📐

JAXベースの深層学習ライブラリMetoryxを開発している

に公開

JAX は、GPU/TPU での高速計算や自動微分が魅力的な Google 製の数値計算ライブラリです。また、関数型プログラミングのスタイルを採用しており、副作用のない純粋関数を基本としているという点も特徴的です。

JAX は特に深層学習の分野で活用されており、HaikuFlax といった深層学習ライブラリが JAX を基盤として構築され、利用されてきました。これらのライブラリは、init-apply 式の API を採用しており、モデルとそのパラメータを明確に分離する「関数型」の考え方に基づいています。このおかげで JAX の自動微分や JIT コンパイルといった強力な機能を最大限に活用できるのですが、一方で転移学習や LoRA のようにモデルの一部を動的に置き換えるような操作が煩雑になりがちであり、そもそもオブジェクト指向に慣れ親しんだユーザにとっては直感的に理解しづらい部分があるといった弱点もありました。

そのため、近年では Flax の新しい API である NNX や Equinox のような、モデルとパラメータを一体化させたオブジェクト指向に近いライブラリが人気を集めています。一方、init-apply 式のライブラリは、以下のように下火になりつつあります。

  • Haiku: 2023 年 7 月にメンテナンスモードに移行
  • 従来の Flax API (Linen): 非推奨とまではいかないが、新規ユーザは Flax NNX を使うことが推奨されている

しかし、世の中には JAX とシームレスに連携できる init-apply 式の API に魅力を感じているユーザもいるのではないでしょうか?(少なくとも、私はそうです。)
そんなユーザのために、JAX の関数型スタイルを貫きつつ、弱点だった「扱いにくさ」を克服した新しいライブラリ Metoryx を開発しています。

インストール

アルファ版ではありますが、PyPI で公開しています。以下のコマンドでインストールできます。

pip install metoryx

また、以下のリンクも興味があれば参照してください。

Metoryx の思想

Metoryx は、モデルを定義する定義フェイズと、モデルパラメータの初期化や実際の計算を行う実行フェイズ の 2 つのフェイズを明確に分離し、それぞれに最適化されたパラダイムを提供するという思想に基づいて開発されています。具体的には、定義フェイズではオブジェクト指向のスタイルを採用し、計算フェイズでは関数型プログラミングのスタイルを採用しています。

言葉で説明するよりも、実際にコードを見た方が早いでしょう。以下に Metoryx の使用例を示します。

import jax
import jax.numpy as jnp
import jax.random as jr
import metoryx as mx

#
# 定義フェイズ (オブジェクト指向)
#
class Mlp(mx.Module):
    def __init__(self, in_size, hidden_size, out_size):
        super().__init__()
        self.layer1 = mx.Dense(in_size, hidden_size)
        self.layer2 = mx.Dense(hidden_size, out_size)

    def __call__(self, x):
        x = self.layer1(x)
        x = mx.relu(x)
        return self.layer2(x)

model = Mlp(784, 256, 10)

#
# 定義フェイズから実行フェイズへの移行
#
init, apply = mx.transform(model)

#
# 実行フェイズ (関数型プログラミング)
#
variables = init(jr.PRNGKey(0))
outputs, new_variables = apply(variables, jr.PRNGKey(1), jnp.ones((1, 784)))

この例では、まず Mlp クラスを定義して、2 層の全結合ニューラルネットワークをオブジェクト指向のスタイルで表現しています。mx.Module クラスを継承し、各レイヤーをクラスの属性として持たせています。次に、mx.transform 関数を使って、定義フェイズで作成したモデルを実行フェイズで使用するための initapply という 2 つの関数に変換しています。最後に、init 関数を使ってモデルのパラメータを初期化し、apply 関数を使って実際に計算を行っています。

このモデルを関数に変換するアプローチを Metoryx では Module Transformation と呼んでいます。Metoryx はこのようなモジュールの変換を通じて、柔軟で直感的なオブジェクト指向のモデル定義と、JAX の関数型プログラミングのスタイルを両立しています。

次に、各フェイズの詳細を見ていきましょう。

実行フェイズ

まずは、Flax Linen や Haiku との差分が少ない実行フェイズから説明します。実行フェイズでは、init 関数と apply 関数を使って、モデルのパラメータの初期化と計算を行います。これらの関数は、原則として純粋関数であり、副作用を持ちません。そのため、JAX の JIT コンパイルや自動微分と組み合わせて使用することができます。

init関数

init 関数は、JAX の PRNG キーを引数に取り、モデルのパラメータを初期化して返します。

variables = init(jr.PRNGKey(0))

init 関数が出力する variables は、モデルのパラメータを格納した辞書であり、その構造は Flax Linen と同じものです。具体的には、以下のような構造を持ちます。

{
    'params': {
        'layer1': {
            'kernel': ...,
            'bias': ...,
        },
        'layer2': {
            'kernel': ...,
            'bias': ...,
        },
    },
}

最初のキー 'params' の下に、各レイヤーのパラメータが格納されています。この 'params' はコレクションと呼ばれ、誤差逆伝播で学習されるパラメータが含まれています。その他にも、例えばバッチ正規化の平均や分散のような学習されない状態を格納するための 'batch_stats' など、様々なコレクションを定義できます。

apply関数

apply 関数は、モデルのパラメータを含む variables 辞書、JAX の PRNG キー、そして入力データを引数に取り、モデルの出力と更新された variables 辞書を返します。

outputs, new_variables = apply(variables, jr.PRNGKey(1), jnp.ones((1, 784)))

Flax Linen や Haiku を使ったことがある方にとっては、見慣れた API ではないでしょうか。PRNG キーの扱い方には若干の違いがありますが、これは後ほど説明します。

パラメータに対する勾配の計算

深層学習モデルを学習する場合には、variables['params'] に対して誤差逆伝播を行い、パラメータを更新するのが一般的です。その場合、以下のようにして勾配を計算することになります。

# 第一引数に対して勾配を計算
def loss_fn(params, model_state, inputs, targets):
    variables = {'params': params, **model_state}  # 'params' とその他のコレクションをまとめる
    outputs, new_state = apply(variables, jr.PRNGKey(1), inputs)
    loss = ...
    return loss, new_state

variables = init(jr.PRNGKey(0))

# パラメータと状態を分離
model_state = variables
params = model_state.pop("params")

# パラメータに対する勾配を計算
grads, new_state = jax.grad(loss_fn, has_aux=True)(params, model_state, inputs, targets)

これも、Flax Linen や Haiku ではよく見られる実装パターンです。

定義フェイズ

定義フェイズでは、モデルのアーキテクチャを定義します。他のライブラリと同じように、mx.Module クラスを継承したクラスとしてモデルを定義し、各レイヤーをクラスの属性として持たせます。例えば、以下のようにして 2 層の全結合ニューラルネットワークを定義できます。

import metoryx as mx


class Mlp(mx.Module):
    def __init__(self, in_size, hidden_size, out_size):
        super().__init__()
        self.layer1 = mx.Dense(in_size, hidden_size)
        self.layer2 = mx.Dense(hidden_size, out_size)

    def __call__(self, x):
        x = self.layer1(x)
        x = mx.relu(x)
        return self.layer2(x)

# インスタンス化
model = Mlp(784, 256, 10)

また、リストや辞書を使って、以下のようにレイヤーをまとめることもできます。

class MlpList(mx.Module):
    def __init__(self, in_size, hidden_sizes, out_size):
        super().__init__()
        self.layers = [mx.Dense(in_size, hidden_sizes[0])]
        for i in range(1, len(hidden_sizes)):
            self.layers.append(mx.Dense(hidden_sizes[i-1], hidden_sizes[i]))
        self.layers.append(mx.Dense(hidden_sizes[-1], out_size))

    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = mx.relu(layer(x))
        return self.layers[-1](x)

モデルの動的な変更

ここからが、Metoryx が Flax Linen や Haiku と大きく異なる部分です。Metoryx では、モデルのインスタンスに対して直接変更を加えることができます。

例えば、モデルのレイヤーを置き換えたい場合は、以下のように実装します。

model = Mlp(784, 256, 10)
model.layer2 = mx.Dense(256, 20)  # 出力層を置き換え

状態やパラメータの初期値を事前に割り当てておくこともできます。初期値を割り当てておくことで、init 関数が呼ばれた際に、その初期値が返されます。

model = Mlp(784, 256, 10)
model.layer1.kernel.value = jnp.zeros((784, 256))  # layer1 のカーネルの初期値をゼロに割り当て

※割り当てられる値の形状は、必ずレイヤーのパラメータの形状と一致している必要があります。もし形状が一致していない場合は、その時点でエラーになります。

この機能を使えば、転移学習は以下のように実装できます。

model = Mlp(784, 256, 10)

# ダミーの事前学習済みパラメータを用意
pretrained_variables = {
    'params': {
        'layer1': {
            'kernel': jnp.ones((784, 256)),
            'bias': jnp.ones((256,)),
        },
        'layer2': {
            'kernel': jnp.ones((256, 10)),
            'bias': jnp.ones((10,)),
        },
    },
}

# 事前学習済みパラメータを割り当て
model = mx.assign_variables(model, pretrained_variables)
model.layer2 = mx.Dense(256, 2)  # 出力層の次元数を変更し、ランダムに初期化

ここで、mx.assign_variables 関数は、与えられた variables 辞書に基づいて、モデルの各レイヤーのパラメータを割り当てるヘルパーメソッドです。どうでしょうか?少しでも Flax Linen や Haiku を知っている方にとっては、Metoryx の API が驚くほど柔軟で直感的であることが伝わるのではないでしょうか。

Module Transformation

ここまで見てきたように、Metoryx では定義フェイズと実行フェイズを明確に分離し、それぞれに最適化されたパラダイムを提供しています。この 2 つのフェイズを橋渡しするのが、Module Transformation です。

Module Transformation は、定義フェイズで作成したモデルを実行フェイズで使用するための init 関数と apply 関数に変換するプロセスです。この変換は、mx.transform 関数を使って行います。

model = Mlp(784, 256, 10)
init, apply = mx.transform(model)

このとき、mx.transform は内部で model を deepcopy した新しいオブジェクトを作成し、そのオブジェクトに対して変換を行います。そのため、Module Transformation によって生成された関数は、もとのオブジェクトが変更されても影響を受けることはありません。この単純な仕組みによって、initapply を純粋関数として扱えるようになり、jax.jit によるコンパイルや jax.vmap による自動ベクトル化といった JAX の強力な機能を最大限に、そして安全に活用できるようになります。

Tips

用途によっては、__call__ 以外のメソッドを apply 関数として使用したい場合もあるでしょう。その場合は、mx.transform 関数の to_callable 引数に以下のような関数を渡すことで、任意のメソッドを apply 関数として使用できます。

def to_callable(model):
    return model.forward  # 例えば forward メソッドを使用したい場合

init, apply = mx.transform(model, to_callable=to_callable)

実装例: MNIST

最後に、Metoryx を使って MNIST の手書き数字分類を実装した例を示します。このコードを実行するには、metoryx に加えて optaxtensorflowtensorflow-datasets のインストールが必要です。

import itertools

import jax
import jax.random as jr
import metoryx as mx
import optax

import tensorflow as tf
import tensorflow_datasets as tfds


# ハイパーパラメータ
random_seed = 42
num_epochs = 10
batch_size = 64

# MNIST のメタデータ
num_images = 60000
num_steps_per_epoch = num_images // batch_size

# Tensorflow の乱数シードを設定
tf.random.set_seed(random_seed)


# MNIST データセットを読み込むための関数
def get_dataset(num_epochs, batch_size, split):
    ds = tfds.load("mnist", split=split)
    ds = ds.map(
        lambda item: {
            "image": tf.cast(item["image"], tf.float32) / 255.0,
            "label": item["label"],
        }
    )
    ds = ds.repeat(num_epochs)
    if split == "train":
        ds = ds.shuffle(1024)
    ds = ds.batch(batch_size, drop_remainder=(split == "train"))
    ds = ds.prefetch(1)
    return ds


# モデルの定義
class ConvNet(mx.Module):
    """A simple CNN for MNIST classification."""

    def __init__(self):
        super().__init__()
        self.conv1 = mx.Conv(1, 32, kernel_size=(3, 3), strides=(1, 1), padding="SAME")
        self.conv2 = mx.Conv(32, 64, kernel_size=(3, 3), strides=(1, 1), padding="SAME")
        self.fc1 = mx.Dense(7 * 7 * 64, 128)
        self.fc2 = mx.Dense(128, 10)

    def __call__(self, x, is_training):
        x = self.conv1(
            x
        )  # Input shape is expected to be (batch, height, width, channels)
        x = mx.relu(x)
        x = mx.max_pool(x, kernel_size=(2, 2), strides=(2, 2))
        x = self.conv2(x)
        x = mx.relu(x)
        x = mx.max_pool(x, kernel_size=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # Flatten
        x = mx.dropout(x, rate=0.5, is_training=is_training)
        x = self.fc1(x)
        x = mx.relu(x)
        return self.fc2(x)


# モデルのインスタンス化と Module Transformation の実行
net = ConvNet()
init, apply = mx.transform(net)


# 学習ステップと評価ステップの定義
# JIT コンパイルを適用して高速化
@jax.jit
def train_step(rng, params, state, opt_state, batch):
    def loss_fn(params):
        variables = {"params": params, **state}
        logits, new_state = apply(variables, rng, batch["image"], is_training=True)
        new_state.pop("params")

        # ロスと精度の計算
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=batch["label"]
        ).mean()
        accuracy = (logits.argmax(axis=-1) == batch["label"]).mean()
        log_dict = {"loss": loss, "accuracy": accuracy}

        return loss, (new_state, log_dict)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, (new_state, log_dict)), grads = grad_fn(params)

    # パラメータの更新
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_state, new_opt_state, log_dict


@jax.jit
def eval_step(params, state, batch):
    variables = {"params": params, **state}
    logits, _ = apply(variables, None, batch["image"], is_training=False)
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch["label"]
    ).mean()
    accuracy = (logits.argmax(axis=-1) == batch["label"]).mean()
    return {"loss": loss, "accuracy": accuracy}


# 初期化と学習の準備
rng = jr.PRNGKey(42)  # Set random seed
rng, init_rng = jr.split(rng)
variables = init(init_rng)

# パラメータと状態を分離
state = variables
params = state.pop("params")

# Optimizer の設定
optimizer = optax.sgd(learning_rate=0.01, momentum=0.9)
opt_state = optimizer.init(params)

# 学習の実行
train_ds = get_dataset(num_epochs, batch_size, split="train")
train_iter = iter(train_ds.as_numpy_iterator())
for epoch in range(num_epochs):
    # AverageMeter は、メトリクスの平均を計算するユーティリティクラス
    meter = mx.utils.AverageMeter()
    for batch in itertools.islice(train_iter, num_steps_per_epoch):
        rng, rng_apply = jr.split(rng)
        params, state, opt_state, metrics = train_step(
            rng_apply, params, state, opt_state, batch
        )
        meter.update(metrics, n=len(batch["image"]))
    print(meter.compute())

# テストの実行
test_ds = get_dataset(1, batch_size, split="test")
meter = mx.utils.AverageMeter()
for batch in test_ds.as_numpy_iterator():
    metrics = eval_step(params, state, batch)
    meter.update(metrics, n=len(batch["image"]))
print("Test:", meter.compute())

実行すると、以下のような出力が得られます。学習を通して、ロスが減少し、精度は向上していることから、ちゃんと学習できていることが分かりますね。

{'accuracy': 0.9218916755602988, 'loss': 0.24640471186576335}
{'accuracy': 0.9744530416221985, 'loss': 0.08127296438466523}
{'accuracy': 0.9802061099252934, 'loss': 0.062307235655616355}
{'accuracy': 0.9840081376734259, 'loss': 0.051378412921629694}
{'accuracy': 0.9864594450373533, 'loss': 0.04249529187615241}
{'accuracy': 0.9875433564567769, 'loss': 0.038763899686518936}
{'accuracy': 0.988977454642476, 'loss': 0.03436332501764504}
{'accuracy': 0.9900780416221985, 'loss': 0.03115085096888914}
{'accuracy': 0.9897778815368197, 'loss': 0.030075993747828107}
{'accuracy': 0.9914120864461046, 'loss': 0.02679101809990476}
Test: {'accuracy': 0.9925, 'loss': 0.020609416964557023}

おわりに

Metoryx は JAX の関数型プログラミングのスタイルを維持しつつ、直感的で柔軟なモデル定義が可能な深層学習ライブラリです。まだまだ発展途上のライブラリではあるものの、ようやく最低限の機能が揃ってきたと感じており、丁度良いタイミングだと思ってこの記事を公開することにしました。

もし興味があれば、ぜひ利用してみてください。フィードバック等頂けますと、めちゃくちゃありがたいです🙏

Appendix

おまけとして、Metoryx の細かい設計について説明します。これらの設計は Metoryx の概要を理解する上で必須ではありませんが、興味があれば目を通してみてください。

パラメータと状態の管理

新しいレイヤーを定義する際には、ユーザが独自にパラメータや状態を定義する必要があります。Metoryx では、mx.Parameter クラスと mx.State クラスを使って、パラメータと状態をそれぞれ定義することができます。

class MyLayer(mx.Module):
    def __init__(self, in_size, out_size):
        super().__init__()
        self.param = mx.Parameter((in_size, out_size), mx.initializers.lecun_normal())  # パラメータ
        self.state = mx.State("some_states", (out_size,), mx.initializers.zeros(), mutable=True)  # 状態

ざっくり言うと、mx.Parameter は誤差逆伝播で学習するパラメータを表し、mx.State は学習しない状態を表します。とはいえ、実装上 mx.Parametermx.State に 'params' という名前のコレクション名を与えたものに過ぎません。つまり、mx.Parametermx.State のシンタックスシュガーとして捉えることもできます。

パラメータや状態は、apply 関数の内部と外部で一部異なる振る舞いをします。中でも代表的なのが、value 属性の振る舞いです。

apply 関数の外部

apply 関数の外部では、value 属性にアクセスすることはできません。これは、パラメータや状態の値は init 関数を通じて初めて生成されるものであり、mx.Parametermx.State のインスタンスには値が存在しないためです。

一方、以下のようにvalue 属性に値を割り当てることは可能です。

p = mx.Parameter((10, 10), mx.initializers.zeros())
p.value = jnp.ones((10, 10))  # 初期値を割り当てる

ただし、この場合も mx.Parameter が割り当てられた初期値を保持するわけではありません。実際のところ、apply 関数の外部で value に値を設定するという操作は、mx.Parameterinitializer 属性に対し、常に jnp.ones((10, 10)) を返すイニシャライザを割り当てる操作のシンタックスシュガーに過ぎません。

apply 関数の内部

apply 関数の内部では、value 属性にアクセスすることができます。これは、apply 関数が呼び出されるたびに、入力された variables 辞書から、各パラメータや状態に対応する値が割り当てられるためです。

また、apply 関数の内部で value 属性に新しい値を割り当てると、状態の更新として扱われます。更新された状態は、apply 関数の戻り値として返される新しい variables 辞書に反映されます。例えば、以下のようにして状態を更新できます。

class MyLayer(mx.Module):
    def __init__(self, in_size, out_size):
        super().__init__()
        self.state = mx.State("some_states", (out_size,), mx.initializers.zeros(), mutable=True)
    def __call__(self, x):
        self.state.value = self.state.value + jnp.sum(x, axis=0)  # 状態を更新
        return x + self.state.value

注意点として、mutable=True として定義された mx.State のみが更新可能であり、mutable=Falsemx.Statemx.Parameter は更新できません。もし更新しようとすると、エラーになります。

PRNG キーの管理

mx.Module クラスの __call__ メソッド内で乱数を使用する場合、ユーザは Metoryx に乱数の管理を任せることができます。ユーザは、apply 関数の第 2 引数に PRNG キーを渡すだけで、Metoryx が自動的にキーを分割し、各レイヤーに適切なキーを供給します。PRNG キーは、以下のように mx.next_rng_key() 関数を使って取得できます。

class AddNoise(mx.Module):
    def __call__(self, x):
        rng = mx.next_rng_key()  # Metoryx が供給する PRNG キーを取得
        z = jr.normal(rng, x.shape)
        return x + z

model = AddNoise()
init, apply = mx.transform(model)
variables = init(jr.PRNGKey(0))

# 第 2 引数に PRNG キーを渡す
outputs, new_variables = apply(variables, jr.PRNGKey(1), jnp.ones((10, 10)))

もしモデル内部で乱数を使用しないのであれば、apply 関数の第 2 引数に None を渡すことも可能です。

outputs, new_variables = apply(variables, None, jnp.ones((10, 10)))

ここだけ見ると、Haiku の乱数の扱い方に似ていますね。

一方、Flax Linen では、apply 関数に辞書形式で複数の PRNG キーを渡すような API を採用しており、より細かい制御が可能です。そのため、Flax Linen の方が好みであるという方も多いのではないでしょうか。実は、Metoryx では Flax Linen のように複数の PRNG キーを渡すことも可能です。

# 第 2 引数に辞書形式で複数の PRNG キーを渡す
rngs = {"dropout": jr.PRNGKey(0), "noise": jr.PRNGKey(1)}
outputs, new_variables = apply(variables, rngs, jnp.ones((10, 10)))

このとき、mx.next_rng_key() 関数にキーの名前を渡すことで、特定のキーを取得できます。

class AddNoise(mx.Module):
    def __call__(self, x):
        rng = mx.next_rng_key("noise")  # noise という名前のキーを取得
        z = jr.normal(rng, x.shape)
        return x + z

また、Haiku と Flax Linen 両方のスタイルを組み合わせることもでき、その場合は mx.PRNGKeys メソッドを使用します。mx.PRNGKeys メソッドは、位置引数でデフォルトで使用される PRNG キーを、キーワード引数で名前付きの PRNG キーを受け取ります。

rngs = mx.PRNGKeys(jr.PRNGKey(42), dropout=jr.PRNGKey(0), noise=jr.PRNGKey(1))
outputs, new_variables = apply(variables, rngs, jnp.ones((10, 10)))

このメソッドを使用すると、mx.next_rng_key() 関数はまず指定された名前付きのキーを探し、見つからなかった場合や名前が指定されなかった場合にはデフォルトのキーを使用するようになります。

mx.next_rng_key("noise")  # noise という名前のキーを使用
mx.next_rng_key()         # デフォルトのキーを使用
mx.next_rng_key("foo")    # foo という名前のキーがないので、デフォルトのキーを使用

このように、Metoryx ではユーザが自身のニーズに応じて、柔軟に乱数の扱い方を選択できるようになっています。

Discussion