PyTorch to JAX 移行ガイド(GPUでのCNN学習 | BatchNorm編)

2021/10/02に公開

背景

JAXベースのNNライブラリであるFlaxを用いて、PyTorchのコードをJAXに移行する方法を紹介しています。特に今回はGPUを用いたCNNの学習を取り上げ、FlaxにおけるBatchNormの使い方について学びます。

https://gist.github.com/yonetaniryo/37e7023bc55db275737ec1f8e5851551
例によってコードはgistにアップロードしてあります。

サンプル: CNNをCIFAR-10で学習する

まずはPyTorchから。基本は前回と同じで、JAXへの移行を簡単にするために学習・推論の1ステップをstep関数に切り出しています。あとは、GPUで学習するので.cudaをつけています。ネットワークはよくあるConv-BN-Reluを積み重ねたCNNです。データローダには、前回高速であることが分かったTFDSを用いています。

import time

import torch
import torch.nn as nn
import torch.optim as optim

import tensorflow as tf
import tensorflow_datasets as tfds


def preprocessing(x, y):
    x = tf.cast(x, tf.float32) / 255.
    return x, y

ds = tfds.load("cifar10", as_supervised=True, shuffle_files=False, download=True)
train_set = ds["train"]
train_set = train_set.shuffle(len(train_set), seed=0, reshuffle_each_iteration=True).batch(32).map(preprocessing).prefetch(1)
val_set = ds["test"]
val_set = val_set.batch(32).map(preprocessing).prefetch(1)

# model
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(4096, 256)
        self.fc2 = nn.Linear(256, 10)
        
    def forward(self, x):
        h = self.conv1(x)
        h = self.bn1(h)
        h = torch.relu(h)
        h = torch.max_pool2d(h, (2, 2))
        h = self.conv2(h)
        h = self.bn2(h)
        h = torch.relu(h)
        h = torch.max_pool2d(h, (2, 2))
        h = h.reshape(len(h), -1)
        h = self.fc1(h)
        h = torch.relu(h)
        h = self.fc2(h)
        y = torch.log_softmax(h, -1)

        return y


model = CNN().cuda()
opt = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.NLLLoss()

def step(x, y, is_training=True):
    x = x.permute(0, 3, 1, 2)  # TFDS loads images in HWC format
    y_pred = model(x)
    loss = criterion(y_pred, y)
    if is_training:
        opt.zero_grad()
        loss.backward()
        opt.step()
    return loss, y_pred
    
for e in range(5):
    tic = time.time()
    train_loss, val_loss, acc = 0., 0., 0.
    for x, y in train_set.as_numpy_iterator():
        x = torch.from_numpy(x).cuda()
        y = torch.from_numpy(y).cuda()
        model.train()
        loss, y_pred = step(x, y, is_training=True)
        train_loss += loss.item()
    train_loss /= len(train_set)

    for x, y in val_set.as_numpy_iterator(): 
        x = torch.from_numpy(x).cuda()
        y = torch.from_numpy(y).cuda()
        model.eval()
        with torch.no_grad():
            loss, y_pred = step(x, y, is_training=False)
        val_loss += loss.item()
        acc += (y_pred.max(-1)[1] == y).float().mean()
    val_loss /= len(val_set)
    acc /= len(val_set)
    elapsed = time.time() - tic
    
    print(f"train_loss: {train_loss:0.2f}, val_loss: {val_loss:0.2f}, val_acc: {acc:0.2f}, elapsed: {elapsed:0.2f}")

Colab (Tesla K80)で回すとこんな感じ。

train_loss: 1.27, val_loss: 1.30, val_acc: 0.54, elapsed: 43.59
train_loss: 0.92, val_loss: 1.02, val_acc: 0.64, elapsed: 32.64
train_loss: 0.79, val_loss: 1.07, val_acc: 0.64, elapsed: 32.69
train_loss: 0.69, val_loss: 0.79, val_acc: 0.72, elapsed: 32.64
train_loss: 0.61, val_loss: 0.81, val_acc: 0.72, elapsed: 32.68

JAXでGPUを使う

実のところやることは一つで、手元のCUDA versionに準拠したjaxをpipで入れたら終わりです。

$ pip install jax[cuda111]

いかがでしたか?これでJAXで書いたコードがGPUで動くようになります。

BatchNormの利用

(とりあえずJAXでCNNが組みたい!という人はこのセクションは飛ばしてください)

CNNに限らずBatchNormを使うことは多いかと思います。BatchNormは学習時と推論時で挙動が違うのが特徴ですね。学習時は、入力ミニバッチを自身の平均(batch_mean)と分散(batch_var)で正規化しつつ、学習データ全体の移動平均(mean)および移動分散(var)を計算します。一方推論時は、学習時に計算された移動平均・移動分散でバッチを正規化します。そのうえでBatchNormには、ハイパーパラメタとしてmomentum, epsilon、データから学習されるパラメタとしてscale, biasがあります。これらは以下のように組み合わせて利用されます。

# x_in: 入力バッチ
# mean: 前バッチから引き継いでいる移動平均
# var: 前バッチから引き継いでいる移動分散

batch_mean = x_in.mean(0)
batch_var = x_in.var(0)
if is_training:
  # 移動平均・分散の更新。momentum=慣性パラメタによって前バッチから引き継いでいる統計量の影響を残す
  mean = momentum * mean + (1 - momentum) * batch_mean
  var = momentum * var + (1 - momentum) * batch_var
else:
  # 推論時は学習データの移動平均・分散を用いる
  batch_mean = mean
  batch_var = var

# 正規化。ゼロ除算を防ぐためにepsilonで下駄をはかせる
x_in_normed = (x_in - batch_mean) / jnp.sqrt(batch_var + epsilon)

# スケール・バイアスの適用
x_out = x_in_normed * scale + bias

FlaxのBatchNorm実装も、変数の名前がちょっと違いますが基本的に同じことをしています)
このような挙動の切り替えを、PyTorchではmodel.train(), model.eval()によって実現しているのでした。ところがFlax にはそんな便利なものはないので、自分でどうにかする必要があります。

BatchNormの挙動を理解するために、以下のようにただ入力をBatchNormするだけのネットワークを作ってみましょう。

class BN(fnn.Module):

    @fnn.compact
    def __call__(self, x, use_running_average):
        x = fnn.BatchNorm(use_running_average=use_running_average, momentum=0.5)(x)
        return x

model = BN()
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 1)), True)

さて、このvariablesの中身をのぞいてみると、

FrozenDict({
    batch_stats: {
        BatchNorm_0: {
            mean: DeviceArray([0.], dtype=float32),
            var: DeviceArray([1.], dtype=float32),
        },
    },
    params: {
        BatchNorm_0: {
            scale: DeviceArray([1.], dtype=float32),
            bias: DeviceArray([0.], dtype=float32),
        },
    },
})

このように、先ほど説明したBatchNormのパラメタ(scale, bias)をまとめたparamsと、移動平均・分散であるmean, varをまとめたbatch_statsが生成されています。

学習時

params = variables["params"]
batch_stats = variables["batch_stats"]
x = jnp.array([[1.], [2.], [3.], [4.]])
y_pred, mutated_vars = model.apply({"params": params, "batch_stats": batch_stats}, x, use_running_average=False, mutable=["batch_stats"])
new_batch_stats = mutated_vars["batch_stats"]

学習時は移動平均・分散を利用しないので、use_running_average=Falseになります。重要な点として、model.applymutable=["batch_stats"]という引数が与えられています。これは、BatchNormレイヤーにバッチが通ることで、batch_statsが変更される(mutable = 可変)ということを伝えています(そうしないとエラーがでます)。new_batch_statsの中身を見てみると、

FrozenDict({
    BatchNorm_0: {
        mean: DeviceArray([1.25], dtype=float32),
        var: DeviceArray([1.125], dtype=float32),
    },
})

一応検算してみましょう。モデルが初期化された段階で、batch_statsの中身はmean = 0.0, var = 1.0でした。一方、新たな入力x = jnp.array([[1.], [2.], [3.], [4.]])について、その平均はbatch_mean = 2.5, batch_var = 1.25です。momentumを0.5に設定しているので、確かに新しい移動平均は0.0 * 0.5 + 2.5 * 0.5 = 1.25、移動分散は1.0 * 0.5 + 1.25 * 0.5 = 1.125となります。

推論時

y_pred = model.apply({"params": params, "batch_stats": new_batch_stats}, x, use_running_average=True)

推論時は、学習の際に得られたnew_batch_statsをbatch_statsとして与えつつ、use_runnining_average=Trueにして、その統計量を使って入力を正規化します。

JAXを用いたCNNの学習

ここからが本番です。先に書いたPyTorchバージョンのCNN学習をJAXに移行してみましょう。

CNNの定義

まずはCNNを定義します。

class CNN(fnn.Module):

    @fnn.compact
    def __call__(self, x, is_training):
        x = fnn.Conv(features=32, kernel_size=(3, 3))(x)
        x = fnn.BatchNorm(use_running_average=not is_training, momentum=0.1)(x)
        x = fnn.relu(x)
        x = fnn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = fnn.Conv(features=64, kernel_size=(3, 3))(x)
        x = fnn.BatchNorm(use_running_average=not is_training, momentum=0.1)(x)
        x = fnn.relu(x)
        x = fnn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = fnn.Dense(features=256)(x)
        x = fnn.relu(x)
        x = fnn.Dense(features=10)(x)
        x = fnn.log_softmax(x)
        
        return x
	
# モデルの初期化
model = CNN()
variables = model.init(jax.random.PRNGKey(0), jnp.ones([1, 32, 32, 3]), True)
params = variables["params"]
batch_stats = variables["batch_stats"]
tx = optax.adam(0.001)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)

モデルを初期化するとparams, batch_statsが出てくるので、paramsTrainStateに渡し、batch_statsはそのまま控えておきます。

学習・推論ステップ

学習時には、あるステップで得られたbatch_statsを次のステップに引き継ぐ必要があります。そこで、学習ステップの返り値としてmutated_varsの中に入っているbatch_statsを与えます。

@partial(jax.jit, static_argnums=(4,))
def step(x, y, state, batch_stats, is_training=True):
    def loss_fn(params, batch_stats):
        y_pred, mutated_vars = state.apply_fn({"params": params, "batch_stats": batch_stats}, x, is_training, mutable=["batch_stats"]) 
        new_batch_stats = mutated_vars["batch_stats"]
        loss = optax.softmax_cross_entropy(logits=y_pred, labels=y).mean()
        return loss, (y_pred, new_batch_stats)
    y = jnp.eye(10)[y]
    if is_training:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, (y_pred, new_batch_stats)), grads = grad_fn(state.params, batch_stats)
        state = state.apply_gradients(grads=grads)
    else:
        loss, (y_pred, new_batch_stats) = loss_fn(state.params, batch_stats)
    return loss, y_pred, state, new_batch_stats

これにあわせて、学習・推論ループも少しだけ変わります。

for e in range(5):
    tic = time.time()
    train_loss, val_loss, acc = 0., 0., 0.
    for x, y in train_set.as_numpy_iterator(): 
        loss, y_pred, state, batch_stats = step(x, y, state, batch_stats, is_training=True)
        train_loss += loss
    train_loss /= len(train_set)
... 

完成: JAX移行後


import time
from functools import partial

import jax
import jax.numpy as jnp
import flax.linen as fnn
from flax.training.train_state import TrainState
import optax
import tensorflow as tf
import tensorflow_datasets as tfds

def preprocessing(x, y):
    x = tf.cast(x, tf.float32) / 255.
    
    return x, y

ds = tfds.load("cifar10", as_supervised=True, shuffle_files=False, download=True)
train_set = ds["train"]
train_set = train_set.shuffle(len(train_set), seed=0, reshuffle_each_iteration=True).batch(32).map(preprocessing).prefetch(1)
val_set = ds["test"]
val_set = val_set.batch(32).map(preprocessing).prefetch(1)

# model
class CNN(fnn.Module):

    @fnn.compact
    def __call__(self, x, is_training):
        x = fnn.Conv(features=32, kernel_size=(3, 3))(x)
        x = fnn.BatchNorm(use_running_average=not is_training, momentum=0.1)(x)
        x = fnn.relu(x)
        x = fnn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = fnn.Conv(features=64, kernel_size=(3, 3))(x)
        x = fnn.BatchNorm(use_running_average=not is_training, momentum=0.1)(x)
        x = fnn.relu(x)
        x = fnn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = fnn.Dense(features=256)(x)
        x = fnn.relu(x)
        x = fnn.Dense(features=10)(x)
        x = fnn.log_softmax(x)
        
        return x

model = CNN()
variables = model.init(jax.random.PRNGKey(0), jnp.ones([1, 32, 32, 3]), True)
params = variables["params"]
batch_stats = variables["batch_stats"]
tx = optax.adam(0.001)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)

@partial(jax.jit, static_argnums=(4,))
def step(x, y, state, batch_stats, is_training=True):
    def loss_fn(params, batch_stats):
        y_pred, mutated_vars = state.apply_fn({"params": params, "batch_stats": batch_stats}, x, is_training, mutable=["batch_stats"]) 
        new_batch_stats = mutated_vars["batch_stats"]
        loss = optax.softmax_cross_entropy(logits=y_pred, labels=y).mean()
        return loss, (y_pred, new_batch_stats)
    y = jnp.eye(10)[y]
    if is_training:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, (y_pred, new_batch_stats)), grads = grad_fn(state.params, batch_stats)
        state = state.apply_gradients(grads=grads)
    else:
        loss, (y_pred, new_batch_stats) = loss_fn(state.params, batch_stats)
    return loss, y_pred, state, new_batch_stats

for e in range(5):
    tic = time.time()
    train_loss, val_loss, acc = 0., 0., 0.
    for x, y in train_set.as_numpy_iterator(): 
        loss, y_pred, state, batch_stats = step(x, y, state, batch_stats, is_training=True)
        train_loss += loss
    train_loss /= len(train_set)

    for x, y in val_set.as_numpy_iterator(): 
        loss, y_pred, state, batch_stats = step(x, y, state, batch_stats, is_training=False)
        val_loss += loss
        acc += (jnp.argmax(y_pred, 1) == y).mean()
    val_loss /= len(val_set)
    acc /= len(val_set)
    elapsed = time.time() - tic
    
    print(f"train_loss: {train_loss:0.2f}, val_loss: {val_loss:0.2f}, val_acc: {acc:0.2f}, elapsed: {elapsed:0.2f}")

回してみましょう。

train_loss: 1.28, val_loss: 1.04, val_acc: 0.63, elapsed: 29.05
train_loss: 0.94, val_loss: 0.97, val_acc: 0.66, elapsed: 13.19
train_loss: 0.81, val_loss: 0.87, val_acc: 0.69, elapsed: 21.37
train_loss: 0.73, val_loss: 0.83, val_acc: 0.72, elapsed: 13.03
train_loss: 0.65, val_loss: 0.87, val_acc: 0.71, elapsed: 12.82

コンパイルが入る最初のエポックを除いて、PyTorch実装のほぼ2倍くらい速くなりました。これはすごい。

まとめ

CNNの学習を例に、PyTorchのコードをJAX (Flax + Optax) に移行する方法を紹介しました。
1エポック32sec->13secはなかなか脅威ですね。

FlaxにおけるBatchNormの利用はPyTorchと比べて一癖あり、バッチの統計量を自分で管理しなければなりません。この辺についての公式説明が現状割と不足しており(c.f. https://github.com/google/flax/issues/932 )チュートリアルのさらなる充実が望まれそうです。普通にえいやで使うだけであれば、本記事が参考になれば幸いです。

Discussion