Open3
JAX/Flax/OptaxでResNetをやってみる

はじめに
JAX/Flax/Optaxを利用し、ResNetのモデルを構築する。ImageNet2012の学習データを用い、精度(Top1-Acc)が再現できるか確認する。
FlaxにはImageNet classificationのサンプルでResNet50が実装されている。自分のリポジトリは、これをもとに実装しているため、今回は学習をConvNeXtにあるTraining Techniquesで行い、 78.8%の精度が再現できるかを確認する。はじめにオリジナルの学習パラメータでの学習、つづいてConvNeXtのTraining Techniquesでの学習でそれぞれの精度を確認する。
参照
ResNet
ConvNeXt

Flaxでの学習パラメータおよび精度
参考にしたFlaxのImageNet classificationのサンプルでの学習パラメータは以下。
オリジナルのResNetと異なり、Epochsは90→100となっている。
精度は以下でTPUv2-32で76.67%のTop-1 accuracyとある。
この実装を元にTPUv2-8で再現を行う。
- https://github.com/NobuoTsukamoto/jax_examples/tree/main/classification
- https://github.com/NobuoTsukamoto/jax_examples/blob/main/classification/configs/imagenet_resnet50_v1_tpu.py
終了時点から3epoch以内(checkpointの保存数)でもっともよいTop-1 Accuracyと参考にBest Top-1 Accuarcyも確認。精度は76.3%でほぼ再現できることを確認。

ConvNeXtのTraining Techniquesでの学習パラメータと精度
学習パラメータ
論文では以下を適用したとある。
- Optimizer
- AdamW
- LR Scheduer
- Warmup cosine decay
- Data augmentation
- Mixup and Cutmix
- RandAugment
- Random Erasing
- Regularization schemes
- Stochastic Depth
- Label Smoothing