🫠

Tweedieの公式を証明してみる

に公開

ゼロから作るDeep Learning 5のP. 246にも登場する、拡散モデルの中で重要な働きをするMaurice Tweedieの公式を示してみます。ちょっと証明を検索してみたけれど簡単なものが見つからなかったので、自分で再構築してみました。

示す公式は以下の通り


\boldsymbol{x} \sim \mathcal{N}(\boldsymbol{x} ; \boldsymbol{\mu}, \boldsymbol{\Sigma})

によってサンプル \boldsymbol{x} が得られたとき、次の式が成り立ちます。

\mathbb{E}[\boldsymbol{\mu} \mid \boldsymbol{x}]=\boldsymbol{x}+\boldsymbol{\Sigma} \nabla_{\boldsymbol{x}} \log p(\boldsymbol{x})

前提

標本(サンプル)\boldsymbol{x}は、観測不能な潜在変数である真の平均\boldsymbol{\mu}と分散\boldsymbol{\Sigma}を中心とした多変量ガウス分布\mathcal{N}(\boldsymbol{x};\boldsymbol{\mu},\boldsymbol{\Sigma})から得られるとし、得られた標本\boldsymbol{x}を用いて、\boldsymbol{\mu}を推定したい。\boldsymbol{x}\boldsymbol{x} = \boldsymbol{\mu} + \boldsymbol{\epsilon},\ \boldsymbol{\epsilon}\sim \mathcal{N}(\boldsymbol{0}, \boldsymbol{\Sigma})のようになって観測可能な値になっていると仮定している。\boldsymbol{\Sigma} \nabla_{\boldsymbol{x}} \log p(\boldsymbol{x})は「対数尤度\log p(\boldsymbol{x})の入力\boldsymbol{x}に関する勾配」であり、機械学習、特に拡散モデルの中ではスコア関数とかスコアとか呼ばれることもある。

あと、p(\boldsymbol{x} \mid \boldsymbol{\mu}) = \mathcal{N}(\boldsymbol{x} ; \boldsymbol{\mu}, \boldsymbol{\Sigma})であってるよね?

証明

第1段階

まず多変量ガウス分布の式は以下の通り

\mathcal{N}(\boldsymbol{x};\boldsymbol{\mu},\boldsymbol{\Sigma}) = \frac{1}{(2\pi)^{d/2}|\boldsymbol{\Sigma}|^{1/2}} \exp\left( -\frac{1}{2} (\boldsymbol{x}-\boldsymbol{\mu})^\top\boldsymbol{\Sigma}^{-1}(\boldsymbol{x}-\boldsymbol{\mu})\right)

後ほどこの関数について\boldsymbol{x}についての微分を取ることになるが、後の計算のために、指数関数を直接微分するのではなく、対数を取ってから微分することにする。

\boldsymbol{x}に依存しないところはCとでもまとめておいておく。

\mathcal{N}(\boldsymbol{x};\boldsymbol{\mu},\boldsymbol{\Sigma}) = C\, \exp\left( -\frac{1}{2} (\boldsymbol{x}-\boldsymbol{\mu})^\top\boldsymbol{\Sigma}^{-1}(\boldsymbol{x}-\boldsymbol{\mu}) \right) \text{ where } C=(2\pi)^{-d/2}|\boldsymbol{\Sigma}|^{-1/2}

対数をとる

\log \mathcal{N}(\boldsymbol{x};\boldsymbol{\mu},\boldsymbol{\Sigma}) = \log C -\frac{1}{2} (\boldsymbol{x}-\boldsymbol{\mu})^\top\boldsymbol{\Sigma}^{-1}(\boldsymbol{x}-\boldsymbol{\mu})

これを微分していく。

この勾配(grad)をとった\nabla_{\boldsymbol{x}} \log \mathcal{N}(\boldsymbol{x};\boldsymbol{\mu},\boldsymbol{\Sigma})を求める。\log C部分の勾配は(\boldsymbol{x}に依存しないので)0

\nabla_{\boldsymbol{x}} = \frac{\partial}{\partial \boldsymbol{x}} = \begin{pmatrix} \partial/\partial x_1\\ \vdots\\ \partial/\partial x_d \end{pmatrix}

である。次に、ゼロから作るDeep Learning5の付録Aの(A.10)式あたりの議論またはMatrix cookbookの(85)式あたりを参考にして以下のように計算できる。

\begin{aligned} \nabla_{\boldsymbol{x}} \log \mathcal{N}(\boldsymbol{x};\boldsymbol{\mu},\boldsymbol{\Sigma}) &= -\frac{1}{2} \frac{\partial}{\partial \boldsymbol{x}} \left( (\boldsymbol{x}-\boldsymbol{\mu})^\top\boldsymbol{\Sigma}^{-1}(\boldsymbol{x}-\boldsymbol{\mu})\right) \\ &= -\frac{1}{2}\cdot 2 \boldsymbol{\Sigma}^{-1}(\boldsymbol{x}-\boldsymbol{\mu}) \\ &= -\boldsymbol{\Sigma}^{-1}(\boldsymbol{x}-\boldsymbol{\mu}) \end{aligned}

ここで、\logに関する微分から元の関数に戻すには、一般に関数f(x)について

\nabla_{\boldsymbol{x}} \log f(\boldsymbol{x}) = \frac{1}{f(\boldsymbol{x})} \nabla_{\boldsymbol{x}} f(\boldsymbol{x}) \tag{☆}

であるから、

\nabla_{\boldsymbol{x}} \log \mathcal{N}(\boldsymbol{x};\boldsymbol{\mu},\boldsymbol{\Sigma}) = \frac{1}{\mathcal{N}(\boldsymbol{x};\boldsymbol{\mu},\boldsymbol{\Sigma})} \nabla_{\boldsymbol{x}} \mathcal{N}(\boldsymbol{x};\boldsymbol{\mu},\boldsymbol{\Sigma})

したがって

\nabla_{\boldsymbol{x}} \mathcal{N}(\boldsymbol{x};\boldsymbol{\mu},\boldsymbol{\Sigma}) = -\boldsymbol{\Sigma}^{-1}(\boldsymbol{x}-\boldsymbol{\mu})\mathcal{N}(\boldsymbol{x};\boldsymbol{\mu},\boldsymbol{\Sigma})

が得られる。

第2段階

一方で、周辺分布を書いてみると、

\begin{aligned} \nabla_{\boldsymbol{x}} p(\boldsymbol{x}) &= \int \nabla_{\boldsymbol{x}} p(\boldsymbol{x} \mid \boldsymbol{\mu}) p(\boldsymbol{\mu})\,d\boldsymbol{\mu} \\ &= \int \nabla_{\boldsymbol{x}} \mathcal N(\boldsymbol{x};\boldsymbol{\mu},\boldsymbol{\Sigma})p(\boldsymbol{\mu})\,d\boldsymbol{\mu} \end{aligned}

第1段階の結果を代入すると

\nabla_{\boldsymbol{x}} p(\boldsymbol{x}) = -\boldsymbol{\Sigma}^{-1} \int (\boldsymbol{x}-\boldsymbol{\mu})\mathcal N(\boldsymbol{x};\boldsymbol{\mu},\boldsymbol{\Sigma})p(\boldsymbol{\mu})\,d\boldsymbol{\mu}

が成り立つ。ここで、p(\boldsymbol{\mu})\boldsymbol{\mu}の事前分布である。

また、Bayes 則より

p(\boldsymbol{\mu}\mid \boldsymbol{x}) = \frac{p(\boldsymbol{x}\mid \boldsymbol{\mu})p(\boldsymbol{\mu})}{p(\boldsymbol{x})} = \frac{\mathcal N(\boldsymbol{x};\boldsymbol{\mu},\boldsymbol{\Sigma})p(\boldsymbol{\mu})}{p(\boldsymbol{x})}

よって、先の式の右辺の積分部分は以下のように書き換えられる。

\begin{aligned} \int (\boldsymbol{x}-\boldsymbol{\mu})\mathcal N(\boldsymbol{x};\boldsymbol{\mu},\boldsymbol{\Sigma})p(\boldsymbol{\mu})\,d\boldsymbol{\mu} &= p(\boldsymbol{x})\int (\boldsymbol{x}-\boldsymbol{\mu})p(\boldsymbol{\mu}\mid \boldsymbol{x})d\boldsymbol{\mu} \\ &= p(\boldsymbol{x})\,\mathbb E[\boldsymbol{x}-\boldsymbol{\mu}\mid \boldsymbol{x}] \end{aligned}

ここで、条件付き期待値\mathbb E[\boldsymbol{x}-\boldsymbol{\mu}\mid \boldsymbol{x}]は線形性より以下のように変形できる。

\begin{aligned} \mathbb E[\boldsymbol{x}-\boldsymbol{\mu}\mid \boldsymbol{x}] &= \mathbb{E}[\boldsymbol{x} \mid \boldsymbol{x}] - \mathbb{E}[\boldsymbol{\mu} \mid \boldsymbol{x}] \\ &= \boldsymbol{x} - \mathbb{E}[\boldsymbol{\mu} \mid \boldsymbol{x}] \end{aligned}

\boldsymbol{x}は条件として固定された(=観測済みの)ベクトル)

第3段階

これらをまとめると、以下のようになる。

\begin{aligned} \nabla_{\boldsymbol{x}} p(\boldsymbol{x}) &= - \boldsymbol{\Sigma}^{-1} p(\boldsymbol{x}) \mathbb{E}[\boldsymbol{x}-\boldsymbol{\mu} \mid \boldsymbol{x}] \\ &= - \boldsymbol{\Sigma}^{-1} p(\boldsymbol{x}) \left( \boldsymbol{x} - \mathbb{E}[\boldsymbol{\mu} \mid \boldsymbol{x}] \right) \end{aligned}

両辺に左から\boldsymbol{\Sigma}をかけて、\mathbb{E}[\boldsymbol{\mu} \mid \boldsymbol{x}]について解くと、

\mathbb{E}[\boldsymbol{\mu} \mid \boldsymbol{x}] = \boldsymbol{x} + \frac{1}{p(\boldsymbol{x})} \boldsymbol{\Sigma} \nabla_{\boldsymbol{x}} p(\boldsymbol{x})

これは、対数尤度の勾配(☆)を用いて以下のように書き換えられる。

\mathbb{E}[\boldsymbol{\mu} \mid \boldsymbol{x}] = \boldsymbol{x} + \boldsymbol{\Sigma} \nabla_{\boldsymbol{x}} \log p(\boldsymbol{x})

よって、示したい式が得られた。

1変数ガウス分布のとき

1変数の場合、\boldsymbol{x}x\boldsymbol{\mu}\mu\boldsymbol{\Sigma}\sigma^2に置き換えられば成り立つ。微分のところはより簡単なのでやってみてください。

Discussion