ゼロ割が見つからない時のNaNの原因
この投稿はrioyokotalab Advent Calendar 2020 17日目の投稿です。
深層学習中最悪のバグ
深層学習の学習コードを何度も自前で組んでいるといつか出くわすNaN
。時に再現性が無かったり、再現するのに1時間かかったり、ひたすらにプログラマの頭を悩ませることになります。もちろん、PyTorchのデバッグ機能をしっかり利用して見つけられることもありますが、学習が進むにつれて、とある値がアンダーフローやオーバーフローすることによって、発生するNaN
は原因を特定するのがひたすらに面倒です。しかも、精度を大きく改善させられそうなユニークなアルゴリズムを実装した時に限ってそういうNaN
はよく発生するものです。そして、このNaN
を解決するためには、時にエスパーと呼ばれる、プログラマの特殊能力が必要になります。
NaN
に対処するための記事はググればいくらでも出てきます。今回は、ゼロ割じゃないと言い切れる場合に起きうるもう一つのよくありがちなNaN
の可能性を紹介します。
見つからないときはbackwardを想像する
自分がNaNに悩まされた時に一番役立った記事がこちら
forwardの計算式はプログラム上で明示的に書かれているので、ゼロ割や以上に巨大な値を出していないか探すのにそこまで苦労はしないと思います。しかし、backwardは結構忘れがちです。まさに自動微分フレームワークという甘ったれた機械学習環境に頼り切ったでーたさいえんてぃすとの自業自得ってやつですね...。
具体的に
sqrt
ベクトルのノルム などを計算する時によく出てくるsqrtです。ベクトルを正規化する時に
などとして、ゼロ割を回避してやった気でいると、踏んでしまうやつです。ちなみにPyTorchのWeightNormalizationというクラスは、ここのゼロ割対策が入っておらず、自前で実装し直す必要がありました。その時に私が踏みました。
正しくは、
としてやる必要があります。
sqrtは中身が0の時、勾配がinf
に飛ぶので、中身が0にならないようにしなければなりません。
まあ、確かに勾配が真上を向いている気がする...
おまけで
atan2
の二次元ベクトルの原点と、から見た角度を出力する関数です。metric learningとか、空間認識系のタスクに必要なのでしょうか。使ったことがないのでわかりません。
を入れると、計算途中で、が計算されるため、が計算されるため、値がNaN
になります。
まあ、何が計算されているのかわからないですね。
まとめ
深層学習の自動微分コードはbackwardまでが実装コードであることを意識しておくこと。さもないと、丸一日分の進捗が無に帰すことになります。