今こそはじめるJAX/Flax入門 Part 1
1. はじめに
2012年から始まった深層学習の発展の過程で、さまざまな学習フレームワークが登場しました。中でもPyTorchとTensorflowは最も広く使われており、それぞれのフレームワークが支持されている背景には、柔軟性、拡張性、そして使いやすさがあります。
一方で、これらのフレームワークはその機能を拡張し続けてきた結果として、全体として非常に巨大で複雑なライブラリになっています。そのため、独自に機能拡張を行いたいユーザーにとっては扱いづらく、性能的にもオーバーヘッドを感じさせることがあります。
そこで新たに出てきたのが「JAX」とその関連ライブラリの組み合わせになります。2019年に登場して以降、特に海外の開発者に支持されてきました。近年注目されている大規模言語モデル(LLM)の分野においても、JAXによるモデルが公開されていることは珍しくなくなりつつあります。
PyTorch(赤)とJAX(青)のGitHubスター数の推移
JAXは、数値計算を行うための高性能ライブラリで、組み込みの自動微分やXLAコンパイラを活用することで、大規模な数値計算を高速に実行することが可能です。FlaxはJAXの上に構築された汎用ニューラルネットワークライブラリで、機械学習のためにより簡潔で直感的なAPIを提供しています。
本記事とそのシリーズでは、JAXとその機械学習用の高レベルAPIであるFlaxの基本から、具体的な実装例までを解説し、これらの技術が機械学習の最前線でどのように活用されているかを紹介していきます。
この記事で解説している内容:
- JAXの特徴と基本的な使い方
- Flaxの特徴と基本的な使い方
- JAX/FlaxによるMNIST学習の実装
2. JAXの基本
2-1. JAXとは?
JAXは、高性能数値計算、特に機械学習研究用に設計されたPythonライブラリです。もともとはGoogle Research(現 Google Deepmind)のチームによって開発されました。そのため、特にGoogle内部の研究者・開発者がよく使っている傾向があります。
JAXは非常にざっくりいうと、機械学習用に設計されたNumPyと説明することができます。実際、JAXの数値関数用のAPIは、NumPyと高度に互換性があり、NumPyと基本的には全く同じようにコードを実装することが可能です。このNumPy APIに加えて、機械学習に役立つ拡張として、次のような特徴があります。
-
自動微分のサポート: 深層学習の基礎となる順方向/逆方向の勾配ベースの演算をネイティブにサポートしています。
grad
,hessian
,jacfwd
,jacrev
などの関数変換を用いて簡単に実現できます。 -
ベクトル化:深層学習では、バッチ全体のLossを計算したり、複数GPUによる分散学習など、単一の処理や関数を多くのデータやデバイスに適用することがよくあります。JAXは任意の関数を
vmap
を介して並列化したり、単一のデバイスでは大きすぎる処理をpmap
により分散させたりすることが可能です。 - JITコンパイル:JAXはXLA(Accelerated Linear Algebra)に基づくJIT(Just-In-Time)コンパイルによる高速化が可能です。XLAは線形代数のためのドメイン固有のコンパイラで、計算グラフを最適化し、効率的な実行を目指すもので、特にGoogle TPU(Tensor Processing Unit)にやNVIDIA GPU向けに最適化されています。
これらの機能により、「高速性」と「スケーラビリティ」を両立した計算処理をNumpyベースのインターフェースで簡単に書くことができます。以下ではJAXの公式ドキュメントを参考に、JAXの基本的な使い方をみていきたいと思います。
2-2. インストール方法
JAXでサポートされているプラットフォームとしては次の通りになっています。ここではUbuntuにおけるCPU/GPU環境でのインストール方法を見ていきます。
- CPU
- NVIDIA GPU
- Google TPU
- AMD GPU (Experimental)
- Apple GPU (Experimental)
CPU
pip install --upgrade pip
pip install --upgrade "jax[cpu]"
GPU
現在JAXはCUDA Capability 5.2 (Maxwell) 以降を搭載したNVIDIA GPUをサポートしています。pip経由でインストールする場合、CUDA12がインストールされている必要があります。(CUDA≥12.1, cuDNN ≥ 8.9, <9.0, NCCL ≥ 2.18)
pip install --upgrade pip
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
それ以外にもGPU向けのDockerコンテナも用意されています。
2-3. JAXの基本操作
JAXはPythonで実装されており、APIも含めてNumPyと高度な互換性があります。ここでは例としてシンプルな行列積を扱います。
まずはNumPyでの実装例です。
import numpy as np
a = np.array([[1, 2, 3], [4, 5, 6]])
b = a.T
c = np.dot(a, b)
JAXではjax.numpy
を用いることでNumPyと全く同じように記述することができます。(ちなみにjax.numpy
は慣習的にjnp
として書くことが多いようです。)
import jax.numpy as jnp
a = jnp.array([[1, 2, 3], [4, 5, 6]])
b = a.T
c = jpn.dot(a, b)
次に、ランダムな行列積を計算する方法を見ていきます。NumPyで3000x3000の行列積を計算してみます。
import numpy as np
a = np.random.normal(size=(3000, 3000)).astype(np.float32)
b = a.T
c = np.dot(a, b)
JAXではjax.random
を用いて行列を定義します。
import jax.numpy as jnp
from jax import random
key = random.key(0)
a = random.normal(key, (3000, 3000), dtype=jnp.float32)
b = a.T
c = np.dot(a, b)
NumPyと違い、JAXでは乱数生成のためのkey
を明示的に定義して渡す必要があります。JAXは関数型言語の思想が強く反映されており、内部状態を持たない純粋関数を必要としていることに由来しています。そのため、random関係の実装はNumPyと差分が生じることに注意してください。
jit()
さて、ここからはJAX特有の機能について見ていきます。まずはJITコンパイルによる高速化です。SELU
関数を実装し、実行速度を計測します。
import jax.numpy as jnp
from jax import random
import timeit
def selu(x, alpha=1.67, lamba=1.05):
return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
key = random.key(0)
x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()
# 1.07 ms ± 261 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
JAXはCPU/GPU/TPU上で透過的に実行されますが、特に指定をしない場合は一つの処理ごとに毎回GPUやTPUにディスパッチして実行されオーバーヘッドが大きくなります。そこでjit()
で一連の操作をまとめてXLAコンパイルすることで高速化できます。
from jax import jit
selu_jit = jit(selu)
%timeit selu(x).block_until_ready()
# 127 µs ± 1.43 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
また、@jit
デコレータを用いることで関数定義時にjitを指定することも出来ます。
@jit
def selu(x, alpha=1.67, lamba=1.05):
return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
grad()
JAXでは数値を計算するだけでなく、自動微分の機能がサポートされています。Autogradと同じようにgrad
で簡単に微分を計算することができます。
import jax.numpy as jnp
from jax import grad
def tanh(x):
y = jnp.exp(-2.0 * x)
return (1.0 - y) / (1.0 + y)
grad_tanh = grad(tanh)
print(grad_tanh(1.0)) # 0.4199743
print(grad(grad(grad(tanh)))(1.0)) # 0.62162673
より高度な自動微分を実装したい場合は自動微分のAPIリファレンスを参照してください。
vmap()
JAXの強力な機能として、vmap()
(Vectorized Map)による自動ベクトル化があります。ベクトル化をざっくりと説明すると、ある関数に対して複数のデータをバッチ処理として適用できるように拡張する機能です。おそらく実装をみたほうが理解しやすいかと思います。
例えばsquare関数を定義し、バッチ処理したい場合、手動で実装すると次のようになります。
import jax.numpy as jnp
from jax import ramdom
def square(x):
return x * x
def batch_square(vec):
return jnp.array([square(x) for x in vec])
key = random.key(0)
v = random.normal(key, (1000000,))
y = batch_square(v)
これはPythonによる繰り返し処理に依存するため、パフォーマンスは低くなることが多いです。また、複雑なデータ構造や関数に対してはそもそも実装が難しいケースもあります。このような場合、vmapを用いることで非常に簡潔かつ高速に動作させることができます。
import jax.numpy as jnp
from jax import ramdom, vmap
def square(x):
return x * x
key = random.key(0)
v = random.normal(key, (1000000,))
# vmapを使用して自動ベクトル化します。
vectorized_square = vmap(square)
y = vectorized_square(v)
vmap()
は、元の関数を取り、それを配列の各要素に自動的に適用する新しい関数を返します。これにより、ループを明示的に書くことなく、高速なベクトル演算を利用して、複数の入力に対して関数を適用することができます。例えば、スカラー入力を取る関数がある場合、vmapを使用すると、その関数をベクトルまたは行列の各要素に適用できるようになります。
pmap()
機械学習では、大量のデータを扱う場合、複数のGPUやTPUを用いて並列計算が必要になるケースがあります。そのような場合にはpmap()
を使うことで高速並列通信の操作を含む、単一プログラム複数データ(SMDP)処理を簡単に実装することができます。
import jax.numpy as jnp
from jax import random, pmap
# 8枚のGPUがある環境で行列積を並列実行します。
# 8つのランダムな5000x6000 行列を定義します。
keys = random.split(random.PRNGKey(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)
# 明示的なデータ転送なく、行列積を各GPUで並列実行できます。
result = pmap(lambda x: jnp.dot(x, x.T))(mats) # result.shape = (8, 5000, 5000)
# 各GPUの計算結果の平均を求めます。
print(pmap(jnp.mean)(result))
このように、JAXはPython/NumPyのインターフェースを備えつつ、GPU/TPUなどで容易に高速化・並列化が可能なフレームワークとなっています。
3. Flaxの基本
3-1. Flaxとは?
JAXの使い方を紹介してきましたが、実はこのままでは深層学習の処理を直接実行すること簡単ではありません。そのために、JAXをベースとした機械学習ライブラリを組み合わせる必要があります。JAXを基盤とした機械学習ライブラリは複数ありますが、Flaxはその中でも最も人気があるフレームワークの一つです。
FlaxはもともとGoogle ResearchのBrainチーム内のエンジニアと研究者によって (JAXチームと緊密に連携して) 開発が開始されましたが、現在はオープンソース コミュニティと共同開発されています。
JAX/Flaxで利用できるモデルとして、HuggingFaceで自然言語処理、画像処理、音声処理をはじめとする様々なモデルが公開されています。
3-2. インストール方法
JAXがインストールされていればpipでインストールすることができます。
pip install flax
3-3. Flaxの基本操作
FlaxはJAXと併用することで、簡単に機械学習を実装することが可能です。まずは単一のDenseレイヤーを例に使い方を確認していきたいと思います。
まずはnn.Dense
を定義して推論してみます。
import flax
import jax
from flax import linen as nn
from jax import random, numpy as jnp
# 単一のDenseレイヤーを定義します。
model = nn.Dense(features=5)
# モデルパラメータの初期化を行います。
key1, key2 = random.split(random.key(0))
x = random.normal(key1, (10,)) # ダミーの入力データ
params = model.init(key2, x) # パラメータの初期化
# モデルの推論
model.apply(params, x)
# DeviceArray([-0.7358944, 1.3583755, -0.7976872, 0.8168598, 0.6297793], dtype=float32)
このモデルを学習していきます。まずは学習用のデータx_sample
とy_sample
を用意します。ランダムに生成した重みとバイアスでx_sample
を入力して計算し、ノイズを付与してy_sample
を求めます。
n_samples = 20
x_dim = 10
y_dim = 5
# 新規に乱数生成用のキーを用意します。
key_w, key_b, key_sample, key_noise, _ = random.split(key1, 5)
x_sample = random.normal(key_sample, (n_sample, x_dim))
# 教師データ生成用の重み
W = random.normal(key_w, (x_dim, y_dim))
b = random.normal(key_b, (y_dim,))
noise = random.normal(key_noise,(n_samples, y_dim))
y_sample = jnp.dot(x_samples, W) + b + 0.1 * noise
学習にはJAXベースの最適化ライブラリであるOptaxを使用します。もともとFlaxには独自の組み込み最適化パッケージflax.optaxがあったのですが、現在では廃止されてOptaxが標準になっています。
import optax
# MSE損失を定義します。
@jax.jit
def mse_loss(params, x, y):
pred = model.apply(params, x)
return jnp.mean(optax.l2_loss(pred, y))
# AdamによるOptimizer設定します。
learning_rate = 0.3
optimizer = optax.adam(learning_rate=learning_rate)
opt_state = optimizer.init(params)
loss_grad_fn = jax.value_and_grad(mse_loss)
# 訓練の実行
for i in range(101):
loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
if i % 10 == 0:
print(f'Loss step {i}: ', loss_val)
# Loss step 0: 0.011576377
# Loss step 10: 0.0115710115
# Loss step 20: 0.011569244
# Loss step 30: 0.011568661
# Loss step 40: 0.011568454
# Loss step 50: 0.011568379
# Loss step 60: 0.011568358
# Loss step 70: 0.01156836
# Loss step 80: 0.01156835
# Loss step 90: 0.011568353
# Loss step 100: 0.011568348
学習したモデルパラメータを保存するにはserialization.to_bytes
を用います。
from flax import serialization
bytes_save = serialization.to_bytes(params)
with open("./params.bin", "wb") as f:
f.write(bytes_save)
逆に保存したファイルからモデルパラメータを呼びだすにはserialization.from_bytes
を使います。
with open("./params.bin", "rb") as f:
bytes_load = f.read()
serialization.from_bytes(params, bytes_load)
Flaxでは独自のニューラルネットワークを定義するとき、nn.Module
および@nn.compact
を使って簡潔に記述することが出来ます。下記の例は単純な畳み込みニューラルネットワークですが、さらに複雑なモデルも同様にModuleベースで実装することができます。
class CNN(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3), strides=(1, 1))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # Flatten the layer
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
return x
4. JAX/FlaxでMNISTを学習する
JAX/Flaxを使った初歩的な機械学習として、MNISTの学習を実装してみます。簡単のため、データセットはtorchvisionを使ってダウンロードします。
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
def load_data():
# PyTorchのデータローディングAPIを使用してMNISTを読み込みます
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
return train_loader, test_loader
モデルを定義します。ここでは3-3で定義したCNNを流用します。
from flax import linen as nn
class CNN(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3), strides=(1, 1))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
x = x.reshape((x.shape[0], -1)) # Flatten the layer
x = nn.Dense(features=256)(x)
x = nn.relu(x)
x = nn.Dense(features=10)(x)
return x
訓練の状態を管理するため、Train Stateを導入します。TrainState
クラスは、機械学習モデルの訓練中の状態を管理するための便利なツールです。このクラスは訓練プロセスにおけるさまざまなコンポーネント(パラメータ、オプティマイザの状態、メトリクスなど)を一元管理し、これによってコードの整理が容易になり、エラーが減少します。
from flax.training import train_state
class CNNTrainState(train_state.TrainState):
batch_stats: Any
def create_train_state(rng_key, learning_rate):
cnn = CNN()
params = cnn.init(rng_key, jnp.ones([1, 28, 28, 1]))['params']
optimizer = optax.adam(learning_rate)
return CNNTrainState.create(
apply_fn=cnn.apply, params=params, tx=optimizer
)
損失関数と訓練ステップを実装します。train_step
関数をJITコンパイルすることで効率的に学習を実行します。ここではsoftmax_cross_entry
損失を用い、勾配を計算してTrain Stateに反映させます。
import jax
import jax.numpy as jnp
import optax
@jax.jit
def train_step(state, batch):
def loss_fn(params):
logits = state.apply_fn({'params': params}, batch['image'])
loss = jnp.mean(optax.softmax_cross_entropy(
logits=logits, labels=jax.nn.one_hot(batch['label'], num_classes=10))
)
return loss
grad_fn = jax.value_and_grad(loss_fn)
loss, grads = grad_fn(state.params)
state = state.apply_gradients(grads=grads)
return state, loss
最後に訓練を実行します。
def main():
train_loader, test_loader = load_data()
key = random.key(0)
state = create_train_state(key, learning_rate=0.001)
for epoch in range(10):
for batch in train_loader:
images, labels = batch
batch = {'image': np.array(images), 'label': np.array(labels)}
state, loss = train_step(state, batch)
print(f'Epoch {epoch}: loss = {loss}')
main()
学習したモデルの推論結果
5. 結局、JAX/Flaxは何がうれしいのか?
ここからは少しJAX/Flaxに関する「お気持ち」を書いていきたいと思います。おそらくこの記事でJAX/Flaxに触れた人の中には「それPyTorchでよくない?」という感想を持った方もいるのではないかと思います。実際、単一の計算機で機械学習をする場合は、PyTorchを使うほうが便利な方がほとんどかと思います。
本記事では部分的にしか紹介できませんでしたが、JAXは関数型プログラミングの思想に強く影響を受けています。そのため、JAXで独自のフレームワークを実装する際にも関数型プログラミングのメリットや技法をそのまま適用できます。具体的には、純粋関数による副作用の回避やイミュータブルなデータ構造によるバグの低減、関数の組み合わせと再利用性などが例としてあげられます。このような利点を活用することで、より生産性高く機械学習のコードを実装することができます。
個人的には、JAX/Flaxは安全性と高速性、スケーラビリティを要求する大規模な機械学習プロダクトの開発に最大限価値を発揮するのではないのかと思います。特にJAXはバックエンドデバイスやその並列化・高速化の機能が強力なので、シミュレータと統合した学習環境を容易に構築することができます。
例えば、JAXで実装されているBraxというライブラリはロボット工学、人間の知覚、材料科学、強化学習、およびその他のシミュレーションを重視したアプリケーションの研究開発に使用される、高速で完全に微分可能な物理エンジンとして開発されています。
Braxのシミュレーター環境
また、Google傘下の自動運転を開発企業であるWaymoは最近JAXベースの自動運転シミュレータWaymaxを公開しました。Waymaxは自動運転研究用の軽量、マルチエージェント、JAX ベースのシミュレーターです。計画のための閉ループ シミュレーションやシミュレーション エージェントの研究から開ループの動作予測に至るまで、自動運転における動作研究のさまざまな側面の研究をサポートするように設計されています。
Waymaxの自動運転シミュレーション
JAXの特性に着目して開発されているライブラリは他にも多く開発されています。このようなシミュレータやパイプラインを組み合わせることで、今までにない強力な深層学習の枠組みを実現できる可能性があります。
近年、LLMをはじめとする大規模なAIモデルの応用がますます広がってきています。一方でテキストだけでなく現実の環境で動作させるには課題も多くあり、そのうちの一つが「身体性(embodiment)の獲得」です。
身体性は機械やAIが物理的な環境において自らの身体(あるいは身体に相当するメカニズム)をどのように認識し、制御し、そこから学習していくかを指す概念です。AIが自らの感覚を統合し、物理法則や他の物体とのインタラクション、自己認識を得る必要があり、大規模なシミュレータ環境による学習が必要になる可能性があります。
万人向けのフレームワークとはいえないものの、新しい機械学習の枠組みを生み出すようなケースでは、JAX/Flaxはますます有力な選択肢となりそうです。
6. まとめ
この記事ではJAX/Flaxの基本的な使い方について紹介しました。JAX/Flaxは速度やTPUへの実行が着目されがちですが、その本質は関数型プログラミングの思想に基づく安全性とスケーラビリティにあるのではないかと思います。
本記事の続編では、環境シミュレータを含めた、より総合的な事例を取り上げます。JAX/Flaxのさらに実践的な使い方として、実装が容易で速度が重要なオセロAIを題材に、以下のような内容の公開を予定しています。
- Part 2: JAXによるオセロ盤面の高速化
- Part 3: Flaxによる方策関数の学習
- Part 4: 分散学習環境における深層強化学習の実現
最後に、筆者の所属するTuringでは、自動運転を実現するためのシステム開発を行っており、このような大規模で新しい機械学習を実現する方法ついても着目しています。Turingは完全自動運転EVの開発を目指すスタートアップで、機械学習から3Dアプリケーション、車体の制作に至るまで、ソフトウェア・ハードウェアを横断的なエンジニアリングに取り組んでいます。
もし興味のある方は私のTwitterのDMにてご連絡いただくか、採用ページで「スカウトを待つ」にぜひ登録をお願いします!
Discussion