🕊️

flax/nnx でとびたつ

2024/11/19に公開

JAX

JAXは端的にいうとGoogleが開発するGPU,TPUなどのアクセラレータでブーストしたnumpy
主な用途は機械学習ライブラリのベースとなっていて、jitvmapなどの使って計算を高速化する

Flax

FlaxはJAXをベースとした機械学習ライブラリで, Googleが開発
DeepMindが開発していたHaikuの開発チームと統合された
Flax(というかJax)の特徴として, 各処理が関数型を意識したコードで直感的でわかりやすい

flax.linen

以下にTransformerを学習するコード例を示します。

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax

def get_dummy_data(batch_size, seq_length, vocab_size):
    return jnp.random.randint(0, vocab_size, (batch_size, seq_length))

class Transformer(nn.Module):
    vocab_size: int
    d_model: int
    num_heads: int
    num_layers: int

    @nn.compact
    def __call__(self, x):
        # 埋め込み層
        x = nn.Embed(num_embeddings=self.vocab_size, features=self.d_model)(x)
        # 位置エンコーディンは簡略化のため省略

        # エンコーダーブロック
        for _ in range(self.num_layers):
            x = nn.SelfAttention(num_heads=self.num_heads)(x)
            x = nn.Dense(self.d_model)(x)

        # 出力層
        logits = nn.Dense(self.vocab_size)(x)
        return logits

def compute_loss(params, x, y):
    logits = model.apply({'params': params}, x)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
    return loss

@jax.jit
def train_step(state, x, y):
    def loss_fn(params):
        return compute_loss(params, x, y)
    grads = jax.grad(loss_fn)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state

vocab_size = 1000
seq_length = 32
batch_size = 64
d_model = 128
num_heads = 8
num_layers = 2

model = Transformer(vocab_size, d_model, num_heads, num_layers)
params = model.init(jax.random.PRNGKey(0), jnp.ones((batch_size, seq_length), dtype=jnp.int32))['params']

tx = optax.adam(learning_rate=0.001)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)


num_epochs = 10

for epoch in range(num_epochs):
    x = get_dummy_data(batch_size, seq_length, vocab_size)
    y = get_dummy_data(batch_size, seq_length, vocab_size)
    state = train_step(state, x, y)
    print(f"Epoch {epoch+1}/{num_epochs} 完了")

特徴として、モデル(nn.Module)は__call__で実際の処理の内容を記述します
Flaxは@nn.compactデコレータで初期化しつつ処理も実装できるので、コードが非常に見通しやすくなります、これは初期のTensorFlowと同じですかね

そして、モデルの処理とパラメータは完全に分けられています。
params = model.init(...)で初期化されたパラメータが得られますが、これをstateと呼ばれるステートに渡します。
このstateはモデルパラメータとoptimizer(tx)のパラメータ(paramsのEMAなど)とmoduleの__call__に当たるapply_fnが渡されます
結局, 学習ではこのstateは毎回更新されますが、実際は状態を変えずに, 新しいstateが元のstateから作られます、ここがポイントです
ようするに

new_state = state.apply_gradients(grads)

これはnew_state = state + diffとなっています
これがコードの見通しやjitでは扱いやすくなっていますが、複雑なことをするときにかなり面倒になります。
例えば、モデル中の特定のレイヤーだけ置き換えたいときなどは面倒です。
paramsから置き換えないレイヤーのみ持ってきて置き換えたりします。

model1 = Model1(...)
params1 = model1.init(...)
...
# train model1
...
model2 = Model2(...)
params2 = model2.init(...)
params2['layer1'] = param21['layer1']
...

この記事のテーマのnnxではsateやモデルに対して参照を持っており, これを介して操作をします。

nnx

以下にflax.linenと同じコード例を示します

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax

def get_dummy_data(batch_size, seq_length, vocab_size):
    return jnp.random.randint(0, vocab_size, (batch_size, seq_length))

class Transformer(nnx.Module):
    vocab_size: int
    d_model: int
    num_heads: int
    num_layers: int

    def __init__(self, rngs: nnx.Rngs):
        self.embed = nnx.Embed(num_embeddings=self.vocab_size, features=self.d_model, rngs=rngs)
        self.layers = [nnx.Sequential([
            nnx.SelfAttention(num_heads=self.num_heads, rngs=rngs),
            nnx.Dense(features=self.d_model, rngs=rngs)
        ]) for _ in range(self.num_layers)]
        self.output_layer = nnx.Dense(features=self.vocab_size)

    def __call__(self, x):
        x = self.embed(x)
        for layer in self.layers:
            x = layer(x)
        logits = self.output_layer(x)
        return logits

def compute_loss(model, x, y):
    logits = model(x)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
    return loss

@nnx.jit
def train_step(model, optimizer, x, y):
  """Train for a single step."""
    def loss_fn(params):
        return compute_loss(params, x, y)
  grad_fn = nnx.grad(loss_fn)
  grads = grad_fn(model, x, y)
  optimizer.update(grads)  # In-place updates.

vocab_size = 1000
seq_length = 32
batch_size = 64
d_model = 128
num_heads = 8
num_layers = 2

model = Transformer(vocab_size, d_model, num_heads, num_layers, rngs=nnx.Rngs(0))

tx = optax.adam(learning_rate=0.001)
optimizer = nnx.Optimizer(model, tx)

num_epochs = 10

for epoch in range(num_epochs):
    x = get_dummy_data(batch_size, seq_length, vocab_size)
    y = get_dummy_data(batch_size, seq_length, vocab_size)
    train_step(model, optimizer, x, y)
    print(f"Epoch {epoch+1}/{num_epochs} 完了")

こうなります、細かいところを具体的に見ていきます。

Module?

flax.linenとはことなり, パラメータの初期化はモデルを作ったときに行われます。
flax.linenでは若干無理矢理感のあったrng keyの扱いも一貫しています。

State?

flax.nnxでは明確なstateは存在しますが、できるだけ表に出ないように(直接操作したり)することのない設計になっています。
上の例では、nnx.Optimizerstateを持っており, gradsを入力としてupdateを行います
そしてoptimizertrain_stepにおいてin-placeな更新を行います

nnx.jit?

flax.nnxではこのようなin-placeなupdateなどを行うために, nnx.jitnnx.gradなどが独自に用意されています
基本的にnnx側のものを使ってよいです

nnxの威力

以下のような例が可能です

class Model1(nnx.Module):
    encoder: nnx.Module
    decoder: nnx.Module
...

class Model2(nnx.Module):
    encoder: nnx.Module
    head: nnx.Module
...

model1 = Model1(...)
# some training steps
...
model2 = Model2(...)
model2.encoder = model1.encoder

こうすることで簡単に事前学習を行ったencoderを別のモデルにわたすことができます
flax.linenで行うと

model2 = Model2(...)
params2 = model2.init(...)
params2['encoder'] = params1['encoder']

とパラメータをやり取りすることになります
この例ではそこまで恩恵がないように見えますが、ここで, LoRAを既存のモデルに対して適用することを考えます


class LoraParam(nnx.Param):
    pass

class LoraLinear(nnx.Module):
    def __init__(self, linear, rank, rngs):
        self.linear = linear
        dim_in, dim_out = linear.kernel.value.shape
        self.A = LoraParam(init_fn(rngs(), (dim_in, rank), param_dtype))
        self.B = LoraParam(init_fn(rngs(), (rank, dim_out), param_dtype))
    
    def __call__(self, x: jax.Array):
        return self.linear(x), x @ self.A.value @ self.B.value)

このようにLoRAを定義したとして、あるモデルのLinearレイヤーをすべてLoRAに置き換えます

model = Model(...)
def apply_LoRA(model, rank, rngs):
    for path, module in model.iter_modules():
        if type(module) is nnx.Linear:
            new_module = uux.LoraLinear(module, rank, rngs)
            update_module(model, path, new_module)
        
def update_module(model, path, new_module):
    if type(path) in (list, tuple) and len(path) > 1:
        if type(model) in (list, tuple, dict):
            model = model[path[0]]
        else:
            model = getattr(model, path[0])
        update_module(model, path[1:], new_module)
    if len(path) == 1:
        path = path[0]
        if hasattr(model, path):
            setattr(model, path, new_module)
    return model

このように機械的に行うことができます

Discussion