Weights & Biases と pytorch-pfn-extrasをくっつけたら最強なんじゃないかと思った
この投稿はrioyokotalab Advent Calendar 2020 9日目の投稿です。
実験のトラッキング
深層学習の学習コードを動かすにあたって、学習の経過を確認したり、複数の学習の結果を比較したりといったことは、学習アルゴリズムの改善を行うにあたって非常に重要となってきます。これを便利に行えるようにしてくれるツールには、mlflow1, comet.ml2などがあります。
自分は普段はWeights & Biases3を使っているので、今回は、W&Bを使った自分の実験トラッキング環境を紹介します。
Weights & Biases
Weights & Biasesというのは、深層学習の実験トラッキングツールの一つで、自前で実験ログ用のストレージを持つ必要のないのが特徴のトラッキングツールです。そのため、アカウント一つで素早く導入することができ、必要に応じて情報の共有が楽だったりします。他にも、ハイパーパラメータサーチのための機能が入っていたり、実験ログをMarkdown形式でまとめたりと、様々な機能が揃っています。(また、コンタクトを取るとレスポンスが素早いという噂も)個人使用を想定している場合は、何も不足することはないと思います。
大体の使い方はこちらを参考にしていただければと思います。
今回は、実験のログ(学習率、損失、検証スコアなど)をWeights & Biasesに送るための自前のやり方を紹介しようと思います。
pytorch-pfn-extrasとの接続
自分は、学習コードに必要な諸機能をpytorch-pfn-extras4でまかなっています。この話については、同AdCのこちらの記事を参照してください。
pytorch-pfn-extrasには学習のログを吐く機能があるのですが、このログこそが、W&Bに送りたい情報なのです。なので、この情報をどうやって取り出すかというのが、今回の話題になります。
extensionの作成
pytorch-pfn-extrasでは、実験のログの取得や可視化など、実験コードに付け加えたい便利機能の諸々はextensionと呼ばれるオブジェクトによって、実装されています。基本的な考え方は、ログの可視化(stdoutへの出力)を行うPrintReportの出力先をW&Bに変えてあげればいいということになります。なので、 pytorch-pfn-extrasのgithubコード中のPrintReportのコードを参照し、次のようなクラスを設計しました。
class SendWandB(ppe.training.extensions.PrintReport): def __init__(self, entries=None, log_report='LogReport', wandb=None): super().__init__(entries, log_report, None) self.wandb = wandb def __call__(self, manager): log_report = self.get_log_report(manager) log = log_report.log log_len = self._log_len while len(log) > log_len: self.wandb.log(log[log_len], step=log_len) log_len += 1 self._log_len = log_len
あとは、
manager.extend(E.PrintReport(['epoch', 'iteration', 'lr', 'train/loss', 'valid/loss', 'valid/acc', 'elapsed_time']), trigger=standard_trigger) manager.extend(SendWandB(['epoch', 'iteration', 'lr', 'train/loss', 'valid/loss', 'valid/acc', 'elapsed_time'], wandb=wandb), trigger=standard_trigger)
といった感じで、PrintReportと同じように書いてあげるとwandbへ送ることができるようになります。
引数のwandbには
import wandb
としてimportしたwandbモジュールを渡すことで、将来的にもしpytorch-pfn-extrasに導入された場合でも依存性を回避できるかなぁとか、別ファイルにモジュール化した場合でも、import文書かなくて良くて依存性回避できるなぁとか、出力先を明示した方がわかりやすいかなぁとか、結局wandbを起動するためにはメインのコード側で一度init
メソッドとか呼ばないといけないしなぁとかいろいろ考えて、モジュールを渡すってデザイン大丈夫なのかとか思いながらこれにしました。
ちなみにwandb.init
を呼ぶタイミングはログを開始する前であれば、extensionを作成した後でも大丈夫です。
これを使えば、pytorch-pfn-extrasを用いたテンプレートコードからそこまで実装を変えることなく、ログを転送できるようになります。
(確か、PrintReportで表示しようとしたログのカラムとSendWandBで送ろうとしたログのカラムの和集合が送られて表示されちゃってた気がして、manager周りで何か共有しちゃってる気がするけれど、別に困らないから直してない。)
テンプレートコード
class Evaluator(ppe.training.extension.Extension): priority = ppe.training.extension.PRIORITY_WRITER def __init__(self, model, device, metrics, loader, prefix='valid/'): self.model = model self.device = device self.metrics = metrics self.prefix = prefix self.loader = loader def __call__(self, manager): self.model.eval() logs = {name: [] for name, metric in self.metrics.items()} with torch.no_grad(): for data in self.loader: data = to(data, device) out = self.model(**data) for name, metric in self.metrics.items(): met = metric(**out).item() logs[name].append(met) for name, value in logs.items(): ppe.reporting.report({ self.prefix + name: np.mean(value) }) manager = ppe.training.ExtensionsManager( model, optimizer, num_epochs, iters_per_epoch=len(train_loader), out_dir=out_dir ) standard_trigger = (1, 'epoch') manager.extend(E.observe_lr(optimizer=optimizer), trigger=standard_trigger) manager.extend(E.LogReport(trigger=standard_trigger)) manager.extend(E.PrintReport(['epoch', 'iteration', 'lr', 'train/loss', 'valid/loss', 'valid/acc', 'elapsed_time']), trigger=standard_trigger) manager.extend(SendWandB(['epoch', 'iteration', 'lr', 'train/loss', 'valid/loss', 'valid/acc', 'elapsed_time'], wandb=wandb), trigger=standard_trigger) manager.extend(Evaluator(model, device, {'loss': criterion, 'acc': metric}, valid_loader, prefix='valid/'), trigger=standard_trigger) manager.extend(E.snapshot(target=model, filename='best.pth'), trigger=ppe.training.triggers.MaxValueTrigger(key='valid/acc', trigger=standard_trigger)) manager.extend(E.snapshot(target=model, filename='model.pth'), trigger=standard_trigger) while not manager.stop_trigger: for data in train_loader: with manager.run_iteration(): model.train() optimizer.zero_grad() data = to(data, device) out = model(**data) loss = criterion(**out) ppe.reporting.report({ 'train/loss': loss.item() }) loss.backward() optimizer.step() scheduler.step()
前述しましたが、pytorch-pfn-extrasの使い方の方も参照していただければと思います。
まとめ
実際、自分がやり易ければ、トラッキングの方法なんてなんでもいいと思ってる。
もちろん自分はこれがやりやすいと思っているので、おすすめです。