JAX で遊んでみる (1) — 線形回帰
目的
JAX をインストールして少し触ってみたという記録。大体 Linear Regression with JAX に書いているのと同じような内容で、普通に線形回帰をしましたという備忘録。
これだけだと記事がすぐに終わってしまうので、統計学による直接計算や scikit-learn の使用例も交えて水増ししてみた。
おさらい
ディープラーニングは画像分類、画像生成、画像認識や物体検出、自然言語処理など色々ジャンルはあると思うが、基本的には何かしら入力データの空間
-
は物件の家賃であり、\mathcal{Y} = \R^1 は駅からの距離、築年数、部屋が角部屋か否かという情報からなる。\mathcal{X} = \R^3 -
は人物の顔写真であり、\mathcal{Y} = \R^{64\times 64} は正規分布に従うランダムノイズからなる。\mathcal{X} = \R^{128} -
は 3 つの日本語の単語からなる集合で、個々の単語は\mathcal{Y} = \R^{3\times 128} 次元のベクトルで符号化されている。128 は 3 つの英語の単語からなる集合で、個々の単語は\mathcal{X} = \R^{3\times 128} 次元のベクトルで符号化されている。128
色々ある問題の中で、特に教師あり学習と呼ばれるものは、例
- 既知のデータ
に対して\hat{y_i} = f(x_i) が成り立つ。\hat{y}_i \approx y_i - 未知のデータ
に対してx \not\in \{y_i\} は何らかの意味でもっともらしい。f(x) \in \mathcal{Y}
今回は何をする?
大袈裟なことは避け、
データ作成
import numpy as np
import random
import matplotlib.pyplot as plt
xs = np.arange(-5, 15, 0.05)
ys = np.array([x*9/5+32 + random.gauss(0,3) for x in xs])
ys_ideal = np.array([x*9/5+32 for x in xs])
plt.scatter(xs,ys)
plt.plot(xs,ys_ideal, color='red')
plt.xlabel('Celsius')
plt.ylabel('Fahrenheit')
plt.show()
を今回の教師データとする。なお、よく知られているように、摂氏
で与えられ、上図で赤線で引いたものがこの直線に対応する。
統計学で解いてみる
摂氏-華氏のデータセットを
で係数
が解として求まる。念の為に Python で解くと
xs_mean = np.mean(xs)
ys_mean = np.mean(ys)
alpha = np.sum((xs - xs_mean)*(ys - ys_mean))/np.sum((xs - xs_mean)**2)
beta = ys_mean - alpha * xs_mean
print('estimate:', alpha, beta)
print('ideal:', 9/5, 32)
estimate: 1.8103115195284565 32.08597124079442
ideal: 1.8 32
という結果であった。
scikit-learn でも解いてみる
正直、今回程度の問題なら scikit-learn で解くのがベストだと思う。参考までに解いてみよう。
from sklearn.linear_model import LinearRegression
model_lr = LinearRegression()
model_lr.fit(xs.reshape(-1,1), ys.reshape(-1,1))
coef, intercept = model_lr.coef_[0][0], model_lr.intercept_[0]
print('estimate:', coef, intercept)
print('ideal:', 9/5, 32)
estimate: 1.8103115195284571 32.08597124079442
ideal: 1.8 32
簡単であるし、何ら問題はない。
JAX で解いてみる
漸くメインである。そして結果は分かっているのでまったく盛り上がらないが仕方ない。
ところでこの手のフレームワークはデータを正規化しないとうまく結果が得られないことが常なので、標準的な正規化を行いたい。
とする。良い係数
という対応になっている。
まずはデータの正規化を Python で実装しよう:
from jax import grad
import jax.numpy as jnp
xs_std = np.std(xs)
ys_std = np.std(ys)
xs_n = (xs - xs_mean) / xs_std
ys_n = (ys - ys_mean) / ys_std
次に線形回帰モデルを実装する:
def model(params, x):
W, b = params
return x * W + b
def loss(params, x, y):
preds = model(params, x)
return jnp.mean((preds - y)**2)
def update(params, x, y, lr=0.1):
return params - lr * grad(loss)(params, x, y)
ここまでできると、後は訓練ループを回すだけである。今回は何も考えずに 5000 回イテレーションを回す。回した後に得られた “最適値” を (3) 式に基づいて “元に戻す”:
params = jnp.array([0., 0.])
for _ in range(5000):
params = update(params, xs_n, ys_n)
a, b = params
a = a * ys_std / xs_std
b = ys_mean + b * ys_std - xs_mean * ys_std / xs_std
一応、結果を表示すると、
print('estimate:', a, b)
print('ideal:', 9/5, 32)
estimate: 1.8103114 31.748932
ideal: 1.8 32
という感じになる。記念にプロットもしておこう:
plt.scatter(xs,ys)
params = jnp.array([a, b])
plt.plot(xs,model(params,xs), color='red')
plt.xlabel('Celsius')
plt.ylabel('Fahrenheit')
plt.show()
めでたく期待通りの結果が得られた。これでモデルは摂氏と華氏の変換の知識を獲得したことになる。
まとめ
特にとりたてて書くほどのまとめもないが、TensorFlow や PyTorch に比べて質素に書けたように思う。データローダの準備だとかテンソルを GPU に乗せるといったことを意識せずに NumPy のように雑に書いて、Python 的な書き方で訓練ループも回せた。loss
を grad
で包むだけで自動微分が実行されるのも楽で良い。
-
統計学では不偏分散の平方根で不偏標準偏差を求めると思うが、ディープラーニングのコンテキストではそこまでしていないように見えるので、普通に標本分散の平方根による標準偏差を用いる。 ↩︎
Discussion