Closed1
pytorchのInstanceNormでtrack_running_stats=Trueにしたときの挙動について

概要
pytorchのinstance_normでtrack_running_stats=True
としたときの挙動がトリッキーらしい。
本稿は[1]の裏どりがメイン。
報告された挙動
- 訓練時、
instance_norm(track_running_stats=True)
の処理結果はinstance_norm(track_running_stats=False)
の処理結果と一致する - 推論時、
instance_norm(track_running_stats=True)
の処理結果はinstance_norm(track_running_stats=False)
の処理結果と異なる。また、batch_norm(track_running_stats=True)
の処理結果と近い値になる
PyTorchのinstance_normの実装
- instance_normは実際にはbatch_normとして実装されている。
- (B, C, H, W)のtensorを(1, B*C, H, W)にreshapeしてbatch_normを適用している
- running meanの挙動は、(1, B*C, H, W)の形状のtensorを(B, C, H, W)にreshapeし、batch方向の統計を計算している
思ったこと
- そもそもrunning mean/varの移動平均をとるのであればサンプル(mini batch)間で相関が生まれるので、純粋なinstance_normではなくなるのでは?(むしろbatch_normに近い振る舞いになる)
参考資料
このスクラップは2023/11/05にクローズされました