🔖

JAXを使いほぼ1からニューラルネットワークを作る

2021/07/24に公開

はじめに

JAXは、numpyのように使えるAPIを持った、自動微分+XLAライブラリです。
これをベース(あるいはバックエンド)にして、ニューラルネットワーク・ディープラーニングライブラリがいくつか作られています。
例として、DeepMindのHaikuやGoogleのFlax、そして何度か記事を書いているTraxが挙げられます。
使用例としては、執筆時の数日前にDeepMindのAlphaFold2の実装が公開されて祭りになっていますが、これにHaikuが使われています。
というわけで(?)、今回は上記のライブラリを使わずに、JAXでほぼ1からニューラルネットワークを作ってみたいと思います。
なお、JAXにはstaxというニューラルネットワークモジュールがあるのですが、今回はこれも使いません。

ゴール

以下が動くように作ります。

if __name__ == '__main__':
    rng = jax.random.PRNGKey(0)
    rng1, rng2 = jax.random.split(rng)
    rng1w, rng1b = jax.random.split(rng1)
    rng2w, rng2b = jax.random.split(rng2)
    params = {
        "linear1": {
            "W": glorot_normal()(rng1w, (4, 100)),
            "b": normal()(rng1b, (100,))
        },
        "linear2": {
            "W": glorot_normal()(rng2w, (100, 3)),
            "b": normal()(rng2b, (3,))
        }
    }
    iris_dataset = load_iris()
    X_train, X_test, y_train, y_test = train_test_split(iris_dataset['data'], iris_dataset['target'], test_size=0.25,  random_state=0)
    y_train = np.eye(3)[y_train]
    params = train(params, X_train, y_train, 100)
    acc = accuracy(params, X_test, y_test)
    print(acc)

タスクはよくあるirisデータセットを用いた分類です。
これはscikit-learnのやつをそのまま使っています。
作成するニューラルネットワークは2層(数え方によっては3層)ニューラルネットワークです。
初期値は横着して、jax.nn.initializersにあるglorot_normal, normalをそのまま使っています。
初期化にPRNG keyを使うので、重みとバイアスの数だけ異なるkeyが得られるよう、適宜jax.random.splitしています。
あとはparamsと訓練データをtrain関数にわたし、100エポック訓練したのち、更新されたparamsを受け取ります。
そのparamsとテストデータをaccuracy関数に入れるとaccuracyが返ってきます。

というように使えるtrainaccuracyを実装しましょう。

モデル

肝心の「モデル」の実装はこんな感じです。

import jax.numpy as jnp
from jax import nn
import jax


@jax.jit
def linear(x, W, b):
    return jnp.dot(x, W) + b


@jax.jit
def Linear(params, inputs):
    outputs = linear(inputs, params["W"], params["b"])
    return outputs


@jax.jit
def MLP(params, inputs):
    linear1_out = Linear(params["linear1"], inputs)
    linear2_in = nn.relu(linear1_out)
    linear2_out = Linear(params["linear2"], linear2_in)
    logits = nn.softmax(linear2_out)
    return logits

JAXのnumpy APIはjnpとして読み込むことが多いです。
活性化関数はJAXのnnモジュールにあるものをそのまま使いました。
あとはあまり解説することはないです。
JITはとりあえずつけています。

学習

lossはCategorical Cross Entropyを使います。

@jax.jit
def categorical_cross_entropy_loss(logits, onehot_labels):
    return jnp.mean(-jnp.sum(onehot_labels * jnp.log(logits), axis=1))

オプティマイザーはSGDを使います。愚直に書いて

@jax.jit
def SGD(params, grad, lr = 0.1):
    params["linear1"]["W"] -= lr * grad["linear1"]["W"]
    params["linear2"]["W"] -= lr * grad["linear2"]["W"]
    params["linear1"]["b"] -= lr * grad["linear1"]["b"]
    params["linear2"]["b"] -= lr * grad["linear2"]["b"]
    return params

こんな感じにしました。
これらを用いて、まずバッチごとのparamsの更新関数を以下のように実装します。

@jax.jit
def train_batch(params, batch_X, batch_y):
    def loss_fn(params_, batch_X_):
        logits = MLP(params_, batch_X_)
        return categorical_cross_entropy_loss(logits, batch_y)
    grad = jax.jit(jax.grad(loss_fn))(params, batch_X)
    return SGD(params, grad)

JAXではjax.gradを用いて勾配を計算します。
そして計算した勾配を用いてSGDでparamsをupdateします。
これを各バッチで呼び出します。
これが1エポックです。

def train_one_epoch(rng, params, X_train, y_train, batch_size = 50):
    index = rng.permutation(X_train.shape[0])
    num_batches = X_train.shape[0] // batch_size + 1
    for batch in range(num_batches):
        batch_index = index[batch * batch_size: (batch + 1) * batch_size]
        params = train_batch(params, X_train[batch_index], y_train[batch_index])
    return params

rngnumpy.random.RandomState(0)です。
データのシャッフルのために使っていますが、これのせいでJITにできなかったり、バッチでの学習をjax.lax.fori_loopにできないので、できればjax.randomのものを使った方が良いのかなと思います。

これを用いてtrain関数は以下のようにかけます。

@jax.jit
def train(params, X_train, y_train, epochs: int):
    rng = npr.RandomState(0)
    params = jax.lax.fori_loop(0, epochs, lambda epoch_, params_: train_one_epoch(rng, params_, X_train, y_train), params)
    return params

train_one_epoch関数は、for文ではなく、jax.lax.fori_loopを使って呼び出すことで高速化しています。
これはドキュメントにある通り、セマンティクスとしては

def fori_loop(lower, upper, body_fun, init_val):
  val = init_val
  for i in range(lower, upper):
    val = body_fun(i, val)
  return val

と同じです。ただ速いです。

これでtrain関数が実装できました。

評価

評価は、学習したparamsとテストデータを入力としてaccuraryを計算します。

@jax.jit
def accuracy(params, X_test, y_test):
    pred = MLP(params, X_test)
    return jnp.mean(y_test == jnp.argmax(pred, axis=1))

これで全ての実装が終わりました。

動かしてみると、accuracyは約92%でした。ちょっと低いですね・・・

JAX全然分からん

というわけでJAXを用いてほぼ1からニューラルネットワークを作成したのですが、著者にはまだ全然分からないことが結構あります。
JITは付けられるだけつければ良いのでしょうか?
staxを見るとあまり使っていないように見えます。
Flaxだとそれなりに使っているようです。
また、staxにしろ、Haikuにしろ、Flaxにしろ、Traxにしろ、あまりfori_loopは使われていません。
batchやepochの繰り返しで一番使いそうな気がするのですが・・・

追記

いくつか修正して、numpyをそのまま使っていたところを全てJAXに置き換えました。
まず

    X_train, X_test, y_train, y_test = train_test_split(jnp.array(iris_dataset['data']), jnp.array(iris_dataset['target']), test_size=0.25,  random_state=0)
    y_train = jnp.eye(3)[y_train]

として、データをjnp.arrayにしました。

次に、train関数とtrain_one_epoch関数で、データセットのシャッフルにJAXを使用することにしました。

@jax.jit
def train_one_epoch(rng, params, X_train, y_train, batch_size = 50):
    index = jax.random.permutation(rng, X_train.shape[0])
    num_batches = X_train.shape[0] // batch_size + 1
    for batch in range(num_batches):
        batch_index = index[batch * batch_size: (batch + 1) * batch_size]
        params = train_batch(params, X_train[batch_index], y_train[batch_index])
    return params


@jax.jit
def train(params, X_train, y_train, epochs: int):
    rng = jax.random.PRNGKey(42)
    params = jax.lax.fori_loop(0, epochs, lambda epoch_, params_: train_one_epoch(rng, params_, X_train, y_train), params)
    return params

これで、train_batchfori_loopで呼び出せると思ったのですが、インデックスが動的に変わるのがダメっぽいです。
jax.lax.dynamic_sliceを使うといけるかもしれませんが、バッチサイズが固定ではない(端数のバッチがある)ので、使えなそうです。
割り切れるバッチサイズにすれば、固定にできるので、fori_loopが使えるかもしれませんが、そのためにバッチサイズを弄る必要があるというのも・・・?

また、SGDはjax.tree_multimapを使いました。

@jax.jit
def SGD(params, grads, lr = 0.1):
    return jax.tree_multimap(lambda param, grad: param - lr * grad, params, grads)

というわけで、修正版のコードは以下のgistにあります。
https://gist.github.com/Catminusminus/002ddb267a7001aba83520b33518b20f

リファレンス

本記事はstaxの実装とそのサンプルを参考に作成しました。
またparamsの形式はHaikuを参考にしました。

@software{jax2018github,
  author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
  title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
  url = {http://github.com/google/jax},
  version = {0.2.5},
  year = {2018},
}
@software{haiku2020github,
  author = {Tom Hennigan and Trevor Cai and Tamara Norman and Igor Babuschkin},
  title = {{H}aiku: {S}onnet for {JAX}},
  url = {http://github.com/deepmind/dm-haiku},
  version = {0.0.3},
  year = {2020},
}

Discussion