🕌

関数型プログラミングっぽくDeep Learningをしたい

2023/12/23に公開

本記事は Uzabase Advent Calendar 2023 の 23 日目の記事になります。

はじめに

なんとなく社内 Slack の盛り上がりを感じてアドベントカレンダーに勢いで立候補したものの最近はブログを一切書いていなかったので、取り立てて書きたいようなネタはなかったので、とりあえず最近プライベートでよく使って遊んでいる JAX/Flax について書いていこうかと思います。
なぜ JAX/Flax かというと、ちょっと前に以下の whisper-jax を見かけて pytorch の実装と比べ、 jax で実装するとこんなに推論速度を早くできるのかとワクワクしたというのが最初の動機です。

今回どう言う内容で書いていくかというと、速度の検証とかは他の人のブログでよく見かけたので、JAX/Flax の安全性や面白い仕様の部分に焦点を絞って書いていくつもりです。
とにかく JAX/Flax はいいぞをこの記事で少しでも伝えられれば良いなと思います。
また、私は関数型プログラミングに関しては会社で Rust や F# をまれに書いたりしてるくらいなだけで、そんなに知識がある方ではないので、「こんなん全然関数型プログラミングっぽくない」などあれば Discussion などで優しくご指摘いただければと思います。

JAX とは

jax が何かをざっくり言うと、numpy に自動微分がついた数値計算のライブラリという感じのものです。
さらに、マルチ GPU やマルチ TPU をサポートしているので、GPU リソースを活用して計算速度を上げることができたりする優れものです。
numpy のインターフェースに似ているので、numpy に慣れている人なら容易に使えるようなライブラリになっています。
また、JAX は XLA を使用して GPU や TPU 上でプログラムを jit コンパイルして、実行できるので計算速度を上げる工夫もできます。
jit コンパイルを行う方法は numba のようにデコレーター@jitを関数に対してつけるだけでできると言う点もお手軽です。
個人的に良いなと感じている点は行列をイミュータブルなオブジェクトとして定義している点です。
numpy の場合、行列のオブジェクトが可変で副作用があるので、脳を numpy に支配されてしまった場合、常人には解読不能な黒魔術的なコードを書いてしまってコードが読みづらくなってしまったり、テストが書きづらいようなコードになってしまったりしたことはないでしょうか?
黒魔術とは違いますが、よくある簡単な例で言えば、以下のようなスクリプトを書いたとします。

import numpy

def example_main():
    a = numpy.ones((5, 5))
    print("before: ", a)
    example_function(a)
    print("after: ", a)

def example_function(a):
    a[1:4, 2:3] = 0
    return iikanji_function(a)

def iikanji_function(a):
    pass

example_main()

このように書いたとき、実は大変なことが起こってしまいます。実行してみると以下のような標準出力がでます。

before:  [[1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]]
after:  [[1. 1. 1. 1. 1.]
 [1. 1. 0. 1. 1.]
 [1. 1. 0. 1. 1.]
 [1. 1. 0. 1. 1.]
 [1. 1. 1. 1. 1.]]

このように example_main 関数の中の a の値が example_function 関数によって書き換えられてしまうのです。
これが原因で数多のバグを引き起こしてしまう惨劇を人類は何度繰り返してしまったか…実装者のスキル不足といえばそれまでですが、できればそういうのは仕組みで解決したいと思うのが自然です。行列がイミュータブルなオブジェクトであればどうなるでしょうか?
同様のコードを jax で書いてみます。ほとんど numpy と同じように実装できるので、簡単に既存のコードの置き換えができるのも嬉しいポイントです。

import jax

def example_main():
    a = jax.numpy.ones((5, 5))
    print("before: ", a)
    example_function(a)
    print("after: ", a)

def example_function(a):
    a[1:4, 2:3] = 0
    return iikanji_function(a)

def iikanji_function(a):
    pass

example_main()

これを実行してみると以下のようなエラーが出ます。

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 4, in example_main
  File "<stdin>", line 2, in example_function
  File "/HOGEHOGE/python3.9/site-packages/jax/_src/numpy/array_methods.py", line 285, in _unimplemented_setitem
    raise TypeError(msg.format(type(self)))
TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

エラーの意味としては行列に代入するときはa[1:4, 2:3] = 0ではなく、a = a.at[1:4, 2:3].set(0)と言う形式で書けと言うことを言っています。この書き方をすることで、元々のaの行列を直接変更せず、aの変更が行われたような新しい行列を作成してそれをaに代入する形にすることで行列の副作用をなくしています。
では言われた通りに書き換えてみましょう。

import jax

def example_main():
    a = jax.numpy.ones((5, 5))
    print("before: ", a)
    example_function(a)
    print("after: ", a)

def example_function(a):
    a = a.at[1:4, 2:3].set(0)
    return iikanji_function(a)

def iikanji_function(a):
    pass

example_main()

実行結果は以下のようになります。aの値が上書きされることがなくなりました。

before:  [[1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]]
after:  [[1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]]

これにより、安全に行列の計算を行うことができ、Python 初心者のデータサイエンティストの書くバグを大幅に減らせることになるでしょう。素晴らしいですね。
Rust などの言語を利用していると最初からこう言う変なことをしなくて済む話ではありますが、Python の簡単さと Python の豊富なエコシステムをそのままに同様に対策ができるのはありがたいものです。

Flax とは

Flax とは JAX を利用したニューラル ネットワークのフレームワークです。
特徴としては、先述の JAX の利点に乗っかることができる点と、compactと言う仕組みによりモデルのコードがより短くできる点、なるべくモデルクラスの副作用をなくすような関数型言語っぽいな思想など、なるべく安全にコードを書く仕組みを取り入れている面白いフレームワークとなっています。
また、Flax の哲学が以下のドキュメントに記載されているので読んでみてください。

https://flax.readthedocs.io/en/latest/philosophy.html

この Flax の哲学の中でも以下の思想は個人的に気づきがあって好きでした。

“Read the manual” is not an appropriate response to developer confusion. The framework should guide developers towards good solutions, such as through assertions and error messages.
An unhelpful error message is a bug.

Flax はドキュメントも豊富ではあるものの、それだけで満足せずに上記にある通りエラー文にこだわっていて、実際エラー文を読めば大体原因がわかるので、触ってる分には非公式のドキュメントが少なくてもあまり困りませんでした。

まずはモデルのコードがより短くできる点について説明します。
例えば、モデルのコードは pytorch では以下のように書くことになります。

from torch import nn

class MyTorchModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 100)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(100, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

一方で Flax では、モデルのコードは以下のように実装します。

from flax import linen as nn

class MyFlaxModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=100)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x

pytorch のモデルのコードと比べてコード量が少なく済むのがわかるかと思います。
pytorch のモデルではコンストラクタの中でモデルを定義する形になっているので、forward メソッドの中身を追いたい時に毎回initのところを見るためにスクロールしたり、ジャンプしたりする必要があり、読みづらさを感じたことがあると思います。(感じたことがない人もいるかもしれませんが)
これを Flax では@nn.compact のデコレーターをつけることでその課題を軽減しています。
とはいえ、例えば、複数の forward メソッドを用意したいなどモデルの仕様によっては定義したい時もあると思います。その場合は setup メソッドを生やしてそこで pytorch と同様に定義することもできます。

また、Flax の面白い点としてはモデルのパラメータと状態の管理が分離されている点です。
どう言うことかというと、モデルを推論するための以下のコードを見てください。

from jax import random
key1, model_init_params = random.split(random.key(0), 2)
x = random.uniform(key1, (4,4))

model = MyFlaxModel()
params = model.init(model_init_params, x)
y = model.apply(params, x)

重要なのは実際に推論をおこなっているy = model.apply(params, x)の部分です。
注目してほしいのは推論時にモデルの推論のメソッドにモデルのパラメーターをモデルの外から渡していて、モデルのオブジェクトがパラメーターを持っていない点です。
こうするメリットとしては、以下のものがあると思います。(他にも色々ありそう)

  • 副作用がないので、モデルのテストが書きやすい。
  • paramsさえあれば完全に同じモデルを再現できるので、異なるパラメータを使って同じモデル構造で実験を行うことがやりやすい。

関数型言語的なクラス

Flax の面白いところはモデルのクラスにも関数型言語の思想が表れているところが特に面白いです。
どう言うことかというと、Flax のモデルクラスのメンバー変数は更新ができません。
なので、副作用のあるコードを書くことができない仕組みになっています。
実際に例を見てみましょう。以下のコードを見てください。

from flax import linen as nn

class MyFlaxNet(nn.Module):
    state: int
    @nn.compact
    def __call__(self, x):
        self.state["hoge"] = 1
        x = nn.Dense(features=100)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x

model_init_params = random.key(0)
x = jnp.ones((5,1))
model = MyFlaxNet(2)
params = model.init(model_init_params, x)

上記のコードを実行すると以下のようなエラーが出力されます。最高ですね。

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 5, in __call__
flax.errors.SetAttributeFrozenModuleError: Can't set state=1 for Module of type MyFlaxNet: Module instance is frozen outside of setup method. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.SetAttributeFrozenModuleError)

モデルクラスの内部でメンバー変数を変更しちゃうことでバグが発生したりすることもあるので、これがあることで安全にコードを書くことができます。
ちなみに辞書のメンバー変数の場合でもイミュータブルになって、辞書に追加することができなくなります。
また、パラメーターについてもFrozenDictの形式で保持されているので、学習中などに dict による上書きをしてしまって、バグるみたいなことも防げたりします。
一方でこう言う仕組みによって自由度を失うことで窮屈さを感じる方もいるかとは思いますし、書きたい処理によっては結構面倒臭いと思います。
例えば、RNN のような再帰的なネットワークを書くときは毎回 forward に状態を渡すみたいなことを書かないといけなくなるので、煩わしさを感じるかもしれないです。
とは言え、以下のような状態を更新する裏技はあったりします。

from flax import linen as nn

class MyFlaxNet(nn.Module):
    @nn.compact
    def __call__(self, x, state):
        state[len(state)] = len(state)
        print("state:", state)
        x = nn.Dense(features=100)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x

model_init_params = random.key(0)
x = jnp.ones((5,1))
model = MyFlaxNet()
params = model.init(model_init_params, x, state)
state = {}
for _ in range(3):
  y_hat = model.apply(params, x, state)

これを実行すると、以下のような出力が得られます。

state: {0: 0, 1: 1}
state: {0: 0, 1: 1, 2: 2}
state: {0: 0, 1: 1, 2: 2, 3: 3}

こんなふうにメソッドの引数に関してはイミュータブルになっていないので、こう言う裏技が使えてしまう感じになっています。
これを使って state を書くことはできてしまうものの変数のスコープを意識する必要があるため基本的には脳内メモリに厳しいのでなるべくこの書き方はしない方が良いでしょう。
状態を保持したいそんなケースもあるでしょう。その場合は公式から推奨されているvariableメソッドを利用して状態を管理すると良さそうです。

テストが書きやすい

前述の通り、Flax はモデルクラスを関数型言語のように副作用がないクラスを利用するので、モデルの内部状態がどうなっているかを意識せずにテストが書けるので比較的書きやすい感じになっています。
例えば、最初にモデルがあるパラメータの時に何かを返すと言う以下のようなテストが pytorch と比べて分かりやすく書けます。

from jax import random
import jax.numpy as jnp

def test_flax_model_with_zero_parameters():
    model_init_params = random.key(0)
    model = MyFlaxModel()
    x = jnp.ones((5,1))
    params = model.init(model_init_params, x)
    # パラメーターの行列を全部0の行列にする
    all_zero_params = jax.tree_map(lambda x: jnp.zeros_like(x), params)
    actual = model.apply(all_zero_params, x)
    expected = jnp.zeros((5,10))
    assert actual.tolist() == expected.tolist()

pytorch で同様のテストをしようとすると以下のようなテストになります。

import torch
def test_torch_model_with_zero_parameters():
    model = MyTorchModel()
    # パラメーターの行列を全部0の行列にする
    for param in model.parameters():
        torch.nn.init.zeros_(param)
    x = torch.ones((5,1))
    actual = model(x)
    expected = torch.zeros((5,10))
    assert actual.data.tolist() == expected.data.tolist()

違いとしては Flax の場合はモデルの外からパラメーターを渡しているのに対して、pytorch ではモデルをインスタンス化して、モデルの内部の変数を書き換える形になっている点です。
「どちらでも同じようにテストができるんだから良いじゃん」と言えなくも無いのですし、「コード量を見ても pytorch の方が少ないし、良いのでは?」と思うかもしれません。
pytorch のコードの場合はモデルのパラメーターを変更することを意識しておかないとバグを産んでしまうわけですが、flax では apply でモデルのパラメーターを求められるので強制的に意識させてくれるので若干脳内メモリに優しい感じになっています。
また、テストコードが長くなってしまっているケースだと、デバッグの際に pytorch のコードだとテスト全体を読まないとパラメーターを変更している箇所が分かりにくいですが、flax の場合はmodel.applyメソッドからパラメーターを辿ることができるのでそこは若干楽かもです。

異なるパラメータを使って同じモデル構造で実験しやすい

異なるパラメータを使って同じモデル構造で実験はしやすいです。例えば、アンサンブルみたいなのは書きやすいかもしれません。
まずは pytorch の例で簡単な感じで異なるパラメータを使って同じモデル構造で推論をするモデルを書くと大体以下のような感じになると思います。(vmap 使って高速化できる余地はあるけどそこは本題じゃないのでここでは取り扱いません)
ちなみに以下の例だとすでにモデルファイルが存在する前提のコードになります。

def ensemble_pytorch_model():
    NUM_MODELS = 3
    x = torch.ones((5,1))
    predictions = []
    for index in range(NUM_MODELS):
        model_tmp = torch.load(f"/tmp/model-{index}")
        predictions.append(model_tmp(x))
    return predictions

このコードではモデルを丸ごと読み込んで、そのモデルで推論みたいなことをしていますが、どの構造のモデルを使ったのかがコードからは分かりません。
そのため、モデルの構造を知りたい場合は print をせざるを得ないので、あまり可読性の高いコードとは呼べません。なので、おそらくほとんどの人はパラメーターをtorch.saveで保存して、それを読み込む以下のようなコード[1]で実装しているかと思います。(これがpytorch の推奨しているやり方でもあります)

def ensemble_pytorch_model():
    NUM_MODELS = 3
    x = torch.ones((5,1))
    predictions = []
    for index in range(NUM_MODELS):
        model = MyTorchModel()
        model.load_state_dict(torch.load(f"/tmp/model-{index}"))
        predictions.append(model(x))
    return predictions

これでモデル構造をコードから読めるようになりました。しかし、毎回モデルを変えるたびにモデルを再度インスタンス化している処理に気づいたでしょうか?若干面倒臭いですね。さらに、このループの中で色々な処理をしていることもあり得ますし、その際に model を変更してしまうコードを書かないように注意を払って実装を進める必要があったり、他の人が副作用を気にせず書いてしまって、バグを産んでしまったりすることもあるかもしれません。(そこまで model を変更しちゃうケースはないとは思いますけどね)

Flax の方を見てみましょう。実際は TrainState を checkpoint で保存するので微妙に実用と違いますが、以下のコードではパラメーターだけ保存しておいたのを読み込んで apply に渡すことで、異なるパラメータを使って同じモデル構造で推論しています。

import orbax.checkpoint as ocp

def ensemble_flax_model():
    NUM_MODELS = 3
    x = jax.numpy.ones((5,1))
    model = MyFlaxModel()
    checkpointer = ocp.PyTreeCheckpointer()
    predictions = []
    for index in range(NUM_MODELS):
      param = checkpointer.restore(f"/tmp/model-flax-{index}")
      predictions.append(model.apply(param, x))
    return predictions

これもどちらでも同じようにかけますが、pytorch の時に微妙だった model を毎回インスタンスする必要がなくてループ内のコードがスッキリしていますし、モデルを変更する心配もなく、安心して実装していくことができます。

まとめ

ここまでで自分が面白いなと思った JAX/Flax の仕様や思想について以下のような内容について書きました。

  • モデルクラスを関数型言語のように副作用がない形で実装を強制してくれる
  • モデルのパラメーターを外部から渡すことで、モデルの内部の状態を意識する必要がなくなる
  • アンサンブルや実験などで同じ構造のモデルを別のパラメーターで推論を実行する処理でモデルの内部状態を意識せずにパラメーターを切り替えられる

ここまで JAX/Flax の良いところを書いてきましたが、pytorch と比べて機能は少ないですし、所々で実装しづらい仕様になっている部分もあって学習コストも比較的高めにはなってしまうので、実際のプロジェクトで採用するかのどうかの判断は慎重に行った方が良いかと思います。
とは言え、特にチーム開発をする上では pytorch よりかは JAX/Flax ではイミュータブルな変数を扱うことを強制されるので安全に書けて、バグへの懸念が少なく済むことで、コードレビューの負担が減るので、良い選択肢にもなり得るかと思います。
また、今回は特に記載はしなかったですが、以下のように素の pytorch より処理速度は早いと言う話もあるので、それだけでも選ぶ価値はあるかと思います。

とは言え、pytorch も今だと色々速度改善する方法は出てきているので、色々工夫していくと結局 Flax よりも速いと言う結果が出る可能性もあるかもしれません。
なので、その辺りの高速化技術についても検討して技術選定していく必要はありそうです。

今回で TrainState、Optax などを利用した学習部分のコードの書き方についてや並列で学習する容易さについても触れたいところではありましたが、気力の都合により省きました。

JAX/Flax は安全性と実行速度を両立した技術的に面白いライブラリ・フレームワークなので、ぜひ触ってみてください。

脚注
  1. 念の為注意書きなのですが、今回はエンジニアリング力がない人がコードを書いたときの状況に近くするのと説明を簡単にするためにあえて 1 つの関数で完結させるように書いてますが、実際は load する処理は関数の外側に書いた方が良いコードになります。モデルを読み込む責務と推論する責務の 2 つが含まれていて単一責務の原則に反しているからです。 ↩︎

Discussion