JAX/Flax/OptaxでConvNeXtをやってみる

はじめに
JAX/Flax/Optaxを利用し、ConvNeXtのモデルを構築する。ImageNet2012の学習データを用い、精度(Top1-Acc)が再現できるか確認する。
参照

モデル実装
前回まで(JAX/Flax/OptaxでMobileNet v3 largeをやってみる)と同様、以下のリポジトリで実装した。
ConvNeXtのモデルは以下で定義した。
ConvNeXtの各バリアントはstage size、filter size で決まる。今回はTiny
を再現することを目的とする。Tiny
の定義は以下。
ConvNeXtではJAX/Flax/Optaxで定義されていない独自レイヤー等を以下で実装した。
- Stochastic depth
- Layer scale
- Model EMA(https://zenn.dev/link/comments/f89daf1d41ce58と同様)
また、ConvNeXt特有の実装について記載する。
- 1x1 ConvをDenseレイヤーで実装

Stochastic depthの実装
FlaxにはStochastic depthのレイヤーは実装されていない。
このため、独自レイヤーとして定義した。モデルの実装は以下。
Stochastic depthレイヤーの定義は以下。学習時のみレイヤー方向に層をドロップする(0にする)。後で気づいたが、EfficientNetの実装と同様、flax.linen.Dropoutで、broadcast_dims
をレイヤー方向にしていすればよかったかもしれない。
class StochasticDepth(nn.Module):
"""Create a stochastic depth layer.
Note: When using :meth:`Module.apply() `, make
sure to include an RNG seed named ``'stochastic_depth'``.
StochasticDepth isn't necessary for variable initialization.
Reference
- Deep Networks with Stochastic Depth
- https://github.com/tensorflow/models/blob/v2.14.2/official/vision/modeling/layers/nn_layers.py#L226-L261
- https://pytorch.org/vision/main/_modules/torchvision/ops/stochastic_depth.html#StochasticDepth
- https://flax.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout
Attributes:
stochastic_depth_drop_rate: the stochastic depth probability.
(_not_ the keep rate!)
deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` and
masked, whereas if true, no mask is applied and the inputs are returned as is.
rng_collection: the rng collection name to use when requesting an rng key
"""
stochastic_depth_drop_rate: float
deterministic: Optional[bool] = None
rng_collection: str = "stochastic_depth"
@nn.compact
def __call__(
self,
inputs,
deterministic: Optional[bool] = None,
rng: Optional[KeyArray] = None,
):
"""Applies a random stochastic depth mask to the input.
Args:
inputs: the inputs that should be randomly masked.
deterministic: if false the inputs are scaled by ``1 / (1 - rate)``
and masked, whereas if true, no mask is applied and the inputs
are returnedas is.
rng: an optional PRNGKey used as the random key, if not specified,
one will be generated using ``make_rng`` with the
``rng_collection`` name.
Returns:
The masked inputs reweighted to preserve mean.
"""
deterministic = merge_param("deterministic", self.deterministic, deterministic)
if (self.stochastic_depth_drop_rate == 0.0) or deterministic:
return inputs
if rng is None:
rng = self.make_rng(self.rng_collection)
keep_prob = 1.0 - self.stochastic_depth_drop_rate
batch_size = inputs.shape[0]
broadcast_shape = list([batch_size] + [1] * int(inputs.ndim - 1))
mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
mask = jnp.broadcast_to(mask, inputs.shape)
return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))
def get_stochastic_depth_rate(init_rate, i, n):
"""Get drop connect rate for the ith block.
Args:
init_rate: A `float` of initial drop rate.
i: An `int` of order of the current block.
n: An `int` total number of blocks.
Returns:
Drop rate of the ith block.
"""
if init_rate < 0 or init_rate > 1:
raise ValueError("Initial drop rate must be within 0 and 1.")
return init_rate * float(i) / n
Randomness and PRNGs in Flaxと同じように乱数生成器を管理する必要がある。jax.random.split
でstochastic depthの乱数生成器を用意する。
rng = jax.random.PRNGKey(seed=config.seed)
params_rng, dropout_rng, stochastic_depth_rng = jax.random.split(rng, num=3)
rngs = {
"params": params_rng,
"dropout": dropout_rng,
"stochastic_depth": stochastic_depth_rng,
}
...
dropout_rngs = jax.random.split(dropout_rng, jax.local_device_count())
stochastic_depth_rngs = jax.random.split(
stochastic_depth_rng, jax.local_device_count()
)
学習ループの際、train_stateのapply_fn
でdropout
とstochastic depth
の乱数を渡し、新たな乱数を生成する。ここは、dropout
の乱数にもう一つ追加された形なのでそこまで分かりにくいことはないはず。
def train_step(
...
):
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
stochastic_depth_rng, new_stochastic_depth_rng = jax.random.split(
stochastic_depth_rng
)
...
def loss_fn(params):
"""loss function used for training."""
...
logits = state.apply_fn(
{"params": params},
batch["image"],
rngs={"dropout": dropout_rng, "stochastic_depth": stochastic_depth_rng},
)
...
return new_state, metrics, new_dropout_rng, new_stochastic_depth_rng
モデル定義の際、stochastic depth
で層ごとにdropする割合を大きくしていく。
ConvNeXtでは以下で割合を生成している。
また、この割合はブロックごとに変化させている。
今回の実装も同様にした。
- https://github.com/NobuoTsukamoto/jax_examples/blob/554318b38a95b3cb5b22ce4420c706642e409dbd/classification/implements/convnext.py#L105-L110
- https://github.com/NobuoTsukamoto/jax_examples/blob/554318b38a95b3cb5b22ce4420c706642e409dbd/classification/implements/convnext.py#L81
- https://github.com/NobuoTsukamoto/jax_examples/blob/554318b38a95b3cb5b22ce4420c706642e409dbd/common/common_layer.py#L209

Layer scale
ConvNeXtにはLayer scaleを利用している。
FlaxにはLayer scaleのレイヤーは実装されていない。このため、独自レイヤーとして定義した。実装は以下。
class LayerScale(nn.Module):
"""Create a layer scale.
Reference
- [Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239v2)
Attributes:
init_values:
projection_dim:
"""
projection_dim: int
init_values: float = 1e-6
dtype: Any = jnp.float32
def setup(self):
initializer = nn.initializers.constant(value=self.init_values, dtype=self.dtype)
self.scale = self.param(
"scale",
initializer,
(self.projection_dim,),
self.dtype,
)
@nn.compact
def __call__(self, inputs):
return inputs * self.scale

ConvNeXtブロックの実装
オリジナルのConvNeXtブロックの実装は、コメントにも記載があるとおり、
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
ではなく
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
である。
論文には記載がないが、
- 1x1 Convではなく、Linearを使用
- PyTorchのLayerNormではなく、オリジナルのLayerNorm(channels_last)を使用
している。
これは、Question 1x1 conv vs linear #18にあるが、推論スループットの向上のためとある。
JAX/Flax/Optaxでの実装はNHWCであるため、1x1 ConvをLinearとするのみをオリジナルと同様にした。

ConvNeXt Tinyの学習のセットアップ
ConvNeXt Tinyの学習を行う。学習パラメータは極力、オリジナルにあわせることとした。ただし、学習リソース(TPUv2-8)より、Gradient accumulationを利用した。また、Data loaderにTensorFlow Datasets 、Data augmentationにTensorFlow Models Vision Librariesを利用した。これにより、オリジナルの実相と一部異なる点がある。
乱数のシード値は42ですべて固定した。
学習リソース
学習はTPUv2-8で行う。
Data augmentation
ConvNeXtでは、以下のData augmentationをtimmで指定している。
- Rand Augment
- Mixup and CutMix
- Random Erasing
自分の実装はFlaxのImageNet classificationのサンプル(ResNet50)をベースにしているため、Data loaderにTensorFlow Datasets 、Data augmentationにTensorFlow Models Vision Librariesを利用している。
timm
への置き換えは大変なため、TensorFlow Datasets
と TensorFlow Models Vision Libraries
を利用して、なるべく元実装と同じにする方針とした。TF-Vision Model Gardenの実装を参考に、Data augmentationを実装した。
Randaugment
tfm.vision.augment.RandAugmentを利用する。
Config | Params | Note |
---|---|---|
Num layers | 2 | |
Magnitude | 9 | |
Cutout const | 40.0 | Exclude ops(除外Ops)にCutoutを指定のため無効 |
Translate const | 100 | |
Magnitude std | 0.5 | |
Prob to apply | None | |
Exclude ops | ["Cutout"] |
元の実装ではtimmのcreate_transformの引数auto_augment
にrand-m9-mstd0.5-inc1
を指定している。rand-m9-mstd0.5-inc1
は以下の値となる。
- Magnitude = 9
- Magnitude std = 0.5
- 拡張の種類を増加(下記のテーブル定義で
Increasing
がつく拡張が対象)
https://github.com/huggingface/pytorch-image-models/blob/v1.0.12/timm/data/auto_augment.py#L641 - Translate constは入力画像サイズから算出し、int(224 * 0.45)=100
https://github.com/huggingface/pytorch-image-models/blob/v1.0.12/timm/data/transforms_factory.py#L170
"拡張の種類を増加" に関してはtfm.vision.augment.RandAugmentで指定ができない。
Mixup and CutMix
tfm.vision.augment.MixupAndCutmixを利用する。
Config | Params | Note |
---|---|---|
Mixup alpha | 0.8 | |
Cutmix alpha | 1.0 | |
Prob | 1.0 | |
Switch prob | 0.5 | |
Label smoothing | 0.1 | optax.losses.smooth_labelsは利用しない |
Num classes | 1000 |
元の実装で指定しているパラメータそのままを利用した。
Random Erasing
tfm.vision.augment.RandomErasingを利用する。
Config | Params | Note |
---|---|---|
Probability | 0.25 | |
Min area | 0.02 | |
Max area | 1 / 3 | |
Min aspect | 0.3 | |
Max aspect | None | |
Min count | 1 | |
Max count | 1 | |
Trials | 10 |
元の実装ではtimmのcreate_transformの引数re_prob
, re_mode
, re_count
にパラメータを指定している。
- https://github.com/facebookresearch/ConvNeXt/blob/main/main.py#L113-L121
- https://github.com/facebookresearch/ConvNeXt/blob/main/datasets.py#L58
- https://github.com/huggingface/pytorch-image-models/blob/v1.0.12/timm/data/transforms_factory.py#L232
上記以外のパラメータはtimmのRandomEraseのデフォルトパラメータとして指定した。
Optimizer
Optimizerはoptax.adamwを利用する。
Config | Params | Note |
---|---|---|
Optimizer | AdamW | |
Learning rate | - | schedulerで指定(別表) |
b1 | 0.9 | * |
b2 | 0.999 | * |
Eps | 1e-08 | * |
Eps root | 0.0 | * |
Mu dtype | None | * |
Weight decay | 0.05 | |
Mask | - | 指定(別表) |
Nesterov | False | * |
- *はデフォルト引数のまま
Learning scheduer
LR ScheduerはWarmup cosine decay
で、optax.schedules.warmup_cosine_decay_scheduleを利用する。
Config | Params | Note |
---|---|---|
Optimizer schedule | Warmup cosine decay | |
Init value | 0.004 | 0.0 |
Peak value | 0.004 | |
Warmup steps | 6260 | 20 epochs * 313 |
Decay steps | 93900 | 300 * 313 |
End value | 1e-6 | |
Exponent | 1.0 | * |
Num epochs | 300 |
Model setup
モデルの入力サイズ、Dropupt、Stochastic depthに指定するパラメータは以下。TPUv2-8ではバッチサイズに4096を確保できないため、Gradient accumulation(every k schedule=4)、バッチサイズを1024とする。
Config | Params | Note |
---|---|---|
Input size | 224x224 | |
Dropout | 0.2 | |
Stochastic depth | 0.0 | |
Kernel initializer | truncated normal stddev=0.02 |
論文の値と実装が異なる? |
Model dtype | bfloat16 | |
Batch size | 4096 | Batch size=1024, Gradient accumulation=4) |
Kernel initializers
jax.nn.initializers.truncated_normalを利用し、flax.linen.Convの引数kernel_init
に指定する。
kernel_initializer = jax.nn.initializers.truncated_normal(stddev=0.02)
conv = partial(
nn.Conv,
use_bias=self.use_bias,
kernel_init=kernel_initializer,
dtype=self.dtype,
)
論文では、trunc. normal (0.2)
(Table 5.)とあるが、実装では0.002
と小さい値を指定している。
実装の0.002
は意図したもの(Transformerの初期化に従う)とあり、実装の値を利用した。
Loss
tfm.vision.augment.MixupAndCutmixで行うため、optax.losses.smooth_labelsは利用しない。
Weight decay
optax.adamwのweight_decay
、およびmask
に指定する。mask
に指定する除外opeについては、前回(JAX/Flax/OptaxでMobileNet v2をやってみる)と同様。
Model EMA
Model EMAを有効にし、以下のパラメータを指定する。
Config | Params |
---|---|
EMA decay | 0.9999 |