📘

ベイズ線形回帰

2023/03/11に公開約14,200字

事後予測分布


事後予測分布の計算による回帰直線の推定。実線は平均、エラーバンドは \pm 3 \sigma^2。データから離れると、エラーバンドが若干広くなっており、回帰曲線の不確かさが増していることがわかる。

確率モデルにおいて、出力データ \mathcal D_D が与えられたという条件のもとで、さらに未知のデータ \mathcal D_\ast が従う分布 p(\mathcal D_\ast | \mathcal D_D)事後予測分布 (posterior predictive distribution) という。

この分布は、次式のように、既知のデータ \mathcal D_D から計算されるパラメータ w の事後分布 p(w | \mathcal D_D) に尤度関数 p(\mathcal D_\ast | w) を掛け、さらに w について積分して周辺化したものとして定義される。

\begin{aligned} p(\mathcal D_\ast | \mathcal D_D) \coloneqq \int dw p(\mathcal D_\ast | w) p(w | \mathcal D_D) \end{aligned}

回帰モデルにおいては次のような形をしている。

\begin{aligned} p(\bm y_\ast | \bm X_\ast, \bm X, \bm y) \coloneqq \int d \bm w p(\bm y_\ast | \bm X_\ast, \bm w) p(\bm w | \bm X, \bm y) \end{aligned}

ただし \bm X_\ast \in \R^{D_\ast \times N}, \bm y_\ast \in \R^{D_\ast} は未知の入出力データであり、次のようなものとする。

\begin{aligned} \bm X_\ast ={}& \left[\begin{darray}{c} \bm x_{1} \\ \bm x_{2} \\ \vdots \\ \bm x_{D_\ast} \\ \end{darray}\right] ,& \bm y_\ast ={}& \left[\begin{darray}{c} y_{1} \\ y_{2} \\ \vdots \\ y_{D_\ast} \\ \end{darray}\right] \end{aligned}

事後予測分布を計算することで、出力が未知であるような入力データに対する出力データを、不確かさも含めて分布として推定することができる。そこで、事後予測分布を用いるという方法によって線形回帰を行うものを ベイズ線形回帰 (Bayesian linear regression) などと言う。

ただし、事後予測分布の計算が解析的に可能な場合は限られており、多くの場合はMCMCなどにより近似的に数値計算するのが現実的である。

具体的な計算

具体例を見てみよう。まず、尤度関数を以下のようなものにする。

\begin{aligned} p(\bm y | \bm X, \bm w) = \mathcal N_D(\bm y | \bm X \bm w, \sigma^2 \bm I_D) \end{aligned}

さらに、未知のデータ \bm X_\ast, \bm y_\ast についても、同じ分布に従うものと考え、分布 p(\bm y_\ast | \bm X_\ast, \bm w) を次式で定める。

p(\bm y_\ast | \bm X_\ast, \bm w) = \mathcal N_{D_\ast} (\bm y_\ast | \bm X_\ast \bm w, \sigma^2 \bm I_{D_\ast})

ただし、今回は未知のデータの個数を D_\ast 個として、\bm X_\ast \in \R^{D_\ast \times N}, \bm y_\ast \in \R^{D_\ast} としておこう。

事後分布の計算

続いて、パラメータの事前分布として正規分布を仮定する。

\begin{aligned} p(\bm w) &= \mathcal N_N(\bm w | \bm m_0, \bm V_0) \\ &\propto \exp \left( -\frac{1}{2} (\bm w - \bm m_0)^\mathsf{T} \bm V_0^{-1} (\bm w - \bm m_0) \right) \end{aligned}

この場合、パラメータの事後分布は次のような正規分布になる。

p(\bm w | \bm X, \bm y) = \mathcal N_N(\bm w | \bm m_{w}, \bm V_{w})
\left\lbrace\begin{aligned} \bm m_{w} &= \bm V_{w} \left( \frac{1}{\sigma^2} \bm X^\mathsf{T} \bm y + \bm V_0^{-1} \bm m_0 \right) \\ \bm V_{w}^{-1} &= \frac{1}{\sigma^2} \bm X^\mathsf{T} \bm X + \bm V_0^{-1} \end{aligned}\right.

計算

材料が揃ったので、事後予測分布の計算に取り掛かる。

以下の計算においては、\bm w および \bm y_\ast に関係のない項をすべて \mathrm{const.} として纏めていることに注意。

\begin{aligned} &\hspace{-1pc} p(\bm y_\ast | \bm X_\ast, \bm X, \bm y) \\ ={}& \int d \bm w p(\bm y_\ast | \bm X_\ast, \bm w) p(\bm w | \bm X, \bm y) \\ ={}& \int d \bm w \mathcal N_{D_\ast}(\bm y_\ast | \bm X_\ast \bm w, \sigma^2 \bm I_{D_\ast}) \mathcal N_{n}(\bm w | \bm m_{w}, \bm V_{w}) \\ \propto{}& \int d \bm w \exp \left( -\frac{1}{2\sigma^2} \| \bm y_\ast - \bm X_\ast \bm w \|_2^2 \right) \exp \left( -\frac{1}{2}(\bm w - \bm m_{w}) \bm V_{w}^{-1} (\bm w - \bm m_{w}) \right) \\ ={}& \int d \bm w \exp \left( -\frac{1}{2} \underbrace{\left( \frac{1}{\sigma^2} \| \bm y_\ast - \bm X_\ast \bm w \|_2^2 + (\bm w - \bm m_{w}) \bm V_{w}^{-1} (\bm w - \bm m_{w}) \right)}_{(1)} \right) \\ &\left|\small\quad\begin{aligned} (1) ={}& \frac{1}{\sigma^2} \bm w^\mathsf{T} \bm X_\ast^\mathsf{T} \bm X_\ast \bm w - 2 \frac{1}{\sigma^2} \bm w^\mathsf{T} \bm X_\ast^\mathsf{T} \bm y_\ast + \frac{1}{\sigma^2} \bm y_\ast^\mathsf{T} \bm y_\ast \\ & + \bm w^\mathsf{T} \bm V_{w}^{-1} \bm w - 2 \bm w^\mathsf{T} \bm V_{w}^{-1} \bm m_{w} + \underbrace{\bm m_{w}^\mathsf{T} \bm V_{w}^{-1} \bm m_{w}}_\mathrm{const.} \\ ={}& \bm w^\mathsf{T} \underbrace{\left( \frac{1}{\sigma^2} \bm X_\ast^\mathsf{T} \bm X_\ast + \bm V_{w}^{-1} \right)}_{\bm V_{w_\ast}^{-1}} \bm w - 2 \bm w^\mathsf{T} \underbrace{\left( \frac{1}{\sigma^2} \bm X_\ast^\mathsf{T} \bm y_\ast + \bm V_{w}^{-1} \bm m_{w} \right)}_{\bm V_{w_\ast}^{-1} \bm m_{+}} + \frac{1}{\sigma^2} \bm y_\ast^\mathsf{T} \bm y_\ast + \mathrm{const.} \\ ={}& \bm w^\mathsf{T} \bm V_{w_\ast}^{-1} \bm w - 2 \bm w^\mathsf{T} \bm V_{w_\ast}^{-1} \bm m_{+} + \frac{1}{\sigma^2} \bm y_\ast^\mathsf{T} \bm y_\ast + \mathrm{const.} \\ ={}& (\bm w - \bm m_{+})^\mathsf{T} \bm V_{w_\ast}^{-1} (\bm w - \bm m_{+}) + \underbrace{ \frac{1}{\sigma^2} \bm y_\ast^\mathsf{T} \bm y_\ast - \bm m_{+} \bm V_{w_\ast}^{-1} \bm m_{+} }_{(2)} + \mathrm{const.} \\ (2) ={}& \frac{1}{\sigma^2} \bm y_\ast^\mathsf{T} \bm y_\ast - \bm m_{+} \bm V_{w_\ast}^{-1} \bm m_{+} \\ ={}& \frac{1}{\sigma^2} \bm y_\ast^\mathsf{T} \bm y_\ast - \bm y_\ast^\mathsf{T} \left( \frac{1}{\sigma^2} \right)^2 \bm X_\ast \bm V_{w_\ast} \bm X_\ast^\mathsf{T} \bm y_\ast - 2 \bm y_\ast^\mathsf{T} \frac{1}{\sigma^2} \bm X_\ast \bm V_{w_\ast} \bm V_{w}^{-1} \bm m_{w} + \mathrm{const.} \\ ={}& \bm y_\ast^\mathsf{T} \underbrace{\left( \frac{1}{\sigma^2} \bm I_{D_\ast} - \left( \frac{1}{\sigma^2} \right)^2 \bm X_\ast \bm V_{w_\ast} \bm X_\ast^\mathsf{T} \right)}_{\bm V_{y_\ast}^{-1}} \bm y_\ast - 2 \bm y_\ast^\mathsf{T} \underbrace{\frac{1}{\sigma^2} \bm X_\ast \bm V_{w_\ast} \bm V_{w}^{-1} \bm m_{w}}_{\bm V_{y_\ast}^{-1} \bm m_{y_\ast}} + \mathrm{const.} \\ ={}& \bm y_\ast^\mathsf{T} \bm V_{y_\ast}^{-1} \bm y_\ast - 2 \bm y_\ast^\mathsf{T} \bm V_{y_\ast}^{-1} \bm m_{y_\ast} + \mathrm{const.} \\ ={}& (\bm y_\ast - \bm m_{y_\ast})^\mathsf{T} \bm V_{y_\ast}^{-1} (\bm y_\ast - \bm m_{y_\ast}) + \mathrm{const.} \\ \end{aligned}\right. \\ \propto{}& \int d \bm w \exp \left( - \frac{1}{2} (\bm w - \bm m_{+})^\mathsf{T} \bm V_{w_\ast}^{-1} (\bm w - \bm m_{+}) - \frac{1}{2} (\bm y_\ast - \bm m_{y_\ast})^\mathsf{T} \bm V_{y_\ast}^{-1} (\bm y_\ast - \bm m_{y_\ast}) \right) \\ \propto{}& \int d \bm w \mathcal N_N (\bm w | \bm m_{+}, \bm V_{w_\ast}) \mathcal N_{D_\ast} (\bm y_\ast | \bm m_{y_\ast}, \bm V_{y_\ast}) \\ ={}& \mathcal N_{D_\ast} (\bm y_\ast | \bm m_{y_\ast}, \bm V_{y_\ast}) \\ \end{aligned}

こうして事後予測分布が計算された。

\begin{aligned} p(\bm y_\ast | \bm X_\ast, \bm X, \bm y) = \mathcal N_{D_\ast} (\bm y_\ast | \bm m_{y_\ast}, \bm V_{y_\ast}) \end{aligned}
\left\{\begin{aligned} \bm m_{y_\ast} ={}& \frac{1}{\sigma^2} \bm V_{y_\ast} \bm X_\ast \bm V_{w_\ast} \bm V_{w}^{-1} \bm m_{w} \\ \bm V_{y_\ast}^{-1} ={}& \frac{1}{\sigma^2} \left( \bm I_{D_\ast} - \frac{1}{\sigma^2} \bm X_\ast \bm V_{w_\ast} \bm X_\ast^\mathsf{T} \right) \\ \bm V_{w_\ast}^{-1} ={}& \frac{1}{\sigma^2} \bm X_\ast^\mathsf{T} \bm X_\ast + \bm V_{w}^{-1} \\ \bm m_{w} ={}& \bm V_{w} \left( \frac{1}{\sigma^2} \bm X^\mathsf{T} \bm y + \bm V_0^{-1} \bm m_0 \right) \\ \bm V_{w}^{-1} ={}& \frac{1}{\sigma^2} \bm X^\mathsf{T} \bm X + \bm V_0^{-1} \end{aligned}\right.

Woodburyの公式

事後分布の式の形を更に書き換えていく。次の公式を使う。

\begin{aligned} (\bm A + \bm B \bm D \bm C)^{-1} = \bm A^{-1} - \bm A^{-1} \bm B (\bm D^{-1} + \bm C \bm A^{-1} \bm B) \bm C \bm A^{-1} \end{aligned}
\begin{aligned} \bm A \in{}& \R^{M \times M} \\ \bm B \in{}& \R^{M \times N} \\ \bm C \in{}& \R^{N \times M} \\ \bm D \in{}& \R^{N \times N} \end{aligned}

特に \bm D = \bm I_N の場合には次式に帰着する。

\begin{aligned} (\bm A + \bm B \bm C)^{-1} = \bm A^{-1} - \bm A^{-1} \bm B (\bm I_N + \bm C \bm A^{-1} \bm B) \bm C \bm A^{-1} \end{aligned}

1. 準備

まず \bm V_{w}\bm m_{w}、さらに \bm V_{w_\ast} を計算する。

\begin{aligned} \bm V_{w} ={}& \left( \frac{1}{\sigma^2} \bm X^\mathsf{T} \bm X + \bm V_0^{-1} \right)^{-1} \\ ={}& \bm V_0 - \bm V_0 \bm X^\mathsf{T} (\bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_D)^{-1} \bm X \bm V_0 \\ \\ \bm m_{w} ={}& \bm V_{w} \left( \frac{1}{\sigma^2} \bm X^\mathsf{T} \bm y + \bm V_0^{-1} \bm m_0 \right) \\ ={}& \left( \bm V_0 - \bm V_0 \bm X^\mathsf{T} (\bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_D)^{-1} \bm X \bm V_0 \right) \left( \frac{1}{\sigma^2} \bm X^\mathsf{T} \bm y + \bm V_0^{-1} \bm m_0 \right) \\ ={}& \frac{1}{\sigma^2} \bm V_0 \bm X^\mathsf{T} \underbrace{ \left( \bm I_D - (\bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_D)^{-1} \bm X \bm V_0 \bm X^\mathsf{T} \right) }_{(1)} \bm y \\ & + \left( \bm I_N - \bm V_0 \bm X^\mathsf{T} (\bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_D)^{-1} \bm X \right) \bm m_0 \\ &\left|\small\quad\begin{aligned} (1) ={}& \underbrace{\bm I_D}_{\bm A^{-1}} - ( \underbrace{ \bm X \bm V_0 \bm X^\mathsf{T} }_{ \bm C } + \underbrace{ \sigma^2 \bm I_D }_{\bm D^{-1}} )^{-1} \underbrace{ \bm X \bm V_0 \bm X^\mathsf{T} }_{ \bm C } \\ ={}& ( \underbrace{ \bm I_D }_{\bm A} + \underbrace{ \frac{1}{\sigma^2} \bm X \bm V_0 \bm X^\mathsf{T} }_{ \bm B \bm D \bm C} )^{-1} \\ ={}& \sigma^2 \left( \bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_D \right)^{-1} \\ \end{aligned}\right. \\ ={}& \bm V_0 \bm X^\mathsf{T} \left( \bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_D \right)^{-1} \bm y + \left( \bm I_N - \bm V_0 \bm X^\mathsf{T} (\bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_D)^{-1} \bm X \right) \bm m_0 \\ \\ \bm V_{w_\ast} ={}& \left( \frac{1}{\sigma^2} \bm X_\ast^\mathsf{T} \bm X_\ast + \bm V_{w}^{-1} \right)^{-1} \\ ={}& \bm V_{w} - \bm V_{w} \bm X_\ast^\mathsf{T} (\bm X_\ast \bm V_{w} \bm X_\ast^\mathsf{T} + \sigma^2 \bm I_{D_\ast})^{-1} \bm X_\ast \bm V_{w} \\ \end{aligned}

2. 共分散行列の計算

続いて \bm V_{y_\ast} を計算する。

\begin{aligned} \bm V_{y_\ast} ={}& \sigma^2 \left( \bm I_{D_\ast} - \frac{1}{\sigma^2} \bm X_\ast \bm V_{w_\ast}\bm X_\ast^\mathsf{T} \right)^{-1} \\ ={}& \sigma^2 \left( \bm I_{D_\ast} - \bm X_\ast (-\sigma^2 \bm V_{w_\ast}^{-1} + \bm X_\ast^\mathsf{T} \bm X_\ast)^{-1} \bm X_\ast^\mathsf{T} \right) \\ ={}& \bm X_\ast \underbrace{\left( \bm V_{w_\ast}^{-1} - \frac{1}{\sigma^2} \bm X_\ast^\mathsf{T} \bm X_\ast \right)^{-1}}_{\bm V_{w}} \bm X_\ast^\mathsf{T} + \sigma^2 \bm I_{D_\ast} \\ ={}& \bm X_\ast \bm V_{w} \bm X_\ast^\mathsf{T} + \sigma^2 \bm I_{D_\ast} \\ ={}& \bm X_\ast \bm V_0 \bm X_\ast^\mathsf{T} + \sigma^2 \bm I_{D_\ast} - \bm X_\ast \bm V_0 \bm X^\mathsf{T} ( \bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_D )^{-1} \bm X \bm V_0 \bm X_\ast^\mathsf{T} \end{aligned}

3. 期待値の計算

最後に \bm m_{y_\ast}。これはちょっと面倒くさい。

\begin{aligned} \bm m_{y_\ast} ={}& \frac{1}{\sigma^2} \bm V_{y_\ast} \bm X_\ast \bm V_{w_\ast} \bm V_{w}^{-1} \bm m_{w} \\ &\left|\small\quad\begin{aligned} \bm V_{w_\ast} ={}& \bm V_{w} - \bm V_{w} \bm X_\ast^\mathsf{T} (\underbrace{\sigma^2 \bm I_{D_\ast} + \bm X_\ast \bm V_{w} \bm X_\ast^\mathsf{T}}_{\bm V_{y_\ast}})^{-1} \bm X_\ast \bm V_{w} \\ ={}& \bm V_{w} - \bm V_{w} \bm X_\ast^\mathsf{T} \bm V_{y_\ast}^{-1} \bm X_\ast \bm V_{w} \end{aligned}\right. \\ ={}& \frac{1}{\sigma^2} \bm V_{y_\ast} \bm X_\ast (\bm I_N - \bm V_{w} \bm X_\ast^\mathsf{T} \bm V_{y_\ast}^{-1} \bm X_\ast) \bm m_{w} \\ ={}& \frac{1}{\sigma^2} \bm V_{y_\ast} (\bm I_{D_\ast} - \bm X_\ast \bm V_{w} \bm X_\ast^\mathsf{T} \bm V_{y_\ast}^{-1}) \bm X_\ast \bm m_{w} \\ ={}& \frac{1}{\sigma^2} \bm V_{y_\ast} \underbrace{(\bm V_{y_\ast} - \bm X_\ast \bm V_{w} \bm X_\ast^\mathsf{T})}_{\sigma^2 \bm I_{D_\ast}} \bm V_{y_\ast}^{-1}\bm X_\ast \bm m_{w} \\ ={}& \bm X_\ast \bm m_{w} \\ ={}& \bm X_\ast \bm V_0 \bm X^\mathsf{T} \left( \bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_D \right)^{-1} \bm y \\ & + \left( \bm X_\ast - \bm X_\ast \bm V_0 \bm X^\mathsf{T} (\bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_D)^{-1} \bm X \right) \bm m_0 \\ \end{aligned}

計算の途中で \bm m_{y_\ast} = \bm X_\ast \bm m_{w} というものが出てきた。これは、\bm y = \bm X \bm w + \bm \varepsilon において、

\bm y \to \bm m_{y_\ast}, \quad\bm X \to \bm X_\ast, \quad\bm w \to \bm m_{w}

と書き換えたものに相当する。

結果

ということで、事後予測分布は次のような形で書ける。

\begin{aligned} p(\bm y_\ast | \bm X_\ast, \bm X, \bm y) = \mathcal N_{D_\ast} (\bm y_\ast | \bm m_{y_\ast}, \bm V_{y_\ast}) \end{aligned}
\left\{\begin{aligned} \bm m_{y_\ast} ={}& \bm X_\ast \bm V_0 \bm X^\mathsf{T} \left( \bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_D \right)^{-1} \bm y \\ & + \left( \bm X_\ast - \bm X_\ast \bm V_0 \bm X^\mathsf{T} (\bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_D)^{-1} \bm X \right) \bm m_0 \\ \bm V_{y_\ast} ={}& \bm X_\ast \bm V_0 \bm X_\ast^\mathsf{T} + \sigma^2 \bm I_{D_\ast} - \bm X_\ast \bm V_0 \bm X^\mathsf{T} ( \bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_D )^{-1} \bm X \bm V_0 \bm X_\ast^\mathsf{T} \\ \end{aligned}\right.

まとめ

  • 事後予測分布は、既知データが与えられた条件下での、未知のデータの分布を表す。

  • 線形回帰を事後予測分布を用いて行うものをベイズ線形回帰と呼ぶ。

  • 事後予測分布の計算は、解析的に可能な場合が限られており、多くの場合はMCMCなどで近似的に数値計算される。

  • 正規事前分布を用いる場合、事後予測分布のパラメータは次式で与えられる。

    \left\{\begin{aligned} \bm m_{y_\ast} ={}& \bm X_\ast \bm V_0 \bm X^\mathsf{T} \left( \bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_D \right)^{-1} \bm y \\ & + \left( \bm X_\ast - \bm X_\ast \bm V_0 \bm X^\mathsf{T} (\bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_D)^{-1} \bm X \right) \bm m_0 \\ \bm V_{y_\ast} ={}& \bm X_\ast \bm V_0 \bm X_\ast^\mathsf{T} + \sigma^2 \bm I_{D_\ast} - \bm X_\ast \bm V_0 \bm X^\mathsf{T} ( \bm X \bm V_0 \bm X^\mathsf{T} + \sigma^2 \bm I_D )^{-1} \bm X \bm V_0 \bm X_\ast^\mathsf{T} \\ \end{aligned}\right.

Discussion

ログインするとコメントできます