Chapter 05

リッジ回帰(Ridge Regression)

リッジ回帰

回帰問題。入力xに対する出力yを予測する。

単純な線形回帰にちょっとした修正を加えただけだが、線形回帰に比べて学習が安定する(線形回帰のようにそもそも解けないという場合がなく、過学習が起こりにくい)。

本シリーズで用いる数式の表記一覧

本シリーズでは独自記号として、以下の「列挙(enumerate)」を使う。

\overset{n}{\underset{i=1}{\sf E}} x _ i = x _ 1, x _ 2, \ldots, x _ n

たとえばこれを用いて集合を定義することができる。

\Bigl\{ \overset{n}{\underset{i=1}{\sf E}} x _ i \Bigr\} = \{ x _ 1, x _ 2, \ldots, x _ n \}

モデル

入力x \in \mathcal{X}、特徴量写像\phi(x) \colon \mathcal{X} \to \mathbb{R} ^ n、出力y \in \mathbb{R}の状況。パラメータはw \in \mathbb{R} ^ n

出力とパラメータの同時確率分布を以下のように仮定する。

p(y,w|x) = p(y|w,x)p(w) \tag{1}

ただし出力yは平均w ^ \mathrm{T} \phi(x)、分散\sigma ^ 2の正規分布に従い、パラメータwは各要素が平均0、精度\lambda(分散\lambda ^ {-1})の正規分布に従うものとする。

\begin{aligned} & p(y|w,x)= \mathcal{N}(y | w ^ \mathrm{T} \phi(x), \sigma ^ 2) \\ & p(w) = \mathcal{N}(w | 0, \lambda ^{-1} I_n) = \prod _ {i=1} ^ n \mathcal{N}(w _ i | 0, \lambda ^ {-1}) \end{aligned} \tag{2}

ここでI _ nn \times nの単位行列。

出力の予測値\hat{y} \in \mathbb{R}は以下で表現する。

\hat{y} = w ^ \mathrm{T} \phi(x)

学習

独立同分布(i.i.d.)なN個のデータを観測したデータセット

\mathcal{D} = \Bigl\{ \overset{N}{\underset{k=1}{\sf E}} (x ^ {(k)}, y ^ {(k)}) \Bigr\}

が与えられたもとでの同時確率密度の最大化、または事後確率の最大化(MAP推定)によってwを求める。これら二つの方法は以下で示されるように等価である。

まず、条件付き確率の定義により

p(y,w|x) = p(w | y, x)p(y | x)

である。ここでp(y|x)p(y,w|x)からwを積分消去した周辺確率分布であるから、

p(y|x) = \int p(y,w | x) dw

である。ここで(1)式と合わせれば、

p(w |y, x) = \frac{p(y,w|x)}{p(y|x)} = \frac{p(y|w,x)p(w)}{\int p(y,w | x) dw}

となる(Bayesの定理)。上式で表されるp(w | y, x)wの事後確率分布と呼ばれる。一見いかめしい形になってしまったが、分母はデータセットのみから定まる周辺尤度(marginal likelihood)で定数なので今回で行う最適化には影響を与えない。分子は同時確率分布であるから、同時確率密度の最大化と、事後確率密度の最大化は等価となる。

データセットがすべて観測されたときの同時確率密度は

F(w) = p(w) \prod _ {k=1} ^ N p(y ^ {(k)}| w, x ^ {(k)})

である。MAP推定の文脈で、上式を周辺尤度で割った

G(w) = \frac{F(w)}{\int p(y,w | x) dw} = \frac{p(w) \prod _ {k=1} ^ N p(y ^ {(k)}|x ^ {(k)}, w)}{\int p(y,w | x) dw}

は、悪しき伝統で「事後確率」と呼ばれる。実のところ、連続型確率分布を用いているので、厳密にはこれは確率ではない(1よりも大きい値を取りうる)。本来ならば「事後確率密度」と呼ぶべきであろう。

ともかく、F(w)を最大化するようなwを求める。例のごとく対数を取った最大化問題を解く。ただし、計算を簡略化するため、(2)式のp(w)で見たように、互いに独立なガウス分布と多変量ガウス分布の関係

\prod _ {k=1} ^ N \mathcal{N}(y ^ {(k)} | w ^ \mathrm{T} \phi (x ^ {(k)}), \sigma ^ 2) = \mathcal{N}(y | \Phi ^\mathrm{T} w, \sigma ^ 2 I _ N)

を早々に用いて計算を簡略化する。ただし\Phi, y, wは線形回帰で使用したものと同じである。

\begin{aligned} &\underset{w \in \mathbb{R ^ n}}{\operatorname{arg} \operatorname{max}} \log F(w) \\ =& \, \underset{w \in \mathbb{R ^ n}}{\operatorname{arg} \operatorname{max}} \sum _ {k=1} ^ N \log \mathcal{N}(y ^ {(k)} | w ^ \mathrm{T} \phi (x ^ {(k)}), \sigma ^ 2) + \log \mathcal{N}(w | 0, \lambda ^ {-1} I _ n) \\ =& \, \underset{w \in \mathbb{R ^ n}}{\operatorname{arg} \operatorname{max}} \log \mathcal{N}(y | \Phi ^\mathrm{T} w, \sigma ^ 2 I _ N) + \log \mathcal{N}(w | 0, \lambda ^ {-1} I _ n) \\ =& \, \underset{w \in \mathbb{R ^ n}}{\operatorname{arg} \operatorname{max}} - \frac{1}{2 \sigma ^ 2}(y - \Phi ^ \mathrm{T} w) ^ \mathrm{T} (y - \Phi ^ \mathrm{T} w) - \frac{\lambda}{2} w ^ \mathrm{T} w + C \\ =& \, \underset{w \in \mathbb{R ^ n}}{\operatorname{arg} \operatorname{min}} \| y - \Phi ^ \mathrm{T} w \| _ 2 ^ 2 + \frac{\lambda}{\sigma ^ 2} \| w \| _ 2 ^ 2 \end{aligned}

\lambdaは人間が勝手に事前知識として与えたパラメータなので、最初からwの精度パラメータが\lambda = \lambda ^ \prime / \sigma ^ 2だったことにすれば上式は

\, \underset{w \in \mathbb{R ^ n}}{\operatorname{arg} \operatorname{min}} \| y - \Phi ^ \mathrm{T} w \| _ 2 ^ 2 + \lambda \| w \| _ 2 ^ 2 \tag{3}

と書ける。以上により、リッジ回帰の最適化問題が導出される。\lambda \| w \| _ 2 ^ 2は学習を安定させる作用があり、正則化項と呼ばれる。

勾配法などで解いてもよいが、\nabla f(w) = 0の微分方程式を直接解くことができる。解けない場合がある単純な線形回帰とは異なり、必ず解けることに注意する。

\nabla f(w) = - 2 \Phi (y - \Phi ^ \mathrm{T} w) + 2 \lambda w

であるから、\nabla f(w) = 0のとき、

\begin{aligned} &\Phi (y - \Phi ^ \mathrm{T} w) - \lambda w = 0 \\ & \therefore \Phi y = \Phi \Phi ^ \mathrm{T} w + \lambda w \\ & \therefore \Phi y = (\Phi \Phi ^ \mathrm{T} + \lambda I _ n) w \\ & \therefore w = (\Phi \Phi ^ \mathrm{T} + \lambda I _ n) ^ {-1} \Phi y \end{aligned}

としてwが求まる。第3式から第4式への変形には(\Phi \Phi ^ \mathrm{T} + \lambda I _ n)が正則であることが必要だが、\Phi \Phi ^ \mathrm{T}が半正定値、\lambda I _ n\lambda > 0のとき正定値だから、(\Phi \Phi ^ \mathrm{T} + \lambda I _ n)は正定値となり正則である(任意の正定値行列は正則である)。したがってリッジ回帰の最適化問題は必ず解ける。

\lambdawの分散の逆数であるから「各要素の値がどれくらい平均の近くに集まるか」を表すパラメータである。(3)式の最適化問題では、\lambdaが大きくなるほど\lambda \| w \|が支配的になり、この項を小さくしようとする。つまりどんなオプティマイザを用いて解いても、そのアルゴリズムは\lambda \to +\inftyのとき\| w \| \to 0となるように振る舞い、\lambdaが大きくなるほどwの各要素の値は0付近に集中する。

過学習が起こるとき、wの各要素はデータに適合させるために絶対値が非常に大きくなる傾向があるため、リッジ回帰の正則化項を付け加えることで過学習が起こりにくくなる。

scikit-learn

sklearn.linear_model.Ridge

実装

import numpy as np

# あとで書く