🌟

jax/flaxの思想:オブジェクト指向との違い

2023/06/01に公開

はじめに

jax/flaxはTensorFlow, PyTorchに次ぐ第3の選択肢として新しくGoogleによって開発された深層学習用フレームワークです。
jaxを使うと簡単に高速化ができますが、設計思想がオブジェクト指向の考え方からかなり違うものになっていて、慣れないとプログラムがうまく書けません。
特に、疑似乱数やTrainStateなど内部状態については真逆とも言える考え方をしていて、オブジェクト指向で作ったプロジェクトの移行には設計の見直しが必要になります。
そこで、筆者なりにjax/flaxの設計思想とオブジェクト指向の違いを解説してみました。

jax/flaxの紹介や使い方、チュートリアルは他の記事に譲ることとして、この記事では

  • なぜ疑似乱数を使うときに不便そうに見える書き方をするのか
  • TrainStateとは何者なのか、継承して変更を加えるときの指針
    を中心に解説します。

基本的には、チュートリアルを動かしてみたが、なぜこんな書き方をするのかわからなかった、という人向けの記事になります。

なお、jax/flaxは関数型言語に強く影響を受けていると思われますが、筆者は関数型言語は少し触ったことがある程度の知識しかないので、用語などは使い方が間違っているものもあると思います。そのときはコメントなどいただければ幸いです。

jax/flax/optaxとは

TensorFlowやPyTorchは深層学習用ライブラリとして、モデルを作るためのクラスや関数、学習アルゴリズム、自動微分機能、GPUを使う機能などが一通り含まれていますが、jaxはそうではありません。
分離できるものは別のライブラリとして提供されており、役割は以下のようになっています。

jax: numpy互換な、自動微分機能とGPU計算、JITによる高速化機能のある数値計算ライブラリ
flax: jaxを使って深層学習のモデルを定義するためのライブラリ
optax: jaxとflaxで定義されたモデルを学習するためのアルゴリズム(勾配法やAdamなど)

jaxとflaxはgoogleのgithubリポジトリですが、optaxはdeepmindのリポジトリで公開されています。

オブジェクトはステートマシン

jaxの説明の前にオブジェクトの副作用の話をします。

ステートマシンとは、入力を受け取って出力を返すものですが、特に状態という概念があり、同じ入力でも状態に応じて出力が変わります。また、状態も入力に応じて随時変わっていきます。
ステートマシンの挙動を図で描いたものが状態遷移図なので、状態遷移図に従って動くもの、と考えてもいいかもしれません。

これをクラスで定義するとこんな感じです。

class StateMachine:
    def __init__(self):
        self.state = initial_state
    
    def action(self, input):
        outuput, self.state = 何らかの処理(self.state, input)
        return output

machine = StateMachine()
for i in range(N):
    output = machine.action(input)

このStateMachineオブジェクトはstateというメンバ変数を外部から見えなくすることで、利便性を向上しています。つまり、このStateMachineを使いたい人は今、内部がどんな状態であるか、どのように更新されるか知らなくともStateMachine.actionというメソッドを使うことができます。
実際にmachineインスタンスを生成して利用する側ではstateには見かけ上一切関与していません。

しかし、ここで副作用の問題が発生します。
プログラムが意図しない動作をした場合にstateがわからないとバグが再現できないという問題です。
actionの中身が複雑になってきたとき、実装がよくわからない外部ライブラリを使っているときに、意図せずstateが変わっていると、このようなバグが発生します。

この例ではメンバ変数とメソッドが一つしかないシンプルなクラスですが、一般のオブジェクトでも上記の問題は常につきまといます。

副作用を回避する

副作用の問題を回避する方法の一つにstateを外に出すというやり方があります。
先ほどの例を以下のように書き換えます。ただしstateは後々のためにクラスにしています。

class State:
    def __init__(self):
        self.value1 = init_value1
        self.value2 = init_value2

def action(state, input):
    output, state = 何らかの処理(state, input)
    return output, state

state = State()
for i in range(N):
    output, state = action(state, input)

これはオブジェクト指向から構造体しかないC言語に戻った印象を受けるかもしれませんが、jax/flaxではこの考え方が推奨されています。

疑似乱数の例

実際に擬似乱数生成器に上の考え方を適用してみます。

まず、疑似乱数の一種である線形合同法が、ステートマシンであることを確認します。線形合同法は直前の値だけで次の値が決まるので次のように書けます。

class RandGenerator:
    def __init__(self, seed):
        self.state = seed

    def generate(self, input=None):
        # inputは使わない
        next_state = 漸化式(self.state)
        # outputとself.stateは同じ
	output = next_state
	self.state = next_state
	return output

# 初期化
rng = RandGenerator(seed)
# 使用時
output = rng.generate()
output = rng.generate()

inputが不要になり、stateとoutputが同じ値になっていますが、最初の特殊な形になっていますが、基本的には同じです。
2回乱数を生成していますが、見た目はどちらもrng.generate()で全く同じです。乱数なので出力は異なるはずなのですが、引数がなく、なぜ異なる結果になるか見た目からはわかりません。
実際には内部でstateが変わっているのですが、それが外に見えておらず、特定の状況の再現が困難になるというのが副作用の問題です。

これを先程の状態を外に出すやり方で書き直してみます。

def generate(state):
    next_state = 漸化式(state)
    return next_state, next_state # 1つ目はoutputに相当

# 初期化
state = seed

# 使用時 1回目
output, state = generate(state)
val = (outputを所望の範囲の数値に変換など)

# 使用時 2回目
output, state = generate(state)
val = (outputを所望の範囲の数値に変換など)

線形合同法なので、outputとstateを分ける意味がないですが、StateMachineの説明に合わせるために無理やりそうしています。

オブジェクト指向のときと違って1回目と2回目でstateが変化しているのがポイントで、outputが異なるのは入力のstateが異なるからだ、ということが使う側にも見えていることが重要です。これが副作用がない状態で、バグが発生したときのstateの値のログを出せば、そのときの状況を確実に再現できるという利点があります。

次に、jaxでの疑似乱数の使い方を見てみましょう。

from jax import random

# 初期化
key = random.PRNGKey(0)

# 使用時 2回目
key, subkey1 = random.split(key)
val = random.normal(subkey1, shape=(1,))
# 使用時 2回目
key, subkey2 = random.split(key)
val = random.normal(subkey2, shape=(1,))

初期化はrandom.PRNGKey(0)で行います。これはkey(実態はuint32の2要素配列)を返します。

このkeyがstateかというとそうでもないみたいです(jaxの疑似乱数生成器は並列化やベクトル化がしやすい設計をしていて、筆者は把握しているわけではないです)が、線形合同法のときと似たような使い方になっています。

splitはkeyを分離して新しいkeyを複数(同時に3つ以上作ることも可能)生成します。乱数を生成するための関数(random.randintやrandom.normalなど)にkeyを渡すと疑似乱数を出力しますが、渡すkeyが同じだと同じ数値が帰ってきます。つまり、出力は引数だけに依存していて副作用がないことがわかります。

ちなみにsplit後のkeyはどれを使ってもいいですが、以降も疑似乱数を生成する場合は1つ残しておいて、次のsplitには未使用のkeyを使わなければいけません。
並列化と再現性を両立するためにこのような仕様になっているみたいですが、使う直前でsplitすると覚えておけばいいかと思います。

モデルの状態

線形回帰モデルのクラスを考えています。わかりやすさのために線型回帰にしましたが、基本的なことはNeural Networkも同じです。

モデルはy=wx+bとします。学習するパラメータはwb、入出力がxyです。
ここではflaxを使わず、自作のクラスを作ってみます。

class Model:
    def __init__(self):
        self.w = init_w
	self.b = init_b

    def predict(self, x):
        return self.w * x + self.b

    def update(self, delta):
        self.w = self.w + delta["w"]
        self.b = self.b + delta["b"]

# 初期化
model = Model()

# 推論
y = model.predict(x)

# 学習(1step分)
grads = 勾配計算(loss_func, dataset, model)
delta = - learning_rate * grads
model.update(delta)

SGD()は勾配を計算して学習率を乗じて更新差分を求める関数としています。

StateMachineとの対応をとると、wとbを合わせたものがstateになります。メソッドがaction1つからpredictupdateの2つに増えています。引数がそれぞれx, deltaだけですが、出力や次の状態が、内部状態であるself.wself.bに依存します。

ここから副作用のない書き方に変更します。

def predict(params, x):
    return params["w"]*x + params["b"]

def update(params, updates):
    return {
        "w": params["w"] + updates["w"],
        "b": params["b"] + updates["b"],
    }

# 初期化
params = {
    "w": init_w,
    "b": init_b,
}

# 推論
y = predict(params, x)

# 学習(1step分)
grads = 勾配計算(loss_func, dataset, params)
updates = - learning_rate * grads
params = update(params, updates)

Modelクラスがなくなってしまいました。変わりにパラメータが状態として外に出ています。つまり、Neural Networkはパラメータだけ管理すればよいということですね。

predictupdateも返り値が引数だけで決まり、副作用がありません。paramsという数値データさえあれば、完全に同じ状況が再現でき、デバッグや学習途中から再開などがやりやすくなります。

オプティマイザ

他に深層学習で状態を持つものは何でしょうか。実はオプティマイザが状態を持っています。
単純な勾配法には状態はありませんが、モーメンタムを使う場合や、Adamなど高度なものは学習率が更新される=状態を持っています。

最適化アルゴリズムのライブラリであるoptaxのSGDの使い方を見てみましょう。
momeuntumを指定しないSGDは状態がなく、常に同じ挙動をするので、momentumを設定してみます。

モーメンタム法は細かい違いでいくつかの流儀がありますが、以下のwikipediaの計算法を採用すると\Delta wを状態として管理すれば良いことがわかります。

\Delta w \leftarrow \eta \nabla Loss(w) + \alpha \Delta w\\w \leftarrow w - \Delta w

# 初期化
tx = optax.sgd(learning_rate=0.1, momentum=0.9)
opt_state = tx.init(params) # NNのparamsを指定する

# 更新
grads = 勾配計算(loss_func, dataset, params)
updates, opt_state = tx.update(grads, opt_state)

txインスタンスをinitメソッドで初期化します。初期化の返り値がstateであり、これがオプティマイザの初期状態を表します。
tx.updateで勾配gradsと前の状態opt_state=\Delta wから次の更新差分updates=\Delta wを計算します。
モーメンタム法では、疑似乱数のときと同じく求めたい更新差分と状態は同じですが、汎用性のためにこのような形式になります。

状態をまとめる

では、オプティマイザを使ってモデルを更新するコードを書いてみましょう。

# 初期化
params = {
    "w": init_w,
    "b": init_b,
}
tx = optax.sgd(learning_rate=0.1, momentum=0.9)
opt_state = tx.init(params) # NNのparamsを指定する

# 推論
y = predict(params, x)

# 学習(1step分)
grads = 勾配計算(loss_func, dataset, params)
updates, opt_state = tx.update(grads, opt_state)
params = update(params, updates)

初期化と学習のブロックにモデルとオプティマイザの初期化と更新がそれぞれ並んでいます。
管理する状態が少なければいいですが、増えてくると煩雑になりそうです。
そこで、2つの状態をまとめてしまいましょう。
また、updateもまとめてできるようにしておくと便利です。

class State:
    def __init__(self, params, tx):
        self.params = params
	self.tx = tx
	self.opt_state = tx.init(params)

    def update(self, grads):
        updates, opt_state = self.tx.update(grads, self.opt_state)
        params = update(self.params, updates) # updatesでパラメータを更新
	return self.replace(params, opt_state)

    def replace(self, params, opt_state):
        copied = self.copy()
	copied.params = params
	copied.opt_state = opt_state

# 初期化
state = State(
    init_params,
    optax.sgd(learning_rate=0.1, momentum=0.9)
    )
# 学習(1step分)
grads = 勾配計算(loss_func, dataset, state.params)
state = state.update(grads)

学習部分がStateupdateメソッドに移動してスッキリ書けるようになりました。

Stateクラスに新しくreplaceというのが増えています。これは何でしょうか。

ここまで触れてきませんでしたが、副作用を回避するためには状態は変更できないという制約が必要になります。この制約がないと、関数に渡したparamsやopt_stateが勝手に変わっていないという保証がなくなるからです。
jax/flaxではFrozenDictという変更できないデータ型(immutable)を使います。
ですが、意図した変更はどうするのか、というと明示的に新しいインスタンスを作って上書きします。
新しいインスタンス生成と上書きを分けて書くと次のような感じになりますが、今まではまとめて1行で書いていました。

new_state = なんかの処理(state)
state = new_state

新しいインスタンスを作るときに元の状態から一部だけを変えたいときに、毎回すべてのメンバを指定するのは冗長なので、変えたい部分だけ指定するためにreplaceを使います。
今回の例はparamsopt_stateを変えてtxだけが残るので、効果がわかりにくいですが、管理している状態が増えたときに必要になります。

ここで、以前の書き方とクラスの役割の違いを考えてみましょう。
従来の書き方ではモデルとオプティマイザが主役になり、内部状態を自身が管理していました。

副作用を回避する方法では状態をimmutableな状態クラスで表現します。この状態クラスはモデルとオプティマイザの更新用関数を知っており、状態を更新するときにこれら更新用関数を使って新しいインスタンスを生成します。生成したインスタンスで上書きするのはクラスの外で行います。

変数の更新がクラスの内部で勝手に行われるか、外部で行われるかというのが大きな違いです。
この違いがあるため、今までと同じ考え方ではうまくコードが書けないので、設計段階で意識しておく必要があります。

jax/flaxのサンプルコード

ここで自作したStateはflaxではTrainStateというクラスで用意されています。updateメソッドの名前がapply_gradientsになったり、initの変わりにクラスメソッドのcreateを使うようになっていたりという違いはありますが、やろうとしていることは同じです。

また、Modelも自作していたので、flaxが用意したものを使うようにします。flaxではnn.Moduleというクラスを継承してモデルを作るとパラメータをFrozenDictで作ってくれます。
書き方にはPyTorchライクなsetup方式とcompact方式がありますが、ここではわかりやすさのため、setup方式を採用しますが、compact方式のほうがコード行数は短くなります。

まず、人工のデータセットを1次関数で作っておきます。線型回帰でwbがそれぞれ3と1が正解になります。

@jax.jit
def func(x):
    return 3*x + 1

X = jnp.arange(-1, 1, 0.01).reshape(200,1)
Y = func(X)
ds = {
    'input': X,
    'output': Y
}

jax.jitというのは関数をJITでコンパイルするためのデコレータです。初めて計算するときにコンパイルされ、2回目以降は高速計算が可能です。

jnpというのはjaxでnumpyを置き換えたもので、import jax.numpy as jnpという定義をしておきます。
numpy互換の関数が使えます。
jaxでGPUが有効化されている場合、最初からGPUのメモリに乗っています。

疑似乱数を初期化します。

rng_key = jax.random.PRNGKey(0)

モデルを用意します。nnflax.linenというモジュールで、importのときにimport flax.linen as nnとしておきます。

class Model(nn.Module):
    def setup(self):
        self.dense = nn.Dense(1)

    def __call__(self, x):
        y = self.dense(x)
        return y

# モデル初期化
rng_key, model_init_key = jax.random.split(rng_key)
model = Model()
_m = model.init(model_init_key, jnp.ones([1,1])) # 乱数でパラメータを初期化

モデルの構造をsetupで定義しています。これを元にinitで渡した仮の入力値でパラメータ数が決まり、乱数のkeyを渡すことでそれぞれのパラメータが乱数で初期化されます。

乱数を使う際には必ず事前にsplitしておきます。
次に乱数を使うときは使っていない方のrng_keyをsplitして新しいものを生成してから使用します。使うもの+次に回すものにsplitする、と考えるクセをつけておくと良いと思います。

モデルの初期化で得られる_mの中身はこのようになっています。kernelは1x1のdense層のカーネルなので、y=wx+bwに、biasbになります。

FrozenDict({
    params: {
        dense: {
            kernel: Array([[-1.3232204]], dtype=float32),
            bias: Array([0.], dtype=float32),
        },
    },
})

オプティマイザはoptaxに用意されているものを使います。ここではSGDを使います。

tx = optax.sgd(learning_rate=0.1, momentum=0.9)

モデルとオプティマイザを格納するTrainStateを作ります。
create内部でオプティマイザは初期化され、両者を格納したTrainStateが作られます。

# TrainStateにモデルのパラメータとオプティマイザを登録
state = TrainState.create(
    apply_fn=model.apply,
    params=_m['params'],
    tx=tx)

学習ステップを定義します。

@jax.jit
def train_step(state, batch):

    def loss_fn(params):
        pred_y = state.apply_fn({'params': params}, batch['input'])
        loss = optax.l2_loss(pred_y, batch['output']).mean()
        return loss
    val_grad_fn = jax.value_and_grad(loss_fn)

    loss, grads = val_grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return loss, state

jax.value_and_gradは導関数を求めるものです。lossが必要ない場合はjax.gradを使います。loss_fnを入力params、出力lossと考えて、lossparamsで微分して得られる導関数\nabla lossval_grad_fnになります。
このval_grad_fnは関数になっていて任意のパラメータwに対する勾配\nabla loss(w)を求められます。val_grad_fnの返り値は元のloss_fnの出力であるlossと勾配gradになります。

この勾配をstate.apply_gradientsに渡せば、stateに格納されているparamsopt_stateを更新したstateが返されます。
このとき、元のstateはimmutableなFrozenDictであるためapply_gradients内部では更新されません。
ここでは、stateを更新したいので、新しいインスタンスである返り値でstateを上書きします。

これを何回か繰り返すと勾配法で線型回帰ができます。
ミニバッチのやり方は本記事の主題から外れるので、今回はミニバッチを使わず、全データを使って勾配を計算しています。

for t in range(100):
    loss, state = train_step(state, ds)

全体では以下のようになります。

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.training.train_state import TrainState
import optax

import matplotlib.pyplot as plt
import seaborn as sns

# 人工データセット
@jax.jit
def func(x):
    return 3*x + 1

X = jnp.arange(-1, 1, 0.01).reshape(200,1)
Y = func(X)
ds = {
    'input': X,
    'output': Y
}

sns.set()
plt.plot(X, Y)
plt.show()

# 疑似乱数
rng_key = jax.random.PRNGKey(0)

# モデル
class Model(nn.Module):
    def setup(self):
        self.dense = nn.Dense(1)

    def __call__(self, x):
        y = self.dense(x)
        return y

# モデル初期化
rng_key, model_init_key = jax.random.split(rng_key)
model = Model()
_m = model.init(rng_key, jnp.ones([1,1])) # 乱数でパラメータを初期化

# 最適化アルゴリズム
tx = optax.sgd(learning_rate=0.1, momentum=0.9)

# TrainStateにモデルのパラメータとオプティマイザを登録
state = TrainState.create(
    apply_fn=model.apply,
    params=_m['params'],
    tx=tx)

# 学習ステップの定義
@jax.jit
def train_step(state, batch):

    def loss_fn(params):
        pred_y = state.apply_fn({'params': params}, batch['input'])
        loss = optax.l2_loss(pred_y, batch['output']).mean()
        return loss
    val_grad_fn = jax.value_and_grad(loss_fn)

    loss, grads = val_grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return loss, state

# 学習
loss_history = []
opt_w_history = []
opt_b_history = []
w_history = []
b_history = []

for t in range(100):
    loss, state = train_step(state, ds)
    
    loss_history.append(loss)
    opt_w_history.append(state.opt_state[0][0]['dense']['kernel'][0])
    opt_b_history.append(state.opt_state[0][0]['dense']['bias'][0])
    w_history.append(state.params['dense']['kernel'][0])
    b_history.append(state.params['dense']['bias'][0])

plt.plot(loss_history, label="loss")
plt.plot(w_history, label="w")
plt.plot(b_history, label="b")
plt.plot(opt_w_history, label="opt_w")
plt.plot(opt_b_history, label="opt_b")
plt.legend()
plt.show()

最後にloss, w, b, モーメンタムのw, bをグラフにすると以下のようになります。

results

まとめ

副作用のないコードを書くという観点でオブジェクト指向の深層学習からはじめて、jax/flaxの流儀にたどり着くまでの流れを解説しました。皆様の参考になれば幸いです。

Discussion