Closed1

pytorchのInstanceNormでtrack_running_stats=Trueにしたときの挙動について

bilzardbilzard

概要

pytorchのinstance_normでtrack_running_stats=Trueとしたときの挙動がトリッキーらしい。
本稿は[1]の裏どりがメイン。

報告された挙動

  • 訓練時、instance_norm(track_running_stats=True)の処理結果はinstance_norm(track_running_stats=False)の処理結果と一致する
  • 推論時、instance_norm(track_running_stats=True)の処理結果はinstance_norm(track_running_stats=False)の処理結果と異なる。また、batch_norm(track_running_stats=True)の処理結果と近い値になる

PyTorchのinstance_normの実装

  • instance_normは実際にはbatch_normとして実装されている。
  • (B, C, H, W)のtensorを(1, B*C, H, W)にreshapeしてbatch_normを適用している
  • running meanの挙動は、(1, B*C, H, W)の形状のtensorを(B, C, H, W)にreshapeし、batch方向の統計を計算している

https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Normalization.cpp#L653-L689

思ったこと

  • そもそもrunning mean/varの移動平均をとるのであればサンプル(mini batch)間で相関が生まれるので、純粋なinstance_normではなくなるのでは?(むしろbatch_normに近い振る舞いになる)

参考資料

このスクラップは2023/11/05にクローズされました