🌟

EMアルゴリズムとその幾何構造

に公開

はじめに

EMアルゴリズムは最尤推定の一種であり、データの一部が欠損しているなどして観測できない箇所があるときに有効な手法です。EMアルゴリズムでググってみると、様々解説記事が上がっていたり、教科書としてもよく取り上げられるトピックですが、本記事ではEMアルゴリズムだけでなく情報幾何学に基づいた、幾何的な性質についても紹介していこうと思います。

EMアルゴリズム

EMアルゴリズムではデータ\mathbf{x}について、\mathbf{x}=(\mathbf{v},\mathbf{h})と書ける場合を考えます。\mathbf{v}はデータの中でも観測可能な部分、\mathbf{h}は観測できない部分を指し、隠れ変数と呼びます。このような隠れ変数を持つ例として前回の記事で挙げたような制限ボルツマンマシンが挙げられます。この観測できた\mathbf{v}を活用し、最尤推定の達成を目指していきます。

EMアルゴリズムの方針

最尤推定ではモデルとなる確率分布の最大化を実行します。ここでは確率分布p_{\bm{\xi}}(\mathbf{v})の最大化を考えます。一般的にはその対数を取った

\log p_{\bm{\xi}}(\mathbf{v})

の最大化を考えます。\bm{\xi}は確率分布のパラメータで、これが決まれば確率分布が1つに決まるものです[1]。この最大化の目的を達成するために適切なパラメータを推定していきます。

EMアルゴリズムでは隠れ変数を含む場合について、パラメータ更新をしながら最大化を図れる工夫がされています[2]。そのための準備をします。天下り的ですが先の結果を見越して確率分布q(\mathbf{h}|\mathbf{v})を用いて整理します[3]

\begin{align*} \log p_{\bm{\xi}}(\mathbf{v}) &=\left(\int q(\mathbf{h}|\mathbf{v})d\mathbf{h}\right)\log \frac{q(\mathbf{h}|\mathbf{v})p_{\bm{\xi}}(\mathbf{v},\mathbf{h})}{q(\mathbf{h}|\mathbf{v})p_{\bm{\xi}}(\mathbf{h}|\mathbf{v})}\\ &=\int q(\mathbf{h}|\mathbf{v})\left(\log \frac{p_{\bm{\xi}}(\mathbf{v},\mathbf{h})}{q(\mathbf{h}|\mathbf{v})}+\log \frac{q(\mathbf{h}|\mathbf{v})}{p_{\bm{\xi}}(\mathbf{h}|\mathbf{v})}\right)d\mathbf{h}\\ &=\int q(\mathbf{h}|\mathbf{v})\log \frac{p_{\bm{\xi}}(\mathbf{v},\mathbf{h})}{q(\mathbf{h}|\mathbf{v})}d\mathbf{h}+\int q(\mathbf{h}|\mathbf{v})\log \frac{q(\mathbf{h}|\mathbf{v})}{p_{\bm{\xi}}(\mathbf{h}|\mathbf{v})}d\mathbf{h} \end{align*}

ここで

\begin{align*} L_{\bm{\xi}}[q(\mathbf{h}|\mathbf{v})]&=\int q(\mathbf{h}|\mathbf{v})\log \frac{p_{\bm{\xi}}(\mathbf{v},\mathbf{h})}{q(\mathbf{h}|\mathbf{v})}d\mathbf{h}\\ D_{\text{KL}}(q(\mathbf{h}|\mathbf{v})||p_{\bm{\xi}}(\mathbf{h}|\mathbf{v}))&=\int q(\mathbf{h}|\mathbf{v})\log \frac{q(\mathbf{h}|\mathbf{v})}{p_{\bm{\xi}}(\mathbf{h}|\mathbf{v})}d\mathbf{h} \end{align*}

とおけば

\log p_{\bm{\xi}}(\mathbf{v})=L_{\bm{\xi}}[q(\mathbf{h}|\mathbf{v})]+D_{\text{KL}}(q(\mathbf{h}|\mathbf{v})||p_{\bm{\xi}}(\mathbf{h}|\mathbf{v}))

と記述することができます。D_{\text{KL}}はKLダイバージェンスであり、0以上の値を取るものです。であれば、

\log p_{\bm{\xi}}(\mathbf{v})\geq L_{\bm{\xi}}[q(\mathbf{h}|\mathbf{v})]

となり、\log p_{\bm{\xi}}(\mathbf{v})の下限がL_{\bm{\xi}}[q(\mathbf{h}|\mathbf{v})]であることが分かります。こうした結果から\log p_{\bm{\xi}}(\mathbf{v})の最大化を下限を押し上げる形で達成しようと思います。

Eステップ

下限を押し上げるというのはL_{\bm{\xi}}[q(\mathbf{h}|\mathbf{v})]について、より大きな値を取るq(\mathbf{h}|\mathbf{v})とパラメータ\bm{\xi}を見つけることに他なりません。このq(\mathbf{h}|\mathbf{v})\bm{\xi}を交互に更新する手続きがEMアルゴリズムの流れです。まずはq(\mathbf{h}|\mathbf{v})の更新について説明します。

q(\mathbf{h}|\mathbf{v})の更新はパラメータ\bm{\xi}を固定して実施します。再度次の式を参照します。

\log p_{\bm{\xi}}(\mathbf{v})=L_{\bm{\xi}}[q(\mathbf{h}|\mathbf{v})]+D_{\text{KL}}(q(\mathbf{h}|\mathbf{v})||p_{\bm{\xi}}(\mathbf{h}|\mathbf{v}))

\log p_{\bm{\xi}}(\mathbf{v})q(\mathbf{h}|\mathbf{v})に対して依存しません。q(\mathbf{h}|\mathbf{v})がどんな関数であろうと\log p_{\bm{\xi}}(\mathbf{v})\log p_{\bm{\xi}}(\mathbf{v})のままです。であるならばKLダイバージェンスD_{\text{KL}}0になるようなq(\mathbf{h}|\mathbf{v})を選べば、L_{\bm{\xi}}[q(\mathbf{h}|\mathbf{v})]は大きな値を取ります。このときq(\mathbf{h}|\mathbf{v})=p_{\bm{\xi}}(\mathbf{h}|\mathbf{v})となります。ただこの過程ではパラメータは固定しているので、固定したパラメータを\bm{\xi}'とでも置けば、

q(\mathbf{h}|\mathbf{v})=p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v})

という更新になります。

Mステップ

次に\bm{\xi}の更新を実施します。このときq(\mathbf{h}|\mathbf{v})は固定です。その固定は先ほど示した通りq(\mathbf{h}|\mathbf{v})=p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v})です。従って

L_{\bm{\xi}}[p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v})]=\int p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v})\log \frac{p_{\bm{\xi}}(\mathbf{v},\mathbf{h})}{p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v})}d\mathbf{h}

がより大きな値を取る\bm{\xi}を考えることになります。これは、

\text{grad}_{\bm{\xi}}L_{\bm{\xi}}[p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v})]=0

となる\bm{\xi}を探せばよさそうです[4]。EMアルゴリズムが最尤推定の手法であることを踏まえると、観測したデータ\mathbf{v}^{(1)},\mathbf{v}^{(2)},\cdots,\mathbf{v}^{(N)}に対する最大化を実施するので正確には、

\text{grad}_{\bm{\xi}} \left\{ \frac{1}{N}\sum_{n=1}^{N}L_{\bm{\xi}}[p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v}^{(n)})]\right\} =0

を解くことになります[5]。この偏微分を実行するにあたり、

L_{\bm{\xi}}[p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v}^{(n)})]= \int p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v}^{(n)})\log p_{\bm{\xi}}(\mathbf{v}^{(n)},\mathbf{h})d\mathbf{h} -\int p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v}^{(n)})\log p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v}^{(n)})d\mathbf{h}

であるので、第2項目は\bm{\xi}に依存しないことからこれは無視します。従って第1項目のみを残した

Q(\bm{\xi},\bm{\xi}')=\frac{1}{N}\sum_{n=1}^{N}\int p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v}^{(n)})\log p_{\bm{\xi}}(\mathbf{v}^{(n)},\mathbf{h})d\mathbf{h}

が最大となる\bm{\xi}を探せばよいという結論に至ります。

情報幾何学超入門

EMアルゴリズムの幾何的な性質を説明する前に、その基礎となる情報幾何学について述べていきます。

情報幾何とリーマン幾何

情報幾何学というのは簡単に言えば統計モデル(つまり確率分布)を対象とした幾何で、パラメータを使ってその幾何構造を紐解いていきます。先ほども示した通りパラメータが一意に決まれば確率分布も1つに決まるため、確率分布の空間S

S=\{p_{\bm{\xi}}(\mathbf{x})\mid \bm{\xi}=(\xi_{1},\cdots,\xi_{k})\in\Xi\}

で定義します。S=\{p_{\bm{\xi}}(\mathbf{x})\}と略して書かれることもあります。ここでパラメータの個数はkのなのでSの次元\dim S=kです。\Xiはパラメータ全体の集合を指します。情報幾何学ではこの確率分布の幾何をリーマン幾何学の枠組みで解析していきます。リーマン幾何に必要なのはこの空間の他に計量gです。この計量はリーマン計量と呼ばれ、テンソルの仲間です[6]。点\bm{\xi}に置ける計量の行列表現g(\bm{\xi})=(g_{ij}(\bm{\xi}))を用いることにすると、ベクトル\bm{x}=(x^{i}),\bm{y}=(y^{j})について、

g_{\bm{\xi}}(\bm{x},\bm{y})=\sum_{i,j}g_{ij}(\bm{\xi})x^{i}y^{j}

という内積g_{\bm{\xi}}(\bm{x},\bm{y})が定まります[7]。この内積を使えば距離などの幾何的な性質を調べることができ、確率分布の空間の計量(内積)としては、

g_{ij}(\bm{\xi})=\mathbb{E}_{p_{\bm{\xi}}(\mathbf{x})}\left[ \frac{\partial\log p_{\bm{\xi}}(\mathbf{x})}{\partial\xi_{i}} \frac{\partial\log p_{\bm{\xi}}(\mathbf{x})}{\partial\xi_{j}} \right]

というフィッシャー情報行列が入ります[8]

確率空間のベクトル表現

情報幾何学には2つの視点があります。具体的には確率分布p_{\bm{\xi}}(\mathbf{x})そのものの観点と確率分布の対数\log p_{\bm{\xi}}(\mathbf{x})を取った視点です。この対数をとった視点は新たに

l_{\bm{\xi}}(\mathbf{x})=\log p_{\bm{\xi}}(\mathbf{x})

で定義することにします。対数は単調増加な関数なので、p_{\bm{\xi}}(\mathbf{x})l_{\bm{\xi}}(\mathbf{x})は1対1です。ここでフィッシャー情報行列の要素g_{ij}(\bm{\xi})を式変形してみます。連続分布の前提で進めます。

\begin{align*} g_{ij}(\bm{\xi})&= \int p_{\bm{\xi}}(\mathbf{x}) \frac{\partial\log p_{\bm{\xi}}(\mathbf{x})}{\partial\xi_{i}}\frac{\partial\log p_{\bm{\xi}}(\mathbf{x})}{\partial\xi_{j}}d\mathbf{x}\\ &=\int p_{\bm{\xi}}(\mathbf{x}) \frac{\partial \log p_{\bm{\xi}}(\mathbf{x})}{\partial\xi_{i}} \frac{1}{p_{\bm{\xi}}(\mathbf{x})} \frac{\partial p_{\bm{\xi}}(\mathbf{x})}{\partial\xi_{j}}d\mathbf{x}\\ &=\int \frac{\partial l_{\bm{\xi}}(\mathbf{x})}{\partial\xi_{i}} \frac{\partial p_{\bm{\xi}}(\mathbf{x})}{\partial\xi_{j}}d\mathbf{x}\\ \end{align*}

この式変形結果より

g_{ij}(\bm{\xi})=\left\langle \frac{\partial l_{\bm{\xi}}}{\partial\xi_{i}}, \frac{\partial p_{\bm{\xi}}}{\partial\xi_{j}}\right\rangle =\left\{ \begin{array}{ll} \displaystyle\int \frac{\partial l_{\bm{\xi}}(\mathbf{x})}{\partial\xi_{i}} \frac{\partial p_{\bm{\xi}}(\mathbf{x})}{\partial\xi_{j}}d\mathbf{x} & \text{連続分布のとき} \\[7pt] \displaystyle\sum_{\mathbf{x}}\frac{\partial l_{\bm{\xi}}(\mathbf{x})}{\partial\xi_{i}} \frac{\partial p_{\bm{\xi}}(\mathbf{x})}{\partial\xi_{j}}& \text{離散分布のとき} \end{array} \right.

を定義します。この式変形結果はユークリッド内積

\bm{a}\cdot\bm{b}=a_{1}b_{1}+a_{2}b_{2}+\cdots+a_{n}b_{n}

を彷彿とさせます。ここでg_{ij}(\bm{\xi})は基底ベクトル\bm{e}_{i},\bm{e}_{j}に関する内積g_{ij}(\bm{\xi})=g_{\bm{\xi}}(\bm{e}_{i},\bm{e}_{j})の結果であるので、この式変形から、ベクトル\bm{e}_{i}\frac{\partial l_{\bm{\xi}}}{\partial\xi_{i}}へ、ベクトル\bm{e}_{j}\frac{\partial p_{\bm{\xi}}}{\partial\xi_{j}}に表現を移して内積計算しているという見方ができます。この\frac{\partial l_{\bm{\xi}}}{\partial\xi_{i}}\frac{\partial p_{\bm{\xi}}}{\partial\xi_{j}}はベクトル\bm{e}_{i},\bm{e}_{j}のe-表現、m-表現と言います。

確率空間の接続

リーマン幾何において異なる点での計量の比較というのは、空間のつながり方を反映しているにすぎません。例えば\xi_{k}方向の計量のずれというのは、

\begin{align*} \frac{\partial g_{ij}(\bm{\xi})}{\partial \xi_{k}} =\frac{\partial}{\partial \xi_{k}}\left\langle \frac{\partial l_{\bm{\xi}}}{\partial\xi_{i}}, \frac{\partial p_{\bm{\xi}}}{\partial\xi_{j}}\right\rangle =\left\langle \frac{\partial^{2} l_{\bm{\xi}}}{\partial\xi_{k}\partial\xi_{i}}, \frac{\partial p_{\bm{\xi}}}{\partial\xi_{j}}\right\rangle + \left\langle \frac{\partial l_{\bm{\xi}}}{\partial\xi_{i}}, \frac{\partial^{2} p_{\bm{\xi}}}{\partial\xi_{k}\partial\xi_{j}}\right\rangle \end{align*}

であり、

\begin{align*} \Gamma^{(e)}_{kij}&=\left\langle \frac{\partial^{2} l_{\bm{\xi}}}{\partial\xi_{k}\partial\xi_{i}}, \frac{\partial p_{\bm{\xi}}}{\partial\xi_{j}}\right\rangle\\ \Gamma^{(m)}_{kji}&=\left\langle \frac{\partial l_{\bm{\xi}}}{\partial\xi_{i}}, \frac{\partial^{2} p_{\bm{\xi}}}{\partial\xi_{k}\partial\xi_{j}}\right\rangle \end{align*}

は接続係数と呼びます[9]。接続係数には接続の名の通り、空間のつながり方を示すものです。そのつながり方は一意かと思いきや、これは視点によって変わってきます。確率分布の視点としてp_{\bm{\xi}}(\mathbf{x})l_{\bm{\xi}}(\mathbf{x})の2つがあることを示しました。\Gamma^{(e)}_{kij}は確率分布の対数取った時の接続、\Gamma^{(m)}_{kji}は通常の確率の接続を表しているので、\Gamma^{(e)}_{kij}にとって平坦なつながり方でも、\Gamma^{(m)}_{kji}から見れば曲がっているということがあるのです。

こうした接続係数\Gamma^{(e)}_{kij},\Gamma^{(m)}_{kji}に基づく接続は\nabla^{(e)},\nabla^{(m)}と書かれそれぞれe-接続、m-接続といいます。これまで空間Sと呼んでいたものは数学では多様体という枠組みで解析され、数学にはSの中に入っている構造を書き下す慣習があります。リーマン計量gが入っている(S,g)はリーマン多様体といいます。さらに接続\nabla^{(e)},\nabla^{(m)}が入っている(S,g,\nabla^{(e)},\nabla^{(m)})は統計多様体と言います。確率空間と濁して呼んでいたSの正式名称は統計多様体です。

指数型分布族と混合型分布族

情報幾何では平坦性が重要ですが、この平坦性を決定づける2つの分布を紹介します。

指数型分布族

パラメータを\bm{\theta}=(\theta^{i})と置いたとき、指数型分布族に所属する確率分布p_{\bm{\theta}}(\mathbf{x})は、

p_{\bm{\theta}}(\mathbf{x})=\exp\left[C(\mathbf{x})+\sum_{i=1}^{k}\theta^{i}F_{i}(\mathbf{x})-\psi(\bm{\theta})\right]

と書け、M=\{p_{\bm{\theta}}(\mathbf{x})\}は指数型分布族と言います。\psi(\bm{\theta})は正規化因子です。指数型分布族について接続係数\Gamma^{(e)}_{kij}を計算すると、

\frac{\partial^{2} l_{\bm{\theta}}(\mathbf{x})}{\partial\theta_{k}\partial\theta_{i}} =\frac{\partial^{2}}{\partial\theta_{k}\partial\theta_{i}}\left[C(\mathbf{x})+\sum_{i=1}^{k}\theta^{i}F_{i}(\mathbf{x})-\psi(\bm{\theta})\right] =-\frac{\partial^{2}\psi(\bm{\theta})}{\partial\theta_{k}\partial\theta_{i}}

より、

\begin{align*} \Gamma^{(e)}_{kij}=\left\langle \frac{\partial^{2} l_{\bm{\theta}}}{\partial\theta_{k}\partial\theta_{i}}, \frac{\partial p_{\bm{\theta}}}{\partial\theta_{j}}\right\rangle &=-\frac{\partial^{2}\psi(\bm{\theta})}{\partial\theta_{k}\partial\theta_{i}}\int \frac{\partial p_{\bm{\theta}}(\mathbf{x})}{\partial\theta_{j}}d\mathbf{x}\\ &=-\frac{\partial^{2}\psi(\bm{\theta})}{\partial\theta_{k}\partial\theta_{i}} \frac{\partial }{\partial\theta_{j}}\int p_{\bm{\theta}}(\mathbf{x})d\mathbf{x}\\ &=-\frac{\partial^{2}\psi(\bm{\theta})}{\partial\theta_{k}\partial\theta_{i}} \frac{\partial }{\partial\theta_{j}}1=0 \end{align*}

0になることが分かります。これは指数型分布族が成す空間がe-接続にとってユークリッド空間のように平坦であることを示しています。e-接続について平坦な空間はe-平坦な空間と呼びます。

混合型分布族

パラメータを\bm{\eta}=(\eta_{i})で置いたとき、混合型分布族に所属する確率分布p_{\bm{\eta}}(\mathbf{x})は、

p_{\bm{\eta}}(\mathbf{x})=\sum_{i=1}^{k}\eta_{i}p_{i}(\mathbf{x})+\eta_{0}p_{0}(\mathbf{x}),\;\left(\eta_{0}=1-\sum_{i=1}^{k}\eta_{i}\right)

という確率分布p_{i}(\mathbf{x})の線形結合で表され、N=\{p_{\bm{\eta}}(\mathbf{x})\}は混合型分布族と言います。混合型分布族について接続係数\Gamma^{(m)}_{kji}を計算すると、明らかに、

\frac{\partial^{2} p_{\bm{\eta}}(\mathbf{x})}{\partial\eta_{k}\partial\eta_{j}}=0

であるので、

\Gamma^{(m)}_{kji}=\left\langle \frac{\partial l_{\bm{\eta}}}{\partial\eta_{i}}, \frac{\partial^{2} p_{\bm{\eta}}}{\partial\eta_{k}\partial\eta_{j}}\right\rangle=0

となります。これは指数型分布族が成す空間がm-接続にとって平坦であることを示しています。m-接続について平坦な空間はm-平坦な空間と呼びます。

測地線

測地線とは空間上の2点間の距離を最小にする曲線で、ユークリッド空間上だとこれは直線になります。日本からアメリカに向かう際、飛行機の航空路はメルカトル図法で見ればベーリング海あたりを飛行する無駄な経路に見えますが、球面上だと最短になります。これが測地線です。先ほどe-平坦な空間とm-平坦な空間を紹介しました。それぞれの空間の測地線を示します。

e-平坦な空間の測地線は、t\in[0,1]を用いて、

p_{t}(\mathbf{x})=\exp\left\{ C(\mathbf{x})+\sum_{i=1}^{k}\left[(1-t)\theta^{i}_{0}+t\theta^{i}_{1}\right]F_{i}(\mathbf{x})-\psi(\bm{\theta}(t)) \right\}

と書けます。ここで\theta^{i}(t)=(1-t)\theta^{i}_{0}+t\theta^{i}_{1}とすれば、\bm{\theta}(t)=(\theta^{i}(t))です。tの範囲でどんな値を取ったとしてもこれは指数型分布族に所属します。(1-t)\theta^{i}_{0}+t\theta^{i}_{1}という表記から察することができると思いますが、e-平坦な空間にとって直線となる線です。m-平坦な空間から見れば曲線です。これをe-測地線と言います。

逆にm-平坦な空間の測地線は、同様にt\in[0,1]を用いて、

p_{t}(\mathbf{x})=(1-t)p_{0}(\mathbf{x})+tp_{1}(\mathbf{x})

と書けます。まさしくこれは混合型分布族の確率分布の形です。とある混合型分布族を張る確率分布にp_{0}(\mathbf{x}),p_{1}(\mathbf{x})が含まれているならば、tのどんな値をとっても、p_{\bm{\eta(t)}}(\mathbf{x})も混合型分布族に所属します。これがm-平坦な空間にとっての直線でe-平坦な空間から見れば曲線です。これをm-測地線と言います。

KLダイバージェンスと射影

確率分布間の比較としてKLダイバージェンスはよく使われますが、統計多様体上では距離的な役割を果たし、確率分布間の離れ具合について示してくれます。特に統計多様体Sの部分多様体[10]MについてMがe-平坦な空間であるとします。統計多様体上の適当な確率分布pについて、Mに直交するように[11]m-測地線を下ろしたとき、その直交点をp'とすれば、次の関係がなりたちます。

D_{\text{KL}}(p||p')=\min_{r\in M} D_{\text{KL}}(p||r)

これらの関係を図示します。

m-射影の図

こうした関係が成り立つとき、p'pのm-射影だと言います。部分多様体上ではpp'だとみなす考え方です。

逆に統計多様体の部分多様体Nについて、m-平坦な空間としたとき、統計多様体の適当な確率分布pからNに直交するようにe-測地線を下ろしたとき、その直交点をp'とすれば、

D_{\text{KL}}(p'||p)=\min_{r\in N} D_{\text{KL}}(r||p)

も成り立ちます。これをp'pのe-射影だと言います。

emアルゴリズム

2つの部分多様体間の最も近い点、具体的には2つの部分多様体をD,Mとしたとき、

\min_{q\in D,p\in M}D_{\text{KL}}(q||p)

となるようなD,M上の点を探索するアルゴリズムとしてemアルゴリズムがあります。これはEMアルゴリズムとは区別されるもので、EMアルゴリズムに対する幾何的な解釈を与えます。まずemアルゴリズムについて説明します。

eステップ

eステップではM上の点p_{t}に対してKLダイバージェンスが最も小さくなる、D上の点q_{t}を探索します。数式で記述するならば次の通りです。

q_{t}=\argmin_{q\in D}D_{\text{KL}}(q||p_{t})

これは点p_{t}Dに射影するe-射影です。

mステップ

mステップではD上の点q_{t}に対してKLダイバージェンスが最も小さくなる、M上の点p_{t+1}を探索します。数式で記述するならば次の通りです。

p_{t+1}=\argmin_{p\in M}D_{\text{KL}}(q_{t}||p)

これは点q_{t}Mに射影するm-射影です。

幾何的な描像

emアルゴリズムはこのeステップ、mステップを繰り返し行います。幾何的には次のようになります。

emアルゴリズムの幾何描像

これらの繰り返しでもって、KLダイバージェンスは

D_{\text{KL}}(q_{t}||p_{t})\geq D_{\text{KL}}(q_{t}||p_{t+1})\geq D_{\text{KL}}(q_{t+1}||p_{t+1})

と小さくなっていくので、部分多様体間の最も近い点の探索ができるというわけです。もしDがm-平坦でMがe-平坦であれば、収束点は一意に決まります。一般的に収束点は1つとは限りません。

EMアルゴリズムの幾何構造

ようやく本題です。EMアルゴリズムは幾何的には観測多様体(またはデータ多様体)Dとモデル多様体Mと呼ばれる多様体間の近い点を探すアルゴリズムで、これはまさしくemアルゴリズムとしての捉え方になります。

隠れ変数を持たない通常の最尤推定は、経験分布\hat{q}という空間上の一点からモデル多様体に最も近いすなわち、KLダイバージェンス最も小さくなる点を探すことになりますが、隠れ変数を持つ場合、経験分布は観測できた部分\mathbf{v}の経験分布\hat{q}(\mathbf{v})しか得られないため、観測多様体として次のように広がりを持ちます。

D=\{\hat{q}(\mathbf{v})q(\mathbf{h}|\mathbf{v})\mid q(\mathbf{h}|\mathbf{v})\text{は任意}\}

ここで任意の2つの確率分布q_{1}(\mathbf{h}|\mathbf{v})q_{2}(\mathbf{h}|\mathbf{v})について、

q(\mathbf{h}|\mathbf{v})=(1-\lambda)q_{1}(\mathbf{h}|\mathbf{v})+\lambda q_{2}(\mathbf{h}|\mathbf{v})

という線形結合を取ると、新たな条件付確率q(\mathbf{h}|\mathbf{v})を作ることができ、

\begin{align*} \hat{q}(\mathbf{v})q(\mathbf{h}|\mathbf{v})&=\hat{q}(\mathbf{v})\left[(1-\lambda)q_{1}(\mathbf{h}|\mathbf{v})+\lambda q_{2}(\mathbf{h}|\mathbf{v})\right]\\ &=(1-\lambda)\hat{q}(\mathbf{v})q_{1}(\mathbf{h}|\mathbf{v})+\lambda \hat{q}(\mathbf{v})q_{2}(\mathbf{h}|\mathbf{v}) \end{align*}

の通り線形関係が見出せるので、\hat{q}(\mathbf{v})q_{1}(\mathbf{h}|\mathbf{v})\hat{q}(\mathbf{v})q_{2}(\mathbf{h}|\mathbf{v})が観測多様体上の点ならば、Dはm-平坦な空間とみなすことができます。対してM

M=\left\{p_{\bm{\xi}}(\mathbf{v},\mathbf{h})\right\}

という統計モデルの空間で、指数型分布族だと仮定すればMはe-平坦な空間になるでしょう。EMアルゴリズムは次のKLダイバージェンスを最も小さくする多様体間の2点を探索するアルゴリズムです。

F_{\bm{\xi}}[q(\mathbf{h}|\mathbf{v})]= D_{\text{KL}}(\hat{q}(\mathbf{v})q(\mathbf{h}|\mathbf{v})||p_{\bm{\xi}}(\mathbf{v},\mathbf{h}))=\int \hat{q}(\mathbf{v})q(\mathbf{h}|\mathbf{v}) \log\frac{\hat{q}(\mathbf{v})q(\mathbf{h}|\mathbf{v})}{p_{\bm{\xi}}(\mathbf{v},\mathbf{h})}d\mathbf{v}d\mathbf{h}

Eステップ

EステップではKLダイバージェンスF_{\bm{\xi}}[q(\mathbf{h}|\mathbf{v})]を小さくするようなq(\mathbf{h}|\mathbf{v})を見つけることが目的です。KLダイバージェンスをp_{\bm{\xi}}(\mathbf{v},\mathbf{h})=p_{\bm{\xi}}(\mathbf{h}|\mathbf{v})p_{\bm{\xi}}(\mathbf{v})の関係を使って式展開します。Eステップではパラメータ\bm{\xi}は固定なのでその固定を\bm{\xi}'とすれば、

\begin{align*} F_{\bm{\xi}'}[q(\mathbf{h}|\mathbf{v})] &=\int \hat{q}(\mathbf{v})q(\mathbf{h}|\mathbf{v}) \log\frac{\hat{q}(\mathbf{v})q(\mathbf{h}|\mathbf{v})}{p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v})p_{\bm{\xi}'}(\mathbf{v})}d\mathbf{v}d\mathbf{h}\\ &=\int \hat{q}(\mathbf{v})q(\mathbf{h}|\mathbf{v}) \left[\log\frac{q(\mathbf{h}|\mathbf{v})}{p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v})} + \log\frac{\hat{q}(\mathbf{v})}{p_{\bm{\xi}'}(\mathbf{v})} \right] d\mathbf{v}d\mathbf{h}\\ &=\int \hat{q}(\mathbf{v})q(\mathbf{h}|\mathbf{v}) \log\frac{q(\mathbf{h}|\mathbf{v})}{p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v})} d\mathbf{v}d\mathbf{h} + \int \hat{q}(\mathbf{v})q(\mathbf{h}|\mathbf{v}) \log\frac{\hat{q}(\mathbf{v})}{p_{\bm{\xi}'}(\mathbf{v})} d\mathbf{v}d\mathbf{h} \end{align*}

となり、諸々の式整理は補足にゆだねますが、これは

F_{\bm{\xi}'}[q(\mathbf{h}|\mathbf{v})] =\int q(\mathbf{h}|\mathbf{v}) \log\frac{q(\mathbf{h}|\mathbf{v})}{p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v})} d\mathbf{h} +\int \hat{q}(\mathbf{v}) \log\frac{\hat{q}(\mathbf{v})}{p_{\bm{\xi}'}(\mathbf{v})} d\mathbf{v}

となります。第2項目はq(\mathbf{h}|\mathbf{v})に依存しないので無視します。KLダイバージェンスが最小になるというのは第1項目が小さくなると言いかえられます。明らかに第1項目はq(\mathbf{h}|\mathbf{v})p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v})に関するKLダイバージェンスなので、これが最小になるのは、

q(\mathbf{h}|\mathbf{v})=p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v})

のときです。これはまさしく初めのEステップの説明で紹介した更新式で、モデル多様体M上の一点p_{\bm{\xi}'}(\mathbf{v},\mathbf{h})を観測多様体Dにe-射影する操作となっています。その射影先が上の更新式の通り、\hat{q}(\mathbf{v})p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v})ということです。

補足: KLダイバージェンスの式整理

KLダイバージェンス

F_{\bm{\xi}'}[q(\mathbf{h}|\mathbf{v})]= \int \hat{q}(\mathbf{v})q(\mathbf{h}|\mathbf{v}) \log\frac{q(\mathbf{h}|\mathbf{v})}{p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v})} d\mathbf{v}d\mathbf{h} + \int \hat{q}(\mathbf{v})q(\mathbf{h}|\mathbf{v}) \log\frac{\hat{q}(\mathbf{v})}{p_{\bm{\xi}'}(\mathbf{v})} d\mathbf{v}d\mathbf{h}

について第1項目は、\int \hat{q}(\mathbf{v})d\mathbf{v}=1より、

\begin{align*} \int \hat{q}(\mathbf{v})q(\mathbf{h}|\mathbf{v}) \log\frac{q(\mathbf{h}|\mathbf{v})}{p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v})} d\mathbf{v}d\mathbf{h} &= \int \hat{q}(\mathbf{v})d\mathbf{v}\int q(\mathbf{h}|\mathbf{v}) \log\frac{q(\mathbf{h}|\mathbf{v})}{p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v})} d\mathbf{h}\\ &= \int q(\mathbf{h}|\mathbf{v}) \log\frac{q(\mathbf{h}|\mathbf{v})}{p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v})} d\mathbf{h} \end{align*}

第2項目は、\int q(\mathbf{h}|\mathbf{v})d\mathbf{h}=1より、

\begin{align*} \int \hat{q}(\mathbf{v})q(\mathbf{h}|\mathbf{v}) \log\frac{\hat{q}(\mathbf{v})}{p_{\bm{\xi}'}(\mathbf{v})} d\mathbf{v}d\mathbf{h} &=\int q(\mathbf{h}|\mathbf{v})d\mathbf{h} \int \hat{q}(\mathbf{v}) \log\frac{\hat{q}(\mathbf{v})}{p_{\bm{\xi}'}(\mathbf{v})} d\mathbf{v}\\ &=\int \hat{q}(\mathbf{v}) \log\frac{\hat{q}(\mathbf{v})}{p_{\bm{\xi}'}(\mathbf{v})} d\mathbf{v} \end{align*}

となるので、最終的にKLダイバージェンスは

F_{\bm{\xi}'}[q(\mathbf{h}|\mathbf{v})] =\int q(\mathbf{h}|\mathbf{v}) \log\frac{q(\mathbf{h}|\mathbf{v})}{p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v})} d\mathbf{h} +\int \hat{q}(\mathbf{v}) \log\frac{\hat{q}(\mathbf{v})}{p_{\bm{\xi}'}(\mathbf{v})} d\mathbf{v}

と書けます。

Mステップ

MステップではKLダイバージェンスF_{\bm{\xi}}[q(\mathbf{h}|\mathbf{v})]を小さくするようなパラメータ\bm{\xi}を見つけることが目的です。q(\mathbf{h}|\mathbf{v})に関しては先ほどの更新式の通りp_{\bm{\xi}'}(\mathbf{h}|\mathbf{v})で固定します。ゆえにF_{\bm{\xi}}[p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v})]が小さくなる条件を考えます。まず式展開します。

\begin{align*} F_{\bm{\xi}}[p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v})] &=\int \hat{q}(\mathbf{v})p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v}) \log\frac{\hat{q}(\mathbf{v})p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v})}{p_{\bm{\xi}}(\mathbf{v},\mathbf{h})}d\mathbf{v}d\mathbf{h}\\ &=\int \hat{q}(\mathbf{v})p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v}) \log\hat{q}(\mathbf{v})p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v})d\mathbf{v}d\mathbf{h}-\int \hat{q}(\mathbf{v})p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v}) \log p_{\bm{\xi}}(\mathbf{v},\mathbf{h})d\mathbf{v}d\mathbf{h} \end{align*}

第1項目は\bm{\xi}に依存しないので無視します。第2項目について、経験分布は観測したデータ\mathbf{v}^{(1)},\mathbf{v}^{(2)},\cdots,\mathbf{v}^{(N)}とデルタ関数\deltaを用いて

\hat{q}(\mathbf{v})=\frac{1}{N}\sum_{i=1}^{N}\delta(\mathbf{v}-\mathbf{v}^{(i)})

で定義されるので、

\int \hat{q}(\mathbf{v})p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v}) \log p_{\bm{\xi}}(\mathbf{v},\mathbf{h})d\mathbf{v}d\mathbf{h}=\frac{1}{N}\sum_{n=1}^{N}\int p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v}^{(n)})\log p_{\bm{\xi}}(\mathbf{v}^{(n)},\mathbf{h})d\mathbf{h}

と書けます。これは初めのMステップの説明で挙げた関数Q(\bm{\xi},\bm{\xi}')に他なりません。Q(\bm{\xi},\bm{\xi}')を最大化することが\log p(\mathbf{v})の下限を上昇させることだと説明しましたが、これはKLダイバージェンスを小さくさせることでもあり、\hat{q}(\mathbf{v})p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v})という観測多様体上Dの一点をモデル多様体Mにm-射影する操作だということがわかります。もしQ(\bm{\xi},\bm{\xi}')の最大化が\bm{\xi}=\bm{\xi}^{*}で達成できるならば射影先はp_{\bm{\xi}^{*}}(\mathbf{v},\mathbf{h})です。

EMアルゴリズムの幾何描像

これまでの説明からEステップはe-射影をすること、Mステップはm-射影をする操作であり、観測多様体とモデル多様体間のKLダイバージェンスを小さくするアルゴリズムであることが分かりました。この意味でEMアルゴリズムはemアルゴリズムであり、EステップとMステップを繰り返せば多様体間の最も近い2点を探し出せそうです。EMアルゴリズムの幾何描像を示します。

EMアルゴリズムの幾何描像

最後に

このEMアルゴリズムの話は、技術書典18という技術書が集まる同人即売会で頒布した、以下の頒布物で書こうと思った話題なのですが、ページ数とか執筆時間とかを鑑みて省いたものです。どこにも書かないのはもったいないなぁと思ってZennに投稿したというワケです。

https://techbookfest.org/product/kB5RDArZvWSRhRiPAd5tdW?productVariantID=46diddqVWYDiqixUEFQYrX

https://diceandgeometry.booth.pm/items/6964515

情報幾何学自体は統計モデルの幾何学になりますが、そこには双対平坦というとても綺麗な構造が入り、この構造が展開する織り成す幾何が非常に興味深いのです。もし興味があればお手に取ってみてください!それでは!

参考文献

  1. 斎藤康毅(著)「ゼロから作るDeep Learning ❺ ―生成モデル編」オライリー・ジャパン (2024)
  2. 甘利俊一 (著)「SGCライブラリ-154 新版 情報幾何学の新展開」サイエンス社 (2019)
  3. 村田昇 (著)「SGC Books M-3 新版 情報理論の基礎 -情報と学習の直観的理解のために-」サイエンス社 (2008)
  4. 甘利俊一「めくるめく数理の世界―情報幾何学・人工知能・神経回路網理論」サイエンス社(2024)
脚注
  1. 例えば正規分布だとそのパラメータは平均値や標準偏差になります。平均値や標準偏差が一意に決まれば正規分布も1つに決まります。 ↩︎

  2. \log p_{\bm{\xi}}(\mathbf{v})=\log \int p_{\bm{\xi}}(\mathbf{v},\mathbf{h})d\mathbf{h}の状態で\bm{\xi}で偏微分しても\log p_{\bm{\xi}}(\mathbf{v})を最大化するようなパラメータを解析的に得られませんが、EMアルゴリズムはこれを回避しています。 ↩︎

  3. \int q(\mathbf{h}|\mathbf{v})d\mathbf{h}=1の性質と、p_{\bm{\xi}}(\mathbf{h}|\mathbf{v})=p_{\bm{\xi}}(\mathbf{v},\mathbf{h})/p_{\bm{\xi}}(\mathbf{v})の条件付確率の定義を用います。 ↩︎

  4. \text{grad}_{\bm{\xi}}=\left(\frac{\partial}{\partial \xi_{1}},\frac{\partial}{\partial \xi_{2}},\cdots\right)です。\xi_{1},\xi_{2},\cdots\bm{\xi}の要素です。 ↩︎

  5. Eステップの話に戻ると、あたかもq(\mathbf{h}|\mathbf{v})=p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v})の1回の更新でよさそうに見えますが、この記述を踏まえると正確にはq(\mathbf{h}|\mathbf{v}^{(n)})=p_{\bm{\xi}'}(\mathbf{h}|\mathbf{v}^{(n)})という観測したデータ個数の更新が必要です。 ↩︎

  6. テンソルというと情報科学の分野では多次元配列を指すことが多々ありますが、こちらは数学的厳密性に基づいた多重線形性を有する写像を指しています。 ↩︎

  7. この行列はなんでも良いわけでなく、内積としての定義を満たすため正定値行列が対象です。 ↩︎

  8. 唐突にフィッシャー情報行列を紹介して「なぜ確率分布の空間の内積がフィッシャー情報行列になるんだろう」と思ったかもしれません。これは十分統計量に関する普遍性を要求すると、こうした計量がフィッシャー情報行列で決まります。これをChentsovの定理と言います。 ↩︎

  9. 一般的に接続係数というと\Gamma^k_{ij}を指しますが、幾何的な意味は異ならないので\Gamma_{ijl} = \sum_{k}g_{lk}\Gamma^k_{ij}を接続係数と呼んでいます。 ↩︎

  10. 部分多様体Mとは雑に説明すると、親の多様体Sより次元数が少なくなった多様体です。 ↩︎

  11. 直交するようにというのは直交点で内積(フィッシャー情報行列)が0になることを指します。 ↩︎

Discussion