WHAT' CHA GONNA DO FOR ME?

Python、統計、機械学習、R、ファイナンスとか

EMアルゴリズムで混合正規分布の推定

EMアルゴリズムの導入として、定番の混合正規分布のパラメータ推定をやってみる。

問題としては、K個の正規分布からある比率に従ってサンプリングされたデータだけが手元に得られた状態で、それぞれの正規分布の平均と分散、及びその比率を推定するというもの。

数式で書けば、混合正規分布は次のようになる。

{\displaystyle
x \sim\ \sum_{k=1}^{K}\pi_k \mathcal{N} (\mu_k, \sigma_k^2), \ \sum_{k=1}^{K}\pi_k=1
}

今回は、最も簡単な{K=2}のケースを設定する。
具体的には、次の分布を使う。

{
x \sim 0.7 \mathcal{N} (-5, 5^2) + 0.3 \mathcal{N} (5, 1^2)
}

ここに出てくる2つの平均、2つの分散、2つの混合比率を推定するのが目的。

まずは、この分布からデータをサンプリングするとこんな感じになる。
理論的な密度関数の値をヒストグラムに重ねてみても大体それに沿っているので、多分想定通りにサンプリングできていると思われる。

f:id:lofas:20150303145322p:plain

このデータだけをもとに(=一つひとつのデータがそれぞれどちらの正規分布からサンプリングされたか知ることなしに)、EMアルゴリズムを使ってパラメータを推定しにいく。
EMアルゴリズムは、パラメータに適当な初期値を与えた上で、EステップとMステップを繰り返すことで(この繰り返す数をイテレーションと呼んだりする。そこはMCMCと同じ)、尤度の高いところのパラメータを見つけに行く最尤推定の一種で、この繰り返し計算により尤度は単調に増加することが知られている(ただし、大域的最適解に辿り着くとは限らない⇔局所解に陥る可能性あり)

EMアルゴリズムアルゴリズムの説明をここで行うのはあまりに大変なのと、既存の本よりうまく説明できる気も全くしないので、興味がある方は例えば下記の本を参照頂きたい。いずれも混合正規分布を例にした説明がある。
今回はおもに一番上のPRMLを参考にしてコードは書いた。

パターン認識と機械学習 下 (ベイズ理論による統計的予測)

パターン認識と機械学習 下 (ベイズ理論による統計的予測)

はじめてのパターン認識

はじめてのパターン認識

これなら分かる最適化数学―基礎原理から計算手法まで

これなら分かる最適化数学―基礎原理から計算手法まで

以下は、EMアルゴリズムを用いて実際に推定を行った結果。
気になるパラメータの前に、まずは対数尤度が本当に単調増加しているかを確認する。
大体イテレーションが100ぐらいのところまでは対数尤度は単調に増加し、それ以降は上昇はストップしている模様。

f:id:lofas:20150303145330p:plain

最後のイテレーションのパラメータを抜いてくると、このような値が得られる。
上段が平均、中段が分散、下段が混合比率を表している。微妙なズレはあるものの、ほぼ正解と一致していることがわかる。

In [21]: theta[:,:,iteration-1]
Out[21]: 
array([[ -5.23184557,   5.00480952],
       [ 23.28299505,   1.04155592],
       [  0.69366902,   0.30633098]])

コード全体はこちら。
本当であれば、対数尤度の増加がある閾値を下回ったところでアルゴリズムを打ち切りたいところだが、今回はイテレーション数を頭のところで決め打ちにしている。
また、今回はEMアルゴリズムでパラメータ推定を行ったが、MCMCを用いたパラメータ推定も可能と思われる。それについてはまた別の機会に書くかもしれないし書かないかもしれない。

# -*- coding: utf-8 -*-
from __future__ import print_function, division
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
import seaborn as sns

#サンプル数
N = 5000
#混合数
K = 2
#混合係数
pi = 0.7
#イテレーション
iteration = 500
#乱数の種
np.random.seed(0)

#混合具合を一様乱数から決める
#N_k1 = np.round(np.random.uniform()*N,decimals=0)
N_k1 = N*pi
N_k2 = N-N_k1

#真のパラメータ
mu1 = -5
sig1 = np.sqrt(25)
mu2 = 5
sig2 = np.sqrt(1)

x1 = np.random.normal(mu1,sig1,N_k1)
x2 = np.random.normal(mu2,sig2,N_k2)

#観測変数
x = np.hstack((x1,x2))
base=np.linspace(np.min(x),np.max(x),1000)
plt.hist(x,bins=100,normed=True)
plt.plot(base,pi*stats.norm.pdf(base,mu1,sig1))
plt.plot(base,(1-pi)*stats.norm.pdf(base,mu2,sig2))

#負担率格納用
gamma = np.zeros((N,K,iteration))
#尤度格納用
llk = np.zeros((iteration,1))
#パラメータ格納用
theta = np.zeros((3,K,iteration))
#初期値
theta[0,0,:]=0
theta[0,1,:]=1
theta[1,0,:]=1
theta[1,1,:]=2
theta[2,0,:]=0.6
theta[2,1,:]=0.4

for i in range(iteration):
  ##################################
  #E-Step
  ##################################
    gamma[:,0,i] = theta[2,0,i]*stats.norm.pdf(x,theta[0,0,i],np.sqrt(theta[1,0,i]))/ \
        (theta[2,0,i]*stats.norm.pdf(x,theta[0,0,i],np.sqrt(theta[1,0,i]))+ \
         theta[2,1,i]*stats.norm.pdf(x,theta[0,1,i],np.sqrt(theta[1,1,i])))
         
    gamma[:,1,i] = theta[2,1,i]*stats.norm.pdf(x,theta[0,1,i],np.sqrt(theta[1,1,i]))/ \
        (theta[2,0,i]*stats.norm.pdf(x,theta[0,0,i],np.sqrt(theta[1,0,i]))+ \
         theta[2,1,i]*stats.norm.pdf(x,theta[0,1,i],np.sqrt(theta[1,1,i])))    

  ##################################
  #M-Step
  ##################################
    if i != iteration-1:
        #mu
        theta[0,0,i+1] = gamma[:,0,i].dot(x)/np.sum(gamma[:,0,i])
        theta[0,1,i+1] = gamma[:,1,i].dot(x)/np.sum(gamma[:,1,i])
        #sig
        theta[1,0,i+1] = gamma[:,0,i].dot((x-theta[0,0,i+1])**2)/np.sum(gamma[:,0,i])
        theta[1,1,i+1] = gamma[:,1,i].dot((x-theta[0,1,i+1])**2)/np.sum(gamma[:,1,i])
        #pi
        theta[2,0,i+1] = np.sum(gamma[:,0,i])/N
        theta[2,1,i+1] = np.sum(gamma[:,1,i])/N
  ##################################
  #対数尤度の計算
  ##################################
        llk[i] = np.sum(np.log( \
              theta[2,0,i+1]*stats.norm.pdf(x,theta[0,0,i+1],np.sqrt(theta[1,0,i+1]))+ \
              theta[2,1,i+1]*stats.norm.pdf(x,theta[0,1,i+1],np.sqrt(theta[1,1,i+1]))  \
        ))  

plt.plot(llk[range(iteration-1)])
theta[:,:,iteration-1]