timmのModelEmaについて(ISIC2024コンペ振り返り①)
先日のKaggleのISICコンペは結構時間をかけて参加しましたが、惨敗でした…。(1088位/2739チーム)
上位解法で気になったものをいくつかまとめようと思います。まずは4thが画像モデルに使っていたtimmのModelEMAについてです。
timmのModelEmaとは?
ModelEMA
(Exponential Moving Average)は、モデルの重みの移動平均を保持するテクニックで、モデルの安定性や汎化性能を向上させるために用いられます。先日の関東kaggler会でもちょろっと紹介されていたようです。
1. EMAの基本的な考え方
EMAは、現在のモデルの重みと過去の重みの指数関数的な移動平均を計算します。これにより、学習中に過度に更新された重みを平均化し、より滑らかな更新結果を得ることができます。一般的には、次の式で表されます。
EMAの重み更新は次の式で表されます:
ここで、
-
は時刻\theta_{EMA}^{t} におけるEMAの重み、t -
は時刻\theta_{model}^{t} における現在のモデルの重み、t -
は0〜1の範囲のハイパーパラメータで、過去の重みをどれだけ重視するかを決定します。\alpha
2. timmのModelEmaV2実装
timmのModelEmaV2
クラスを使えば、この実装が簡単にできます。下記コードのように、トレーニング中にmodel_ema.update(model)
を呼び出すことで、現在のモデルの重みからEMAを計算・更新してくれます。
import timm
from timm.utils import ModelEmaV2
# timmから任意のモデルをロード
model = timm.create_model('resnet50', pretrained=True)
# ModelEmaV2の初期化
model_ema = ModelEmaV2(model, decay=0.9999)
# トレーニングループ内での更新
for batch in dataloader:
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_fn(outputs, targets)
loss.backward()
optimizer.step()
# モデルのEMAを更新
model_ema.update(model)
# EMAモデルの重みを保存
torch.save(model_ema.module.state_dict(), 'ema_model.pth')
decay
は先に示した
注意点として、EMAモデルを使用する場合は、基本的に.moduleという記述を挟む必要があります。ModelEMAV2
は、元のモデルを内部で module
という属性にラップして管理しています。したがって、EMAモデルでの操作(推論や重みの保存・読み込みなど)は、常に model_ema.module
を通して行うことになります。
-
推論モードへの切り替え:
model_ema.module.eval() # 推論モードに切り替え
-
重みの保存:
torch.save(model_ema.module.state_dict(), 'ema_model.pth') # EMAモデルの重みを保存
-
重みの読み込み:
model_ema.module.load_state_dict(torch.load('ema_model.pth')) # EMAモデルの重みを読み込み
-
推論:
outputs = model_ema.module(inputs) # EMAモデルでの推論
3. EMAの効果
実際に自分がコンペ中に作った画像モデルに、EMAを適用してみました。モデルの設定は以下のノートブックを参照してください。
- EMA適用前の学習曲線
- EMA適用後の学習曲線
- 評価指標の確認(pAUC80)
CV平均 | PublicLB | PrivateLB | |
---|---|---|---|
EMA無し | 0.1580 | 0.1570 | 0.1461 |
EMAあり | 0.1595 | 0.1589 | 0.1467 |
EMA適用後は学習曲線の推移が滑らかになっているのが分かります。ただ、学習の進みは遅くなるので、パラメータを多少いじる必要はありそうです。
個人的に一番恩恵がありそうなのは、全データで学習しなおす時だと思います。全データで学習するとバリデーションデータを使って学習の進捗を確認できないため、EMAを使って比較的安定した学習ができるのは結構助かります。
4. まとめ
ModelEMAは通常のモデルの学習に少し追記するだけで簡単に実装できるので、覚えておいて損はないと感じました。今後のコンペで試そうと思います。
Discussion