Rank Gaussという正規化手法

この投稿はrioyokotalab Advent Calendar 2020 3日目の投稿です。

adventar.org

Rank Gauss

Kaggleでは良く知られているRank Gaussという正規化手法について紹介します。RankGaussは2017年のKaggleコンペ"Porto Seguro’s Safe Driver Prediction”で提案された1正規化手法です。この正規化手法は、平均を引いて標準偏差で割るという、画像処理などでよく使われる正規化とは違い、どんなデータ分布であっても、標準正規分布に直すことができるという強力な手法となっています。

正規分布に正規化するモチベーション

機械学習モデルについて統計的な議論を行う時、汎化や収束速度などの議論をする場合、入力データ分布に正規分布を仮定することが多いです。特にニューラルネットワーク系の話題になるとその色が強くなります。つまり、ニューラルネットワークにおける、様々なアルゴリズムの恩恵を受けるためには、入力が正規分布である必要があるのです。

RankGaussを使うモチベーション

実世界のデータは必ず正規分布になっているとは限りません。しかし、正規分布という、自然界のいろいろなものがなぜか従う分布を背後にもっているデータ分布はあるかもしれません。データ分布が生成されるモデル背景とその背後にある正規分布を見つけることができれば、話は早いですが、そんなことをしなくても、生成するモデルが単調性さえ維持していれば、RankGaussで正規分布に戻せるのです。

アルゴリズム

ある、数値データを考えます。まず、数値データを順位付けし、その順位が上位どの割合にあるかという[0, 1]の値に変換します。これだけで、元のデータ分布を一様分布に押し込めることができました。同じ要領で正規分布に押し込めることができます。

  1. 順位付けする
  2. -1 ~ 1の範囲にスケーリングする
  3. 逆誤差関数:  \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^{2}} d t逆関数に通す

ことで、正規化ができます。

f:id:deoxy:20201203233753p:plain

大体こんなイメージ。

実装

scikit-learnにQuantileTransformer2というメソッドがあります。このoutput_distributionという引数に'normal'を渡すとRankGaussと同じ機能になります。 正規化は学習データに基づいて統計的な操作を加えるものなので、学習データの統計情報を保持する必要があります。QurantileTransformerはscikit-learnの学習モデルインターフェースに従って、.fitメソッドと.transformメソッドがあり、.fitメソッドに学習データを渡した後、.transformに学習、検証、テストデータを渡してあげると、正規分布に直されたデータが手に入ります。

まとめ

要は最強の正規化アルゴリズムf:id:deoxy:20201203235726p:plain ...まあ、何でもかんでも正規分布に押し込めりゃいいって話ではない。