⚔️

JAX の学習ステップはクロージャとして実装するといい

に公開

この記事では、JAX を使って深層学習を実装する際、学習ステップをクロージャとして実装するというアイデアを紹介します。

学習ステップ

この記事において、学習ステップとはミニバッチを 1 つ受け取ってモデルやオプティマイザの状態を更新する関数とします。JAX を用いた実装では、以下の関数 train_step のようなイメージで実装されることが多いです。

def compute_loss(params, batch):
    # Flax などのライブラリを使って順伝播を行い、ロスを計算する
    ...

def update_train_state(grads, train_state):
    # Optax などのライブラリを使ってパラメータの更新を行う
    ...

def train_step(train_state, batch):
    inputs, labels = batch

    # ニューラルネットワークのパラメータについて、その勾配を計算
    grad_fn = jax.grad(compute_loss)
    grads = grad_fn(train_state.params, batch)

    # 勾配を使ってニューラルネットワークのパラメータやオプティマイザの状態などを更新
    new_train_state = update_train_state(grads, train_state)
    return new_train_state

train_step は、高速化のために jax.jit などを使って JIT コンパイルされた後、学習が完了するまで繰り返しモデルの状態を更新します。

num_epochs = 100  # エポック数
train_state = ...  # モデルのパラメータやオプティマイザの状態などを初期化
data_loader = ... # ミニバッチを生成するデータローダ。例: tf.data.Dataset, torch.utils.data.DataLoader

p_train_step = jax.jit(train_step)  # 学習ステップを JIT コンパイル
for epoch in range(num_epochs):
    for batch in iter(data_loader):
        train_state = p_train_step(train_state, batch)

ハイパーパラメータの注入

モデルの学習を行う際、少しでも複雑な学習を行おうとすると、ほぼ確実に学習ステップ train_step 内でハイパーパラメータが必要になります。

最も単純に情報を追加する方法は train_step 内で固定値としてハイパーパラメータを定義してしまう方法ですが、柔軟性などの観点から好ましくありません。そのため、外部からハイパーパラメータを渡す方法を考える必要があります。

具体的には、以下の方法が簡単に思いつくでしょうか。

  • train_step の引数にハイパーパラメータを設定する
  • train_step をクラスとして定義し、初期化時にハイパーパラメータを渡す

私は、いずれの方法も微妙であると考えています。
これから、その理由について説明します。

train_step の引数にハイパーパラメータを設定する

train_step の引数にハイパーパラメータを設定する方法は、以下のように実装できます。

def train_step(train_state, batch, **hyperparams):
    ...

hyperparams が必要なハイパーパラメータを受け取るための引数となります。後は、train_step を JIT コンパイルし、呼び出す際にハイパーパラメータを渡すだけです。

p_train_step = jax.jit(train_step)  # 学習ステップを JIT コンパイル
for epoch in range(num_epochs):
    for batch in iter(data_loader):
        # ここでエラーになりうる
        train_state = p_train_step(train_state, batch, **hyperparams)

ところが、この方法だとエラーが発生する可能性があります。なぜなら、train_step に渡される hyperparams が必ずしも PyTree であるとは限らないためです。(PyTree の説明については省略します。気になる方は、公式ドキュメント等を参考にしてください。)

このエラーを避けるためには、JIT コンパイル時に hyperparams の要素が PyTree でないことを明示的に伝えるために、static_argnamesを指定する必要があります。

hyperparams = {
    "loss_fun": ...,  # 損失関数を計算するための関数。関数は PyTree ではないオブジェクトのひとつ
}
p_train_step = jax.jit(train_step, static_argnames=["loss_fun"])
for epoch in range(num_epochs):
    for batch in iter(data_loader):
        train_state = p_train_step(train_state, batch, **hyperparams)

ただ、static_argnames を指定するのは面倒ですし、忘れがちです。そのため、保守性や可読性の面でも難があるため、train_step の引数にハイパーパラメータを設定する方法はあまりおすすめできません。

train_step をクラスとして定義し、初期化時にハイパーパラメータを渡す

train_step をクラスとして定義し、初期化時にハイパーパラメータを渡す方法は、以下のように実装できます。

class TrainStep:
    def __init__(self, **hyperparams):
        self.hyperparams = hyperparams

    def __call__(self, train_state, batch):
        ...

train_step がクラスとして定義されているため、ハイパーパラメータを初期化時に渡すことができます。

この方法は、深層学習の実装に習熟している人にとっては自然な方法に見えます。例えば、PyTorch Lightning はこの方法を採用していますよね。

しかし、個人的な意見としては、この方法と JAX は相性が悪いと感じています。なぜなら、JAX はクラスと組み合わせた際に癖のある挙動をすることがあるためです。

具体的な癖のある挙動の例として、JIT コンパイルを実行した後にクラスの属性値を変更しても、__call__ メソッド内部ではその変更が反映されないという点が挙げられます。以下のコードを見てください。

import jax

class AddNumber:
    """入力に number を加える関数"""
    def __init__(self):
        self.number = 0

    def __call__(self, x):
        return x + self.number

add_number = AddNumber()
p_add_number = jax.jit(add_number)

AddNumber は、入力値に number を加えるシンプルな処理を行うクラスです。このクラスを初期化し、JIT コンパイルを行っています。では、AddNumberに対して以下のようなコードを実行したらどうなるでしょうか?

print(add_number.number, p_add_number(jax.numpy.array(1.0)))

add_number.number = 1
print(add_number.number, p_add_number(jax.numpy.array(1.0)))

結果は、以下のような出力を得られます。

0 1.0
1 1.0

この出力からわかるように、add_number.number の値を変更しても、__call__ メソッド内部ではその変更が反映されていません。一方で add_number 自体の属性値は変更されているため、JIT コンパイルされたメソッドとその他で属性値の不一致が生じています。

実は、この現象は以下のように __call__ メソッドを直接 JIT コンパイルしても避けることは出来ません。

from functools import partial

class AddNumber:
    """入力に number を加える関数"""
    def __init__(self):
        self.number = 0

    @partial(jax.jit, static_argnames="self")
    def __call__(self, x):
        return x + self.number

このような挙動は、JAX の仕様通りの挙動ではあるのですが、Python の一般的な挙動とは異なるため、JAX 初学者にとって混乱の原因となりえます。例えば、あるエポック数に達した段階で損失関数を別のものに切り替えたい場合、train_step がクラスとして定義されていると、ついついそのクラスの属性値を変更したくなってしまいませんか?

そのような誤った使い方を避けるためにも、train_step をクラスとして定義する方法はあまり良くないと考えています。

提案:学習ステップをクロージャとして実装する

上記の問題点を踏まえて、私がおすすめするのはクロージャとして学習ステップを実装する方法です。以下に、クロージャを使った学習ステップの実装例を示します。

def make_train_step(**hyperparams):

    def train_step(train_state, batch):
        ...

    return train_step

train_step = make_train_step(model, optimizer)

このクロージャを使った実装方法を使うことで、クラスを使った方法と同様に train_step の引数にハイパーパラメータを渡す必要がなくなります。そのため、「train_step の引数にハイパーパラメータを設定する」で述べた問題点を回避することができます。

更に、ハイパーパラメータが変更できないことを明示的に示せるというメリットもあります。
JAX と関係なく、Python の仕様上クロージャの外部にある変数を変更することは基本的にはできません。そのため、IDE や GitHub Copilot などによるコード補完機能を使っても、ハイパーパラメータの変更が提案される可能性を抑えることが可能です。

具体例

具体的な例として、JAX を使って、label smoothing を適用したクラス分類の実装例を示します。学習ステップに加えて、初期化関数もクロージャとして実装し、ハイパーパラメータとして以下の値を受け取るようにしました。

  • flax_model: Flax で実装されたモデル
  • optimizer: オプティマイザ
  • num_classes: クラス数
  • label_smoothing: ラベルスムージングの値
from typing import NamedTuple, Callable
import jax.numpy as jnp
from flax import linen as nn
import optax


class TrainState(NamedTuple):
    params: dict
    opt_state: dict


class Trainer(NamedTuple):
    init_step: Callable
    train_step: Callable


def make_trainer(
    flax_model: nn.Module,
    optimizer: optax.GradientTransformation,
    num_classes: int,
    label_smoothing: float = 0.1,
) -> Trainer:

    def init_step(rng, inputs):
        params = flax_model.init(rng, inputs)
        opt_state = optimizer.init(params)
        return TrainState(params=params, opt_state=opt_state)

    def train_step(train_state, batch):
        inputs, labels = batch
        labels = labels * (1 - label_smoothing) + label_smoothing / num_classes

        def compute_loss(params, batch):
            logits = flax_model.apply(params, inputs)
            loss = -jnp.mean(labels * jax.nn.log_softmax(logits))
            return loss

        grad_fn = jax.grad(compute_loss)
        grads = grad_fn(train_state.params, batch)

        updates, new_opt_state = optimizer.update(grads, train_state.opt_state)
        new_params = optax.apply_updates(train_state.params, updates)
        return TrainState(params=new_params, opt_state=new_opt_state)

    return Trainer(init_step, train_step)


# ハイパーパラメータの設定
flax_model = ...
optimizer = optax.adam(1e-3)
num_classes = 10
label_smoothing = 0.1

# ダミーデータの準備
inputs = jnp.ones((32, 28, 28, 1))
labels = jnp.eye(num_classes)[jnp.arange(32) % num_classes]
train_data = [(inputs, labels)] * 100

# 初期化関数と学習ステップ関数を作成
trainer = make_trainer(flax_model, optimizer, num_classes, label_smoothing)
p_train_step = jax.jit(trainer.train_step)

# 状態の初期化 & 学習
train_state = trainer.init_step(jax.random.PRNGKey(0), inputs)
for _ in range(100):
    for batch in iter(train_data):
        train_state = p_train_step(train_state, batch)

いかがでしょうか。個人的にはかなりスッキリしたコードになっていると感じています。

まとめ

この記事では、JAX を使って深層学習を実装する際、学習ステップをクロージャとして実装するというアイデアを紹介しました。クロージャを使うことで、学習ステップの可読性や保守性を向上させることができると考えています。

また、クロージャを使った方法は、最後の実装例で示したように学習ステップだけでなく、状態の初期化を行う関数や、モデルを評価する関数などにも応用することができます。

とはいえ、この方法が JAX のベストプラクティスであるとは限りません。
※個人的にはそう思っていますが

ご意見やご感想、アドバイスなどがあれば、ぜひコメント欄にお寄せください。

Discussion