nn.DataParallelを使えばデータをdeviceに送らなくていい
この投稿はrioyokotalab Advent Calendar 2020 13日目の投稿です。
nn.DataParallel
PyTorchでマルチGPUを使った機械学習を行いたい場合、
- nn.DataParalllel
- DistributedDataParallel
の二種類の方法があります。
DistributedDataParallel
はマルチプロセスで動かせるので、データの読み込みやモデルのアップデートなども並列に行え、高い並列化性能を出すことが可能です。さらに、複数のコンピュータを並列で動かして学習することも可能で、大規模なマルチノード並列実行を行うことが可能です。
nn.DataParallel
は、これらの利点を捨てる代わりに、特にコード変更を行うことなく、(新たにバグが生まれる可能性や、並列時の通信によるバグをふむ可能性を回避できる)マルチGPUの恩恵を受けることができるようになる機能です。なので、PyTorch的には、DistributerdDataParallelを推奨しているようですが、捨てられない機能の一つとなっています。
nn.DataParallel
はシングルプロセスでデータの読み込み、デバイス間通信などを行いますが、データは適宜、それぞれのGPUに転送してくれて、それぞれのGPUでforward, backwardを行ってくれます。
デバイス転送の管理を楽にする
nn.DataParallel
は、
model_ = Model() model = nn.DataParallel(model_, device_ids=[0, 1])
のように元々のmodelをwrapして使います。こうしてできたmodel
はマルチGPUで動くので、
for data in train_loader: optimizer.zero_grad() data = data.to('cuda:0') out = model_(data) ...
などとして、cpuから特定のdeviceを指定して転送することができず、代わりに、
for data in train_loader: optimizer.zero_grad() out = model(data) ...
というふうに、data
の転送先を特に明示せず、実行することができます。実際には、model
の内部で、
- ミニバッチを適当にn等分して、それぞれのデバイスに転送
- それぞれのデバイスでのforwardを実行
- 出力を集約し、特定のdevice上に集める
- その他計算グラフに関わる計算を実行
- backwardパスに従って、model内部で、出力勾配を配る
- それぞれのdeviceで、backwardを実行
といった処理が行われているようです。
どういった場面で使いたいか
学習アルゴリズム中に、複数のモデルが出る場合などに便利です。それぞれのモデルをメモリの関係で、別々のデバイス上で学習させたいが、デバイスも複数になって、データをどのデバイスに転送しないといけないか、不透明になってきた場合、モデルが乗っているデバイスに転送するというルールの元、特に明示的に書くことなくデバイスへの転送が行えます。
例えば、
- GANのgeneratorとdiscriminatorを分ける
- Cross Validationで並列に複数モデルを学習させる
など。より、難しいアルゴリズムを考えれば、考えるほど、こういったデバイスの管理は煩雑になってくると思います。
まとめ
別に、楽にしたいだけで、これしか選択肢がない話でもないし、普通に、デバイスをちゃんと変数で持ってればいい話でもあるし...
普段は使ってないですが、たまに、面白半分で使ってます。