🌟

半精度、混合精度学習の安定性とepsについて

2023/08/18に公開

経緯

  • モデルの予測結果にNaNを含む場合に例外を投げて死ぬというロジック(Code 1)を入れたところ、自動混合精度学習(AMP)中に死んだ
  • 原因を調べると、通常の精度(float32)では再現しないことに気づいた
  • 混合精度の話ではないが、半精度(float16)学習で1e-8が0に丸められるという記事を見つけた[1]
  • 混合精度でも同じようなことが起こるか半信半疑だったが、可能な範囲でeps=1e-4に設定するとNaNで死ぬ事象は亡くなった。少なくとも安定性の問題は改善したと言える

Code 1: モデルの予測結果にNaNを含む場合に例外を投げるコード

with autocast(device_type=str(device), enabled=cfg.use_amp):
    output = model(input_dict)
    logit = output["logit"]
    with torch.no_grad():
        if logit.isnan().any():
            torch.save(
                {
                    "state_dict": model.state_dict(),
                    "image": image.detach().cpu().numpy(),
                    "mask": mask.detach().cpu().numpy(),
                    "logit": logit.detach().cpu().numpy(),
                },
                cfg.model_dir / f"model_{cfg.fold_name}_{cfg.seed:04d}_nan.pth",
            )
            raise ValueError("output is nan while training")

注意点

  • PyTorchの自動混合精度学習は「gradientにNaNを含む場合に重みの更新を見送る」という仕様になっている[2]。つまり、特定のstepにおいてlossやgradientにNaNを含む可能性は織り込み済みということである。「lossにNaNを含む場合に例外を投げる」という仕様はあまりスマートではないかもしれない。
  • eps=1e-4があらゆるケースで良いということを主張しているわけではない(大きく設定しすぎるとモデルの表現力を損なうリスクもある)
  • 半精度の話と混合精度の話は分けて考える必要がある。半精度で1e-8が0に丸められるからといって、混合精度にも当てはまるとは限らない

学び/Tips

  • epsを大きめにすると学習の安定性が改善する場合がある
  • lossにNaNが含まれる場合に直ちに例外を投げるのではなく、「NaNを含むstepの割合をカウントしておき、一定の閾値を超えたら死ぬ」というロジックが良いかもしれない
  • 死ぬ場合も、ただメッセージを残すだけでなく、後で問題を解析できるようにエラーが起こったコンテキスト(モデルの重み、バッチ、ラベル、モデルの予測etc.)を保存しておくとよい。

参考資料

Appendix

A. 半精度で表現できる範囲について

[3]によると、

  • 最小の正の非正規化数: 2−24 ≈ 5.96 × 10−8
  • 最小の正の正規化数: 2−14 ≈ 6.10 × 10−5
  • 表現可能な最大数: 65,504

1e-8はこの範囲に入らないので0に丸められる。なお、6.10x10-5よりも小さい数は精度を犠牲にして表現しているので、半精度でepsを設定する場合はこの値より大きな値に設定するのが望ましいのではないかと思う。

実際にオーバーフローするかチェックする方法は以下:

import torch


def is_rounded_to_zero_in_fp16(value):
    half_value = torch.tensor([value], dtype=torch.float32).to(dtype=torch.float16).item()
    return half_value == 0


value = 1e-8
print(is_rounded_to_zero_in_fp16(value))
GitHubで編集を提案

Discussion