nn.DataParallelを使えばデータをdeviceに送らなくていい

この投稿はrioyokotalab Advent Calendar 2020 13日目の投稿です。

adventar.org

nn.DataParallel

PyTorchでマルチGPUを使った機械学習を行いたい場合、

  • nn.DataParalllel
  • DistributedDataParallel

の二種類の方法があります。

DistributedDataParallelはマルチプロセスで動かせるので、データの読み込みやモデルのアップデートなども並列に行え、高い並列化性能を出すことが可能です。さらに、複数のコンピュータを並列で動かして学習することも可能で、大規模なマルチノード並列実行を行うことが可能です。

nn.DataParallelは、これらの利点を捨てる代わりに、特にコード変更を行うことなく、(新たにバグが生まれる可能性や、並列時の通信によるバグをふむ可能性を回避できる)マルチGPUの恩恵を受けることができるようになる機能です。なので、PyTorch的には、DistributerdDataParallelを推奨しているようですが、捨てられない機能の一つとなっています。

nn.DataParallelはシングルプロセスでデータの読み込み、デバイス間通信などを行いますが、データは適宜、それぞれのGPUに転送してくれて、それぞれのGPUでforward, backwardを行ってくれます。

バイス転送の管理を楽にする

nn.DataParallelは、

model_ = Model()
model = nn.DataParallel(model_, device_ids=[0, 1])

のように元々のmodelをwrapして使います。こうしてできたmodelはマルチGPUで動くので、

for data in train_loader:
    optimizer.zero_grad()
    data = data.to('cuda:0')
    out = model_(data)
    ...

などとして、cpuから特定のdeviceを指定して転送することができず、代わりに、

for data in train_loader:
    optimizer.zero_grad()
    out = model(data)
    ...

というふうに、dataの転送先を特に明示せず、実行することができます。実際には、modelの内部で、

  1. ミニバッチを適当にn等分して、それぞれのデバイスに転送
  2. それぞれのデバイスでのforwardを実行
  3. 出力を集約し、特定のdevice上に集める
  4. その他計算グラフに関わる計算を実行
  5. backwardパスに従って、model内部で、出力勾配を配る
  6. それぞれのdeviceで、backwardを実行

といった処理が行われているようです。

どういった場面で使いたいか

学習アルゴリズム中に、複数のモデルが出る場合などに便利です。それぞれのモデルをメモリの関係で、別々のデバイス上で学習させたいが、デバイスも複数になって、データをどのデバイスに転送しないといけないか、不透明になってきた場合、モデルが乗っているデバイスに転送するというルールの元、特に明示的に書くことなくデバイスへの転送が行えます。

例えば、

  • GANのgeneratorとdiscriminatorを分ける
  • Cross Validationで並列に複数モデルを学習させる

など。より、難しいアルゴリズムを考えれば、考えるほど、こういったデバイスの管理は煩雑になってくると思います。

まとめ

別に、楽にしたいだけで、これしか選択肢がない話でもないし、普通に、デバイスをちゃんと変数で持ってればいい話でもあるし...

普段は使ってないですが、たまに、面白半分で使ってます。