flax/nnx でとびたつ
JAX
JAXは端的にいうとGoogleが開発するGPU,TPUなどのアクセラレータでブーストしたnumpy
主な用途は機械学習ライブラリのベースとなっていて、jit
やvmap
などの使って計算を高速化する
Flax
FlaxはJAXをベースとした機械学習ライブラリで, Googleが開発
DeepMindが開発していたHaikuの開発チームと統合された
Flax(というかJax)の特徴として, 各処理が関数型を意識したコードで直感的でわかりやすい
flax.linen
以下にTransformerを学習するコード例を示します。
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
def get_dummy_data(batch_size, seq_length, vocab_size):
return jnp.random.randint(0, vocab_size, (batch_size, seq_length))
class Transformer(nn.Module):
vocab_size: int
d_model: int
num_heads: int
num_layers: int
@nn.compact
def __call__(self, x):
# 埋め込み層
x = nn.Embed(num_embeddings=self.vocab_size, features=self.d_model)(x)
# 位置エンコーディンは簡略化のため省略
# エンコーダーブロック
for _ in range(self.num_layers):
x = nn.SelfAttention(num_heads=self.num_heads)(x)
x = nn.Dense(self.d_model)(x)
# 出力層
logits = nn.Dense(self.vocab_size)(x)
return logits
def compute_loss(params, x, y):
logits = model.apply({'params': params}, x)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
return loss
@jax.jit
def train_step(state, x, y):
def loss_fn(params):
return compute_loss(params, x, y)
grads = jax.grad(loss_fn)(state.params)
new_state = state.apply_gradients(grads=grads)
return new_state
vocab_size = 1000
seq_length = 32
batch_size = 64
d_model = 128
num_heads = 8
num_layers = 2
model = Transformer(vocab_size, d_model, num_heads, num_layers)
params = model.init(jax.random.PRNGKey(0), jnp.ones((batch_size, seq_length), dtype=jnp.int32))['params']
tx = optax.adam(learning_rate=0.001)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)
num_epochs = 10
for epoch in range(num_epochs):
x = get_dummy_data(batch_size, seq_length, vocab_size)
y = get_dummy_data(batch_size, seq_length, vocab_size)
state = train_step(state, x, y)
print(f"Epoch {epoch+1}/{num_epochs} 完了")
特徴として、モデル(nn.Module
)は__call__
で実際の処理の内容を記述します
Flaxは@nn.compact
デコレータで初期化しつつ処理も実装できるので、コードが非常に見通しやすくなります、これは初期のTensorFlowと同じですかね
そして、モデルの処理とパラメータは完全に分けられています。
params = model.init(...)
で初期化されたパラメータが得られますが、これをstate
と呼ばれるステートに渡します。
このstate
はモデルパラメータとoptimizer(tx
)のパラメータ(paramsのEMAなど)とmoduleの__call__に当たるapply_fn
が渡されます
結局, 学習ではこのstate
は毎回更新されますが、実際は状態を変えずに, 新しいstate
が元のstate
から作られます、ここがポイントです
ようするに
new_state = state.apply_gradients(grads)
これはnew_state = state + diff
となっています
これがコードの見通しやjit
では扱いやすくなっていますが、複雑なことをするときにかなり面倒になります。
例えば、モデル中の特定のレイヤーだけ置き換えたいときなどは面倒です。
params
から置き換えないレイヤーのみ持ってきて置き換えたりします。
model1 = Model1(...)
params1 = model1.init(...)
...
# train model1
...
model2 = Model2(...)
params2 = model2.init(...)
params2['layer1'] = param21['layer1']
...
この記事のテーマのnnx
ではsateやモデルに対して参照を持っており, これを介して操作をします。
nnx
以下にflax.linen
と同じコード例を示します
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
def get_dummy_data(batch_size, seq_length, vocab_size):
return jnp.random.randint(0, vocab_size, (batch_size, seq_length))
class Transformer(nnx.Module):
vocab_size: int
d_model: int
num_heads: int
num_layers: int
def __init__(self, rngs: nnx.Rngs):
self.embed = nnx.Embed(num_embeddings=self.vocab_size, features=self.d_model, rngs=rngs)
self.layers = [nnx.Sequential([
nnx.SelfAttention(num_heads=self.num_heads, rngs=rngs),
nnx.Dense(features=self.d_model, rngs=rngs)
]) for _ in range(self.num_layers)]
self.output_layer = nnx.Dense(features=self.vocab_size)
def __call__(self, x):
x = self.embed(x)
for layer in self.layers:
x = layer(x)
logits = self.output_layer(x)
return logits
def compute_loss(model, x, y):
logits = model(x)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
return loss
@nnx.jit
def train_step(model, optimizer, x, y):
"""Train for a single step."""
def loss_fn(params):
return compute_loss(params, x, y)
grad_fn = nnx.grad(loss_fn)
grads = grad_fn(model, x, y)
optimizer.update(grads) # In-place updates.
vocab_size = 1000
seq_length = 32
batch_size = 64
d_model = 128
num_heads = 8
num_layers = 2
model = Transformer(vocab_size, d_model, num_heads, num_layers, rngs=nnx.Rngs(0))
tx = optax.adam(learning_rate=0.001)
optimizer = nnx.Optimizer(model, tx)
num_epochs = 10
for epoch in range(num_epochs):
x = get_dummy_data(batch_size, seq_length, vocab_size)
y = get_dummy_data(batch_size, seq_length, vocab_size)
train_step(model, optimizer, x, y)
print(f"Epoch {epoch+1}/{num_epochs} 完了")
こうなります、細かいところを具体的に見ていきます。
Module?
flax.linen
とはことなり, パラメータの初期化はモデルを作ったときに行われます。
flax.linen
では若干無理矢理感のあったrng keyの扱いも一貫しています。
State?
flax.nnx
では明確なstate
は存在しますが、できるだけ表に出ないように(直接操作したり)することのない設計になっています。
上の例では、nnx.Optimizer
がstate
を持っており, grads
を入力としてupdateを行います
そしてoptimizer
は train_step
においてin-placeな更新を行います
nnx.jit?
flax.nnx
ではこのようなin-placeなupdateなどを行うために, nnx.jit
やnnx.grad
などが独自に用意されています
基本的にnnx
側のものを使ってよいです
nnxの威力
以下のような例が可能です
class Model1(nnx.Module):
encoder: nnx.Module
decoder: nnx.Module
...
class Model2(nnx.Module):
encoder: nnx.Module
head: nnx.Module
...
model1 = Model1(...)
# some training steps
...
model2 = Model2(...)
model2.encoder = model1.encoder
こうすることで簡単に事前学習を行ったencoder
を別のモデルにわたすことができます
flax.linen
で行うと
model2 = Model2(...)
params2 = model2.init(...)
params2['encoder'] = params1['encoder']
とパラメータをやり取りすることになります
この例ではそこまで恩恵がないように見えますが、ここで, LoRAを既存のモデルに対して適用することを考えます
class LoraParam(nnx.Param):
pass
class LoraLinear(nnx.Module):
def __init__(self, linear, rank, rngs):
self.linear = linear
dim_in, dim_out = linear.kernel.value.shape
self.A = LoraParam(init_fn(rngs(), (dim_in, rank), param_dtype))
self.B = LoraParam(init_fn(rngs(), (rank, dim_out), param_dtype))
def __call__(self, x: jax.Array):
return self.linear(x), x @ self.A.value @ self.B.value)
このようにLoRAを定義したとして、あるモデルのLinearレイヤーをすべてLoRAに置き換えます
model = Model(...)
def apply_LoRA(model, rank, rngs):
for path, module in model.iter_modules():
if type(module) is nnx.Linear:
new_module = uux.LoraLinear(module, rank, rngs)
update_module(model, path, new_module)
def update_module(model, path, new_module):
if type(path) in (list, tuple) and len(path) > 1:
if type(model) in (list, tuple, dict):
model = model[path[0]]
else:
model = getattr(model, path[0])
update_module(model, path[1:], new_module)
if len(path) == 1:
path = path[0]
if hasattr(model, path):
setattr(model, path, new_module)
return model
このように機械的に行うことができます
Discussion