Stanで混合正規分布の推定①
前回↓に引き続き、混合正規分布のパラメータ推定。ただし、今回はEMアルゴリズムではなくStanを使ってMCMCにてパラメータを推定する。
前回に引き続き、データは以下の分布からサンプリングすることにより作る。
ここに出てくる2つの平均、2つの分散、2つの混合比率を推定するのが目的。
サンプリングされたデータのヒストグラムは以下の通りで、乱数の種を揃えているので、前回と完全に同じデータとなる。
ただし今回は、密度関数の値はそれぞれの分布のものに加えて、それらの合計も紫の線で示した。
Stanを実行することで得られたパラメータ推定値の分布とパスのプロットがこちら。
上段がそれぞれの分布の平均、中段が標準偏差(分散ではない)、下段が混合係数となっていて、それぞれ問題なく分布収束しているように見える。
バーンインは100しか取ってないのにStanすごいやん。
続いて、パラメータ推定値のサマリー。
EMアルゴリズムの結果とほぼ完璧に一致していることが確認できる。
ただ、EMアルゴリズムはパラメータの標準誤差を計算するのが難しい(がんばればできるようだが)ことから信頼区間が構築しづらいのに対し、MCMCでは以下のように簡単に信用区間を構築できるため、その点はMCMCの方が便利。
In [18]: fitchan Out[18]: Inference for Stan model: anon_model_912e3ad2e9344e95d4acd093ad7b9b1b. 1 chains, each with iter=2000; warmup=100; thin=1; post-warmup draws per chain=1900, total post-warmup draws=1900. mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat mu[0] -5.24 7.7e-3 0.11 -5.44 -5.31 -5.24 -5.17 -5.01 194.0 1.0 mu[1] 5.0 1.3e-3 0.03 4.94 4.98 5.0 5.02 5.06 511.0 1.0 sig[0] 4.82 3.2e-3 0.07 4.68 4.78 4.82 4.87 4.97 531.0 1.0 sig[1] 1.02 1.1e-3 0.03 0.97 1.01 1.02 1.04 1.07 488.0 1.0 pi[0] 0.69 4.1e-4 7.7e-3 0.68 0.69 0.69 0.7 0.71 365.0 1.0 pi[1] 0.31 4.1e-4 7.7e-3 0.29 0.3 0.31 0.31 0.32 365.0 1.0 lp__ -1.5e4 0.09 1.59 -1.5e4 -1.5e4 -1.5e4 -1.5e4 -1.5e4 341.0 1.0 Samples were drawn using NUTS(diag_e) at 03/05/15 15:46:37. For each parameter, n_eff is a crude measure of effective sample size, and Rhat is the potential scale reduction factor on split chains (at convergence, Rhat=1).
Pythonコードはこちら。
# -*- coding: utf-8 -*- from __future__ import print_function, division import numpy as np from scipy import stats import matplotlib.pyplot as plt import seaborn as sns import pystan #サンプル数 N = 5000 #混合数 K = 2 #混合係数 pi = 0.7 #乱数の種 np.random.seed(0) #混合係数から各分布のサンプリング数を決める N_k1 = N*pi N_k2 = N-N_k1 #真のパラメータ mu1 = -5 sig1 = np.sqrt(25) mu2 = 5 sig2 = np.sqrt(1) x1 = np.random.normal(mu1,sig1,N_k1) x2 = np.random.normal(mu2,sig2,N_k2) #観測変数 x = np.hstack((x1,x2)) base=np.linspace(np.min(x),np.max(x),1000) plt.hist(x,bins=100,normed=True) plt.plot(base,pi*stats.norm.pdf(base,mu1,sig1)) plt.plot(base,(1-pi)*stats.norm.pdf(base,mu2,sig2)) plt.plot(base,pi*stats.norm.pdf(base,mu1,sig1)+(1-pi)*stats.norm.pdf(base,mu2,sig2)) #Stan stan_data = {'N': N, 'M': K, 'x': x} model = pystan.StanModel('D:/Python/GMM.stan') fitchan = model.sampling(data=stan_data, iter=2000, warmup=100,chains=1) fitchan fitchan.plot() plt.tight_layout()
Stanコードはこちら。
data { int<lower=1> N; int<lower=1> M; real x[N]; } parameters { vector[M] mu; vector<lower=0.0001>[M] sig; simplex[M] pi; } model { real ps[M]; for(n in 1:N){ for(m in 1:M){ ps[m] <- log(pi[m]) + normal_log(x[n], mu[m], sig[m]); } increment_log_prob(log_sum_exp(ps)); } }