ガウス過程回帰とGPyのメモ
はじめに
Kindle の日替わりセールで買って積読になっていた『ガウス過程と機械学習』を読んだので,ガウス過程回帰について整理し,GPy で試してみました.この記事の内容は『ガウス過程と機械学習』の 2-4 章に対応しています.
ガウス過程回帰とは回帰分析手法の一種で,非線形な関数関係を表すことができます.ガウス過程回帰ではカーネル関数と呼ばれる類似度を表す関数を設定しますが,このカーネル関数を変えることで無限回微分可能な滑らかな関数やブラウン運動,周期性など様々なものを扱えます.
ガウス分布
平均
で与えられます.平均
多変量ガウス分布
1 変数のガウス分布を拡張した,多変量ガウス分布の確率密度関数は,
で与えられます.ここで,
ガウス過程
入力の集合
ガウス過程は共分散行列
カーネル関数には様々な種類のものがありますが,ここでは RBF カーネル(ガウスカーネル)と呼ばれる次の式を使います.
ガウス過程回帰
ガウス過程回帰では,ガウス過程に基づいて予測分布を計算します.まず,
が与えられているとします.ただし,式を簡単にするために平均を
ここで,行列
予測分布の式からわかるように,ガウス過程回帰では
GPy を用いたガウス過程回帰
Python でガウス過程回帰を行えるライブラリはいくつかありますが,今回は GPy を用いてガウス過程回帰を試してみました.
インストール
GPy は pip でインストールができます.詳細は SheffieldML/GPy をご覧ください.
$ pip install GPy
ガウス過程回帰(入力・出力 1 次元)
今回は適当に作った関数からランダムにサンプリングをすることで学習データを作成し,それらを用いてガウス過程回帰を行います.ただし,グラフを簡単にするために入力・出力の次元は 1 次元とします.
各種ライブラリのインポート
import GPy
import numpy as np
import matplotlib.pyplot as plt
学習データの作成
正弦波を重ね合わせた関数を定義し,ガウス分布
def func(x):
return np.sin(2*np.pi*0.1*x) + np.sin(2*np.pi*0.2*x) + np.sin(2*np.pi*0.3*x)
n_train = 40
x_train = np.random.uniform(-10, 10, n_train)
y_train = func(x_train) + np.random.normal(0, 0.1, n_train)
次の図は func(x)
と
x_true = np.linspace(-12.0,12.0,500)
fig,axes = plt.subplots()
axes.plot(x_true, func(x_true), label="True")
axes.scatter(x_train,y_train, label="Measured")
axes.legend()
plt.show()
カーネルの定義とガウス過程回帰
今回は上述の RBF カーネルを使用します.必須の引数は入力の次元 input_dim
です.ガウス過程回帰は GPy.models.GPRegression()
で行えます.
kern = GPy.kern.RBF(1)
model = GPy.models.GPRegression(x_train.reshape(-1,1), y_train.reshape(-1,1), kern)
model
GPy には可視化用の関数が用意されているのでそれを使います.× は学習データ,青の実線は期待値,薄い青色の範囲は信頼区間を指します.
fig = plt.figure(figsize=(6,8))
fig,axes = plt.subplots()
model.plot(ax=axes)
plt.show()
現状では上手く回帰が出来ていないので,ハイパーパラメータを最適化します.GPy では関数 optimize()
で最適化を行えます.
model.optimize(messages=True)
また,精度の悪い局所解を避けるために複数回の最適化を行う関数 optimize_restarts()
も用意されています.引数 num_restarts
で回数を指定します.今回は結果がほとんど変わらないことから,おそらく最適解が得られていると予想されます.
model.optimize_restarts(num_restarts = 10)
Optimization restart 1/10, f = 11.810668361597326
Optimization restart 2/10, f = 11.810668361508423
Optimization restart 3/10, f = 11.810668361673926
Optimization restart 4/10, f = 11.810668361659776
Optimization restart 5/10, f = 11.81066836149197
Optimization restart 6/10, f = 11.810668361486357
Optimization restart 7/10, f = 11.810668361489427
Optimization restart 8/10, f = 11.810668361494514
Optimization restart 9/10, f = 11.810668361664248
Optimization restart 10/10, f = 11.810668361552
[<paramz.optimization.optimization.opt_lbfgsb at 0x7f43d4819b40>,
<paramz.optimization.optimization.opt_lbfgsb at 0x7f43d4819ae0>,
<paramz.optimization.optimization.opt_lbfgsb at 0x7f43d7b57e20>,
<paramz.optimization.optimization.opt_lbfgsb at 0x7f43d481b010>,
<paramz.optimization.optimization.opt_lbfgsb at 0x7f43d4819bd0>,
<paramz.optimization.optimization.opt_lbfgsb at 0x7f43d48b7580>,
<paramz.optimization.optimization.opt_lbfgsb at 0x7f43d48b7880>,
<paramz.optimization.optimization.opt_lbfgsb at 0x7f43d4818a60>,
<paramz.optimization.optimization.opt_lbfgsb at 0x7f43d4819ba0>,
<paramz.optimization.optimization.opt_lbfgsb at 0x7f43d481a740>,
<paramz.optimization.optimization.opt_lbfgsb at 0x7f43d481bb50>]
次の図は最適化後の結果をプロットしたものです.最適化前と比べると明らかに回帰の精度が向上しています.データ数が少ない範囲では信頼区間が広くなっているので,その部分では自信がないということが明確にわかります.
fig = plt.figure(figsize=(6,8))
fig,axes = plt.subplots()
model.plot(ax=axes)
plt.show()
まとめ
本記事ではガウス過程回帰の概要について整理し,GPy を用いてガウス過程回帰を行いました.機械学習については初心者ですが,『ガウス過程と機械学習』は非常にわかりやすく楽しかったので,今更ですが記事を書いてみました.ガウス過程でここまで複雑なこともできるのかとワクワクしながら読み進めることができました.特にガウス過程とニューラルネットワークの関係性はびっくりしました.もう少し機械学習も勉強してみようという気分になりました.
Discussion