JAXを使いほぼ1からニューラルネットワークを作る
はじめに
JAXは、numpyのように使えるAPIを持った、自動微分+XLAライブラリです。
これをベース(あるいはバックエンド)にして、ニューラルネットワーク・ディープラーニングライブラリがいくつか作られています。
例として、DeepMindのHaikuやGoogleのFlax、そして何度か記事を書いているTraxが挙げられます。
使用例としては、執筆時の数日前にDeepMindのAlphaFold2の実装が公開されて祭りになっていますが、これにHaikuが使われています。
というわけで(?)、今回は上記のライブラリを使わずに、JAXでほぼ1からニューラルネットワークを作ってみたいと思います。
なお、JAXにはstaxというニューラルネットワークモジュールがあるのですが、今回はこれも使いません。
ゴール
以下が動くように作ります。
if __name__ == '__main__':
rng = jax.random.PRNGKey(0)
rng1, rng2 = jax.random.split(rng)
rng1w, rng1b = jax.random.split(rng1)
rng2w, rng2b = jax.random.split(rng2)
params = {
"linear1": {
"W": glorot_normal()(rng1w, (4, 100)),
"b": normal()(rng1b, (100,))
},
"linear2": {
"W": glorot_normal()(rng2w, (100, 3)),
"b": normal()(rng2b, (3,))
}
}
iris_dataset = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris_dataset['data'], iris_dataset['target'], test_size=0.25, random_state=0)
y_train = np.eye(3)[y_train]
params = train(params, X_train, y_train, 100)
acc = accuracy(params, X_test, y_test)
print(acc)
タスクはよくあるirisデータセットを用いた分類です。
これはscikit-learnのやつをそのまま使っています。
作成するニューラルネットワークは2層(数え方によっては3層)ニューラルネットワークです。
初期値は横着して、jax.nn.initializers
にあるglorot_normal, normal
をそのまま使っています。
初期化にPRNG keyを使うので、重みとバイアスの数だけ異なるkeyが得られるよう、適宜jax.random.split
しています。
あとはparams
と訓練データをtrain
関数にわたし、100エポック訓練したのち、更新されたparams
を受け取ります。
そのparams
とテストデータをaccuracy
関数に入れるとaccuracyが返ってきます。
というように使えるtrain
とaccuracy
を実装しましょう。
モデル
肝心の「モデル」の実装はこんな感じです。
import jax.numpy as jnp
from jax import nn
import jax
@jax.jit
def linear(x, W, b):
return jnp.dot(x, W) + b
@jax.jit
def Linear(params, inputs):
outputs = linear(inputs, params["W"], params["b"])
return outputs
@jax.jit
def MLP(params, inputs):
linear1_out = Linear(params["linear1"], inputs)
linear2_in = nn.relu(linear1_out)
linear2_out = Linear(params["linear2"], linear2_in)
logits = nn.softmax(linear2_out)
return logits
JAXのnumpy APIはjnp
として読み込むことが多いです。
活性化関数はJAXのnnモジュールにあるものをそのまま使いました。
あとはあまり解説することはないです。
JITはとりあえずつけています。
学習
lossはCategorical Cross Entropyを使います。
@jax.jit
def categorical_cross_entropy_loss(logits, onehot_labels):
return jnp.mean(-jnp.sum(onehot_labels * jnp.log(logits), axis=1))
オプティマイザーはSGDを使います。愚直に書いて
@jax.jit
def SGD(params, grad, lr = 0.1):
params["linear1"]["W"] -= lr * grad["linear1"]["W"]
params["linear2"]["W"] -= lr * grad["linear2"]["W"]
params["linear1"]["b"] -= lr * grad["linear1"]["b"]
params["linear2"]["b"] -= lr * grad["linear2"]["b"]
return params
こんな感じにしました。
これらを用いて、まずバッチごとのparamsの更新関数を以下のように実装します。
@jax.jit
def train_batch(params, batch_X, batch_y):
def loss_fn(params_, batch_X_):
logits = MLP(params_, batch_X_)
return categorical_cross_entropy_loss(logits, batch_y)
grad = jax.jit(jax.grad(loss_fn))(params, batch_X)
return SGD(params, grad)
JAXではjax.grad
を用いて勾配を計算します。
そして計算した勾配を用いてSGDでparamsをupdateします。
これを各バッチで呼び出します。
これが1エポックです。
def train_one_epoch(rng, params, X_train, y_train, batch_size = 50):
index = rng.permutation(X_train.shape[0])
num_batches = X_train.shape[0] // batch_size + 1
for batch in range(num_batches):
batch_index = index[batch * batch_size: (batch + 1) * batch_size]
params = train_batch(params, X_train[batch_index], y_train[batch_index])
return params
rng
はnumpy.random.RandomState(0)
です。
データのシャッフルのために使っていますが、これのせいでJITにできなかったり、バッチでの学習をjax.lax.fori_loop
にできないので、できればjax.random
のものを使った方が良いのかなと思います。
これを用いてtrain
関数は以下のようにかけます。
@jax.jit
def train(params, X_train, y_train, epochs: int):
rng = npr.RandomState(0)
params = jax.lax.fori_loop(0, epochs, lambda epoch_, params_: train_one_epoch(rng, params_, X_train, y_train), params)
return params
train_one_epoch
関数は、for文ではなく、jax.lax.fori_loop
を使って呼び出すことで高速化しています。
これはドキュメントにある通り、セマンティクスとしては
def fori_loop(lower, upper, body_fun, init_val):
val = init_val
for i in range(lower, upper):
val = body_fun(i, val)
return val
と同じです。ただ速いです。
これでtrain
関数が実装できました。
評価
評価は、学習したparamsとテストデータを入力としてaccuraryを計算します。
@jax.jit
def accuracy(params, X_test, y_test):
pred = MLP(params, X_test)
return jnp.mean(y_test == jnp.argmax(pred, axis=1))
これで全ての実装が終わりました。
動かしてみると、accuracyは約92%でした。ちょっと低いですね・・・
JAX全然分からん
というわけでJAXを用いてほぼ1からニューラルネットワークを作成したのですが、著者にはまだ全然分からないことが結構あります。
JITは付けられるだけつければ良いのでしょうか?
staxを見るとあまり使っていないように見えます。
Flaxだとそれなりに使っているようです。
また、staxにしろ、Haikuにしろ、Flaxにしろ、Traxにしろ、あまりfori_loopは使われていません。
batchやepochの繰り返しで一番使いそうな気がするのですが・・・
追記
いくつか修正して、numpyをそのまま使っていたところを全てJAXに置き換えました。
まず
X_train, X_test, y_train, y_test = train_test_split(jnp.array(iris_dataset['data']), jnp.array(iris_dataset['target']), test_size=0.25, random_state=0)
y_train = jnp.eye(3)[y_train]
として、データをjnp.array
にしました。
次に、train
関数とtrain_one_epoch
関数で、データセットのシャッフルにJAXを使用することにしました。
@jax.jit
def train_one_epoch(rng, params, X_train, y_train, batch_size = 50):
index = jax.random.permutation(rng, X_train.shape[0])
num_batches = X_train.shape[0] // batch_size + 1
for batch in range(num_batches):
batch_index = index[batch * batch_size: (batch + 1) * batch_size]
params = train_batch(params, X_train[batch_index], y_train[batch_index])
return params
@jax.jit
def train(params, X_train, y_train, epochs: int):
rng = jax.random.PRNGKey(42)
params = jax.lax.fori_loop(0, epochs, lambda epoch_, params_: train_one_epoch(rng, params_, X_train, y_train), params)
return params
これで、train_batch
もfori_loop
で呼び出せると思ったのですが、インデックスが動的に変わるのがダメっぽいです。
jax.lax.dynamic_slice
を使うといけるかもしれませんが、バッチサイズが固定ではない(端数のバッチがある)ので、使えなそうです。
割り切れるバッチサイズにすれば、固定にできるので、fori_loop
が使えるかもしれませんが、そのためにバッチサイズを弄る必要があるというのも・・・?
また、SGDはjax.tree_multimap
を使いました。
@jax.jit
def SGD(params, grads, lr = 0.1):
return jax.tree_multimap(lambda param, grad: param - lr * grad, params, grads)
というわけで、修正版のコードは以下のgistにあります。
リファレンス
本記事はstaxの実装とそのサンプルを参考に作成しました。
またparamsの形式はHaikuを参考にしました。
@software{jax2018github,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/google/jax},
version = {0.2.5},
year = {2018},
}
@software{haiku2020github,
author = {Tom Hennigan and Trevor Cai and Tamara Norman and Igor Babuschkin},
title = {{H}aiku: {S}onnet for {JAX}},
url = {http://github.com/deepmind/dm-haiku},
version = {0.0.3},
year = {2020},
}
Discussion