flax/nnx でとびたつ
JAX
JAXは端的にいうとGoogleが開発するGPU,TPUなどのアクセラレータでブーストしたnumpy
主な用途は機械学習ライブラリのベースとなっていて、jitやvmapなどの使って計算を高速化する
Flax
FlaxはJAXをベースとした機械学習ライブラリで, Googleが開発
DeepMindが開発していたHaikuの開発チームと統合された
Flax(というかJax)の特徴として, 各処理が関数型を意識したコードで直感的でわかりやすい
flax.linen
以下にTransformerを学習するコード例を示します。
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
def get_dummy_data(batch_size, seq_length, vocab_size):
return jnp.random.randint(0, vocab_size, (batch_size, seq_length))
class Transformer(nn.Module):
vocab_size: int
d_model: int
num_heads: int
num_layers: int
@nn.compact
def __call__(self, x):
# 埋め込み層
x = nn.Embed(num_embeddings=self.vocab_size, features=self.d_model)(x)
# 位置エンコーディンは簡略化のため省略
# エンコーダーブロック
for _ in range(self.num_layers):
x = nn.SelfAttention(num_heads=self.num_heads)(x)
x = nn.Dense(self.d_model)(x)
# 出力層
logits = nn.Dense(self.vocab_size)(x)
return logits
def compute_loss(params, x, y):
logits = model.apply({'params': params}, x)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
return loss
@jax.jit
def train_step(state, x, y):
def loss_fn(params):
return compute_loss(params, x, y)
grads = jax.grad(loss_fn)(state.params)
new_state = state.apply_gradients(grads=grads)
return new_state
vocab_size = 1000
seq_length = 32
batch_size = 64
d_model = 128
num_heads = 8
num_layers = 2
model = Transformer(vocab_size, d_model, num_heads, num_layers)
params = model.init(jax.random.PRNGKey(0), jnp.ones((batch_size, seq_length), dtype=jnp.int32))['params']
tx = optax.adam(learning_rate=0.001)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)
num_epochs = 10
for epoch in range(num_epochs):
x = get_dummy_data(batch_size, seq_length, vocab_size)
y = get_dummy_data(batch_size, seq_length, vocab_size)
state = train_step(state, x, y)
print(f"Epoch {epoch+1}/{num_epochs} 完了")
特徴として、モデル(nn.Module)は__call__で実際の処理の内容を記述します
Flaxは@nn.compactデコレータで初期化しつつ処理も実装できるので、コードが非常に見通しやすくなります、これは初期のTensorFlowと同じですかね
そして、モデルの処理とパラメータは完全に分けられています。
params = model.init(...)で初期化されたパラメータが得られますが、これをstateと呼ばれるステートに渡します。
このstateはモデルパラメータとoptimizer(tx)のパラメータ(paramsのEMAなど)とmoduleの__call__に当たるapply_fnが渡されます
結局, 学習ではこのstateは毎回更新されますが、実際は状態を変えずに, 新しいstateが元のstateから作られます、ここがポイントです
ようするに
new_state = state.apply_gradients(grads)
これはnew_state = state + diffとなっています
これがコードの見通しやjitでは扱いやすくなっていますが、複雑なことをするときにかなり面倒になります。
例えば、モデル中の特定のレイヤーだけ置き換えたいときなどは面倒です。
paramsから置き換えないレイヤーのみ持ってきて置き換えたりします。
model1 = Model1(...)
params1 = model1.init(...)
...
# train model1
...
model2 = Model2(...)
params2 = model2.init(...)
params2['layer1'] = param21['layer1']
...
この記事のテーマのnnxではsateやモデルに対して参照を持っており, これを介して操作をします。
nnx
以下にflax.linenと同じコード例を示します
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
def get_dummy_data(batch_size, seq_length, vocab_size):
return jnp.random.randint(0, vocab_size, (batch_size, seq_length))
class Transformer(nnx.Module):
vocab_size: int
d_model: int
num_heads: int
num_layers: int
def __init__(self, rngs: nnx.Rngs):
self.embed = nnx.Embed(num_embeddings=self.vocab_size, features=self.d_model, rngs=rngs)
self.layers = [nnx.Sequential([
nnx.SelfAttention(num_heads=self.num_heads, rngs=rngs),
nnx.Dense(features=self.d_model, rngs=rngs)
]) for _ in range(self.num_layers)]
self.output_layer = nnx.Dense(features=self.vocab_size)
def __call__(self, x):
x = self.embed(x)
for layer in self.layers:
x = layer(x)
logits = self.output_layer(x)
return logits
def compute_loss(model, x, y):
logits = model(x)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
return loss
@nnx.jit
def train_step(model, optimizer, x, y):
"""Train for a single step."""
def loss_fn(params):
return compute_loss(params, x, y)
grad_fn = nnx.grad(loss_fn)
grads = grad_fn(model, x, y)
optimizer.update(grads) # In-place updates.
vocab_size = 1000
seq_length = 32
batch_size = 64
d_model = 128
num_heads = 8
num_layers = 2
model = Transformer(vocab_size, d_model, num_heads, num_layers, rngs=nnx.Rngs(0))
tx = optax.adam(learning_rate=0.001)
optimizer = nnx.Optimizer(model, tx)
num_epochs = 10
for epoch in range(num_epochs):
x = get_dummy_data(batch_size, seq_length, vocab_size)
y = get_dummy_data(batch_size, seq_length, vocab_size)
train_step(model, optimizer, x, y)
print(f"Epoch {epoch+1}/{num_epochs} 完了")
こうなります、細かいところを具体的に見ていきます。
Module?
flax.linenとはことなり, パラメータの初期化はモデルを作ったときに行われます。
flax.linenでは若干無理矢理感のあったrng keyの扱いも一貫しています。
State?
flax.nnxでは明確なstateは存在しますが、できるだけ表に出ないように(直接操作したり)することのない設計になっています。
上の例では、nnx.Optimizerがstateを持っており, gradsを入力としてupdateを行います
そしてoptimizerは train_stepにおいてin-placeな更新を行います
nnx.jit?
flax.nnxではこのようなin-placeなupdateなどを行うために, nnx.jitやnnx.gradなどが独自に用意されています
基本的にnnx側のものを使ってよいです
nnxの威力
以下のような例が可能です
class Model1(nnx.Module):
encoder: nnx.Module
decoder: nnx.Module
...
class Model2(nnx.Module):
encoder: nnx.Module
head: nnx.Module
...
model1 = Model1(...)
# some training steps
...
model2 = Model2(...)
model2.encoder = model1.encoder
こうすることで簡単に事前学習を行ったencoderを別のモデルにわたすことができます
flax.linenで行うと
model2 = Model2(...)
params2 = model2.init(...)
params2['encoder'] = params1['encoder']
とパラメータをやり取りすることになります
この例ではそこまで恩恵がないように見えますが、ここで, LoRAを既存のモデルに対して適用することを考えます
class LoraParam(nnx.Param):
pass
class LoraLinear(nnx.Module):
def __init__(self, linear, rank, rngs):
self.linear = linear
dim_in, dim_out = linear.kernel.value.shape
self.A = LoraParam(init_fn(rngs(), (dim_in, rank), param_dtype))
self.B = LoraParam(init_fn(rngs(), (rank, dim_out), param_dtype))
def __call__(self, x: jax.Array):
return self.linear(x), x @ self.A.value @ self.B.value)
このようにLoRAを定義したとして、あるモデルのLinearレイヤーをすべてLoRAに置き換えます
model = Model(...)
def apply_LoRA(model, rank, rngs):
for path, module in model.iter_modules():
if type(module) is nnx.Linear:
new_module = uux.LoraLinear(module, rank, rngs)
update_module(model, path, new_module)
def update_module(model, path, new_module):
if type(path) in (list, tuple) and len(path) > 1:
if type(model) in (list, tuple, dict):
model = model[path[0]]
else:
model = getattr(model, path[0])
update_module(model, path[1:], new_module)
if len(path) == 1:
path = path[0]
if hasattr(model, path):
setattr(model, path, new_module)
return model
このように機械的に行うことができます
Discussion