🔥

PytorchでLossがnanが出たらどうするか

2022/07/15に公開約300字1件のコメント

解決法1. torch.autograd.set_detect_anomaly(True)を追加する

これによって勾配のエラーがわかるので、エラーの概観を掴むことができる。これをしてからより細かくみていったほうがいいと思った。

解決法2. assert not torch.isnan(hoge_tensor).any()を追加する

気になるところにこれをとりあえずかく。print文のように後から消したり、コメント化しなくていいのでこっちの方がいいかもしれない。

多分よくある原因

  • データセットの問題

データがそもそも良くない。z-normalizationでnanが出る

  • Lossの計算方法の問題

ニューラルネットワークというよりもLossの計算に問題がある

Discussion

基本的にデータセットにそもそもNaNが含まれているのがオチであることがおおいと思います

ログインするとコメントできます