PyTorch to JAX 移行ガイド(GPUでのCNN学習 | BatchNorm編)
背景
JAXベースのNNライブラリであるFlaxを用いて、PyTorchのコードをJAXに移行する方法を紹介しています。特に今回はGPUを用いたCNNの学習を取り上げ、FlaxにおけるBatchNormの使い方について学びます。
例によってコードはgistにアップロードしてあります。
サンプル: CNNをCIFAR-10で学習する
まずはPyTorchから。基本は前回と同じで、JAXへの移行を簡単にするために学習・推論の1ステップをstep
関数に切り出しています。あとは、GPUで学習するので.cuda
をつけています。ネットワークはよくあるConv-BN-Reluを積み重ねたCNNです。データローダには、前回高速であることが分かったTFDSを用いています。
import time
import torch
import torch.nn as nn
import torch.optim as optim
import tensorflow as tf
import tensorflow_datasets as tfds
def preprocessing(x, y):
x = tf.cast(x, tf.float32) / 255.
return x, y
ds = tfds.load("cifar10", as_supervised=True, shuffle_files=False, download=True)
train_set = ds["train"]
train_set = train_set.shuffle(len(train_set), seed=0, reshuffle_each_iteration=True).batch(32).map(preprocessing).prefetch(1)
val_set = ds["test"]
val_set = val_set.batch(32).map(preprocessing).prefetch(1)
# model
class CNN(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
self.bn1 = nn.BatchNorm2d(32)
self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
self.bn2 = nn.BatchNorm2d(64)
self.fc1 = nn.Linear(4096, 256)
self.fc2 = nn.Linear(256, 10)
def forward(self, x):
h = self.conv1(x)
h = self.bn1(h)
h = torch.relu(h)
h = torch.max_pool2d(h, (2, 2))
h = self.conv2(h)
h = self.bn2(h)
h = torch.relu(h)
h = torch.max_pool2d(h, (2, 2))
h = h.reshape(len(h), -1)
h = self.fc1(h)
h = torch.relu(h)
h = self.fc2(h)
y = torch.log_softmax(h, -1)
return y
model = CNN().cuda()
opt = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.NLLLoss()
def step(x, y, is_training=True):
x = x.permute(0, 3, 1, 2) # TFDS loads images in HWC format
y_pred = model(x)
loss = criterion(y_pred, y)
if is_training:
opt.zero_grad()
loss.backward()
opt.step()
return loss, y_pred
for e in range(5):
tic = time.time()
train_loss, val_loss, acc = 0., 0., 0.
for x, y in train_set.as_numpy_iterator():
x = torch.from_numpy(x).cuda()
y = torch.from_numpy(y).cuda()
model.train()
loss, y_pred = step(x, y, is_training=True)
train_loss += loss.item()
train_loss /= len(train_set)
for x, y in val_set.as_numpy_iterator():
x = torch.from_numpy(x).cuda()
y = torch.from_numpy(y).cuda()
model.eval()
with torch.no_grad():
loss, y_pred = step(x, y, is_training=False)
val_loss += loss.item()
acc += (y_pred.max(-1)[1] == y).float().mean()
val_loss /= len(val_set)
acc /= len(val_set)
elapsed = time.time() - tic
print(f"train_loss: {train_loss:0.2f}, val_loss: {val_loss:0.2f}, val_acc: {acc:0.2f}, elapsed: {elapsed:0.2f}")
Colab (Tesla K80)で回すとこんな感じ。
train_loss: 1.27, val_loss: 1.30, val_acc: 0.54, elapsed: 43.59
train_loss: 0.92, val_loss: 1.02, val_acc: 0.64, elapsed: 32.64
train_loss: 0.79, val_loss: 1.07, val_acc: 0.64, elapsed: 32.69
train_loss: 0.69, val_loss: 0.79, val_acc: 0.72, elapsed: 32.64
train_loss: 0.61, val_loss: 0.81, val_acc: 0.72, elapsed: 32.68
JAXでGPUを使う
実のところやることは一つで、手元のCUDA versionに準拠したjaxをpipで入れたら終わりです。
$ pip install jax[cuda111]
いかがでしたか?これでJAXで書いたコードがGPUで動くようになります。
BatchNormの利用
(とりあえずJAXでCNNが組みたい!という人はこのセクションは飛ばしてください)
CNNに限らずBatchNormを使うことは多いかと思います。BatchNormは学習時と推論時で挙動が違うのが特徴ですね。学習時は、入力ミニバッチを自身の平均(batch_mean
)と分散(batch_var
)で正規化しつつ、学習データ全体の移動平均(mean
)および移動分散(var
)を計算します。一方推論時は、学習時に計算された移動平均・移動分散でバッチを正規化します。そのうえでBatchNormには、ハイパーパラメタとしてmomentum
, epsilon
、データから学習されるパラメタとしてscale
, bias
があります。これらは以下のように組み合わせて利用されます。
# x_in: 入力バッチ
# mean: 前バッチから引き継いでいる移動平均
# var: 前バッチから引き継いでいる移動分散
batch_mean = x_in.mean(0)
batch_var = x_in.var(0)
if is_training:
# 移動平均・分散の更新。momentum=慣性パラメタによって前バッチから引き継いでいる統計量の影響を残す
mean = momentum * mean + (1 - momentum) * batch_mean
var = momentum * var + (1 - momentum) * batch_var
else:
# 推論時は学習データの移動平均・分散を用いる
batch_mean = mean
batch_var = var
# 正規化。ゼロ除算を防ぐためにepsilonで下駄をはかせる
x_in_normed = (x_in - batch_mean) / jnp.sqrt(batch_var + epsilon)
# スケール・バイアスの適用
x_out = x_in_normed * scale + bias
(FlaxのBatchNorm実装も、変数の名前がちょっと違いますが基本的に同じことをしています)
このような挙動の切り替えを、PyTorchではmodel.train()
, model.eval()
によって実現しているのでした。ところがFlax にはそんな便利なものはないので、自分でどうにかする必要があります。
BatchNormの挙動を理解するために、以下のようにただ入力をBatchNormするだけのネットワークを作ってみましょう。
class BN(fnn.Module):
@fnn.compact
def __call__(self, x, use_running_average):
x = fnn.BatchNorm(use_running_average=use_running_average, momentum=0.5)(x)
return x
model = BN()
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 1)), True)
さて、このvariables
の中身をのぞいてみると、
FrozenDict({
batch_stats: {
BatchNorm_0: {
mean: DeviceArray([0.], dtype=float32),
var: DeviceArray([1.], dtype=float32),
},
},
params: {
BatchNorm_0: {
scale: DeviceArray([1.], dtype=float32),
bias: DeviceArray([0.], dtype=float32),
},
},
})
このように、先ほど説明したBatchNormのパラメタ(scale
, bias
)をまとめたparams
と、移動平均・分散であるmean
, var
をまとめたbatch_stats
が生成されています。
学習時
params = variables["params"]
batch_stats = variables["batch_stats"]
x = jnp.array([[1.], [2.], [3.], [4.]])
y_pred, mutated_vars = model.apply({"params": params, "batch_stats": batch_stats}, x, use_running_average=False, mutable=["batch_stats"])
new_batch_stats = mutated_vars["batch_stats"]
学習時は移動平均・分散を利用しないので、use_running_average=False
になります。重要な点として、model.apply
にmutable=["batch_stats"]
という引数が与えられています。これは、BatchNormレイヤーにバッチが通ることで、batch_stats
が変更される(mutable = 可変)ということを伝えています(そうしないとエラーがでます)。new_batch_stats
の中身を見てみると、
FrozenDict({
BatchNorm_0: {
mean: DeviceArray([1.25], dtype=float32),
var: DeviceArray([1.125], dtype=float32),
},
})
一応検算してみましょう。モデルが初期化された段階で、batch_stats
の中身はmean = 0.0
, var = 1.0
でした。一方、新たな入力x = jnp.array([[1.], [2.], [3.], [4.]])
について、その平均はbatch_mean = 2.5
, batch_var = 1.25
です。momentum
を0.5に設定しているので、確かに新しい移動平均は0.0 * 0.5 + 2.5 * 0.5 = 1.25
、移動分散は1.0 * 0.5 + 1.25 * 0.5 = 1.125
となります。
推論時
y_pred = model.apply({"params": params, "batch_stats": new_batch_stats}, x, use_running_average=True)
推論時は、学習の際に得られたnew_batch_stats
をbatch_statsとして与えつつ、use_runnining_average=True
にして、その統計量を使って入力を正規化します。
JAXを用いたCNNの学習
ここからが本番です。先に書いたPyTorchバージョンのCNN学習をJAXに移行してみましょう。
CNNの定義
まずはCNNを定義します。
class CNN(fnn.Module):
@fnn.compact
def __call__(self, x, is_training):
x = fnn.Conv(features=32, kernel_size=(3, 3))(x)
x = fnn.BatchNorm(use_running_average=not is_training, momentum=0.1)(x)
x = fnn.relu(x)
x = fnn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
x = fnn.Conv(features=64, kernel_size=(3, 3))(x)
x = fnn.BatchNorm(use_running_average=not is_training, momentum=0.1)(x)
x = fnn.relu(x)
x = fnn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = fnn.Dense(features=256)(x)
x = fnn.relu(x)
x = fnn.Dense(features=10)(x)
x = fnn.log_softmax(x)
return x
# モデルの初期化
model = CNN()
variables = model.init(jax.random.PRNGKey(0), jnp.ones([1, 32, 32, 3]), True)
params = variables["params"]
batch_stats = variables["batch_stats"]
tx = optax.adam(0.001)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)
モデルを初期化するとparams
, batch_stats
が出てくるので、params
はTrainState
に渡し、batch_stats
はそのまま控えておきます。
学習・推論ステップ
学習時には、あるステップで得られたbatch_stats
を次のステップに引き継ぐ必要があります。そこで、学習ステップの返り値としてmutated_vars
の中に入っているbatch_stats
を与えます。
@partial(jax.jit, static_argnums=(4,))
def step(x, y, state, batch_stats, is_training=True):
def loss_fn(params, batch_stats):
y_pred, mutated_vars = state.apply_fn({"params": params, "batch_stats": batch_stats}, x, is_training, mutable=["batch_stats"])
new_batch_stats = mutated_vars["batch_stats"]
loss = optax.softmax_cross_entropy(logits=y_pred, labels=y).mean()
return loss, (y_pred, new_batch_stats)
y = jnp.eye(10)[y]
if is_training:
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, (y_pred, new_batch_stats)), grads = grad_fn(state.params, batch_stats)
state = state.apply_gradients(grads=grads)
else:
loss, (y_pred, new_batch_stats) = loss_fn(state.params, batch_stats)
return loss, y_pred, state, new_batch_stats
これにあわせて、学習・推論ループも少しだけ変わります。
for e in range(5):
tic = time.time()
train_loss, val_loss, acc = 0., 0., 0.
for x, y in train_set.as_numpy_iterator():
loss, y_pred, state, batch_stats = step(x, y, state, batch_stats, is_training=True)
train_loss += loss
train_loss /= len(train_set)
...
完成: JAX移行後
import time
from functools import partial
import jax
import jax.numpy as jnp
import flax.linen as fnn
from flax.training.train_state import TrainState
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
def preprocessing(x, y):
x = tf.cast(x, tf.float32) / 255.
return x, y
ds = tfds.load("cifar10", as_supervised=True, shuffle_files=False, download=True)
train_set = ds["train"]
train_set = train_set.shuffle(len(train_set), seed=0, reshuffle_each_iteration=True).batch(32).map(preprocessing).prefetch(1)
val_set = ds["test"]
val_set = val_set.batch(32).map(preprocessing).prefetch(1)
# model
class CNN(fnn.Module):
@fnn.compact
def __call__(self, x, is_training):
x = fnn.Conv(features=32, kernel_size=(3, 3))(x)
x = fnn.BatchNorm(use_running_average=not is_training, momentum=0.1)(x)
x = fnn.relu(x)
x = fnn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
x = fnn.Conv(features=64, kernel_size=(3, 3))(x)
x = fnn.BatchNorm(use_running_average=not is_training, momentum=0.1)(x)
x = fnn.relu(x)
x = fnn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # flatten
x = fnn.Dense(features=256)(x)
x = fnn.relu(x)
x = fnn.Dense(features=10)(x)
x = fnn.log_softmax(x)
return x
model = CNN()
variables = model.init(jax.random.PRNGKey(0), jnp.ones([1, 32, 32, 3]), True)
params = variables["params"]
batch_stats = variables["batch_stats"]
tx = optax.adam(0.001)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)
@partial(jax.jit, static_argnums=(4,))
def step(x, y, state, batch_stats, is_training=True):
def loss_fn(params, batch_stats):
y_pred, mutated_vars = state.apply_fn({"params": params, "batch_stats": batch_stats}, x, is_training, mutable=["batch_stats"])
new_batch_stats = mutated_vars["batch_stats"]
loss = optax.softmax_cross_entropy(logits=y_pred, labels=y).mean()
return loss, (y_pred, new_batch_stats)
y = jnp.eye(10)[y]
if is_training:
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, (y_pred, new_batch_stats)), grads = grad_fn(state.params, batch_stats)
state = state.apply_gradients(grads=grads)
else:
loss, (y_pred, new_batch_stats) = loss_fn(state.params, batch_stats)
return loss, y_pred, state, new_batch_stats
for e in range(5):
tic = time.time()
train_loss, val_loss, acc = 0., 0., 0.
for x, y in train_set.as_numpy_iterator():
loss, y_pred, state, batch_stats = step(x, y, state, batch_stats, is_training=True)
train_loss += loss
train_loss /= len(train_set)
for x, y in val_set.as_numpy_iterator():
loss, y_pred, state, batch_stats = step(x, y, state, batch_stats, is_training=False)
val_loss += loss
acc += (jnp.argmax(y_pred, 1) == y).mean()
val_loss /= len(val_set)
acc /= len(val_set)
elapsed = time.time() - tic
print(f"train_loss: {train_loss:0.2f}, val_loss: {val_loss:0.2f}, val_acc: {acc:0.2f}, elapsed: {elapsed:0.2f}")
回してみましょう。
train_loss: 1.28, val_loss: 1.04, val_acc: 0.63, elapsed: 29.05
train_loss: 0.94, val_loss: 0.97, val_acc: 0.66, elapsed: 13.19
train_loss: 0.81, val_loss: 0.87, val_acc: 0.69, elapsed: 21.37
train_loss: 0.73, val_loss: 0.83, val_acc: 0.72, elapsed: 13.03
train_loss: 0.65, val_loss: 0.87, val_acc: 0.71, elapsed: 12.82
コンパイルが入る最初のエポックを除いて、PyTorch実装のほぼ2倍くらい速くなりました。これはすごい。
まとめ
CNNの学習を例に、PyTorchのコードをJAX (Flax + Optax) に移行する方法を紹介しました。
1エポック32sec->13secはなかなか脅威ですね。
FlaxにおけるBatchNormの利用はPyTorchと比べて一癖あり、バッチの統計量を自分で管理しなければなりません。この辺についての公式説明が現状割と不足しており(c.f. https://github.com/google/flax/issues/932 )チュートリアルのさらなる充実が望まれそうです。普通にえいやで使うだけであれば、本記事が参考になれば幸いです。
Discussion