拡散生成モデルで学ぶJax/Flaxによる深層学習プログラミング
はじめに
深層学習モデルやその学習を実装する際には、多くの場合でPyTorchやKerasなどのフレームワークが使われます。本記事では、Googleより公開されているJaxというフレームワークを用いた深層学習プログラミングを紹介します。
コードは以下に配置しています。
Jaxとは
JaxはGoogleから公開されている、自動微分を備えた数値計算ライブラリと言えます。Numpyとほぼ同じように計算処理を実装でき、またGPUやTPUによって高速に演算を実行することもできます。これによって深層学習モデルを実装し、学習することができます。またNumpyと近い使い方ができるので、やろうと思えば深層学習以外の多くのアルゴリズムを実装することもできます。
Jaxにはいくつかの派生ライブラリがあります。深層学習でよく利用されるような畳み込み層やバッチ正規化などはFlaxというライブラリで提供されており、本記事でもこれを使ってモデルを構築していきます。他にDeepMindからも深層学習モデル構築を意識したHaikuや、強化学習向けのRLaxなどがJaxベースの派生ライブラリとして公開されています。
拡散生成モデル
拡散生成モデル(Denoising Diffusion Model)は、画像に対するノイズの付与とそれに対するデノイズ(ノイズ除去)によって画像の生成を学習するモデルです。最終的にはノイズ画像から反復的にノイズ除去を繰り返すことで、画像生成が可能になります。
本記事ではJaxの使い方に注目するため、拡散生成モデルについては詳しく扱いません。拡散生成モデルについては下記の記事などが参考になります。
Jax/Flaxによる深層学習プログラミング
ここではJaxとFlaxを用いて、深層学習モデルの構築や学習などがどのように実装できるか紹介します。
モデルの実装
ニューラルネットワークのモデル構築では、FlaxのAPIであるflax.linen
を用います。Jaxのみでナイーブに実装することもできますが、Flaxを用いることで重みの生成(初期化)や更新を簡単に実行できます。Exampleなどではimport flax.linen as nn
として利用されることが多いようです。
全結合層や畳み込み層などといった基本的な処理はnn.Dense, nn.Conv
などのAPIが提供されています。独自の処理やまとまった処理層を実装したい場合は、nn.Module
を継承してクラスを実装することになります。例えば、今回のコードにおけるResidualBlockは以下のように実装できます。
from flax import linen as nn
class ResidualBlock(nn.Module):
features: int
@nn.compact
def __call__(self, x, train: bool):
input_features = x.shape[3]
if input_features == self.features:
residual = x
else:
residual = nn.Conv(self.features, kernel_size=(3, 3))(x)
x = nn.BatchNorm(use_running_average=not train,
use_bias=False, use_scale=False)(x)
x = nn.Conv(self.features, (3, 3), 1, 1)(x)
x = nn.swish(x)
x = nn.Conv(self.features, (3, 3), 1, 1)(x)
x += residual
return
処理を実行する__call__
メソッドが@nn.compact
デコレータで修飾されていることがわかります。これによって処理内で使う層の指定と、処理の実行をまとめて記載することができます。
一方で、層の指定を分けて記載する方法もあります。この場合はsetup
メソッドを用意し、その中で必要な層を記載することになります。例えば、今回のコードでは拡散生成モデル全体を表すクラスで、setup
メソッドを用いています。
class DiffusionModel(nn.Module):
# UNet parameters
feature_stages: List[int] = field(default_factory=lambda:
[32, 64, 96, 128])
blocks: int = 2
min_freq: float = 1.0
max_freq: float = 1000.0
embedding_dims: int = 32
# Sampling (reverse diffusion) parameters
min_signal_rate: float = 0.02
max_signal_rate: float = 0.95
def setup(self):
self.normalizer = nn.BatchNorm(use_bias=False, use_scale=False)
self.network = UNet(feature_stages=self.feature_stages,
blocks=self.blocks,
min_freq=self.min_freq,
max_freq=self.max_freq,
embedding_dims=self.embedding_dims)
def __call__(self, images, rng, train: bool):
...
基本的にsetup, nn.compact
は同じように動作します。Flaxのドキュメントによれば、それぞれを使うモチベーションとしては以下が挙げられています。
-
nn.compact
を使うモチベーション- サブモジュールやパラメータ、その他の変数を利用箇所と同じ場所に記載できるため、コードがまとまる
- 条件分岐やForループなどに応じてサブモジュールやパラメータを定義する際に、コードの重複が少なくて済む
- 数式的な記法に近く処理を実装できる場合がある
- 入力変数のshapeに応じてパラメータのshapeや値を決めたい場合
-
setup
を使うモチベーション- PyTorchに近い感覚で実装できる
- 人によってはサブモジュールや変数の定義とその利用箇所を分けて記載する方が自然に感じる
- 単一のフォワードパス以外にも、複数の処理を実装できる
データセットの処理
JaxはNumpy配列を入力としてシームレスに処理を実行できるため、モデルの学習に用いるデータセットはいろいろな方法で実装できます。今回のコードでは、TensorFlow Datasetsを用いて、Oxford Flowers102という画像のデータセットをNumpy配列をイテレートします。
import tensorflow as tf
import tensorflow_datasets as tfds
def preprocess_image(data, image_size):
# 画像のクロップ、リサイズ、値のスケーリング
...
return tf.clip_by_value(image / 255.0, 0.0, 1.0)
def prepare_datasets(image_size: int = 64,
batch_size: int = 64):
dataset_name = 'oxford_flowers102'
split_train = 'train[:80%]+validation[:80%]+test[:80%]'
split_val = 'train[80%:]+validation[80%:]+test[80%:]'
preprocess_fn = partial(preprocess_image, image_size=image_size)
ds_train = tfds.load(dataset_name, split=split_train, shuffle_files=True)\
.map(preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE)\
.cache()\
.shuffle(buffer_size=10*batch_size)\
.batch(batch_size, drop_remainder=True)\
.prefetch(buffer_size=tf.data.AUTOTUNE)
ds_train = tfds.as_numpy(ds_train)
ds_val = tfds.load(dataset_name, split=split_val, shuffle_files=True)\
.map(preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE)\
.cache()\
.batch(batch_size, drop_remainder=True)\
.prefetch(buffer_size=tf.data.AUTOTUNE)
ds_val = tfds.as_numpy(ds_val)
return ds_train, ds_val
if __name__ == '__main__':
tf.config.experimental.set_visible_devices([], 'GPU')
...
データのキャッシュやミニバッチ化などはtf.data
APIを通じて実行しており、最後にtfds.as_numpy
によってNumpy配列をイテレートするデータセットインスタンスを生成しています。
今回は画像の前処理で部分的にTensorFlowを使っています。このときTensorFlowによってGPUが確保されてしまうのを防ぐため、tf.config.experimental.set_visible_devices([], 'GPU')
という指示を実行しています。
下記のドキュメントではそれぞれtensorflow/datasets
とPyTorchのデータローダーを使った場合のデータのロード方法が紹介されています。
モデルの学習・評価
Flaxによって作成したモデルとNumpy配列によるデータセットを用いて、モデルを学習します。
明示的な乱数生成
Jaxにおける擬似乱数生成(pseudo random number generation: PRNG)では、処理の再現性や並列性、ベクトル化のために乱数生成のキーを明示的に指定します。
最初のキーはjax.random.PRNGKey(0)
として作成できます(引数は0でなくても良い)。生成したキーはjax.random.split
によって分割し、独立した新たなキーを作成していくことができます。
同じキーを用いた乱数生成は同じ結果になります。よって独立した乱数を生成したい場合には、随時キーを分割して行く必要があります。またドキュメントによれば、異なる乱数生成関数でも、同じキーを利用すると結果に相関が発生し得るため、基本的に非推奨とされています。
モデルの初期化によるパラメータの生成
学習を実行する前に、モデルを初期化してパラメータを生成する必要があります。Flaxのnn.Module
を継承して作成したモデルクラスは、model.init
メソッドによってパラメータ生成のための乱数生成キーとサンプル入力を渡すことで初期化し、パラメータを生成することができます。
今回のコードでは以下のようになります。
import jax
from model import DiffusionModel
...
def run(...):
rng = jax.random.PRNGKey(0)
rng, key_init, key_diffusion = jax.random.split(rng, 3)
image_shape = (batch_size, image_size, image_size, 3)
dummy = jnp.ones(image_shape, dtype=jnp.float32)
model = DiffusionModel()
variables = model.init(key_init, dummy, key_diffusion,
train=True)
...
はじめにjax.random.PRNGKey(0)
によって乱数生成キーを作成しています。これをさらにjax.random.split
によって分割しています。key_init
はモデルの初期化のため、key_diffusion
は拡散生成モデルで用いる乱数画像を生成するため、rng
は以降の処理で引き続き乱数生成キーを作るためのキーです。
model.init
はモデルの初期化キーと、推論のためのサンプル入力を受け取り、パラメータを初期化するメソッドです。DiffusionModel
クラスはnn.Module
を継承しているためこのメソッドを通じてパラメータを初期化することができます。他の引数であるdummy, key_diffusion, train
はDiffusionModel.__call__
メソッドの引数として渡されます。
初期化によって得られるvariables
は辞書型の変数であり、variables['params']
にモデルのパラメータが格納されています。また今回のモデルではバッチ正規化を用いており、バッチ統計量についてはvariables['batch_stats']
からアクセスできます。
KerasやPyTorchでは、パラメータはモデルインスタンスの内部に保持されていますが、Flaxではモデルとパラメータを明示的に別に扱っており、特徴的に感じます。
モデルの推論
Flaxにおけるモデルの推論は、model.apply(variables, *args, ...)
という形で実行できます。variables
はモデルで利用するためのパラメータやバッチ統計量などの値であり、*args
は推論時の入力変数です。
apply
メソッドにはmethod
という引数があり、これによって実行する推論メソッドを指定できます。デフォルトではmethod=None
となっており、自動的にmodel.__call__
が実行されます。
モデル内でバッチ正規化などの処理を用いている場合は、推論時に変数を上書きします。この場合にはapply
メソッドでmutable
引数を指定する必要があります。mutable
引数を指定した場合は、推論の出力と同時に変更された変数も出力されます。
outputs, mutated_vars = model.apply(variables, *args, mutable=['batch_stats'])
# mutated_vars['batch_stats']に変更後のバッチ統計量が格納される
パラメータの更新
Flaxによるモデルの学習では、学習状態(パラメータやオプティマイザ)を管理するためにflax.training.train_state.TrainState
というクラスを以下のように利用できます。
state = TrainState.create(
apply_fn=model.apply,
params=variables['params'],
tx=tx)
apply_fn
はstate
インスタンスから推論処理を実行するための引数であり、前述のmodel.apply
を渡すことで柔軟に推論処理を指定できます。params
は学習時の管理対象となるパラメータであり、tx
はオプティマイザです。
TrainState
は管理対象となるパラメータやオプティマイザと、勾配法によるパラメータ更新メソッドを備えただけのシンプルなクラスであるため、簡単に拡張できるとされています。例えば今回のコードでは、バッチ統計量も合わせて保持しておくため、新たにTrainState
クラスを用意しています。
from typing import Any
from flax.training import train_state
import optax
class TrainState(train_state.TrainState):
batch_stats: Any
def run(...):
...
state = TrainState.create(
apply_fn=model.apply,
params=variables['params'],
batch_stats=variables['batch_stats'],
tx=optax.adamw(learning_rate, weight_decay=weight_decay)
)
...
損失関数と勾配の計算、それに基づくパラメータの更新は以下のように実装できます。
def l1_loss(predictions, targets):
return jnp.abs(predictions - targets)
@jax.jit
def train_step(state, batch, rng):
def loss_fn(params):
outputs, mutated_vars = state.apply_fn(
{'params': params, 'batch_stats': state.batch_stats},
batch, rng, train=True,
mutable=['batch_stats']
)
noises, images, pred_noises, pred_images = outputs
noise_loss = l1_loss(pred_noises, noises).mean()
image_loss = l1_loss(pred_images, images).mean()
loss = noise_loss + image_loss
return loss, mutated_vars
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, mutated_vars), grads = grad_fn(state.params)
state = state.apply_gradients(
grads=grads,
batch_stats=mutated_vars['batch_stats'])
return state, loss
def run(...):
...
for epoch in range(epochs):
for images in ds_train:
state, loss = train_step(state, images, key)
...
train_step
が1回のパラメータ更新を表す関数であり、loss_fn
がパラメータを入力として損失(および更新後のバッチ統計量)を計算する関数です。jax.grad(loss_fn)
とすることで、損失関数に対するパラメータの導関数を得ることができます。
今回のコードではjax.value_and_grad
という関数を使っており、これによってloss_fn
の出力とパラメータの勾配をまとめて返す関数を得ることができ、grad_fn(state.params)
によってそれぞれの値を計算しています。
各パラメータの勾配を得たあとは、state.apply_gradients
によってパラメータを更新します。今回のコードでは、引数にbatch_stats
を合わせて指定することで、state
インスタンス内のバッチ統計量も置き換えています。
JITコンパイルによる高速化
JaxではJITコンパイルを用いることで、関数の処理をコンパイル・キャッシュしておき、GPUやTPUでの計算を高速化することができます。今回の例ではパラメータの更新とモデルの評価(画像生成の検証)を@jax.jit
によってデコレートしています。
@jax.jit
def train_step(state, batch, rng):
...
return state, loss
@partial(jax.jit, static_argnums=4)
def evaluate(state, params, rng, batch, diffusion_steps: int):
variables = {'params': params, 'batch_stats': state.batch_stats}
generated_images = state.apply_fn(variables,
rng, batch.shape, diffusion_steps,
method=DiffusionModel.generate)
return generated_images
evaluate
関数については、@partial(jax.jit, static_argnums=4)
によって、diffusion_steps
(0から数えて4番目の引数)が定数であることを指定しています。static_argnums
で指定した引数が同じ場合のみ、コンパイル・キャッシュされた処理が実行されます。定数として指定した値が変更されると再度コンパイルの時間が発生するため注意が必要です。
JITコンパイルを適用することで、初回実行時はコンパイルの時間を要しますが、以降同じ処理を実行する際はキャッシュされた関数が利用され、処理を高速化することができます。特にevaluate
関数は、JITコンパイルによって大きく高速化できることを確認しました。
基本的にはプログラムの中で何回も実行される、ある程度大きな処理のまとまり(今回はパラメータの更新や画像生成の評価)をjax.jit
でコンパイルするのが良いようです。これによってコンパイラによる最適化の余地が大きくなります。
その他
本節では、今回のコードでは使っていないものの有用・重要と思われるものを簡単にまとめています。
テンソル操作
einopsは多次元のテンソル操作を直感的に実装するためのライブラリです。PyTorchやTensorFlowに加えてJaxもサポートしており、複雑なテンソル操作を簡潔に実装する上で有用と言えます。
学習率のスケジューリング
勾配法によるパラメータ更新では、学習の進行に伴って学習率を調整(スケジューリング)することがあります。JaxではオプティマイザのAPIを提供しているOptaxを通じて、学習率のスケジューリングを利用できます。
マルチデバイスの利用
深層学習の開発では、複数のGPUやTPUを利用することがあり、Jaxでもこれが可能です。関数をjax.pmap
でデコレートすることで、複数のデバイスごとに対象の処理を実行することができます。またflax.jax_utils.replicate, unreplicate
を用いることで、デバイス間で変数を複製・統合することができます。
下記ドキュメントでは、デバイスごとに異なる乱数キーを用いてモデルを初期化し、複数のモデルを並行で学習する方法を紹介しています。
単一のモデルを複数のデバイスに複製し、学習データをデバイスに分散して推論・学習する場合(データ並列)も概ね同様に実装できると考えられます。また下記ディスカッションでは、マルチデバイスでバッチ正規化の統計量を同期する方法について触れています。
学習済みモデル
JaxはPyTorchやKerasほどには学習済みモデルは豊富ではない一方で、いくつかの学習済みモデルが公開されています。以下のリポジトリではResNetやGPT-2が公開されています。
またTransformerベースのモデルを多く公開しているhuggingface/transformersでも、いくつかのモデルがJax形式で提供されています。
デプロイ
Jaxでは実験段階のAPIとして、FlaxのモデルをTensorFlowのSavedModel形式に変換するjax2tf APIを提供しています。SavedModel形式のモデルは、TensorFlowのエコシステムであるTFLiteやTF.jsなどを利用することができます。
まとめ
本記事ではJaxおよびその派生ライブラリであるFlaxを用いて、拡散生成モデルを実装する上で必要になったTipsを紹介しました。PyTorchやKerasなどの深層学習ライブラリに比べると、Jaxはパラメータの保持や更新などで自分で実装する部分が多いと言えます。一方でAPIは洗練されており、またライブラリに隠匿される部分が少ないため、細かいところまで思い通りに実装しやすいとも感じました。今後も機会があれば深層学習モデルの実装に使ってみようと思います。
Discussion