PyTorch to JAX 移行ガイド(MLP学習編)

16 min read 2

背景

「JAX最高」「GoogleではみんなJAXやってる」などと巷で言われているが、研の活動をやってると、比較手法がPyTorchで提供されていたり、ちょっと特殊な損失関数とかを使わないといけなかったり、あとはネットワーク魔改造をしたくなったりと、「とりあえずまずはPyTorchでやっとくか…」と思わせる要素がたくさんあり、PyTorchから抜け出せずにいた。

ムムッでもこれは2013年ごろを思い出す…その頃自分はとにかくMatlabで全部書いてて、なかなかPythonに移行出来ずにいた。そんななか「飯の種ネタをPythonで書き始めれば、Pythonできない→成果が出ない→死」なので自動的にPythonを習得できるのでは???と思い、えいやとPythonの海に飛び込んだのである。思えばPyTorchもDockerもそんな感じで飛び込んだが、今こそJAXに飛び込む時なのかもしれない。

移行手順

  • データローダをtensorflow-datasetsで書き直す
  • ネットワークをflax.linenで書き直す
  • 学習の1ステップをJAXで書き直してjax.jitでデコレートする

JAX自体はNN学習に関するあれこれをサポートしていないので、それ用のライブラリを追加で利用する必要があります。公式で紹介されている有名どころはflaxdm-haikuですが、optax(損失関数やoptimizerを実装したライブラリ)との組み合わせやすさという点からflaxを選びました。

python==3.7, jax==0.2.20, flax==0.3.5, optax==0.0.9で動作確認をしています。いずれもバージョンアップとともに実装が変わる可能性もあるので注意

サンプル: MLPをMNISTで学習する

https://gist.github.com/yonetaniryo/f40bde60faef48bace3f3b1b949b277e
コードはこちらにアップロードしています。速度は上記をGoogle colabのCPU環境で計測したものです。
import time

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader
from torchvision import datasets, transforms


# dataloader
train_loader = DataLoader(datasets.mnist.MNIST("./data", download=True, transform=transforms.ToTensor(), train=True), batch_size=32, shuffle=True)
val_loader = DataLoader(datasets.mnist.MNIST("./data", download=True, transform=transforms.ToTensor(), train=False), batch_size=32, shuffle=False)

# model
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(),
                                 nn.Linear(128, 256), nn.ReLU(),
                                 nn.Linear(256, 10), nn.LogSoftmax(-1))
        
    def forward(self, x):
        return self.net(x)


model = MLP()
opt = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.NLLLoss()

def step(x, y, is_train=True):
    x = x.reshape(-1, 28 * 28)
    y_pred = model(x)
    loss = criterion(y_pred, y)
    if is_train:
        opt.zero_grad()
        loss.backward()
        opt.step()
    return loss, y_pred
    
for e in range(10):
    tic = time.time()
    train_loss, val_loss, acc = 0., 0., 0.
    for x, y in train_loader: 
        model.train()
        loss, y_pred = step(x, y, is_train=True)
        train_loss += loss.item()
    train_loss /= len(train_loader)

    for x, y in val_loader:
        model.eval()
        with torch.no_grad():
            loss, y_pred = step(x, y, is_train=False)
        val_loss += loss.item()
        acc += (y_pred.max(-1)[1] == y).float().mean()
    val_loss /= len(val_loader)
    acc /= len(val_loader)
    elapsed = time.time() - tic
    
    print(f"train_loss: {train_loss:0.2f}, val_loss: {val_loss:0.2f}, val_acc: {acc:0.2f}, elapsed: {elapsed:0.2f}")

まあなんかいつもの感じですね。一つポイントを挙げるなら、学習・推論の1ステップをstep関数に切り出しておきます。こうすることで、移行が簡単になります。回すとだいたいこんな感じで、だいたい1epoch 14秒ちょっとです。

train_loss: 0.25, val_loss: 0.13, val_acc: 0.96, elapsed: 14.31
train_loss: 0.10, val_loss: 0.10, val_acc: 0.97, elapsed: 14.93
train_loss: 0.07, val_loss: 0.08, val_acc: 0.97, elapsed: 14.97
train_loss: 0.05, val_loss: 0.09, val_acc: 0.98, elapsed: 15.04
train_loss: 0.04, val_loss: 0.08, val_acc: 0.98, elapsed: 15.04
train_loss: 0.03, val_loss: 0.08, val_acc: 0.98, elapsed: 15.05
train_loss: 0.03, val_loss: 0.08, val_acc: 0.98, elapsed: 15.14
train_loss: 0.03, val_loss: 0.08, val_acc: 0.98, elapsed: 14.99
train_loss: 0.02, val_loss: 0.07, val_acc: 0.98, elapsed: 15.04
train_loss: 0.02, val_loss: 0.11, val_acc: 0.98, elapsed: 15.06

データローダをtensorflow-datasetsで書き直す

これは実のところjaxあんま関係ないのですが、データローダをPyTorch純正のものから、tensorflow-datasets (TFDS) のものに取り換えます。TFDSは「次のバッチをプリフェッチして高速化できる」「map関数を使った前処理やデータ拡張が簡単」「numpyでロードできる」などいくつかのうれしい機能があり、PyTorchと組み合わせても有効に使えます。以下を参考にしました -> Pytorch+Tensorflowのちゃんぽんコードのすゝめ(tfdsでpytorchをブーストさせる話)

置き換えるとこんな感じになります。as_numpy_iterator()を使って、バッチをnumpy.arrayで吐くイテレータが使えます。

# ...
# 前略
# ...

import tensorflow as tf
import tensorflow_datasets as tfds

def preprocessing(x, y):
    x = tf.cast(x, tf.float32) / 255.
    return x, y

ds = tfds.load("mnist", as_supervised=True, shuffle_files=False, download=True)
train_set = ds["train"]
train_set = train_set.shuffle(len(train_set), seed=0, reshuffle_each_iteration=True).batch(32).map(preprocessing).prefetch(1)
val_set = ds["test"]
val_set = val_set.batch(32).map(preprocessing).prefetch(1)

# ...
# 中略
# ...

for e in range(10):
    tic = time.time()
    train_loss, val_loss, acc = 0., 0., 0.
    for x, y in train_set.as_numpy_iterator():
        x = torch.from_numpy(x)
        y = torch.from_numpy(y)
        model.train()
        loss, y_pred = step(x, y, is_train=True)
        train_loss += loss.item()
    train_loss /= len(train_set)

# ...
# 後略
# ...

回してみましょう。

train_loss: 0.26, val_loss: 0.13, val_acc: 0.96, elapsed: 13.90
train_loss: 0.11, val_loss: 0.10, val_acc: 0.97, elapsed: 10.88
train_loss: 0.07, val_loss: 0.09, val_acc: 0.97, elapsed: 7.21
train_loss: 0.06, val_loss: 0.08, val_acc: 0.98, elapsed: 7.21
train_loss: 0.04, val_loss: 0.08, val_acc: 0.97, elapsed: 10.61
train_loss: 0.04, val_loss: 0.08, val_acc: 0.98, elapsed: 10.61
train_loss: 0.03, val_loss: 0.08, val_acc: 0.98, elapsed: 7.32
train_loss: 0.03, val_loss: 0.09, val_acc: 0.98, elapsed: 7.36
train_loss: 0.02, val_loss: 0.08, val_acc: 0.98, elapsed: 7.54
train_loss: 0.02, val_loss: 0.11, val_acc: 0.98, elapsed: 7.33

倍くらい速くなっていますね。TFDSすごい。

ネットワークをflax.linenで書き直す

さてここからが本番です。torch.nnに対応するモジュールとして、flax.linenがあり、この中に各種レイヤーが実装されています。PyTorchのnn.Moduleと違い、__init__forwardを書く必要はなく、以下のような@fnn.compactでデコレートされた__call__を書くことになります。

import jax
import jax.numpy as jnp
import flax.linen as fnn

class MLP(fnn.Module):
    @fnn.compact
    def __call__(self, x):
        x = fnn.Dense(128)(x)
        x = fnn.relu(x)
        x = fnn.Dense(256)(x)
        x = fnn.relu(x)
        x = fnn.Dense(10)(x)
        x = fnn.log_softmax(x)
        return x

PyTorchとflaxの大きな違いは、モデル自体がパラメタを保持しないという点です。flaxでは以下のように、パラメタが別のオブジェクトとして作成されます。

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones([1, 28 * 28]))['params']

いろいろ出てきましたが順番に説明すると、

  • jax.numpynumpyに入っているいろんな関数が実装されています。たいていのnumpy関数はjax.numpyにも入っています。
  • model.initでは、特定の疑似乱数生成器(jax.random.PRNGKey)を使い、モデルのパラメタを初期化します。このとき、入力テンソルの形をjnp.ones([1, 28 * 28])のように与えることで、チャネル数が自動的によしなにされます(flattenしたあと入力のチャネル数がいくつになるっけ・・・?みたいなのを手計算しなくて良くなります)
  • このparamsの中身を見てみると、
FrozenDict({
    Dense_0: {
        kernel: DeviceArray([[-6.9373280e-02, -1.7459888e-02,  2.6923235e-04, ...,
                      -3.5230294e-02, -3.9390869e-02,  1.8164413e-02],
                     [ 7.5236415e-03,  2.2690799e-02, -1.8007368e-05, ...,
                       2.8825440e-02,  5.4414038e-02,  1.0891268e-02],
                     [-1.6200040e-02, -4.6700433e-02,  5.5902313e-02, ...,
                       5.1038559e-03, -3.8261801e-02, -1.7489832e-02],
                     ...,
                     [-1.9485658e-02, -1.1464893e-02,  1.1733949e-02, ...,
                      -3.8100831e-02,  4.0606514e-02, -2.5036685e-02],
                     [-7.7254638e-02,  3.7608985e-02,  5.9655067e-02, ...,
                       3.0473005e-02, -3.4684684e-02, -2.2349501e-02],
                     [ 5.9532784e-02,  6.1933022e-02,  2.4909737e-02, ...,
                      -3.8187259e-03, -1.4835574e-02,  4.1087948e-02]],            dtype=float32),
        bias: DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                     0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                     0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                     0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                     0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                     0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                     0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                     0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                     0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
    },
    Dense_1: {

みたいな感じで、Dict形式でパラメタが保存されています。

flax.training.train_state.TrainStateを使ってひとまとめに

Optimizerや後述の損失関数は、optaxというライブラリに入っているのでこれを使います。また、flaxにはTrainStateという、モデルのforward, パラメタ、optimzerをひとまとめにする便利なクラスがあるので、こいつも使います。

from flax.training.train_state import TrainState
import optax

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones([1, 28 * 28]))['params']
tx = optax.adam(0.001)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)

学習の1ステップをJAXで書き直してjax.jitでデコレートする

def step(x, y, is_train=True):
    x = x.reshape(-1, 28 * 28)
    y_pred = model(x)
    loss = criterion(y_pred, y)
    if is_train:
        opt.zero_grad()
        loss.backward()
        opt.step()
    return loss, y_pred

これをjaxで書き直すとこんな感じになります。

@partial(jax.jit, static_argnums=(3,))
def step(x, y, state, is_train=True):
    def loss_fn(params):
        y_pred = state.apply_fn({'params': params}, x) 
        loss = optax.softmax_cross_entropy(logits=y_pred, labels=y).mean()
        return loss, y_pred
    x = x.reshape(-1, 28 * 28)
    y = jnp.eye(10)[y]  # optax.softmax_cross_entropyはone-hotを受け取る
    if is_train:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, y_pred), grads = grad_fn(state.params)
        state = state.apply_gradients(grads=grads)
    else:
        loss, y_pred = loss_fn(state.params)
    return loss, y_pred, state

ちょっとややこしいですね。一つずつみていきましょう。

    def loss_fn(params):
        y_pred = state.apply_fn({'params': params}, x) 
        loss = optax.softmax_cross_entropy(logits=y_pred, labels=y).mean()
        return loss, y_pred

ここはモデルのforward~損失関数の定義です。先に定義したTrainStateのインスタンスstateの持つstate.apply_fn関数にモデルのパラメタparamsと入力xを食わせると、モデルの出力y_predが得られます。損失関数はoptaxの中にいろいろ入っています。

grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, y_pred), grads = grad_fn(state.params)
state = state.apply_gradients(grads=grads)

上記で定義したloss_fnについての勾配を計算する関数grad_fnを用意します。lossを出力しなくていい場合は、jax.gradを使います。loss_fnがlossだけ返す場合はhas_aux=False、loss以外も返す(y_predとか)場合はhas_aux=Trueにします。この関数でgradsを計算でき、その後のstate.apply_gradientsに与えることで、モデルのパラメタが更新されます。

else:
    loss, y_pred = loss_fn(state.params)

推論時でgradientの計算が不要な場合は、loss_fnをそのまま使います。

@partial(jax.jit, static_argnums=(3,))
def step(x, y, state, is_train=True):

jax.jitを使ってstep関数をXLAコンパイルし、高速化を目指します。詳細は以下のブログが詳しかったです -> JAX入門~高速なNumPyとして使いこなすためのチュートリアル~

ポイントとしては、is_trainは固定値なのでstatic_argnumsで指定する必要があります。

完成: JAX移行後

import time
from functools import partial

import jax
import jax.numpy as jnp
import flax.linen as fnn
from flax.training.train_state import TrainState
import optax
import tensorflow as tf
import tensorflow_datasets as tfds

def preprocessing(x, y):
    x = tf.cast(x, tf.float32) / 255.
    
    return x, y

ds = tfds.load("mnist", as_supervised=True, shuffle_files=False, download=True)
train_set = ds["train"]
train_set = train_set.shuffle(len(train_set), seed=0, reshuffle_each_iteration=True).batch(32).map(preprocessing).prefetch(1)
val_set = ds["test"]
val_set = val_set.batch(32).map(preprocessing).prefetch(1)

# model
class MLP(fnn.Module):
    @fnn.compact
    def __call__(self, x):
        x = fnn.Dense(128)(x)
        x = fnn.relu(x)
        x = fnn.Dense(256)(x)
        x = fnn.relu(x)
        x = fnn.Dense(10)(x)
        x = fnn.log_softmax(x)
        
        return x

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones([1, 28 * 28]))['params']
tx = optax.adam(0.001)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)

@partial(jax.jit, static_argnums=(3,))
def step(x, y, state, is_train=True):
    def loss_fn(params):
        y_pred = state.apply_fn({'params': params}, x) 
        loss = optax.softmax_cross_entropy(logits=y_pred, labels=y).mean()
        return loss, y_pred
    x = x.reshape(-1, 28 * 28)
    y = jnp.eye(10)[y]
    if is_train:
        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
        (loss, y_pred), grads = grad_fn(state.params)
        state = state.apply_gradients(grads=grads)
    else:
        loss, y_pred = loss_fn(state.params)
    return loss, y_pred, state

for e in range(10):
    tic = time.time()
    train_loss, val_loss, acc = 0., 0., 0.
    for x, y in train_set.as_numpy_iterator(): 
        loss, y_pred, state = step(x, y, state, is_train=True)
        train_loss += loss
    train_loss /= len(train_set)

    for x, y in val_set.as_numpy_iterator(): 
        loss, y_pred, state = step(x, y, state, is_train=False)
        val_loss += loss
        acc += (jnp.argmax(y_pred, 1) == y).mean()
    val_loss /= len(val_set)
    acc /= len(val_set)
    elapsed = time.time() - tic
    
    print(f"train_loss: {train_loss:0.2f}, val_loss: {val_loss:0.2f}, val_acc: {acc:0.2f}, elapsed: {elapsed:0.2f}")

回してみましょう

train_loss: 0.22, val_loss: 0.11, val_acc: 0.96, elapsed: 13.84
train_loss: 0.10, val_loss: 0.09, val_acc: 0.97, elapsed: 6.18
train_loss: 0.07, val_loss: 0.09, val_acc: 0.97, elapsed: 4.02
train_loss: 0.05, val_loss: 0.08, val_acc: 0.98, elapsed: 6.18
train_loss: 0.04, val_loss: 0.08, val_acc: 0.98, elapsed: 6.40
train_loss: 0.04, val_loss: 0.07, val_acc: 0.98, elapsed: 6.18
train_loss: 0.03, val_loss: 0.08, val_acc: 0.98, elapsed: 6.17
train_loss: 0.03, val_loss: 0.09, val_acc: 0.98, elapsed: 6.18
train_loss: 0.02, val_loss: 0.09, val_acc: 0.98, elapsed: 3.80
train_loss: 0.02, val_loss: 0.10, val_acc: 0.98, elapsed: 6.18

最初のエポックだけコンパイルが入るので時間がかかりますが、その後はかなり高速化されていますね。

実装 dataloader runtime/epoch (s)
PyTorch PyTorch 14秒前後
PyTorch TFDS 7.5~10秒前後
JAX TFDS 6秒前後

まとめ

MLPの学習を例に、PyTorchのコードをJAX (Flax + Optax) に移行する方法を紹介しました。割とすんなり移行できそうに感じます。今後FlaxとOptaxでカバーできるNNレイヤーやoptimizer/loss functionの種類が増えてくると、さらに使いやすくなりそうです。

今後、以下のような内容を取り上げたいと思います。

  • GPUでの学習
  • サンプリングが含まれるモデルの学習(VAE, GANなど)

Discussion

興味深い記事ありがとうございます!
完成: JAX移行後が、PyTorchのコードになってしまっている気がします

コメントありがとうございます!完全におっしゃるとおりでした & 修正しました!

ログインするとコメントできます