📝

Jax/FlaxでKaggleをやってみよう!

2022/12/19に公開

こんにちは。いのいちです。
この記事は(の18日目の記事ですこの記事はKaggle Advent Calendr 2022の18日目の記事です。

最近はスプラトゥーンにはまっていて毎日忙しくてあまりコンペに参加できてませんが、2023年はどんどんコンペに参加していきたいなと思っています。私はコンペに参加するときはいつも1つはこれまでやったことないことをやると決めています。そこで新しい挑戦としてJax/Flaxを使ってみようと思い至りました。私が普段参加するComputer Vison系のコンペでは主にPytorchが使用されており、TPUをぶん回すコンペでTensorflowが使われていたりします。Jax/Flaxのnotebookも時々見かけますが、まだベースラインとなるようなnotebookが共有されたり、がっつりJax/Flaxでコンペをやったというのは見たことがありません。しかし、調べていくとJax/Flaxを使ってもコンペで十分戦える環境が整ってきているように感じました。
この記事では、その第一歩としてJax/Flaxを使ってkaggle notebookで学習し、サブミットするまでのコードを作って簡単に紹介しようと思っています。まだ使い始めたばかりなので、間違ってるところやバッドプラクティスもあると思いますので、もしお気づきの点があれば指摘していただけると嬉しいです!

JaxとFlax

まず簡単にJaxとFlaxについて紹介します。

Jaxとは

JaxとはGoogleが開発しているライブラリで、第一印象はGPUやTPUを簡単に使える微分可能なNumpyという感じです。
https://github.com/google/jax
https://jax.readthedocs.io/en/latest/index.html#

JAXはAutogradとXLAを組み合わせた、高性能な数値計算と機械学習研究のためのツールです。Python+NumPyのプログラムに対して、微分、ベクトル化、並列化、GPU/TPUへのジャストインタイムコンパイルなどの変換を行うことが可能です。
(公式Documentより)

Kaggleでモデル学習する上では、このあたりが特に嬉しい特徴に感じます。

  • Numpyと同じように書ける
  • XLAコンパイルで高速化できる
  • TPUも簡単に使える
    深層学習だけでなく、様々な課題使えるライブラリだと思います。Jax自体には深層学習に特化したモジュールはあまりないので、Kaggleで使う場合は次に説明するFlaxを活用します。

Flaxとは

FlaxもGoogleが開発しているライブラリで、Jaxで深層学習をするために開発されています。
https://github.com/google/flax
https://flax.readthedocs.io/en/latest/
Jaxベースの深層学習ライブラリを使うことでPytorchで構築するくらいの労力で学習のパイプラインを組むことが可能です。類似ライブラリとして他にもGoogleの別チームが開発しているtraxやDeepmindが作っているhaikuなどがあります。Repositoryのスター数でいえばtraxが一番多いのですが、HuggingFaceでも採用されたりとFlaxの方が勢いがあるように感じました。ViTの実装をはじめ、Google researchの研究でもよく使われていることがわかります。(参考: Google researchのRepositoryでimport flaxしているコード)野良のコードはまだそんなに多くありませんが、Google researchの実装を参考にできるなら困ることはないと思います。さらにFlaxはドキュメントも丁寧で充実しており読んでいて面白いのも、他のライブラリよりも良いポイントだと思います。
なので、数あるJaxベースの深層学習ライブラリの中から、Flaxを使うことにしました。

なんでFlaxをkaggleで使うの?

TPUを簡単に使えるところが魅力です。実験のイテレーションを高速にまわすことはKaggleでは必須です。また大きいモデルを学習できることで有利になることも多いと思います。TPUをうまく使いこなせば計算リソースで遅れを取ることはないと思っています。
またTPUを使うだけならTensorflowでも良いのですが、やはり実装の柔軟性は低いと感じます(私が慣れていないというのもありますが...。)。モデルの構築だけでなく、Loss関数の設計もJaxとFlaxを使えばNumpyを使うかのごとく簡単に設計できます。Pytorchは実装はしやすいですが、未だにTPUがうまく使えるのか使えないのかよくわからない感じです。Flaxはちょうどこの2つのライブラリの中間くらいに位置すると考えています。

Flaxで学習

Flaxでの実装を説明するために、3年前に開催されていたFlower Classification with TPUs
を使っていきたいと思います。
Kaggle notebookでTPUが使えるようになった頃に開かれたコンペで、このコンペのCode欄ではTPUを使う方法が数多く共有されています。データもTFRecordの形で与えられておりすぐに使えます。
コンペの課題としてもとてもシンプルで、花の種類を分類するコンペです。評価指標はマクロF1です。

サンプルnotebook

細かい実装についてはサンプルのnotebookを用意したのでそちらを見てください。

訓練のnotebookをみるとValidationのF1スコアがおおよそ0.65くらいで、サブミットの結果がPublic 0.68378 / Private 0.65325なのでスコア自体は低いですが正しく学習、推論できてると思います。
Multi-GPU(T4 x 2枚)とTPUの1エポックあたりの学習速度を比較すると、

Multi-GPU TPU
27秒 2秒

びっくりですね。
条件にもよると思いますが、TPUで学習が高速になることがわかりました。

Dataloaderの準備

まずは学習のイテレーションを回すためのDataloaderを準備していきます。PytorchやTensorflowにはDataloaderが用意されているのですが、FlaxではDataloaderは用意されていません。Flaxのモデルへの入力はnumpyを使うため、PytorchのDataloaderでもTensorflowのDataloaderでも、入力の段階でndarrayになっていれば使えます。今回はTFRecordでデータが与えられているのでTensorflowを使っていきます。PytorchのDataloaderを使う場合は、自動でtorch.tensorに変換されてしまうのでcollate関数を用意してnumpyとして出力されるように用意しましょう。こちらのnumpy_collateを使用すればできます。
Tensorflowの場合は、numpyでイテレーションを回すメソッドが用意されているので、それでイテレーターを作成すればそのままFlaxの学習に使えます。
TFRecordをTensorflowで扱うにはtf.data.TFRecordDatasetを使います。TFRecordを読み込んだあとは、データをパースして画像をデコードします。これで今回のデータの最低限のデータの読み出しはできます。

AUTO = tf.data.experimental.AUTOTUNE

TFRECORD_FILES = tf.io.gfile.glob(
    '/path/to/tfrecord/*.tfrec'
)

# function for decoding and reshaping image
def decode_image(image_data):
    image = tf.image.decode_jpeg(image_data, channels=3)
    image = tf.cast(image, tf.float32)
    image = tf.reshape(image, [*IMAGE_SIZE, 3])
    return image


# function for reading labeled tf_record
def read_labeled_tfrecord(example):
    LABELED_TFREC_FORMAT = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "class": tf.io.FixedLenFeature([], tf.int64),
    }
    example = tf.io.parse_single_example(example, LABELED_TFREC_FORMAT)
    image = decode_image(example['image'])
    label = tf.cast(example['class'], tf.int32)
    return {"image": image, "label": label}


# function to load datasets
def load_dataset(filenames):
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    dataset = dataset.map(
        read_labeled_tfrecord, num_parallel_calls=AUTO)
    return dataset
    
dataset = load_dataset(TFRECORD_FILES)

データに対して前処理やシャッフル、バッチ化をする場合は、作成したdatasetのmapやshuffle、batchを呼び出して行います。

# augmentation
def aug(data):
    img = data["image"]
    img = tf.image.random_flip_left_right(img)
    return dict(image=img, label=data["label"])

dataset = dataset.map(aug, AUTO)
dataset = dataset.shuffle(256)
dataset = dataset.batch(64, drop_remainder=True)

Flaxを使う上で重要なのは、出力するデータの形を[num_device, batch, height, width, channel]にすることです。Flaxではpmapを使用することで、簡単に複数のGPU/TPUに分散させて学習することができます。1枚のGPUの場合でも同じように使えるので、基本的にこの形にすれば良いと思います。pmapについてはあとで説明します。

# Shard data such that it can be distributed accross devices
num_devices = jax.local_device_count()

def _shard(data):
    data['image'] = tf.reshape(
        data['image'],
        [num_devices, -1, *image_size, data['image'].shape[-1]])
    data['label'] = tf.reshape(data['label'],
                               [num_devices, -1, num_classes])
    return data

dataset = dataset.map(_shard, AUTO)

これでDataloaderの準備は完了です。

モデルの準備

Flaxで簡単なモデルを作ってみましょう。

# Based on flax docs: https://flax.readthedocs.io/en/latest/#basic-usage
import flax.linen as nn
import jax.numpy as jnp
from jax.random import PRNGKey

class MLP(nn.Module):                              # create a Flax Module dataclass
  out_dims: int

  @nn.compact
  def __call__(self, x):
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(128)(x)                           # create inline Flax Module submodules
    x = nn.relu(x)
    x = nn.Dense(self.out_dims)(x)                 # shape inference
    return x

model = MLP(out_dims=10)                           # instantiate the MLP model

x = jnp.empty((4, 28, 28, 1))                      # generate random data
params = model.init(PRNGKey(42), x)["params"]      # initialize the weights
y = model.apply({"params"=params}, x)              # make forward pass

nn.Moduleのdataclassを継承する形でモデルを定義します。ここでは一番シンプルな@nn.compactで実装する方法を示しています。Pytorchのように初期化のためのsetup関数にレイヤーの定義をして、call関数で実際の計算する方法でも実装できます。
Pytorchとの大きな違いは、Pytorchだとモデル自体がパラメタを持っていますが、Flaxは関数型のライブラリなのでモデルとパラメタを別々で扱います。モデル自体は計算の方法だけをもっていて、各レイヤーのパラメタやBatchnormalizationの平均や分散などをモデルとは別のオブジェクトで管理します。
model.initの部分がパラメタを初期化している部分です。この出力のvariablesを学習のタイミングでアップデートしていきます。ランダムシードの扱いも厳格で、jax.random.PRNGKeyを使って生成します。乱数について詳しく知りたい方はJaxの公式ドキュメントを参照してください。
推論するときは、model.applyvariablesと入力xの2つ与えます。

残念ながらFlaxではtimmのようなモデルライブラリはまだありません。Pretrainedのモデルの数はそんなに多くありませんが、Pytorchからの移植が簡単にできるのでぜひ挑戦してみてください。

すでにFlaxで実装されているものだと、Googleのvision_transformerやtimmの作者の方が実装したefficientnet-jaxなどがあります。サンプルのKaggle notebookではシンプルさを重視してFlaxのRepositoryのexampleにあるResNetの実装を使っています。
ResNetで注意すべきところはBatch normalizationのところです。nn.Module.paramsで作られたパラメタはparamsで統一されているのですが、nn.Module.variableで作られたパラメタは別で保持されます。Batch normalizationのmeanとvarはnn.Module.variableで作られているので、モデルのパラメタとは別で扱います。

model = ResNet()

x = jnp.empty((4, 256, 256, 3))
variables = model.init(PRNGKey(42), x)
params = variables["params"]
batch_stats = variables["batch_stats"]
y = model.apply({"params"=params, "batch_stats": batch_stats}, x)

Pytorchだとこのあたりはモデル側でよしなにやってくれてるけど、Flaxだとちゃんと自分で管理しないといけないので少し注意が必要です。

学習の準備

基本的な流れは公式のドキュメントにシンプルにまとまっています。
https://flax.readthedocs.io/en/latest/getting_started.html

今回はTPUを使いたいので公式ドキュメントに加えて、vision_transformerのコードも参考にしました。

用意しているもの:

  1. loss関数
  2. metric関数
  3. optimizer
  4. 勾配更新用の関数
  5. 評価用関数
  6. 訓練のループ
    1つずつ見ていけば特に難しいところはありません。
    すこしややこしいのは勾配を更新するところだと思います。
    optimizerはoptaxというDeepmind製のライブラリを使用しています。optaxを使用するときの流れです。
# optimizerを呼び出す
optimizer = optax.adam(learning_rate)
# paramsでoptimizerを初期化
opt_state = optimizer.init(params)
# gradientでoptimizerをupdate
updates, opt_state = optimizer.update(grads, opt_state)
# paramsの値をupdate
params = optax.apply_updates(params, updates)

簡略化していますが、この4ステップが基本になります。Loss関数やSchedulerもoptaxでいくつか用意されているのでAPIのページを参考にしてください。

次に勾配更新用の関数をResNetのBatch normalizationも含んだもので簡単に説明します。

@jax.jit
def train_step(state, batch):
    def loss_fn(params, batch_stats):
        logits, new_model_state = CNN().apply({"params": params, "batch_stats": batch_stats}, batch["image"], mutable=['batch_stats'])
        loss = cross_entropy_loss(logits=logits, labels=batch["target"])
        return loss, (new_model_state, logits)
    
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    aux, grads = grad_fn(state.params, state.batch_stats)
    
    new_model_state, logits = aux[1]
    
    metrics = compute_metrics(logits=logits, labels=batch["target"])
    
    state = state.apply_gradients(grads=grads, batch_stats=new_model_state["batch_stats"])
    return state, metrics, logits, batch["target"]

損失関数を計算する関数loss_fnjax.value_and_gradに渡します。その返り値の関数を使うことで、勾配を計算してくれます。jax.value_and_gradのhas_aux=Trueにすることで、loss_fnの返り値を受け取り、メトリックの計算やbatch_statsの更新に使います。
この例ではflax.trainingtrain_stateを使ってparamsの更新を行ったり、batch_statsの管理をしています。Flax exampleのtrain.pyに詳細はあります。サンプルのKaggle notebookの方ではvision_transformerの実装をベースに作っています。
忘れてはいけないのは@jax.jitです。これにより、コンパイルされて高速になります。

pmapによる並列化

ここがJaxを使う一番のポイントだと思っています。Jaxではvmappmapといった、関数をベクトル化する関数が用意されています。関数をベクトル化するというのはイマイチわかりにくいですが、同じ関数を複製して同時に入力を与えて並列で処理できるようにする、と理解しています。pmapはXLAコンパイルを行い、さらに使用可能なデバイスに分散させて計算することができます。マルチGPUでもTPUでも同じように使うことができます。
pmapを使うためには、

  1. dataloaderの入力を[num_device, batch, height, width, channel]にする
  2. flax.jax_utils.replicateを使ってparams, optimizerの複製
  3. jax.pmapで関数の並列化
  4. jax.lax.pmeanなどで各デバイスで計算された値を集約

サンプルのKaggle notebookのここが該当箇所です。

params_repl = flax.jax_utils.replicate(params)
model_apply_repl = jax.pmap(
    lambda params、, inputs: model.apply(params, inputs['image'], train=False),
    axis_name='batch')
opt_state_repl = flax.jax_utils.replicate(opt_state)

flax.jax_utils.replicateをするとparamsの頭に1つ次元が増えて、num_deviceの値が入ることになります。もしparamsが[4, 10]のパラメタだとしたら、[num_device, 4, 10]になります。
入力もparamsもrepliceateをすることで最初にnum_deviceの次元が追加され、これらをmodel_apply_replに渡すことで各デバイスに割り振って計算を行ってくれます。シングルGPUでnum_deviceが1のときも、同じように使えるのでわざわざ変える必要はありません。
pmapを使って各デバイスで計算した勾配はpmeanを使って平均します。

grad = jax.tree_map(lambda x: jax.lax.pmean(x, axis_name='batch'), grad)
updates, opt_state = tx.update(grad, opt_state)
params = optax.apply_updates(params, updates)
loss = jax.lax.pmean(loss, axis_name='batch')

先程の勾配更新用の関数のgradを計算して以降の部分です。
ちなみにflax.jax_utils.replicateで複製したものはflax.jax_utils.unreplicateで元に戻せます。

学習の実行とweightの保存

学習の実行は特に変わったこともなく、最初に用意したTensorflowのDataloaderをas_numpy_iteratorでループを回すだけです。それをpmapでラップした関数に渡してあげるとマルチGPUやTPUが使える場合は分散して計算してくれます。
詳しくはサンプルのkaggle notebookを見てください。

学習し終わったWegihtの保存は、paramsをシリアライズして保存しています。
flax.serializationto_bytesを使ってシリアライズしたものをそのままファイルに書き込んでいます。

from flax.serialization import to_bytes, from_bytes
from flax.linen import FrozenDict

def save_params(params: FrozenDict, path: str) -> None:
    serialized_params = to_bytes(params)
    with open(path, 'wb') as f:
        f.write(serialized_params)

ここまでずっとparamsと言っていたものの実態はFrozenDictと呼ばれるもので、編集不可能なPythonの辞書になります。

Flaxで推論&サブミット

まずはweightの読み込みです。to_bytesで保存したparamsはfrom_bytesで読み込めます。

from flax.serialization import to_bytes, from_bytes
from flax.linen import FrozenDict

def load_params(params: FrozenDict, path: str) -> FrozenDict:
    with open(path, 'rb') as f:
        serialized_params = f.read()

    return FrozenDict(from_bytes(params, serialized_params))

この時、paramsも引数に必要です。なのでまずはmodel.initでparamsを呼び出します。そしてシリアライズ化されたパラメタを読み込んで、paramsに書き込みます。
paramsはreplicateして保存した場合、デバイスの数が保存したデバイス数と違う場合はunreplicateして再度replicateする必要があります。保存のときにunreplicateしておくほうが親切かもしれません。

# re-replicate for the device used now
params = flax.jax_utils.unreplicate(params_repl)
params_repl = flax.jax_utils.replicate(params)

推論の関数は学習時の評価用関数と基本的に同じです。推論時に注意する点としてはjax.jitjax.pmapでラップした関数の中では数値しか扱うことができません。なので、推論時に文字列やbyte型の画像IDなども一緒にdataloaderで取得する場合は、ラップした関数に渡すのを入力の画像だけに変更する必要があります。

model_apply_repl = jax.pmap(
    lambda params、, inputs: model.apply(params, inputs['image'], train=False),
    axis_name='batch')
      ↓
model_apply_repl = jax.pmap(
    lambda params, image: model.apply(params, image, train=False),
    axis_name='batch')

次にすること

ベースのコードができたとはいえ、まだまだPytorchのように周辺の機能が充実しているわけではありません。そのあたりは自分で実装して行く必要があると思います。その他にもJax/Flaxを使ってKaggleで戦うために必要そうなことをあげてみました。

  • Colaboratoryでも使えるようにする
  • TFRecordを自由自在に作れるようにする
  • Augmentation関数を充実させる
  • Gradcamなどのモデルを評価する関数の実装
  • Backboneをたくさん用意
  • アクセサリモジュールを自由に足せるようにする

このあたりでしょうか。やることいっぱいあって楽しいですね。

まとめ

この記事ではJax/Flaxの使い方をざっくり説明しました。細かい部分は用意したnotebookや参考リンクを見てください。
Jax/Flaxで学習から推論、サブミットまでやってみて、かなり簡単に使えることがわかりました。公式のドキュメントが充実しているのと、GoogleがJax/Flaxで論文実装を出しており、使えるコードがけっこうあるのが大きいと思います。
またJax/Flaxだと驚くほど簡単にTPUも使えました。しかも目に見えて高速になっていたので挑戦してみて良かったです。ただ残念なことに、Kaggle notebookのTPUは待ち時間が長く、制限時間もあまり多くはないので、Colaboratoryでも同じように学習できるようにするのは必須だと感じました。(ColaboratoryのTPUは少し性能が低いですが...。)
「Jax/FlaxはPytorchに取って代わる!」とまでは言い切れませんが、実装力さえついてくれば十分Kaggleの画像コンペで戦えると感じました。私は2023年は基本的にJax/FlaxでKaggleやろうと思います。

参考リンク

Discussion