EMアルゴリズムで混合正規分布の推定
EMアルゴリズムの導入として、定番の混合正規分布のパラメータ推定をやってみる。
問題としては、K個の正規分布からある比率に従ってサンプリングされたデータだけが手元に得られた状態で、それぞれの正規分布の平均と分散、及びその比率を推定するというもの。
数式で書けば、混合正規分布は次のようになる。
今回は、最も簡単なのケースを設定する。
具体的には、次の分布を使う。
ここに出てくる2つの平均、2つの分散、2つの混合比率を推定するのが目的。
まずは、この分布からデータをサンプリングするとこんな感じになる。
理論的な密度関数の値をヒストグラムに重ねてみても大体それに沿っているので、多分想定通りにサンプリングできていると思われる。
このデータだけをもとに(=一つひとつのデータがそれぞれどちらの正規分布からサンプリングされたか知ることなしに)、EMアルゴリズムを使ってパラメータを推定しにいく。
EMアルゴリズムは、パラメータに適当な初期値を与えた上で、EステップとMステップを繰り返すことで(この繰り返す数をイテレーションと呼んだりする。そこはMCMCと同じ)、尤度の高いところのパラメータを見つけに行く最尤推定の一種で、この繰り返し計算により尤度は単調に増加することが知られている(ただし、大域的最適解に辿り着くとは限らない⇔局所解に陥る可能性あり)
EMアルゴリズムのアルゴリズムの説明をここで行うのはあまりに大変なのと、既存の本よりうまく説明できる気も全くしないので、興味がある方は例えば下記の本を参照頂きたい。いずれも混合正規分布を例にした説明がある。
今回はおもに一番上のPRMLを参考にしてコードは書いた。
- 作者: C.M.ビショップ,元田浩,栗田多喜夫,樋口知之,松本裕治,村田昇
- 出版社/メーカー: 丸善出版
- 発売日: 2012/02/29
- メディア: 単行本
- 購入: 6人 クリック: 14回
- この商品を含むブログを見る
- 作者: 平井有三
- 出版社/メーカー: 森北出版
- 発売日: 2012/07/31
- メディア: 単行本(ソフトカバー)
- 購入: 1人 クリック: 7回
- この商品を含むブログ (2件) を見る
- 作者: 金谷健一
- 出版社/メーカー: 共立出版
- 発売日: 2005/09/01
- メディア: 単行本
- 購入: 29人 クリック: 424回
- この商品を含むブログ (41件) を見る
以下は、EMアルゴリズムを用いて実際に推定を行った結果。
気になるパラメータの前に、まずは対数尤度が本当に単調増加しているかを確認する。
大体イテレーションが100ぐらいのところまでは対数尤度は単調に増加し、それ以降は上昇はストップしている模様。
最後のイテレーションのパラメータを抜いてくると、このような値が得られる。
上段が平均、中段が分散、下段が混合比率を表している。微妙なズレはあるものの、ほぼ正解と一致していることがわかる。
In [21]: theta[:,:,iteration-1] Out[21]: array([[ -5.23184557, 5.00480952], [ 23.28299505, 1.04155592], [ 0.69366902, 0.30633098]])
コード全体はこちら。
本当であれば、対数尤度の増加がある閾値を下回ったところでアルゴリズムを打ち切りたいところだが、今回はイテレーション数を頭のところで決め打ちにしている。
また、今回はEMアルゴリズムでパラメータ推定を行ったが、MCMCを用いたパラメータ推定も可能と思われる。それについてはまた別の機会に書くかもしれないし書かないかもしれない。
# -*- coding: utf-8 -*- from __future__ import print_function, division import numpy as np from scipy import stats import matplotlib.pyplot as plt import seaborn as sns #サンプル数 N = 5000 #混合数 K = 2 #混合係数 pi = 0.7 #イテレーション iteration = 500 #乱数の種 np.random.seed(0) #混合具合を一様乱数から決める #N_k1 = np.round(np.random.uniform()*N,decimals=0) N_k1 = N*pi N_k2 = N-N_k1 #真のパラメータ mu1 = -5 sig1 = np.sqrt(25) mu2 = 5 sig2 = np.sqrt(1) x1 = np.random.normal(mu1,sig1,N_k1) x2 = np.random.normal(mu2,sig2,N_k2) #観測変数 x = np.hstack((x1,x2)) base=np.linspace(np.min(x),np.max(x),1000) plt.hist(x,bins=100,normed=True) plt.plot(base,pi*stats.norm.pdf(base,mu1,sig1)) plt.plot(base,(1-pi)*stats.norm.pdf(base,mu2,sig2)) #負担率格納用 gamma = np.zeros((N,K,iteration)) #尤度格納用 llk = np.zeros((iteration,1)) #パラメータ格納用 theta = np.zeros((3,K,iteration)) #初期値 theta[0,0,:]=0 theta[0,1,:]=1 theta[1,0,:]=1 theta[1,1,:]=2 theta[2,0,:]=0.6 theta[2,1,:]=0.4 for i in range(iteration): ################################## #E-Step ################################## gamma[:,0,i] = theta[2,0,i]*stats.norm.pdf(x,theta[0,0,i],np.sqrt(theta[1,0,i]))/ \ (theta[2,0,i]*stats.norm.pdf(x,theta[0,0,i],np.sqrt(theta[1,0,i]))+ \ theta[2,1,i]*stats.norm.pdf(x,theta[0,1,i],np.sqrt(theta[1,1,i]))) gamma[:,1,i] = theta[2,1,i]*stats.norm.pdf(x,theta[0,1,i],np.sqrt(theta[1,1,i]))/ \ (theta[2,0,i]*stats.norm.pdf(x,theta[0,0,i],np.sqrt(theta[1,0,i]))+ \ theta[2,1,i]*stats.norm.pdf(x,theta[0,1,i],np.sqrt(theta[1,1,i]))) ################################## #M-Step ################################## if i != iteration-1: #mu theta[0,0,i+1] = gamma[:,0,i].dot(x)/np.sum(gamma[:,0,i]) theta[0,1,i+1] = gamma[:,1,i].dot(x)/np.sum(gamma[:,1,i]) #sig theta[1,0,i+1] = gamma[:,0,i].dot((x-theta[0,0,i+1])**2)/np.sum(gamma[:,0,i]) theta[1,1,i+1] = gamma[:,1,i].dot((x-theta[0,1,i+1])**2)/np.sum(gamma[:,1,i]) #pi theta[2,0,i+1] = np.sum(gamma[:,0,i])/N theta[2,1,i+1] = np.sum(gamma[:,1,i])/N ################################## #対数尤度の計算 ################################## llk[i] = np.sum(np.log( \ theta[2,0,i+1]*stats.norm.pdf(x,theta[0,0,i+1],np.sqrt(theta[1,0,i+1]))+ \ theta[2,1,i+1]*stats.norm.pdf(x,theta[0,1,i+1],np.sqrt(theta[1,1,i+1])) \ )) plt.plot(llk[range(iteration-1)]) theta[:,:,iteration-1]