WHAT' CHA GONNA DO FOR ME?

Python、統計、機械学習、R、ファイナンスとか

Stanで混合正規分布の推定①

前回↓に引き続き、混合正規分布のパラメータ推定。ただし、今回はEMアルゴリズムではなくStanを使ってMCMCにてパラメータを推定する。

前回に引き続き、データは以下の分布からサンプリングすることにより作る。

{
x \sim 0.7 \mathcal{N} (-5, 5^2) + 0.3 \mathcal{N} (5, 1^2)
}

ここに出てくる2つの平均、2つの分散、2つの混合比率を推定するのが目的。

サンプリングされたデータのヒストグラムは以下の通りで、乱数の種を揃えているので、前回と完全に同じデータとなる。
ただし今回は、密度関数の値はそれぞれの分布のものに加えて、それらの合計も紫の線で示した。

f:id:lofas:20150305160039p:plain

Stanを実行することで得られたパラメータ推定値の分布とパスのプロットがこちら。
上段がそれぞれの分布の平均、中段が標準偏差(分散ではない)、下段が混合係数となっていて、それぞれ問題なく分布収束しているように見える。
バーンインは100しか取ってないのにStanすごいやん。

f:id:lofas:20150305155256p:plain

続いて、パラメータ推定値のサマリー。
EMアルゴリズムの結果とほぼ完璧に一致していることが確認できる。
ただ、EMアルゴリズムはパラメータの標準誤差を計算するのが難しい(がんばればできるようだが)ことから信頼区間が構築しづらいのに対し、MCMCでは以下のように簡単に信用区間を構築できるため、その点はMCMCの方が便利。

In [18]: fitchan
Out[18]: 
Inference for Stan model: anon_model_912e3ad2e9344e95d4acd093ad7b9b1b.
1 chains, each with iter=2000; warmup=100; thin=1; 
post-warmup draws per chain=1900, total post-warmup draws=1900.

         mean se_mean     sd   2.5%    25%    50%    75%  97.5%  n_eff   Rhat
mu[0]   -5.24  7.7e-3   0.11  -5.44  -5.31  -5.24  -5.17  -5.01  194.0    1.0
mu[1]     5.0  1.3e-3   0.03   4.94   4.98    5.0   5.02   5.06  511.0    1.0
sig[0]   4.82  3.2e-3   0.07   4.68   4.78   4.82   4.87   4.97  531.0    1.0
sig[1]   1.02  1.1e-3   0.03   0.97   1.01   1.02   1.04   1.07  488.0    1.0
pi[0]    0.69  4.1e-4 7.7e-3   0.68   0.69   0.69    0.7   0.71  365.0    1.0
pi[1]    0.31  4.1e-4 7.7e-3   0.29    0.3   0.31   0.31   0.32  365.0    1.0
lp__   -1.5e4    0.09   1.59 -1.5e4 -1.5e4 -1.5e4 -1.5e4 -1.5e4  341.0    1.0

Samples were drawn using NUTS(diag_e) at 03/05/15 15:46:37.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

Pythonコードはこちら。

# -*- coding: utf-8 -*-
from __future__ import print_function, division
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
import seaborn as sns
import pystan

#サンプル数
N = 5000
#混合数
K = 2
#混合係数
pi = 0.7
#乱数の種
np.random.seed(0)

#混合係数から各分布のサンプリング数を決める
N_k1 = N*pi
N_k2 = N-N_k1

#真のパラメータ
mu1 = -5
sig1 = np.sqrt(25)
mu2 = 5
sig2 = np.sqrt(1)

x1 = np.random.normal(mu1,sig1,N_k1)
x2 = np.random.normal(mu2,sig2,N_k2)

#観測変数
x = np.hstack((x1,x2))
base=np.linspace(np.min(x),np.max(x),1000)
plt.hist(x,bins=100,normed=True)
plt.plot(base,pi*stats.norm.pdf(base,mu1,sig1))
plt.plot(base,(1-pi)*stats.norm.pdf(base,mu2,sig2))
plt.plot(base,pi*stats.norm.pdf(base,mu1,sig1)+(1-pi)*stats.norm.pdf(base,mu2,sig2))

#Stan
stan_data = {'N': N, 'M': K, 'x': x}

model = pystan.StanModel('D:/Python/GMM.stan')

fitchan = model.sampling(data=stan_data, iter=2000, warmup=100,chains=1)

fitchan    

fitchan.plot()
plt.tight_layout()

Stanコードはこちら。

data {
  int<lower=1> N;
  int<lower=1> M;
  real x[N];
}

parameters {
  vector[M] mu;
  vector<lower=0.0001>[M] sig;
  simplex[M] pi;
}

model {
    real ps[M];
    for(n in 1:N){
        for(m in 1:M){
            ps[m] <- log(pi[m]) + normal_log(x[n], mu[m], sig[m]);
        }
        increment_log_prob(log_sum_exp(ps));
    }
}