Stanで混合正規分布の推定②
混合正規分布のパラメータ推定シリーズ。
これまでEMアルゴリズムとStanのMCMCにてパラメータの推定を行ってきた。
今回は、数値最適化によるパラメータ推定をやってみようと思う。
StanといえばHMCを用いたMCMCだが、実は最近のバージョンアップにより数値最適化関数が用意されたようなので、今回はそれを使ってみる。
前回に引き続き、データは以下の分布からサンプリングすることにより作る。
ここに出てくる2つの平均、2つの分散、2つの混合比率を推定するのが目的。
サンプリングされたデータのヒストグラムは以下の通りで、乱数の種を揃えているので、いつもと完全に同じデータ。
数値最適化のやり方は非常に簡単で、Stanのモデルコードを読み込んだあとに、モデルのsamplingメソッドを使っていたところを、optimizingメソッドに変更するだけでOK。
Stanコードに書いたパラメータの制約条件もちゃんと考慮してくれる。
In [36]: fitchan = model.optimizing(data=stan_data)
Stanの数値最適化関数を実行することで得られた各パラメータの推定結果はこちら。
EMアルゴリズムとStanのMCMCの推定結果とほぼ一致していることが確認できる。
In [29]: fitchan Out[29]: OrderedDict([(u'mu', array([-5.23176606, 5.00483165])), (u'sig', array([ 4.82527209, 1.02056465])), (u'pi', array([ 0.69367853, 0.30632147]))])
ただし、Stanの最適化関数は特に何も指定しない限り初期値をランダムに選択するようで、初期値によっては全然違った結果が返ってくることもある。
以下は、想定の答えが返ってこなかった例。
In [24]: fitchan Out[24]: OrderedDict([(u'mu', array([-2.09644408, -1.57536974])), (u'sig', array([ 6.22396311, 6.08568693])), (u'pi', array([ 9.99260820e-01, 7.39179944e-04]))])
このように、初期値にかなり依存するので使う場合は注意が必要。
EMアルゴリズムも初期値に依存する点は同様だが、よく言われるように、その依存度合いは数値最適化よりは小さい印象。
Pythonコードはこちらで、sampling→optimizingの1行だけ変わっている。
# -*- coding: utf-8 -*- from __future__ import print_function, division import numpy as np from scipy import stats import matplotlib.pyplot as plt import seaborn as sns import pystan #サンプル数 N = 5000 #混合数 K = 2 #混合係数 pi = 0.7 #乱数の種 np.random.seed(0) #混合係数から各分布のサンプリング数を決める N_k1 = N*pi N_k2 = N-N_k1 #真のパラメータ mu1 = -5 sig1 = np.sqrt(25) mu2 = 5 sig2 = np.sqrt(1) x1 = np.random.normal(mu1,sig1,N_k1) x2 = np.random.normal(mu2,sig2,N_k2) #観測変数 x = np.hstack((x1,x2)) base=np.linspace(np.min(x),np.max(x),1000) plt.hist(x,bins=100,normed=True) plt.plot(base,pi*stats.norm.pdf(base,mu1,sig1)) plt.plot(base,(1-pi)*stats.norm.pdf(base,mu2,sig2)) plt.plot(base,pi*stats.norm.pdf(base,mu1,sig1)+(1-pi)*stats.norm.pdf(base,mu2,sig2)) #Stan stan_data = {'N': N, 'M': K, 'x': x} model = pystan.StanModel('D:/Python/GMM.stan') fitchan = model.optimizing(data=stan_data) fitchan
Stanコードはこちらで、前回と全く同じ。
data { int<lower=1> N; int<lower=1> M; real x[N]; } parameters { vector[M] mu; vector<lower=0.0001>[M] sig; simplex[M] pi; } model { real ps[M]; for(n in 1:N){ for(m in 1:M){ ps[m] <- log(pi[m]) + normal_log(x[n], mu[m], sig[m]); } increment_log_prob(log_sum_exp(ps)); } }