🦄

拡散生成モデルで学ぶJax/Flaxによる深層学習プログラミング

2022/07/31に公開

はじめに

深層学習モデルやその学習を実装する際には、多くの場合でPyTorchやKerasなどのフレームワークが使われます。本記事では、Googleより公開されているJaxというフレームワークを用いた深層学習プログラミングを紹介します。

コードは以下に配置しています。

https://github.com/daigo0927/jax-ddim

Jaxとは

JaxはGoogleから公開されている、自動微分を備えた数値計算ライブラリと言えます。Numpyとほぼ同じように計算処理を実装でき、またGPUやTPUによって高速に演算を実行することもできます。これによって深層学習モデルを実装し、学習することができます。またNumpyと近い使い方ができるので、やろうと思えば深層学習以外の多くのアルゴリズムを実装することもできます。

Jaxにはいくつかの派生ライブラリがあります。深層学習でよく利用されるような畳み込み層やバッチ正規化などはFlaxというライブラリで提供されており、本記事でもこれを使ってモデルを構築していきます。他にDeepMindからも深層学習モデル構築を意識したHaikuや、強化学習向けのRLaxなどがJaxベースの派生ライブラリとして公開されています。

https://jax.readthedocs.io/en/latest/

拡散生成モデル

拡散生成モデル(Denoising Diffusion Model)は、画像に対するノイズの付与とそれに対するデノイズ(ノイズ除去)によって画像の生成を学習するモデルです。最終的にはノイズ画像から反復的にノイズ除去を繰り返すことで、画像生成が可能になります。

本記事ではJaxの使い方に注目するため、拡散生成モデルについては詳しく扱いません。拡散生成モデルについては下記の記事などが参考になります。

https://keras.io/examples/generative/ddim/

https://cvpr2022-tutorial-diffusion-models.github.io/

Jax/Flaxによる深層学習プログラミング

ここではJaxとFlaxを用いて、深層学習モデルの構築や学習などがどのように実装できるか紹介します。

モデルの実装

ニューラルネットワークのモデル構築では、FlaxのAPIであるflax.linenを用います。Jaxのみでナイーブに実装することもできますが、Flaxを用いることで重みの生成(初期化)や更新を簡単に実行できます。Exampleなどではimport flax.linen as nnとして利用されることが多いようです。

全結合層や畳み込み層などといった基本的な処理はnn.Dense, nn.ConvなどのAPIが提供されています。独自の処理やまとまった処理層を実装したい場合は、nn.Moduleを継承してクラスを実装することになります。例えば、今回のコードにおけるResidualBlockは以下のように実装できます。

model.py
from flax import linen as nn

class ResidualBlock(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x, train: bool):
        input_features = x.shape[3]
        if input_features == self.features:
            residual = x
        else:
            residual = nn.Conv(self.features, kernel_size=(3, 3))(x)

        x = nn.BatchNorm(use_running_average=not train,
                         use_bias=False, use_scale=False)(x)
        x = nn.Conv(self.features, (3, 3), 1, 1)(x)
        x = nn.swish(x)
        x = nn.Conv(self.features, (3, 3), 1, 1)(x)
        x += residual
        return 

処理を実行する__call__メソッドが@nn.compactデコレータで修飾されていることがわかります。これによって処理内で使う層の指定と、処理の実行をまとめて記載することができます。

一方で、層の指定を分けて記載する方法もあります。この場合はsetupメソッドを用意し、その中で必要な層を記載することになります。例えば、今回のコードでは拡散生成モデル全体を表すクラスで、setupメソッドを用いています。

model.py
class DiffusionModel(nn.Module):
    # UNet parameters
    feature_stages: List[int] = field(default_factory=lambda:
                                      [32, 64, 96, 128])
    blocks: int = 2
    min_freq: float = 1.0
    max_freq: float = 1000.0
    embedding_dims: int = 32

    # Sampling (reverse diffusion) parameters
    min_signal_rate: float = 0.02
    max_signal_rate: float = 0.95

    def setup(self):
        self.normalizer = nn.BatchNorm(use_bias=False, use_scale=False)
        self.network = UNet(feature_stages=self.feature_stages,
                            blocks=self.blocks,
                            min_freq=self.min_freq,
                            max_freq=self.max_freq,
                            embedding_dims=self.embedding_dims)

    def __call__(self, images, rng, train: bool):
        ...

基本的にsetup, nn.compactは同じように動作します。Flaxのドキュメントによれば、それぞれを使うモチベーションとしては以下が挙げられています。

  • nn.compactを使うモチベーション
    • サブモジュールやパラメータ、その他の変数を利用箇所と同じ場所に記載できるため、コードがまとまる
    • 条件分岐やForループなどに応じてサブモジュールやパラメータを定義する際に、コードの重複が少なくて済む
    • 数式的な記法に近く処理を実装できる場合がある
    • 入力変数のshapeに応じてパラメータのshapeや値を決めたい場合
  • setupを使うモチベーション
    • PyTorchに近い感覚で実装できる
    • 人によってはサブモジュールや変数の定義とその利用箇所を分けて記載する方が自然に感じる
    • 単一のフォワードパス以外にも、複数の処理を実装できる

https://flax.readthedocs.io/en/latest/guides/setup_or_nncompact.html

データセットの処理

JaxはNumpy配列を入力としてシームレスに処理を実行できるため、モデルの学習に用いるデータセットはいろいろな方法で実装できます。今回のコードでは、TensorFlow Datasetsを用いて、Oxford Flowers102という画像のデータセットをNumpy配列をイテレートします。

train.py
import tensorflow as tf
import tensorflow_datasets as tfds

def preprocess_image(data, image_size):
    # 画像のクロップ、リサイズ、値のスケーリング
    ... 
    return tf.clip_by_value(image / 255.0, 0.0, 1.0)


def prepare_datasets(image_size: int = 64,
                     batch_size: int = 64):
    dataset_name = 'oxford_flowers102'
    split_train = 'train[:80%]+validation[:80%]+test[:80%]'
    split_val = 'train[80%:]+validation[80%:]+test[80%:]'

    preprocess_fn = partial(preprocess_image, image_size=image_size)
    
    ds_train = tfds.load(dataset_name, split=split_train, shuffle_files=True)\
                   .map(preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE)\
                   .cache()\
                   .shuffle(buffer_size=10*batch_size)\
                   .batch(batch_size, drop_remainder=True)\
                   .prefetch(buffer_size=tf.data.AUTOTUNE)
    ds_train = tfds.as_numpy(ds_train)
                   
    ds_val = tfds.load(dataset_name, split=split_val, shuffle_files=True)\
                 .map(preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE)\
                 .cache()\
                 .batch(batch_size, drop_remainder=True)\
                 .prefetch(buffer_size=tf.data.AUTOTUNE)
    ds_val = tfds.as_numpy(ds_val)

    return ds_train, ds_val
    
if __name__ == '__main__':
    tf.config.experimental.set_visible_devices([], 'GPU')
    ...

データのキャッシュやミニバッチ化などはtf.dataAPIを通じて実行しており、最後にtfds.as_numpyによってNumpy配列をイテレートするデータセットインスタンスを生成しています。

今回は画像の前処理で部分的にTensorFlowを使っています。このときTensorFlowによってGPUが確保されてしまうのを防ぐため、tf.config.experimental.set_visible_devices([], 'GPU')という指示を実行しています。

下記のドキュメントではそれぞれtensorflow/datasetsとPyTorchのデータローダーを使った場合のデータのロード方法が紹介されています。

https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html#data-loading-with-tensorflow-datasets

https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html

モデルの学習・評価

Flaxによって作成したモデルとNumpy配列によるデータセットを用いて、モデルを学習します。

明示的な乱数生成

Jaxにおける擬似乱数生成(pseudo random number generation: PRNG)では、処理の再現性や並列性、ベクトル化のために乱数生成のキーを明示的に指定します。

最初のキーはjax.random.PRNGKey(0)として作成できます(引数は0でなくても良い)。生成したキーはjax.random.splitによって分割し、独立した新たなキーを作成していくことができます。

同じキーを用いた乱数生成は同じ結果になります。よって独立した乱数を生成したい場合には、随時キーを分割して行く必要があります。またドキュメントによれば、異なる乱数生成関数でも、同じキーを利用すると結果に相関が発生し得るため、基本的に非推奨とされています。

https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html

モデルの初期化によるパラメータの生成

学習を実行する前に、モデルを初期化してパラメータを生成する必要があります。Flaxのnn.Moduleを継承して作成したモデルクラスは、model.initメソッドによってパラメータ生成のための乱数生成キーとサンプル入力を渡すことで初期化し、パラメータを生成することができます。

今回のコードでは以下のようになります。

train.py
import jax

from model import DiffusionModel

...

def run(...):
    rng = jax.random.PRNGKey(0)
    rng, key_init, key_diffusion = jax.random.split(rng, 3)

    image_shape = (batch_size, image_size, image_size, 3)
    dummy = jnp.ones(image_shape, dtype=jnp.float32)
    
    model = DiffusionModel()
    variables = model.init(key_init, dummy, key_diffusion,
                           train=True)
    ...

はじめにjax.random.PRNGKey(0)によって乱数生成キーを作成しています。これをさらにjax.random.splitによって分割しています。key_initはモデルの初期化のため、key_diffusionは拡散生成モデルで用いる乱数画像を生成するため、rngは以降の処理で引き続き乱数生成キーを作るためのキーです。

model.initはモデルの初期化キーと、推論のためのサンプル入力を受け取り、パラメータを初期化するメソッドです。DiffusionModelクラスはnn.Moduleを継承しているためこのメソッドを通じてパラメータを初期化することができます。他の引数であるdummy, key_diffusion, trainDiffusionModel.__call__メソッドの引数として渡されます。

初期化によって得られるvariablesは辞書型の変数であり、variables['params']にモデルのパラメータが格納されています。また今回のモデルではバッチ正規化を用いており、バッチ統計量についてはvariables['batch_stats']からアクセスできます。

KerasやPyTorchでは、パラメータはモデルインスタンスの内部に保持されていますが、Flaxではモデルとパラメータを明示的に別に扱っており、特徴的に感じます。

https://flax.readthedocs.io/en/latest/advanced_topics/module_lifecycle.html#construction-initialization

モデルの推論

Flaxにおけるモデルの推論は、model.apply(variables, *args, ...)という形で実行できます。variablesはモデルで利用するためのパラメータやバッチ統計量などの値であり、*argsは推論時の入力変数です。

applyメソッドにはmethodという引数があり、これによって実行する推論メソッドを指定できます。デフォルトではmethod=Noneとなっており、自動的にmodel.__call__が実行されます。

モデル内でバッチ正規化などの処理を用いている場合は、推論時に変数を上書きします。この場合にはapplyメソッドでmutable引数を指定する必要があります。mutable引数を指定した場合は、推論の出力と同時に変更された変数も出力されます。

outputs, mutated_vars = model.apply(variables, *args, mutable=['batch_stats'])
# mutated_vars['batch_stats']に変更後のバッチ統計量が格納される

https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#module

パラメータの更新

Flaxによるモデルの学習では、学習状態(パラメータやオプティマイザ)を管理するためにflax.training.train_state.TrainStateというクラスを以下のように利用できます。

state = TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=tx)

apply_fnstateインスタンスから推論処理を実行するための引数であり、前述のmodel.applyを渡すことで柔軟に推論処理を指定できます。paramsは学習時の管理対象となるパラメータであり、txはオプティマイザです。

TrainStateは管理対象となるパラメータやオプティマイザと、勾配法によるパラメータ更新メソッドを備えただけのシンプルなクラスであるため、簡単に拡張できるとされています。例えば今回のコードでは、バッチ統計量も合わせて保持しておくため、新たにTrainStateクラスを用意しています。

train.py
from typing import Any
from flax.training import train_state
import optax


class TrainState(train_state.TrainState):
    batch_stats: Any
    
def run(...):
    ...
    state = TrainState.create(
        apply_fn=model.apply,
        params=variables['params'],
        batch_stats=variables['batch_stats'],
        tx=optax.adamw(learning_rate, weight_decay=weight_decay)
    )
    ...

https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#train-state

損失関数と勾配の計算、それに基づくパラメータの更新は以下のように実装できます。

train.py
def l1_loss(predictions, targets):
    return jnp.abs(predictions - targets)
    
    
@jax.jit
def train_step(state, batch, rng):
    def loss_fn(params):
        outputs, mutated_vars = state.apply_fn(
            {'params': params, 'batch_stats': state.batch_stats},
            batch, rng, train=True,
            mutable=['batch_stats']
        )
        noises, images, pred_noises, pred_images = outputs
        
        noise_loss = l1_loss(pred_noises, noises).mean()
        image_loss = l1_loss(pred_images, images).mean()
        loss = noise_loss + image_loss
        return loss, mutated_vars
    
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, mutated_vars), grads = grad_fn(state.params)
    state = state.apply_gradients(
        grads=grads,
        batch_stats=mutated_vars['batch_stats'])
    return state, loss
    

def run(...):
    ...
    for epoch in range(epochs):
        for images in ds_train:
            state, loss = train_step(state, images, key)
    ...

train_stepが1回のパラメータ更新を表す関数であり、loss_fnがパラメータを入力として損失(および更新後のバッチ統計量)を計算する関数です。jax.grad(loss_fn)とすることで、損失関数に対するパラメータの導関数を得ることができます。

今回のコードではjax.value_and_gradという関数を使っており、これによってloss_fnの出力とパラメータの勾配をまとめて返す関数を得ることができ、grad_fn(state.params)によってそれぞれの値を計算しています。

各パラメータの勾配を得たあとは、state.apply_gradientsによってパラメータを更新します。今回のコードでは、引数にbatch_statsを合わせて指定することで、stateインスタンス内のバッチ統計量も置き換えています。

https://flax.readthedocs.io/en/latest/guides/jax_for_the_impatient.html#gradients-and-autodiff

https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#taking-derivatives-with-grad

JITコンパイルによる高速化

JaxではJITコンパイルを用いることで、関数の処理をコンパイル・キャッシュしておき、GPUやTPUでの計算を高速化することができます。今回の例ではパラメータの更新とモデルの評価(画像生成の検証)を@jax.jitによってデコレートしています。

train.py
@jax.jit
def train_step(state, batch, rng):
    ...
    return state, loss
        

@partial(jax.jit, static_argnums=4)
def evaluate(state, params, rng, batch, diffusion_steps: int):
    variables = {'params': params, 'batch_stats': state.batch_stats}
    generated_images = state.apply_fn(variables,
                                      rng, batch.shape, diffusion_steps,
                                      method=DiffusionModel.generate)
    return generated_images

evaluate関数については、@partial(jax.jit, static_argnums=4)によって、diffusion_steps(0から数えて4番目の引数)が定数であることを指定しています。static_argnumsで指定した引数が同じ場合のみ、コンパイル・キャッシュされた処理が実行されます。定数として指定した値が変更されると再度コンパイルの時間が発生するため注意が必要です。

JITコンパイルを適用することで、初回実行時はコンパイルの時間を要しますが、以降同じ処理を実行する際はキャッシュされた関数が利用され、処理を高速化することができます。特にevaluate関数は、JITコンパイルによって大きく高速化できることを確認しました。

基本的にはプログラムの中で何回も実行される、ある程度大きな処理のまとまり(今回はパラメータの更新や画像生成の評価)をjax.jitでコンパイルするのが良いようです。これによってコンパイラによる最適化の余地が大きくなります。

https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#using-jit-to-speed-up-functions

https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html

その他

本節では、今回のコードでは使っていないものの有用・重要と思われるものを簡単にまとめています。

テンソル操作

einopsは多次元のテンソル操作を直感的に実装するためのライブラリです。PyTorchやTensorFlowに加えてJaxもサポートしており、複雑なテンソル操作を簡潔に実装する上で有用と言えます。

https://github.com/arogozhnikov/einops

学習率のスケジューリング

勾配法によるパラメータ更新では、学習の進行に伴って学習率を調整(スケジューリング)することがあります。JaxではオプティマイザのAPIを提供しているOptaxを通じて、学習率のスケジューリングを利用できます。

https://flax.readthedocs.io/en/latest/guides/lr_schedule.html

https://optax.readthedocs.io/en/latest/api.html#optimizer-schedules

マルチデバイスの利用

深層学習の開発では、複数のGPUやTPUを利用することがあり、Jaxでもこれが可能です。関数をjax.pmapでデコレートすることで、複数のデバイスごとに対象の処理を実行することができます。またflax.jax_utils.replicate, unreplicateを用いることで、デバイス間で変数を複製・統合することができます。

下記ドキュメントでは、デバイスごとに異なる乱数キーを用いてモデルを初期化し、複数のモデルを並行で学習する方法を紹介しています。

https://flax.readthedocs.io/en/latest/guides/ensembling.html

単一のモデルを複数のデバイスに複製し、学習データをデバイスに分散して推論・学習する場合(データ並列)も概ね同様に実装できると考えられます。また下記ディスカッションでは、マルチデバイスでバッチ正規化の統計量を同期する方法について触れています。

https://github.com/google/flax/discussions/2080

学習済みモデル

JaxはPyTorchやKerasほどには学習済みモデルは豊富ではない一方で、いくつかの学習済みモデルが公開されています。以下のリポジトリではResNetやGPT-2が公開されています。

https://github.com/matthias-wright/flaxmodels

またTransformerベースのモデルを多く公開しているhuggingface/transformersでも、いくつかのモデルがJax形式で提供されています。

https://github.com/huggingface/transformers

デプロイ

Jaxでは実験段階のAPIとして、FlaxのモデルをTensorFlowのSavedModel形式に変換するjax2tf APIを提供しています。SavedModel形式のモデルは、TensorFlowのエコシステムであるTFLiteやTF.jsなどを利用することができます。

https://flax.readthedocs.io/en/latest/guides/flax_basics.html#exporting-to-tensorflow-s-savedmodel-with-jax2tf

https://www.tensorflow.org/lite/examples/jax_conversion/overview

まとめ

本記事ではJaxおよびその派生ライブラリであるFlaxを用いて、拡散生成モデルを実装する上で必要になったTipsを紹介しました。PyTorchやKerasなどの深層学習ライブラリに比べると、Jaxはパラメータの保持や更新などで自分で実装する部分が多いと言えます。一方でAPIは洗練されており、またライブラリに隠匿される部分が少ないため、細かいところまで思い通りに実装しやすいとも感じました。今後も機会があれば深層学習モデルの実装に使ってみようと思います。

Discussion