Open6

JAX/Flax/OptaxでConvNeXtをやってみる

nb.onb.o

モデル実装

前回まで(JAX/Flax/OptaxでMobileNet v3 largeをやってみる)と同様、以下のリポジトリで実装した。

ConvNeXtのモデルは以下で定義した。

ConvNeXtの各バリアントはstage size、filter size で決まる。今回はTinyを再現することを目的とする。Tinyの定義は以下。

ConvNeXtではJAX/Flax/Optaxで定義されていない独自レイヤー等を以下で実装した。

また、ConvNeXt特有の実装について記載する。

  • 1x1 ConvをDenseレイヤーで実装
nb.onb.o

Stochastic depthの実装

FlaxにはStochastic depthのレイヤーは実装されていない。
このため、独自レイヤーとして定義した。モデルの実装は以下。

Stochastic depthレイヤーの定義は以下。学習時のみレイヤー方向に層をドロップする(0にする)。後で気づいたが、EfficientNetの実装と同様、flax.linen.Dropoutで、broadcast_dimsをレイヤー方向にしていすればよかったかもしれない。


class StochasticDepth(nn.Module):
    """Create a stochastic depth layer.

        Note: When using :meth:`Module.apply() `, make
        sure to include an RNG seed named ``'stochastic_depth'``.
        StochasticDepth isn't necessary for variable initialization.

        Reference
            - Deep Networks with Stochastic Depth
            - https://github.com/tensorflow/models/blob/v2.14.2/official/vision/modeling/layers/nn_layers.py#L226-L261
            - https://pytorch.org/vision/main/_modules/torchvision/ops/stochastic_depth.html#StochasticDepth
            - https://flax.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout

    Attributes:
        stochastic_depth_drop_rate: the stochastic depth probability.
            (_not_ the keep rate!)
        deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` and
            masked, whereas if true, no mask is applied and the inputs are returned as is.
        rng_collection: the rng collection name to use when requesting an rng key
    """

    stochastic_depth_drop_rate: float
    deterministic: Optional[bool] = None
    rng_collection: str = "stochastic_depth"

    @nn.compact
    def __call__(
        self,
        inputs,
        deterministic: Optional[bool] = None,
        rng: Optional[KeyArray] = None,
    ):
        """Applies a random stochastic depth mask to the input.

        Args:
            inputs: the inputs that should be randomly masked.
            deterministic: if false the inputs are scaled by ``1 / (1 - rate)``
                and masked, whereas if true, no mask is applied and the inputs
                are returnedas is.
            rng: an optional PRNGKey used as the random key, if not specified,
                one will be generated using ``make_rng`` with the
                ``rng_collection`` name.

        Returns:
            The masked inputs reweighted to preserve mean.
        """
        deterministic = merge_param("deterministic", self.deterministic, deterministic)

        if (self.stochastic_depth_drop_rate == 0.0) or deterministic:
            return inputs

        if rng is None:
            rng = self.make_rng(self.rng_collection)

        keep_prob = 1.0 - self.stochastic_depth_drop_rate
        batch_size = inputs.shape[0]

        broadcast_shape = list([batch_size] + [1] * int(inputs.ndim - 1))

        mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
        mask = jnp.broadcast_to(mask, inputs.shape)
        return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))


def get_stochastic_depth_rate(init_rate, i, n):
    """Get drop connect rate for the ith block.

    Args:
        init_rate: A `float` of initial drop rate.
        i: An `int` of order of the current block.
        n: An `int` total number of blocks.

    Returns:
        Drop rate of the ith block.
    """
    if init_rate < 0 or init_rate > 1:
        raise ValueError("Initial drop rate must be within 0 and 1.")
    return init_rate * float(i) / n

Randomness and PRNGs in Flaxと同じように乱数生成器を管理する必要がある。jax.random.splitでstochastic depthの乱数生成器を用意する。

    rng = jax.random.PRNGKey(seed=config.seed)
    params_rng, dropout_rng, stochastic_depth_rng = jax.random.split(rng, num=3)
    rngs = {
        "params": params_rng,
        "dropout": dropout_rng,
        "stochastic_depth": stochastic_depth_rng,
    }
...
    dropout_rngs = jax.random.split(dropout_rng, jax.local_device_count())
    stochastic_depth_rngs = jax.random.split(
        stochastic_depth_rng, jax.local_device_count()
    )

学習ループの際、train_stateのapply_fndropoutstochastic depthの乱数を渡し、新たな乱数を生成する。ここは、dropoutの乱数にもう一つ追加された形なのでそこまで分かりにくいことはないはず。

def train_step(
...
):
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
    stochastic_depth_rng, new_stochastic_depth_rng = jax.random.split(
        stochastic_depth_rng
    )

...
    def loss_fn(params):
        """loss function used for training."""
...
            logits = state.apply_fn(
                {"params": params},
                batch["image"],
                rngs={"dropout": dropout_rng, "stochastic_depth": stochastic_depth_rng},
            )

...

    return new_state, metrics, new_dropout_rng, new_stochastic_depth_rng

モデル定義の際、stochastic depthで層ごとにdropする割合を大きくしていく。
ConvNeXtでは以下で割合を生成している。

また、この割合はブロックごとに変化させている。

今回の実装も同様にした。

nb.onb.o

Layer scale

ConvNeXtにはLayer scaleを利用している。

FlaxにはLayer scaleのレイヤーは実装されていない。このため、独自レイヤーとして定義した。実装は以下。

class LayerScale(nn.Module):
    """Create a layer scale.

        Reference
            - [Going deeper with Image Transformers](https://arxiv.org/abs/2103.17239v2)

    Attributes:
        init_values:
        projection_dim:
    """

    projection_dim: int
    init_values: float = 1e-6
    dtype: Any = jnp.float32

    def setup(self):
        initializer = nn.initializers.constant(value=self.init_values, dtype=self.dtype)
        self.scale = self.param(
            "scale",
            initializer,
            (self.projection_dim,),
            self.dtype,
        )

    @nn.compact
    def __call__(self, inputs):
        return inputs * self.scale
nb.onb.o

ConvNeXtブロックの実装

オリジナルのConvNeXtブロックの実装は、コメントにも記載があるとおり、
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
ではなく
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
である。

論文には記載がないが、

  • 1x1 Convではなく、Linearを使用
  • PyTorchのLayerNormではなく、オリジナルのLayerNorm(channels_last)を使用
    している。
    これは、Question 1x1 conv vs linear #18にあるが、推論スループットの向上のためとある。

JAX/Flax/Optaxでの実装はNHWCであるため、1x1 ConvをLinearとするのみをオリジナルと同様にした。

nb.onb.o

ConvNeXt Tinyの学習のセットアップ

ConvNeXt Tinyの学習を行う。学習パラメータは極力、オリジナルにあわせることとした。ただし、学習リソース(TPUv2-8)より、Gradient accumulationを利用した。また、Data loaderにTensorFlow Datasets 、Data augmentationにTensorFlow Models Vision Librariesを利用した。これにより、オリジナルの実相と一部異なる点がある。
乱数のシード値は42ですべて固定した。

学習リソース

学習はTPUv2-8で行う。

Data augmentation

ConvNeXtでは、以下のData augmentationをtimmで指定している。

  • Rand Augment
  • Mixup and CutMix
  • Random Erasing

自分の実装はFlaxのImageNet classificationのサンプル(ResNet50)をベースにしているため、Data loaderにTensorFlow Datasets 、Data augmentationにTensorFlow Models Vision Librariesを利用している。
timmへの置き換えは大変なため、TensorFlow DatasetsTensorFlow Models Vision Libraries を利用して、なるべく元実装と同じにする方針とした。TF-Vision Model Gardenの実装を参考に、Data augmentationを実装した。

Randaugment

tfm.vision.augment.RandAugmentを利用する。

Config Params Note
Num layers 2
Magnitude 9
Cutout const 40.0 Exclude ops(除外Ops)にCutoutを指定のため無効
Translate const 100
Magnitude std 0.5
Prob to apply None
Exclude ops ["Cutout"]

元の実装ではtimmのcreate_transformの引数auto_augmentrand-m9-mstd0.5-inc1を指定している。rand-m9-mstd0.5-inc1は以下の値となる。

"拡張の種類を増加" に関してはtfm.vision.augment.RandAugmentで指定ができない。

Mixup and CutMix

tfm.vision.augment.MixupAndCutmixを利用する。

Config Params Note
Mixup alpha 0.8
Cutmix alpha 1.0
Prob 1.0
Switch prob 0.5
Label smoothing 0.1 optax.losses.smooth_labelsは利用しない
Num classes 1000

元の実装で指定しているパラメータそのままを利用した。

Random Erasing

tfm.vision.augment.RandomErasingを利用する。

Config Params Note
Probability 0.25
Min area 0.02
Max area 1 / 3
Min aspect 0.3
Max aspect None
Min count 1
Max count 1
Trials 10

元の実装ではtimmのcreate_transformの引数re_prob, re_mode, re_countにパラメータを指定している。

上記以外のパラメータはtimmのRandomEraseのデフォルトパラメータとして指定した。

Optimizer

Optimizerはoptax.adamwを利用する。

Config Params Note
Optimizer AdamW
Learning rate - schedulerで指定(別表)
b1 0.9
b2 0.999
Eps 1e-08
Eps root 0.0
Mu dtype None
Weight decay 0.05
Mask - 指定(別表)
Nesterov False
  • *はデフォルト引数のまま

Learning scheduer

LR ScheduerはWarmup cosine decayで、optax.schedules.warmup_cosine_decay_scheduleを利用する。

Config Params Note
Optimizer schedule Warmup cosine decay
Init value 0.004 0.0
Peak value 0.004
Warmup steps 6260 20 epochs * 313
Decay steps 93900 300 * 313
End value 1e-6
Exponent 1.0
Num epochs 300

Model setup

モデルの入力サイズ、Dropupt、Stochastic depthに指定するパラメータは以下。TPUv2-8ではバッチサイズに4096を確保できないため、Gradient accumulation(every k schedule=4)、バッチサイズを1024とする。

Config Params Note
Input size 224x224
Dropout 0.2
Stochastic depth 0.0
Kernel initializer truncated normal
stddev=0.02
論文の値と実装が異なる?
Model dtype bfloat16
Batch size 4096 Batch size=1024, Gradient accumulation=4)

Kernel initializers

jax.nn.initializers.truncated_normalを利用し、flax.linen.Convの引数kernel_initに指定する。

        kernel_initializer = jax.nn.initializers.truncated_normal(stddev=0.02)
        conv = partial(
            nn.Conv,
            use_bias=self.use_bias,
            kernel_init=kernel_initializer,
            dtype=self.dtype,
        )

論文では、trunc. normal (0.2)(Table 5.)とあるが、実装では0.002と小さい値を指定している。

実装の0.002は意図したもの(Transformerの初期化に従う)とあり、実装の値を利用した。

Loss

tfm.vision.augment.MixupAndCutmixで行うため、optax.losses.smooth_labelsは利用しない。

Weight decay

optax.adamwweight_decay、およびmaskに指定する。maskに指定する除外opeについては、前回(JAX/Flax/OptaxでMobileNet v2をやってみる)と同様。

Model EMA

Model EMAを有効にし、以下のパラメータを指定する。

Config Params
EMA decay 0.9999