📝

多変量ガウス分布の条件付き確率をきちんと導出する

2022/07/30に公開約18,400字

ガウス過程回帰モデルを勉強する際に、多変量ガウス分布に従う複数の確率分布についての公式を使うことが多いので、その公式の導出について簡単に紹介します。

公式

ある連続確率ベクトル\boldsymbol{x}_1 \in \mathbb R^{k_1}, \boldsymbol{x}_2\in \mathbb R^{k_2}について、それらの確率密度関数が

\begin{align*} p(\boldsymbol{x}_1) &= \mathcal N(\boldsymbol{x}_1|\boldsymbol{m}_1, \boldsymbol{V}_1), \\ p(\boldsymbol{x}_2) &= \mathcal N(\boldsymbol{x}_2|\boldsymbol{m}_2, \boldsymbol{V}_2), \\ p(\boldsymbol{x}_1, \boldsymbol{x}_2) &= \mathcal N\!\left( \begin{bmatrix}\boldsymbol{x}_1 \\ \boldsymbol{x}_2 \end{bmatrix} \middle| \begin{bmatrix}\boldsymbol{m}_1 \\ \boldsymbol{m}_2 \end{bmatrix} , \begin{bmatrix}\boldsymbol{V}_1 & \boldsymbol{V}_{12} \\ \boldsymbol{V}_{12}^\mathsf{T} & \boldsymbol{V}_2 \end{bmatrix} \right) \end{align*}

で与えられるとき、\boldsymbol{x}_1が定まった条件のもとでの\boldsymbol{x}_2もまたガウス分布に従い、

p(\boldsymbol{x}_2|\boldsymbol{x}_1) = \mathcal N(\boldsymbol x_2| \boldsymbol m_2 + \boldsymbol{V}_{12}^\mathsf{T}\boldsymbol{V}_1^{-1}(\boldsymbol{x}_1 - \boldsymbol{m}_1), \boldsymbol{V}_2 - \boldsymbol{V}_{12}^\mathsf{T}\boldsymbol{V}_1^{-1} \boldsymbol{V}_{12})
\begin{align*} \mathbb E[\boldsymbol{x}_2|\boldsymbol{x}_1] &= \boldsymbol m_2 + \boldsymbol{V}_{12}^\mathsf{T}\boldsymbol{V}_1^{-1}(\boldsymbol{x}_1 - \boldsymbol{m}_1) \\ \mathrm{Cov}[\boldsymbol{x}_2|\boldsymbol{x}_1] &= \boldsymbol{V}_2 - \boldsymbol{V}_{12}^\mathsf{T}\boldsymbol{V}_1^{-1} \boldsymbol{V}_{12} \end{align*}

となる。

確率密度関数

多変量ガウス分布の確率密度関数は次式で与えられます。

\mathcal N(\boldsymbol{x}|\boldsymbol{m},\boldsymbol{V}) = \frac{1}{\sqrt{(2\pi)^D\det \boldsymbol{V}}} \exp\!\left( -\frac{1}{2} (\boldsymbol{x}-\boldsymbol{m})^\mathsf{T} \boldsymbol{V}^{-1} (\boldsymbol{x}-\boldsymbol{m}) \right)

ただし、

\begin{align*} \boldsymbol{x} &\in \mathbb R^D \\ \boldsymbol{m} &\in \mathbb R^D \\ \boldsymbol{V} &\in \mathbb R^{D \times D} \end{align*}

であるとし、また\boldsymbol{V}は半正定値実対称行列とします。

\boldsymbol{m} \in \mathbb R^Dは平均ベクトルで、確率ベクトル\boldsymbol{x}の各要素の平均値を要素に持ちます。\boldsymbol{V} \in \mathbb R^{D\times D}は共分散行列と呼ばれ、\boldsymbol{V}(i,j)成分が\boldsymbol{x}i成分x_ij成分x_jの共分散\mathrm{Cov}[x_i,x_j]を表します。i=jの箇所、すなわち対角成分はx_iの分散\mathbb V[x_i]を表します。それぞれ、

\begin{align*} \mathbb E[\boldsymbol{x}] &= \boldsymbol m \\ \mathrm{Cov}[\boldsymbol{x}] &= \boldsymbol{V} \end{align*}

などと表します。

定数部分1 / \sqrt{(2\pi)^D\det \boldsymbol{V}}は単なる正規化定数です。これを1/Zと置けば、

\mathcal N(\boldsymbol{x}|\boldsymbol{m},\boldsymbol{V}) = \frac{1}{Z} \exp\!\left( -\frac{1}{2} (\boldsymbol{x}-\boldsymbol{m})^\mathsf{T} \boldsymbol{V}^{-1} (\boldsymbol{x}-\boldsymbol{m}) \right)

と書いてしまうことができます。Zは、正規化条件

\int_{\mathbb{R}^D} \mathcal N(\boldsymbol{x}|\boldsymbol{m},\boldsymbol{V}) d\boldsymbol{x} = 1

を用いて

Z = \int_{\mathbb R^D}\exp\!\left( -\frac{1}{2} (\boldsymbol{x}-\boldsymbol{m})^\mathsf{T} \boldsymbol{V}^{-1} (\boldsymbol{x}-\boldsymbol{m}) \right) d\boldsymbol{x}

とし、多変量ガウス積分の公式を適用して計算することができるので、それほど重要ではありません。したがって、ガウス分布の形の決定に本質的に重要な部分は、指数の内側の二次形式の部分

(\boldsymbol{x}-\boldsymbol{m})^\mathsf{T} \boldsymbol{V}^{-1} (\boldsymbol{x}-\boldsymbol{m})

ということになります。これは、\boldsymbol{V}が実対称行列であることから、

(\boldsymbol{x}-\boldsymbol{m})^\mathsf{T} \boldsymbol{V}^{-1} (\boldsymbol{x}-\boldsymbol{m}) = \boldsymbol{x}^\mathsf{T} \boldsymbol{V}^{-1} \boldsymbol{x} -2 \boldsymbol{x}^\mathsf{T} \boldsymbol{V}^{-1} \boldsymbol{m} + \boldsymbol{m}^\mathsf{T} \boldsymbol{V}^{-1} \boldsymbol{m}

と分解することができます。このことを念頭に入れて、導出していきます。

導出

1. 条件付き確率の式

条件付き確率の定義式

p(\boldsymbol{x}_2|\boldsymbol{x}_1) = \frac{ p(\boldsymbol{x}_1, \boldsymbol{x}_2) }{ p(\boldsymbol{x}_1) }

に、

\begin{align*} p(\boldsymbol{x}_1, \boldsymbol{x}_2) &= \mathcal N\!\left( \begin{bmatrix}\boldsymbol{x}_1 \\ \boldsymbol{x}_2 \end{bmatrix} \middle| \begin{bmatrix}\boldsymbol{m}_1 \\ \boldsymbol{m}_2 \end{bmatrix} , \begin{bmatrix} \boldsymbol{V}_1 & \boldsymbol{V}_{12} \\ \boldsymbol{V}_{12}^\mathsf{T} & \boldsymbol{V}_2 \end{bmatrix} \right) \\ &= \mathcal N(\boldsymbol{x}|\boldsymbol{m},\boldsymbol{V}) \\ &= \frac{1}{\sqrt{(2\pi)^{k} \det \boldsymbol{V}}} \exp\!\left( -\frac{1}{2} (\boldsymbol{x}-\boldsymbol{m})^\mathsf{T} \boldsymbol{V}^{-1} (\boldsymbol{x}-\boldsymbol{m}) \right) \end{align*}

および

\begin{align*} p(\boldsymbol{x}_1) &= \mathcal N(\boldsymbol{x}_1|\boldsymbol{m}_1, \boldsymbol{V}_1) \\ &= \frac{1}{\sqrt{(2\pi)^{k_1} \det \boldsymbol{V}_1}} \exp\!\left( -\frac{1}{2} (\boldsymbol{x}_1-\boldsymbol{m}_1)^\mathsf{T} \boldsymbol{V}_1^{-1} (\boldsymbol{x}_1-\boldsymbol{m}_1) \right) \end{align*}

を代入すると、

\begin{align*} p(\boldsymbol{x}_2|\boldsymbol{x}_1) &= \frac{ p(\boldsymbol{x}_1, \boldsymbol{x}_2) }{ p(\boldsymbol{x}_1) } \\ &= \frac{1}{ \sqrt{ (2\pi)^{k-k_1} \frac{\det \boldsymbol{V}}{\det \boldsymbol{V}_1}} } \exp\!\left( -\frac{1}{2} \left\{ (\boldsymbol{x}-\boldsymbol{m})^\mathsf{T} \boldsymbol{V}^{-1} (\boldsymbol{x}-\boldsymbol{m}) - (\boldsymbol{x}_1-\boldsymbol{m}_1)^\mathsf{T} \boldsymbol{V}_1^{-1} (\boldsymbol{x}_1-\boldsymbol{m}_1) \right\} \right) \end{align*}

となります。これの指数部分を抜き出すと、

(\boldsymbol{x}-\boldsymbol{m})^\mathsf{T} \boldsymbol{V}^{-1} (\boldsymbol{x}-\boldsymbol{m}) - (\boldsymbol{x}_1-\boldsymbol{m}_1)^\mathsf{T} \boldsymbol{V}_1^{-1} (\boldsymbol{x}_1-\boldsymbol{m}_1)

となります。よって、これを計算した結果が、条件付き確率の公式

p(\boldsymbol{x}_2|\boldsymbol{x}_1) = \mathcal N(\boldsymbol x_2| \boldsymbol m_2 + \boldsymbol{V}_{12}^\mathsf{T}\boldsymbol{V}_1^{-1}(\boldsymbol{x}_1 - \boldsymbol{m}_1), \boldsymbol{V}_2 - \boldsymbol{V}_{12}^\mathsf{T}\boldsymbol{V}_1^{-1} \boldsymbol{V}_{12})

の指数部分

\left( \boldsymbol{x}_2 - \left( \boldsymbol{m}_2 + \boldsymbol{V}_{12}^\mathsf{T} \boldsymbol{V}_1^{-1}(\boldsymbol{x}_1 - \boldsymbol{m}_1) \right) \right)^\mathsf{T} \boldsymbol{W}_2^{-1} \left( \boldsymbol{x}_2 - \left( \boldsymbol{m}_2 + \boldsymbol{V}_{12}^\mathsf{T} \boldsymbol{V}_1^{-1}(\boldsymbol{x}_1 - \boldsymbol{m}_1) \right) \right)
\boldsymbol{W}_2 = \boldsymbol{V}_2 - \boldsymbol{V}_{12}^\mathsf{T}\boldsymbol{V}_1^{-1} \boldsymbol{V}_{12}

に等しいことを示すのが目標ということになります。


目標:

\begin{align*} &(\boldsymbol{x}-\boldsymbol{m})^\mathsf{T} \boldsymbol{V}^{-1} (\boldsymbol{x}-\boldsymbol{m}) - (\boldsymbol{x}_1-\boldsymbol{m}_1)^\mathsf{T} \boldsymbol{V}_1^{-1} (\boldsymbol{x}_1-\boldsymbol{m}_1) \\ &\quad= \left( \boldsymbol{x}_2 - \left( \boldsymbol{m}_2 + \boldsymbol{V}_{12}^\mathsf{T} \boldsymbol{V}_1^{-1}(\boldsymbol{x}_1 - \boldsymbol{m}_1) \right) \right)^\mathsf{T} \boldsymbol{W}_2^{-1} \left( \boldsymbol{x}_2 - \left( \boldsymbol{m}_2 + \boldsymbol{V}_{12}^\mathsf{T} \boldsymbol{V}_1^{-1}(\boldsymbol{x}_1 - \boldsymbol{m}_1) \right) \right) \end{align*}

ただし

\boldsymbol{W}_2 = \boldsymbol{V}_2 - \boldsymbol{V}_{12}^\mathsf{T}\boldsymbol{V}_1^{-1} \boldsymbol{V}_{12}

2. 精度行列

共分散行列

\boldsymbol{V} = \begin{bmatrix}\boldsymbol{V}_1 & \boldsymbol{V}_{12} \\ \boldsymbol{V}_{12}^\mathsf{T} & \boldsymbol{V}_2 \end{bmatrix}

の逆行列\boldsymbol{V}^{-1}を求めます。まず、両側から行列を掛けてブロック対角化します。

\begin{align*} & \begin{bmatrix} \boldsymbol{I} & \boldsymbol{O} \\ -\boldsymbol{V}_{12}^\mathsf{T} \boldsymbol{V}_1^{-1} & \boldsymbol{I} \end{bmatrix} \begin{bmatrix} \boldsymbol{V}_1 & \boldsymbol{V}_{12} \\ \boldsymbol{V}_{12}^\mathsf{T} & \boldsymbol{V}_2 \end{bmatrix} \begin{bmatrix} \boldsymbol{I} & -\boldsymbol{V}_1^{-1} \boldsymbol{V}_{12} \\ \boldsymbol{O} & \boldsymbol{I} \end{bmatrix} \\ &= \begin{bmatrix} \boldsymbol{V}_1 & \boldsymbol{O} \\ \boldsymbol{O} & \boldsymbol{V}_2 - \boldsymbol{V}_{12}^\mathsf{T} \boldsymbol{V}_1^{-1} \boldsymbol{V}_{12} \end{bmatrix} \\ &= \begin{bmatrix} \boldsymbol{V}_1 & \boldsymbol{O} \\ \boldsymbol{O} & \boldsymbol{W}_2 \end{bmatrix} \end{align*}

両辺の逆行列をとることにより、

\left( \begin{bmatrix} \boldsymbol{I} & \boldsymbol{O} \\ -\boldsymbol{V}_{12}^\mathsf{T} \boldsymbol{V}_1^{-1} & \boldsymbol{I} \end{bmatrix} \begin{bmatrix} \boldsymbol{V}_1 & \boldsymbol{V}_{12} \\ \boldsymbol{V}_{12}^\mathsf{T} & \boldsymbol{V}_2 \end{bmatrix} \begin{bmatrix} \boldsymbol{I} & -\boldsymbol{V}_1^{-1} \boldsymbol{V}_{12} \\ \boldsymbol{O} & \boldsymbol{I} \end{bmatrix} \right)^{-1} = \begin{bmatrix} \boldsymbol{V}_1 & \boldsymbol{O} \\ \boldsymbol{O} & \boldsymbol{W}_2 \end{bmatrix} ^{-1}

を得ます。括弧を外すと、

\begin{bmatrix} \boldsymbol{I} & -\boldsymbol{V}_1^{-1} \boldsymbol{V}_{12} \\ \boldsymbol{O} & \boldsymbol{I} \end{bmatrix} ^{-1} \begin{bmatrix} \boldsymbol{V}_1 & \boldsymbol{V}_{12} \\ \boldsymbol{V}_{12}^\mathsf{T} & \boldsymbol{V}_2 \end{bmatrix} ^{-1} \begin{bmatrix} \boldsymbol{I} & \boldsymbol{O} \\ -\boldsymbol{V}_{12}^\mathsf{T} \boldsymbol{V}_1^{-1} & \boldsymbol{I} \end{bmatrix} ^{-1} = \begin{bmatrix} \boldsymbol{V}_1^{-1} & \boldsymbol{O} \\ \boldsymbol{O} & \boldsymbol{W}_2^{-1} \end{bmatrix}

となります。左辺の中心にある行列が求めたい逆行列です。両辺から行列を掛けて余計なものを消します。

\begin{align*} \begin{bmatrix} \boldsymbol{V}_1 & \boldsymbol{V}_{12} \\ \boldsymbol{V}_{12}^\mathsf{T} & \boldsymbol{V}_2 \end{bmatrix} ^{-1} &= \begin{bmatrix} \boldsymbol{I} & -\boldsymbol{V}_1^{-1} \boldsymbol{V}_{12} \\ \boldsymbol{O} & \boldsymbol{I} \end{bmatrix} \begin{bmatrix} \boldsymbol{V}_1^{-1} & \boldsymbol{O} \\ \boldsymbol{O} & \boldsymbol{W}_2^{-1} \end{bmatrix} \begin{bmatrix} \boldsymbol{I} & \boldsymbol{O} \\ -\boldsymbol{V}_{12}^\mathsf{T} \boldsymbol{V}_1^{-1} & \boldsymbol{I} \end{bmatrix} \\ &= \begin{bmatrix} \boldsymbol{V}_1^{-1} + \boldsymbol{V}_1^{-1} \boldsymbol{V}_{12} \boldsymbol{W}_{2}^{-1} \boldsymbol{V}_{12}^\mathsf{T} \boldsymbol{V}_1^{-1} & -\boldsymbol{V}_1 \boldsymbol{V}_{12} \boldsymbol{W}_2^{-1} \\ -\boldsymbol{W}_2^{-1} \boldsymbol{V}_{12}^\mathsf{T} \boldsymbol{V}_1 & \boldsymbol{W}_2^{-1} \end{bmatrix} \end{align*}

後の計算のために、ここにもうひと工夫入れておきます。

\begin{align*} &= \begin{bmatrix} \boldsymbol{V}_1^{-1} & \boldsymbol{O} \\ \boldsymbol{O} & \boldsymbol{O} \end{bmatrix} + \begin{bmatrix} -\boldsymbol{V}_1^{-1} \boldsymbol{V}_{12} \\ \boldsymbol{I} \end{bmatrix} \boldsymbol{W}_2^{-1} \begin{bmatrix} - \boldsymbol{V}_{12}^\mathsf{T} \boldsymbol{V}_1^{-1} & \boldsymbol{I} \end{bmatrix} \\ &= \begin{bmatrix} \boldsymbol{V}_1^{-1} & \boldsymbol{O} \\ \boldsymbol{O} & \boldsymbol{O} \end{bmatrix} + \begin{bmatrix} -\boldsymbol{V}_1^{-1} \boldsymbol{V}_{12} \\ \boldsymbol{I} \end{bmatrix} \boldsymbol{W}_2^{-1} \begin{bmatrix} -\boldsymbol{V}_1^{-1} \boldsymbol{V}_{12} \\ \boldsymbol{I} \end{bmatrix}^\mathsf{T} \end{align*}

3. 第1項の計算

第1項の

(\boldsymbol{x}-\boldsymbol{m})^\mathsf{T}\boldsymbol{V}^{-1}(\boldsymbol{x}-\boldsymbol{m}) = \begin{bmatrix} \boldsymbol{x}_1 - \boldsymbol{m}_1 \\ \boldsymbol{x}_2 - \boldsymbol{m}_2 \end{bmatrix}^\mathsf{T} \begin{bmatrix} \boldsymbol{V}_1 & \boldsymbol{V}_{12} \\ \boldsymbol{V}_{12}^\mathsf{T} & \boldsymbol{V}_2 \end{bmatrix}^{-1} \begin{bmatrix} \boldsymbol{x}_1 - \boldsymbol{m}_1 \\ \boldsymbol{x}_2 - \boldsymbol{m}_2 \end{bmatrix}

を計算します。先ほど求めた逆行列を代入して計算すると、

\begin{align*} &= \begin{bmatrix} \boldsymbol{x}_1 - \boldsymbol{m}_1 \\ \boldsymbol{x}_2 - \boldsymbol{m}_2 \end{bmatrix}^\mathsf{T} \begin{bmatrix} \boldsymbol{V}_1^{-1} & \boldsymbol{O} \\ \boldsymbol{O} & \boldsymbol{O} \end{bmatrix} \begin{bmatrix} \boldsymbol{x}_1 - \boldsymbol{m}_1 \\ \boldsymbol{x}_2 - \boldsymbol{m}_2 \end{bmatrix} \\ &\quad + \begin{bmatrix} \boldsymbol{x}_1 - \boldsymbol{m}_1 \\ \boldsymbol{x}_2 - \boldsymbol{m}_2 \end{bmatrix}^\mathsf{T} \begin{bmatrix} -\boldsymbol{V}_1^{-1} \boldsymbol{V}_{12} \\ \boldsymbol{I} \end{bmatrix} \boldsymbol{W}_2^{-1} \begin{bmatrix} -\boldsymbol{V}_1^{-1} \boldsymbol{V}_{12} \\ \boldsymbol{I} \end{bmatrix}^\mathsf{T} \begin{bmatrix} \boldsymbol{x}_1 - \boldsymbol{m}_1 \\ \boldsymbol{x}_2 - \boldsymbol{m}_2 \end{bmatrix} \\ &= (\boldsymbol{x}_1 - \boldsymbol{m}_1)^\mathsf{T} \boldsymbol{V}_1^{-1} (\boldsymbol{x}_1 - \boldsymbol{m}_1) \\ &\quad + \left( \begin{bmatrix} -\boldsymbol{V}_1^{-1} \boldsymbol{V}_{12} \\ \boldsymbol{I} \end{bmatrix}^\mathsf{T} \begin{bmatrix} \boldsymbol{x}_1 - \boldsymbol{m}_1 \\ \boldsymbol{x}_2 - \boldsymbol{m}_2 \end{bmatrix} \right)^\mathsf{T} \boldsymbol{W}_2^{-1} \begin{bmatrix} -\boldsymbol{V}_1^{-1} \boldsymbol{V}_{12} \\ \boldsymbol{I} \end{bmatrix}^\mathsf{T} \begin{bmatrix} \boldsymbol{x}_1 - \boldsymbol{m}_1 \\ \boldsymbol{x}_2 - \boldsymbol{m}_2 \end{bmatrix} \end{align*}

となります。

4. 指数部分をまとめる

こうして求まった

\begin{align*} & (\boldsymbol{x}-\boldsymbol{m})^\mathsf{T}\boldsymbol{V}^{-1}(\boldsymbol{x}-\boldsymbol{m}) \\ &= (\boldsymbol{x}_1 - \boldsymbol{m}_1)^\mathsf{T} \boldsymbol{V}_1^{-1} (\boldsymbol{x}_1 - \boldsymbol{m}_1) \\ &\quad + \left( \begin{bmatrix} -\boldsymbol{V}_1^{-1} \boldsymbol{V}_{12} \\ \boldsymbol{I} \end{bmatrix}^\mathsf{T} \begin{bmatrix} \boldsymbol{x}_1 - \boldsymbol{m}_1 \\ \boldsymbol{x}_2 - \boldsymbol{m}_2 \end{bmatrix} \right)^\mathsf{T} \boldsymbol{W}_2^{-1} \begin{bmatrix} -\boldsymbol{V}_1^{-1} \boldsymbol{V}_{12} \\ \boldsymbol{I} \end{bmatrix}^\mathsf{T} \begin{bmatrix} \boldsymbol{x}_1 - \boldsymbol{m}_1 \\ \boldsymbol{x}_2 - \boldsymbol{m}_2 \end{bmatrix} \end{align*}

を、

(\boldsymbol{x}-\boldsymbol{m})^\mathsf{T} \boldsymbol{V}^{-1} (\boldsymbol{x}-\boldsymbol{m}) - (\boldsymbol{x}_1-\boldsymbol{m}_1)^\mathsf{T} \boldsymbol{V}_1^{-1} (\boldsymbol{x}_1-\boldsymbol{m}_1)

に代入すると、代入した式の第2項の部分だけが生き残って、

\begin{align*} & (\boldsymbol{x}-\boldsymbol{m})^\mathsf{T} \boldsymbol{V}^{-1} (\boldsymbol{x}-\boldsymbol{m}) - (\boldsymbol{x}_1-\boldsymbol{m}_1)^\mathsf{T} \boldsymbol{V}_1^{-1} (\boldsymbol{x}_1-\boldsymbol{m}_1) \\ &\quad =\left( \begin{bmatrix} -\boldsymbol{V}_1^{-1} \boldsymbol{V}_{12} \\ \boldsymbol{I} \end{bmatrix}^\mathsf{T} \begin{bmatrix} \boldsymbol{x}_1 - \boldsymbol{m}_1 \\ \boldsymbol{x}_2 - \boldsymbol{m}_2 \end{bmatrix} \right)^\mathsf{T} \boldsymbol{W}_2^{-1} \begin{bmatrix} -\boldsymbol{V}_1^{-1} \boldsymbol{V}_{12} \\ \boldsymbol{I} \end{bmatrix}^\mathsf{T} \begin{bmatrix} \boldsymbol{x}_1 - \boldsymbol{m}_1 \\ \boldsymbol{x}_2 - \boldsymbol{m}_2 \end{bmatrix} \end{align*}

となります。ここで、

\begin{align*} & \begin{bmatrix} -\boldsymbol{V}_1^{-1} \boldsymbol{V}_{12} \\ \boldsymbol{I} \end{bmatrix}^\mathsf{T} \begin{bmatrix} \boldsymbol{x}_1 - \boldsymbol{m}_1 \\ \boldsymbol{x}_2 - \boldsymbol{m}_2 \end{bmatrix} \\ &= -\boldsymbol{V}_{12}^\mathsf{T} \boldsymbol{V}_1^{-1}(\boldsymbol{x}_1 - \boldsymbol{m}_1) + (\boldsymbol{x}_2 - \boldsymbol{m}_2) \\ &= \boldsymbol{x}_2 - \left( \boldsymbol{m}_2 + \boldsymbol{V}_{12}^\mathsf{T} \boldsymbol{V}_1^{-1}(\boldsymbol{x}_1 - \boldsymbol{m}_1) \right) \end{align*}

なので、指数部分は最終的に

\begin{align*} & (\boldsymbol{x}-\boldsymbol{m})^\mathsf{T} \boldsymbol{V}^{-1} (\boldsymbol{x}-\boldsymbol{m}) - (\boldsymbol{x}_1-\boldsymbol{m}_1)^\mathsf{T} \boldsymbol{V}_1^{-1} (\boldsymbol{x}_1-\boldsymbol{m}_1) \\ &\quad =\left( \boldsymbol{x}_2 - \left( \boldsymbol{m}_2 + \boldsymbol{V}_{12}^\mathsf{T} \boldsymbol{V}_1^{-1}(\boldsymbol{x}_1 - \boldsymbol{m}_1) \right) \right)^\mathsf{T} \boldsymbol{W}_2^{-1} \left( \boldsymbol{x}_2 - \left( \boldsymbol{m}_2 + \boldsymbol{V}_{12}^\mathsf{T} \boldsymbol{V}_1^{-1}(\boldsymbol{x}_1 - \boldsymbol{m}_1) \right) \right) \end{align*}

となります。こうして目標は達成されました。やれやれ。

5. (おまけ) 正規化定数部分

正規化条件が使える限りでは、指数の内側の二次形式さえ分かれば正規化定数はどうにでもなるのですが、一応正規化定数をきちんと計算することもできます。

\begin{align*} p(\boldsymbol{x}_2|\boldsymbol{x}_1) &= \frac{ p(\boldsymbol{x}_1, \boldsymbol{x}_2) }{ p(\boldsymbol{x}_1) } \\ &= \frac{1}{ \sqrt{ (2\pi)^{k-k_1} \frac{\det \boldsymbol{V}}{\det \boldsymbol{V}_1}} } \exp\!\left( -\frac{1}{2} \left\{ (\boldsymbol{x}-\boldsymbol{m})^\mathsf{T} \boldsymbol{V}^{-1} (\boldsymbol{x}-\boldsymbol{m}) - (\boldsymbol{x}_1-\boldsymbol{m}_1)^\mathsf{T} \boldsymbol{V}_1^{-1} (\boldsymbol{x}_1-\boldsymbol{m}_1) \right\} \right) \end{align*}

の正規化定数部分

\frac{1}{ \sqrt{ (2\pi)^{k-k_1} \frac{\det \boldsymbol{V}}{\det \boldsymbol{V}_1}} }

において、

k - k_1 = k_1 + k_2 - k_1 = k_2

および、

\begin{align*} \frac{\det \boldsymbol{V}}{\det \boldsymbol{V}_1} &= \frac{1}{\det \boldsymbol{V}_1} \left( \det \boldsymbol{V}_1 \det \boldsymbol{V}_2 - \det \boldsymbol{V}_{12} \det \boldsymbol{V}_{12}^\mathsf{T} \right) \\ &= \det \boldsymbol{V}_2 - \det \boldsymbol{V}_{12}^\mathsf{T} \det \boldsymbol{V}_1^{-1} \det \boldsymbol{V}_{12} \\ &= \det\!\left( \boldsymbol{V}_2 - \boldsymbol{V}_{12}^\mathsf{T} \boldsymbol{V}_1^{-1} \boldsymbol{V}_{12} \right) \end{align*}

により、

\frac{1}{ \sqrt{ (2\pi)^{k-k_1} \frac{\det \boldsymbol{V}}{\det \boldsymbol{V}_1}} } = \frac{1}{ \sqrt{ (2\pi)^{k_2} \det\!\left( \boldsymbol{V}_2 - \boldsymbol{V}_{12}^\mathsf{T} \boldsymbol{V}_1^{-1} \boldsymbol{V}_{12} \right) }}

となります。正規化定数部分もきちんと成り立つことがわかりますね。

6. まとめる

以上をまとめて書くと、

\begin{align*} & p(\boldsymbol{x}_2|\boldsymbol{x}_1) \\ &= \frac{1}{ \sqrt{ (2\pi)^{k_2} \det\!\left( \boldsymbol{V}_2 - \boldsymbol{V}_{12}^\mathsf{T} \boldsymbol{V}_1^{-1} \boldsymbol{V}_{12} \right) }} \\ &\quad \exp\!\left( -\frac{1}{2} \left( \boldsymbol{x}_2 - \left( \boldsymbol{m}_2 + \boldsymbol{V}_{12}^\mathsf{T} \boldsymbol{V}_1^{-1}(\boldsymbol{x}_1 - \boldsymbol{m}_1) \right) \right)^\mathsf{T} \boldsymbol{W}_2^{-1} \left( \boldsymbol{x}_2 - \left( \boldsymbol{m}_2 + \boldsymbol{V}_{12}^\mathsf{T} \boldsymbol{V}_1^{-1}(\boldsymbol{x}_1 - \boldsymbol{m}_1) \right) \right) \right) \\ &= \mathcal N(\boldsymbol x_2| \boldsymbol m_2 + \boldsymbol{V}_{12}^\mathsf{T}\boldsymbol{V}_1^{-1}(\boldsymbol{x}_1 - \boldsymbol{m}_1), \boldsymbol{W}_2) \\ &= \mathcal N(\boldsymbol x_2| \boldsymbol m_2 + \boldsymbol{V}_{12}^\mathsf{T}\boldsymbol{V}_1^{-1}(\boldsymbol{x}_1 - \boldsymbol{m}_1), \boldsymbol{V}_2 - \boldsymbol{V}_{12}^\mathsf{T}\boldsymbol{V}_1^{-1} \boldsymbol{V}_{12}) \end{align*}

であり、

\begin{align*} \mathbb E[\boldsymbol{x}_2|\boldsymbol{x}_1] &= \boldsymbol m_2 + \boldsymbol{V}_{12}^\mathsf{T}\boldsymbol{V}_1^{-1}(\boldsymbol{x}_1 - \boldsymbol{m}_1) \\ \mathrm{Cov}[\boldsymbol{x}_2|\boldsymbol{x}_1] &= \boldsymbol{V}_2 - \boldsymbol{V}_{12}^\mathsf{T}\boldsymbol{V}_1^{-1} \boldsymbol{V}_{12} \end{align*}

ということがわかります。

参考文献

  1. Carl Edward Rasmussen and Christopher K. I. Williams: "Gaussian Process for Machine Learning," A. Mathematical Background, https://gaussianprocess.org/gpml/chapters/, The MIT Press, 2006.

Discussion

ログインするとコメントできます