[強化学習][ゼロつく4]Q学習にニューラルネットワーク導入
ゼロつく4の7章の勉強メモです。
本章では、ゼロつく3で登場したDeZeroというフレームワークが使われています。
導入
- 6章までの強化学習では、例えば3x4のグリッドワールドのような状態と行動の数が限られた比較的小規模な問題を扱ってきた。
- しかし、現実世界の問題はもっと複雑。例えばチェスの盤面の状態は10の123乗通りもあり、このような規模になると従来のQテーブルを用いた方法では対応できない。
- そこで登場するのが、Q関数の近似。特に、ディープラーニングを用いてQ関数を近似する方法(深層強化学習)が有力。
多次元配列(テンソル)
機械学習では、特に深層学習では一般的に多次元配列を扱う。
スカラ、ベクトル、行列
多次元配列は、その次元数によって異なる名称で呼ばれる
- スカラ:0次元の配列で、単一の数値を表す。例:5
- ベクトル:1次元の配列で、数値の列を表す。例:[1,2,3]
- 行列:2次元の配列で、数値の表を表す。例:[[1,2],[3,4]]
ベクトルの内積
ベクトルの内積は、2つのベクトル間の要素ごとの積の和として定義される。
2つのベクトル
行列の積
行列の積は、左側の行列の行ベクトルと右側の行列の列ベクトルの内積を要素とする新しい行列を生成する。
例 2×2の行列AとBの積:
実践での応用
機械学習においては、これらの演算が頻繁に使用される。
例えば
- 入力データはベクトルや行列で表される
- ニューラルネットワークの重みは通常行列で表される
- ニューラルネットワークの層を通じてデータを伝播させる際に行列の積が用いられる。
実際にベクトルの内積と行列の積を計算してみる。
import numpy as np
from dezero import Variable
import dezero.functions as F
# ベクトルの内積
a = np.array([1, 2, 3])
b = np.array([4, 5, 6])
a, b = Variable(a), Variable(b) # 省略可能
c = F.matmul(a, b)
print(c)
# 行列の積
a = np.array([[1, 2], [3, 4]])
b = np.array([[5, 6], [7, 8]])
c = F.matmul(a, b)
print(c)
出力結果:
variable(32)
variable([[19 22][43 50]])
VariableはNumPyの多次元配列(np.ndarray)を包み込むクラスで、勾配の保持や微分などの機能が備わっている。
matmul関数はDeZeroの関数のひとつで、受け取った配列の次元に応じて、ベクトルの内積や行列の積を計算する。
引数はvariableインスタンスもしくはnp.ndarrayインスタンスを渡すことができる。
最適化
最適化とは、関数の最小値(もしくは最大値)を取る関数の引数(入力)を見つける作業のことである。
以下の式で表されるローゼンブロック関数の最小値を見つけてみる。
この関数は最小値の探索が難しく、最適化アルゴリズムのベンチマークとしてよく使われる。
この関数の出力が最小となる
DeZeroを使った実装
ローゼンブロック関数を実装し、ある点での勾配を計算してみる。
import numpy as np
from dezero import Variable
def rosenbrock(x0, x1):
y = 100 * (x1 - x0 ** 2) ** 2 + (x0 - 1) ** 2
return y
x0 = Variable(np.array(0.0))
x1 = Variable(np.array(2.0))
y = rosenbrock(x0, x1)
y.backward()
print(x0.grad, x1.grad)
出力結果
variable(-2.0) variable(400.0)
variableで数値データを包むことで、backward関数により微分が計算される。
出力結果より、
勾配降下法
勾配降下法は、勾配の計算とその方向(もしくは逆方向)に進むのを繰り返すことで最大値(最小値)を見つける方法である。
これを用いて最小値を求めてみる。
x0 = Variable(np.array(0.0))
x1 = Variable(np.array(2.0))
lr = 0.001 # 学習率
iters = 10000 # 繰り返す回数
for i in range(iters):
y = rosenbrock(x0, x1)
x0.cleargrad()
x1.cleargrad()
y.backward()
x0.data -= lr * x0.grad.data
x1.data -= lr * x1.grad.data
print(x0, x1)
出力結果
variable(0.9944984367782456) variable(0.9890050527419593)
- 繰り返し更新する回数を
iters
、学習率をlr
として設定している。 -
x0.data -= lr * x0.grad.data
の式では、その場所の勾配の値に学習率を掛け、その逆方向にパラメータを更新している。 -
.data
属性を使わないと、余分なバックプロパゲーションが行われてしまう。(あんま分かってない) - DeZeroは勾配を累積する設計のため、各ステップで勾配をリセットしている。
ローゼンメイデン関数の最小値は(1.0,1.0)なので、おおむね正しい値が得られた。
線形回帰
機械学習の最も基本となる線形回帰を実装する。
- 回帰:実数値yの値をxの値から予測すること
- 線形回帰:モデルを線形として回帰すること
線形回帰モデル
線形回帰は、y = Wx + b という形のモデルを仮定する。ここで、Wとbがモデルのパラメータになる。
実際のデータ点と予測値の差(残差)を最小化するために
損失関数(モデルの悪さを評価する関数)として以下に示す平均二乗誤差を使用する。
目標はこの損失関数が最小となるWとbを見つけることで、これは勾配降下法によって解くことができる。
線形回帰の実装
線形回帰を実装していく。
- 必要なライブラリとデータセットの準備
import numpy as np
from dezero import Variable
import dezero.functions as F
# トイ・データセットの生成
np.random.seed(0)
x = np.random.rand(100, 1)
y = 5 + 2 * x + np.random.rand(100, 1)
x, y = Variable(x), Variable(y) # DeZeroのVariable型に変換
ここでは、NumPyを使用して100個のサンプルを持つ1次元のランダムデータを生成している。
yは線形関係(y = 5 + 2x)にノイズを加えたものになっている。
- モデルの定義
W = Variable(np.zeros((1, 1)))
b = Variable(np.zeros(1))
def predict(x):
y = F.matmul(x, W) + b
return y
パラメータWとbをDeZeroのVariable型で初期化している。
predict関数では、行列の積のmatmal関数を使って計算している。
行列の積を使うことで、複数データに対してまとめて計算することができる。
- 損失関数の実装
def mean_squared_error(x0, x1):
diff = x0 - x1
return F.sum(diff ** 2) / len(diff)
平均二乗誤差(MSE)を計算する関数を定義する。
- 学習ループ
lr = 0.1
iters = 100
for i in range(iters):
y_pred = predict(x)
loss = mean_squared_error(y, y_pred)
W.cleargrad()
b.cleargrad()
loss.backward()
W.data -= lr * W.grad.data
b.data -= lr * b.grad.data
if i % 10 == 0:
print(loss.data)
print('====')
print('W =', W.data)
print('b =', b.data)
このループでは、以下の手順を繰り返す:
- モデルによる予測
- 損失の計算
- 勾配のクリア
- 誤差逆伝播による勾配の計算
- 勾配降下法によるパラメータの更新
また学習率0.1で100回のイテレーションを行い、10回ごとに損失値を出力する。
出力結果
42.296340129442335
0.24915731977561134
0.10078974954301652
0.09461859803040694
0.0902667138137311
0.08694585483964615
0.08441084206493275
0.08247571022229121
0.08099850454041051
0.07987086218625004
====
W = [[2.11807369]]
b = [5.46608905]
出力結果より、損失関数の値が徐々に減っていくことがわかる。
また、図7-9のグラフよりデータに適合したモデルを得られたことがわかる。
Discussion