PyTorch to JAX 移行ガイド(MLP学習編)
背景
「JAX最高」「GoogleではみんなJAXやってる」などと巷で言われているが、研の活動をやってると、比較手法がPyTorchで提供されていたり、ちょっと特殊な損失関数とかを使わないといけなかったり、あとはネットワーク魔改造をしたくなったりと、「とりあえずまずはPyTorchでやっとくか…」と思わせる要素がたくさんあり、PyTorchから抜け出せずにいた。
ムムッでもこれは2013年ごろを思い出す…その頃自分はとにかくMatlabで全部書いてて、なかなかPythonに移行出来ずにいた。そんななか「飯の種ネタをPythonで書き始めれば、Pythonできない→成果が出ない→死」なので自動的にPythonを習得できるのでは???と思い、えいやとPythonの海に飛び込んだのである。思えばPyTorchもDockerもそんな感じで飛び込んだが、今こそJAXに飛び込む時なのかもしれない。
移行手順
- データローダを
tensorflow-datasets
で書き直す - ネットワークを
flax.linen
で書き直す - 学習の1ステップをJAXで書き直して
jax.jit
でデコレートする
JAX自体はNN学習に関するあれこれをサポートしていないので、それ用のライブラリを追加で利用する必要があります。公式で紹介されている有名どころはflax
とdm-haiku
ですが、optax
(損失関数やoptimizerを実装したライブラリ)との組み合わせやすさという点からflaxを選びました。
サンプル: MLPをMNISTで学習する
コードはこちらにアップロードしています。速度は上記をGoogle colabのCPU環境で計測したものです。
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# dataloader
train_loader = DataLoader(datasets.mnist.MNIST("./data", download=True, transform=transforms.ToTensor(), train=True), batch_size=32, shuffle=True)
val_loader = DataLoader(datasets.mnist.MNIST("./data", download=True, transform=transforms.ToTensor(), train=False), batch_size=32, shuffle=False)
# model
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(),
nn.Linear(128, 256), nn.ReLU(),
nn.Linear(256, 10), nn.LogSoftmax(-1))
def forward(self, x):
return self.net(x)
model = MLP()
opt = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.NLLLoss()
def step(x, y, is_train=True):
x = x.reshape(-1, 28 * 28)
y_pred = model(x)
loss = criterion(y_pred, y)
if is_train:
opt.zero_grad()
loss.backward()
opt.step()
return loss, y_pred
for e in range(10):
tic = time.time()
train_loss, val_loss, acc = 0., 0., 0.
for x, y in train_loader:
model.train()
loss, y_pred = step(x, y, is_train=True)
train_loss += loss.item()
train_loss /= len(train_loader)
for x, y in val_loader:
model.eval()
with torch.no_grad():
loss, y_pred = step(x, y, is_train=False)
val_loss += loss.item()
acc += (y_pred.max(-1)[1] == y).float().mean()
val_loss /= len(val_loader)
acc /= len(val_loader)
elapsed = time.time() - tic
print(f"train_loss: {train_loss:0.2f}, val_loss: {val_loss:0.2f}, val_acc: {acc:0.2f}, elapsed: {elapsed:0.2f}")
まあなんかいつもの感じですね。一つポイントを挙げるなら、学習・推論の1ステップをstep
関数に切り出しておきます。こうすることで、移行が簡単になります。回すとだいたいこんな感じで、だいたい1epoch 14秒ちょっとです。
train_loss: 0.25, val_loss: 0.13, val_acc: 0.96, elapsed: 14.31
train_loss: 0.10, val_loss: 0.10, val_acc: 0.97, elapsed: 14.93
train_loss: 0.07, val_loss: 0.08, val_acc: 0.97, elapsed: 14.97
train_loss: 0.05, val_loss: 0.09, val_acc: 0.98, elapsed: 15.04
train_loss: 0.04, val_loss: 0.08, val_acc: 0.98, elapsed: 15.04
train_loss: 0.03, val_loss: 0.08, val_acc: 0.98, elapsed: 15.05
train_loss: 0.03, val_loss: 0.08, val_acc: 0.98, elapsed: 15.14
train_loss: 0.03, val_loss: 0.08, val_acc: 0.98, elapsed: 14.99
train_loss: 0.02, val_loss: 0.07, val_acc: 0.98, elapsed: 15.04
train_loss: 0.02, val_loss: 0.11, val_acc: 0.98, elapsed: 15.06
tensorflow-datasets
で書き直す
データローダをこれは実のところjaxあんま関係ないのですが、データローダをPyTorch純正のものから、tensorflow-datasets
(TFDS) のものに取り換えます。TFDSは「次のバッチをプリフェッチして高速化できる」「map
関数を使った前処理やデータ拡張が簡単」「numpyでロードできる」などいくつかのうれしい機能があり、PyTorchと組み合わせても有効に使えます。以下を参考にしました -> Pytorch+Tensorflowのちゃんぽんコードのすゝめ(tfdsでpytorchをブーストさせる話)。
置き換えるとこんな感じになります。as_numpy_iterator()
を使って、バッチをnumpy.arrayで吐くイテレータが使えます。
# ...
# 前略
# ...
import tensorflow as tf
import tensorflow_datasets as tfds
def preprocessing(x, y):
x = tf.cast(x, tf.float32) / 255.
return x, y
ds = tfds.load("mnist", as_supervised=True, shuffle_files=False, download=True)
train_set = ds["train"]
train_set = train_set.shuffle(len(train_set), seed=0, reshuffle_each_iteration=True).batch(32).map(preprocessing).prefetch(1)
val_set = ds["test"]
val_set = val_set.batch(32).map(preprocessing).prefetch(1)
# ...
# 中略
# ...
for e in range(10):
tic = time.time()
train_loss, val_loss, acc = 0., 0., 0.
for x, y in train_set.as_numpy_iterator():
x = torch.from_numpy(x)
y = torch.from_numpy(y)
model.train()
loss, y_pred = step(x, y, is_train=True)
train_loss += loss.item()
train_loss /= len(train_set)
# ...
# 後略
# ...
回してみましょう。
train_loss: 0.26, val_loss: 0.13, val_acc: 0.96, elapsed: 13.90
train_loss: 0.11, val_loss: 0.10, val_acc: 0.97, elapsed: 10.88
train_loss: 0.07, val_loss: 0.09, val_acc: 0.97, elapsed: 7.21
train_loss: 0.06, val_loss: 0.08, val_acc: 0.98, elapsed: 7.21
train_loss: 0.04, val_loss: 0.08, val_acc: 0.97, elapsed: 10.61
train_loss: 0.04, val_loss: 0.08, val_acc: 0.98, elapsed: 10.61
train_loss: 0.03, val_loss: 0.08, val_acc: 0.98, elapsed: 7.32
train_loss: 0.03, val_loss: 0.09, val_acc: 0.98, elapsed: 7.36
train_loss: 0.02, val_loss: 0.08, val_acc: 0.98, elapsed: 7.54
train_loss: 0.02, val_loss: 0.11, val_acc: 0.98, elapsed: 7.33
倍くらい速くなっていますね。TFDSすごい。
flax.linen
で書き直す
ネットワークをさてここからが本番です。torch.nn
に対応するモジュールとして、flax.linen
があり、この中に各種レイヤーが実装されています。PyTorchのnn.Module
と違い、__init__
やforward
を書く必要はなく、以下のような@fnn.compact
でデコレートされた__call__
を書くことになります。
import jax
import jax.numpy as jnp
import flax.linen as fnn
class MLP(fnn.Module):
@fnn.compact
def __call__(self, x):
x = fnn.Dense(128)(x)
x = fnn.relu(x)
x = fnn.Dense(256)(x)
x = fnn.relu(x)
x = fnn.Dense(10)(x)
x = fnn.log_softmax(x)
return x
PyTorchとflaxの大きな違いは、モデル自体がパラメタを保持しないという点です。flaxでは以下のように、パラメタが別のオブジェクトとして作成されます。
model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones([1, 28 * 28]))['params']
いろいろ出てきましたが順番に説明すると、
-
jax.numpy
にnumpy
に入っているいろんな関数が実装されています。たいていのnumpy関数はjax.numpyにも入っています。 -
model.init
では、特定の疑似乱数生成器(jax.random.PRNGKey
)を使い、モデルのパラメタを初期化します。このとき、入力テンソルの形をjnp.ones([1, 28 * 28])
のように与えることで、チャネル数が自動的によしなにされます(flatten
したあと入力のチャネル数がいくつになるっけ・・・?みたいなのを手計算しなくて良くなります) - この
params
の中身を見てみると、
FrozenDict({
Dense_0: {
kernel: DeviceArray([[-6.9373280e-02, -1.7459888e-02, 2.6923235e-04, ...,
-3.5230294e-02, -3.9390869e-02, 1.8164413e-02],
[ 7.5236415e-03, 2.2690799e-02, -1.8007368e-05, ...,
2.8825440e-02, 5.4414038e-02, 1.0891268e-02],
[-1.6200040e-02, -4.6700433e-02, 5.5902313e-02, ...,
5.1038559e-03, -3.8261801e-02, -1.7489832e-02],
...,
[-1.9485658e-02, -1.1464893e-02, 1.1733949e-02, ...,
-3.8100831e-02, 4.0606514e-02, -2.5036685e-02],
[-7.7254638e-02, 3.7608985e-02, 5.9655067e-02, ...,
3.0473005e-02, -3.4684684e-02, -2.2349501e-02],
[ 5.9532784e-02, 6.1933022e-02, 2.4909737e-02, ...,
-3.8187259e-03, -1.4835574e-02, 4.1087948e-02]], dtype=float32),
bias: DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
},
Dense_1: {
みたいな感じで、Dict形式でパラメタが保存されています。
flax.training.train_state.TrainState
を使ってひとまとめに
Optimizerや後述の損失関数は、optax
というライブラリに入っているのでこれを使います。また、flax
にはTrainState
という、モデルのforward, パラメタ、optimzerをひとまとめにする便利なクラスがあるので、こいつも使います。
from flax.training.train_state import TrainState
import optax
model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones([1, 28 * 28]))['params']
tx = optax.adam(0.001)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)
jax.jit
でデコレートする
学習の1ステップをJAXで書き直してdef step(x, y, is_train=True):
x = x.reshape(-1, 28 * 28)
y_pred = model(x)
loss = criterion(y_pred, y)
if is_train:
opt.zero_grad()
loss.backward()
opt.step()
return loss, y_pred
これをjaxで書き直すとこんな感じになります。
@partial(jax.jit, static_argnums=(3,))
def step(x, y, state, is_train=True):
def loss_fn(params):
y_pred = state.apply_fn({'params': params}, x)
loss = optax.softmax_cross_entropy(logits=y_pred, labels=y).mean()
return loss, y_pred
x = x.reshape(-1, 28 * 28)
y = jnp.eye(10)[y] # optax.softmax_cross_entropyはone-hotを受け取る
if is_train:
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, y_pred), grads = grad_fn(state.params)
state = state.apply_gradients(grads=grads)
else:
loss, y_pred = loss_fn(state.params)
return loss, y_pred, state
ちょっとややこしいですね。一つずつみていきましょう。
def loss_fn(params):
y_pred = state.apply_fn({'params': params}, x)
loss = optax.softmax_cross_entropy(logits=y_pred, labels=y).mean()
return loss, y_pred
ここはモデルのforward~損失関数の定義です。先に定義したTrainState
のインスタンスstate
の持つstate.apply_fn
関数にモデルのパラメタparams
と入力x
を食わせると、モデルの出力y_pred
が得られます。損失関数はoptax
の中にいろいろ入っています。
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, y_pred), grads = grad_fn(state.params)
state = state.apply_gradients(grads=grads)
上記で定義したloss_fn
についての勾配を計算する関数grad_fn
を用意します。lossを出力しなくていい場合は、jax.grad
を使います。loss_fn
がlossだけ返す場合はhas_aux=False
、loss以外も返す(y_predとか)場合はhas_aux=True
にします。この関数でgrads
を計算でき、その後のstate.apply_gradients
に与えることで、モデルのパラメタが更新されます。
else:
loss, y_pred = loss_fn(state.params)
推論時でgradientの計算が不要な場合は、loss_fn
をそのまま使います。
@partial(jax.jit, static_argnums=(3,))
def step(x, y, state, is_train=True):
jax.jit
を使ってstep関数をXLAコンパイルし、高速化を目指します。詳細は以下のブログが詳しかったです -> JAX入門~高速なNumPyとして使いこなすためのチュートリアル~
ポイントとしては、is_train
は固定値なのでstatic_argnums
で指定する必要があります。
完成: JAX移行後
import time
from functools import partial
import jax
import jax.numpy as jnp
import flax.linen as fnn
from flax.training.train_state import TrainState
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
def preprocessing(x, y):
x = tf.cast(x, tf.float32) / 255.
return x, y
ds = tfds.load("mnist", as_supervised=True, shuffle_files=False, download=True)
train_set = ds["train"]
train_set = train_set.shuffle(len(train_set), seed=0, reshuffle_each_iteration=True).batch(32).map(preprocessing).prefetch(1)
val_set = ds["test"]
val_set = val_set.batch(32).map(preprocessing).prefetch(1)
# model
class MLP(fnn.Module):
@fnn.compact
def __call__(self, x):
x = fnn.Dense(128)(x)
x = fnn.relu(x)
x = fnn.Dense(256)(x)
x = fnn.relu(x)
x = fnn.Dense(10)(x)
x = fnn.log_softmax(x)
return x
model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones([1, 28 * 28]))['params']
tx = optax.adam(0.001)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)
@partial(jax.jit, static_argnums=(3,))
def step(x, y, state, is_train=True):
def loss_fn(params):
y_pred = state.apply_fn({'params': params}, x)
loss = optax.softmax_cross_entropy(logits=y_pred, labels=y).mean()
return loss, y_pred
x = x.reshape(-1, 28 * 28)
y = jnp.eye(10)[y]
if is_train:
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, y_pred), grads = grad_fn(state.params)
state = state.apply_gradients(grads=grads)
else:
loss, y_pred = loss_fn(state.params)
return loss, y_pred, state
for e in range(10):
tic = time.time()
train_loss, val_loss, acc = 0., 0., 0.
for x, y in train_set.as_numpy_iterator():
loss, y_pred, state = step(x, y, state, is_train=True)
train_loss += loss
train_loss /= len(train_set)
for x, y in val_set.as_numpy_iterator():
loss, y_pred, state = step(x, y, state, is_train=False)
val_loss += loss
acc += (jnp.argmax(y_pred, 1) == y).mean()
val_loss /= len(val_set)
acc /= len(val_set)
elapsed = time.time() - tic
print(f"train_loss: {train_loss:0.2f}, val_loss: {val_loss:0.2f}, val_acc: {acc:0.2f}, elapsed: {elapsed:0.2f}")
回してみましょう
train_loss: 0.22, val_loss: 0.11, val_acc: 0.96, elapsed: 13.84
train_loss: 0.10, val_loss: 0.09, val_acc: 0.97, elapsed: 6.18
train_loss: 0.07, val_loss: 0.09, val_acc: 0.97, elapsed: 4.02
train_loss: 0.05, val_loss: 0.08, val_acc: 0.98, elapsed: 6.18
train_loss: 0.04, val_loss: 0.08, val_acc: 0.98, elapsed: 6.40
train_loss: 0.04, val_loss: 0.07, val_acc: 0.98, elapsed: 6.18
train_loss: 0.03, val_loss: 0.08, val_acc: 0.98, elapsed: 6.17
train_loss: 0.03, val_loss: 0.09, val_acc: 0.98, elapsed: 6.18
train_loss: 0.02, val_loss: 0.09, val_acc: 0.98, elapsed: 3.80
train_loss: 0.02, val_loss: 0.10, val_acc: 0.98, elapsed: 6.18
最初のエポックだけコンパイルが入るので時間がかかりますが、その後はかなり高速化されていますね。
実装 | dataloader | runtime/epoch (s) |
---|---|---|
PyTorch | PyTorch | 14秒前後 |
PyTorch | TFDS | 7.5~10秒前後 |
JAX | TFDS | 6秒前後 |
まとめ
MLPの学習を例に、PyTorchのコードをJAX (Flax + Optax) に移行する方法を紹介しました。割とすんなり移行できそうに感じます。今後FlaxとOptaxでカバーできるNNレイヤーやoptimizer/loss functionの種類が増えてくると、さらに使いやすくなりそうです。
今後、以下のような内容を取り上げたいと思います。
- GPUでの学習
- サンプリングが含まれるモデルの学習(VAE, GANなど)
Discussion
興味深い記事ありがとうございます!
完成: JAX移行後が、PyTorchのコードになってしまっている気がします
コメントありがとうございます!完全におっしゃるとおりでした & 修正しました!