Open7

JAX/Flax/OptaxでMobileNet v3 smallをやってみる

nb.onb.o

モデル実装

前回(JAX/Flax/OptaxでMobileNet v2をやってみる)と同様、以下のリポジトリで実装した。

MobileNet v3のモデルは以下で定義した。

Model EMAの実装

MobileNet v3は学習時、Model EMA(Exponential Moving Average)を使用している。Optaxにはoptax.emaがあるが、Model EMAへの利用方法がよくわからない(optax.chainと組み合わせると、学習時のパラメータと別にEMAのパラメータを保持できない)ため、flax.training.train_stateを継承したクラスに独自にEMAを行うメソッドとパラメータを用意することとした。これにより、1つのtrain_stateでEMAを行う前のパラメータや状態、EMAのパラメータを管理できることにした。

class TrainStateWithBatchNorm(train_state.TrainState):
    batch_stats: Any
    dynamic_scale: dynamic_scale_lib.DynamicScale
    ema_decay: float = 0.0
    ema_params: Any = None

    def apply_ema(self):
        return jax.tree_util.tree_map(
            lambda ema, param: (ema * self.ema_decay + param * (1.0 - self.ema_decay)),
            self.ema_params,
            self.params,
        )

モデルの初期化時にema_paramsを初期化、

        state = TrainStateWithoutBatchNorm.create(
            apply_fn=model.apply,
            params=params,
            tx=tx,
            dynamic_scale=dynamic_scale,
            ema_params=params,
            ema_decay=config.model_ema_decay,
        )

学習時、apply_emaメソッドを実行することで更新したパラメータをもとにEMAのパラメータを得る(実装では、Gradient accumulationを考慮してapply_emaメソッドを呼び出すステップを調整)。

        new_state = jax.lax.cond(
            step % gradient_accumulation_steps == 0,
            lambda _: new_state.replace(ema_params=new_state.apply_ema()),
            lambda _: new_state,
            None,
        )
nb.onb.o

MobileNet v3 smallの学習のセットアップ

MobileNet v3 smallの学習を行う。
学習パラメータを以下に示す。学習パラメータはTPUv2-8のリソースで実行できるようにTensorFlow-Slimの実装を参照し、バッチサイズを1536(192 par chip)とした。乱数のシード値は42ですべて固定した。

学習リソース

学習はTPUv2-8で行う。

Data augmentation

前回(JAX/Flax/OptaxでMobileNet v2をやってみる)と同様。

Optimizer

Optimizerはoptax.rmspropsを利用。

Config Params Note
optimizer RMSProp
Learning rate - schedulerで指定(別表)
Decay 0.9
Eps 0.002
Initial scale 1.0
Eps in sqrt True
Eentered False
Momentum 0.9
Nesterov False
Bias correction False
  • *はデフォルト引数のまま

Learning scheduer

LR ScheduerはWarmup exponential decayで、optax.warmup_exponential_decay_scheduleを利用する。
以下のパラメータを指定する。learning_ratetransition_stepsはバッチサイズによる調整を行う。

Config Params Note
Optimizer schedule Warmup exponential decay
Initial learning rate 0.0
Learning rate 0.16 0.02 * (batch_size(1536) / 192)
Warmup epochs 5
Exponential decay rate 0.99
Transition steps 2505 3.0 * steps_per_epoch (835)
LR drop staircase True
Num epochs 1000

Model setup

モデルの入力サイズ、Batch normalization、Dropuptに指定するパラメータは以下。

Config Params Note
Input size 224x224
BatchNorm momentum 0.997
BatchNorm esp 0.001
Dropout 0.2
Stochastic depth 0.0
Initializer truncated normal
scale=1.0
別途記載
Model dtype bfloat16
Batch size 1536

Kernel initializers

flax.linen.initializers.variance_scalingで指定する。

        kernel_initializer = jax.nn.initializers.variance_scaling(
            scale=1.0, mode="fan_in", distribution="truncated_normal"
        )
        conv = partial(
            nn.Conv, use_bias=False, kernel_init=kernel_initializer, dtype=self.dtype
        )

Loss

optax.losses.smooth_labelsにてLabel smoothingを指定する。

Config Params
Label smoothing 0.1

L2 weight decay

optax.add_decayed_weightsにて、L2 weight decayを指定する。指定や除外するopeについては、前回(JAX/Flax/OptaxでMobileNet v2をやってみる)と同様。

Config Params
L2 weight decay 0.00001

Model EMA

Model EMAを有効にし、以下のパラメータを指定する。

Config Params
EMA decay 0.9999
nb.onb.o

モデル学習の結果

1回目

終了時点から3epoch以内(checkpointの保存数)でもっともよいTop-1 Accuracyと参考にBest Top-1 Accuarcyも確認。
元論文はTop-1 Accuracyは67.5%で、今回の結果は許容できる誤差でなかった。
slimの実装でも8 GPU(バッチサイズ1536)での結果は若干の精度が低下があるが、その範囲を超えている。

Non-EMA Model

Result

Epoch Top-1 Accuracy Note
999 64.77
992 65.29 best

EMA Model

Result

Epoch Top-1 Accuracy Note
999 64.62
909 65.44 best

nb.onb.o

学習2回目

バッチサイズを大きくした場合、精度が1回目より向上するか確認する。バッチサイズを4096とし、LR scheduerを下記の表のとおりとする。

Learning scheduer

Config Params Note
Learning rate 0.426 0.02 * (batch_size(4096) / 192)
Transition steps 936 3.0 * steps_per_epoch (312)

Model setup

Config Params Note
Batch size 4096 Gradient accumulation
batch size=1024, every k schedule=4

学習の結果

期待できる精度を達成できなかった。また、Model EMAは1回目と比べ、十分に上がりきらない。

Non Model EMA

Result

Epoch Top-1 Accuracy Note
998 64.39
983 64.65 best

Model EMA

Result

Epoch Top-1 Accuracy Note
1000 55.64
946 60.04 best

nb.onb.o

学習3回目

実装の見直し

期待した精度を大幅に下回ったため、モデルの実装を中心に見直しを行った。結果、アクティベーションの指定に誤りがあった。元実装ではreluのところをrelu6としていた。
アクティベーションの指定を修正する。

学習の結果

2回目と同様のパラメータで学習を行う。精度は多少向上したが、元実装の精度67.5%を実現できなかった。

Non Model EMA

Result

Epoch Top-1 Accuracy Note
1000 65.20
980 65.53 best

Model EMA

Result

Epoch Top-1 Accuracy Note
1000 61.44 best

nb.onb.o

学習4回目

Model EMAの見直し

TF-Vision Model Garden(TFM)の実装を再確認。TFMではModel EMAにtfm.optimization.ExponentialMovingAverageを利用している。tf.train.ExponentialMovingAverageと動作が異なる。

  • trainable_weights_onlyフラグ
    Falseの場合、全パラメータ(batch normalization 移動平均(mean)と分散(var))を対象とする。
  • dynamic_decayフラグ
    Trueの場合、decay rateが0.1からはじまり、minimum(decay_rate, (1. + decay) / (10. + decay))で計算されたdecay rateが設定される。

tfm.optimization.ExponentialMovingAverageにあわせるように、Model EMAの実装を見直した。

学習の結果

期待した精度(Top-1、Top-5)を達成した。
設定したEpoch数(1000)に到達する前に、論文の精度まで到達したため、中止。

Epoch Top-1 accuracy Top-5 accuracy
619 67.52% 87.60%