Open3

JAX/Flax/OptaxでResNetをやってみる

nb.onb.o

はじめに

JAX/Flax/Optaxを利用し、ResNetのモデルを構築する。ImageNet2012の学習データを用い、精度(Top1-Acc)が再現できるか確認する。
FlaxにはImageNet classificationのサンプルでResNet50が実装されている。自分のリポジトリは、これをもとに実装しているため、今回は学習をConvNeXtにあるTraining Techniquesで行い、 78.8%の精度が再現できるかを確認する。はじめにオリジナルの学習パラメータでの学習、つづいてConvNeXtのTraining Techniquesでの学習でそれぞれの精度を確認する。

参照

ResNet

ConvNeXt

nb.onb.o

Flaxでの学習パラメータおよび精度

参考にしたFlaxのImageNet classificationのサンプルでの学習パラメータは以下。

オリジナルのResNetと異なり、Epochsは90→100となっている。

精度は以下でTPUv2-32で76.67%のTop-1 accuracyとある。

この実装を元にTPUv2-8で再現を行う。

終了時点から3epoch以内(checkpointの保存数)でもっともよいTop-1 Accuracyと参考にBest Top-1 Accuarcyも確認。精度は76.3%でほぼ再現できることを確認。

nb.onb.o

ConvNeXtのTraining Techniquesでの学習パラメータと精度

学習パラメータ

論文では以下を適用したとある。

  • Optimizer
    • AdamW
  • LR Scheduer
    • Warmup cosine decay
  • Data augmentation
    • Mixup and Cutmix
    • RandAugment
    • Random Erasing
  • Regularization schemes
    • Stochastic Depth
    • Label Smoothing