機械学習エンジニアの備忘録

主に自分が勉強したことのメモ

Pytorchを使っているときにlossがnanになったときに確認すること

pytorchでモデルを学習させているときにlossがnanになって時間を結構溶かしたのでメモ。

BCEではなくBCEWithLogitsLossを使っているか

2値分類問題で自分で最終レイヤをいじったモデルはnanにならなくてOSSのモデル実装を使うとなぜかnanになる現象にぶちあった。
自分が定義したモデルをよくよく見るとforward関数の最後でsigmoid関数に通しており、BCEロスを使って学習していた。
しかし、OSSの実装はforwardの最後でsigmoidを通しておらず、自分はその状態でBCEを使って学習させていた。
どうやらこれがロスがnanになってしまった原因だった。
実際にBCEWithLogitsLossを使って学習するように変更したところロスがnanにならなくなった。

PytorchでSOTAモデルの実装を集めたGithubレポジトリは最後にsigmoidを通していない場合が多い気がするので基本BCEWithLogitsLossを使ったほうがいいかもしれない。(本来はちゃんとコードを読んで実装を確認するべき)
また、

公式のページでもBCEよりBCEWithLogitsLossのほうが数値計算的に安定しているらしいので基本はBCEWithLogitsLossを使う方針でよいと思う。
ただし、その場合はpredictの際には自分でモデルの出力をsigmoid関数に通すことを忘れずに。

pytorch.org

今回自分が陥ったケースではないがロスがnanになる原因として、入力にnanが含まれている、正規化していない等の問題が考えられる。