Weights & Biases と pytorch-pfn-extrasをくっつけたら最強なんじゃないかと思った

この投稿はrioyokotalab Advent Calendar 2020 9日目の投稿です。

adventar.org

実験のトラッキング

深層学習の学習コードを動かすにあたって、学習の経過を確認したり、複数の学習の結果を比較したりといったことは、学習アルゴリズムの改善を行うにあたって非常に重要となってきます。これを便利に行えるようにしてくれるツールには、mlflow1, comet.ml2などがあります。

自分は普段はWeights & Biases3を使っているので、今回は、W&Bを使った自分の実験トラッキング環境を紹介します。

Weights & Biases

www.wandb.com

Weights & Biasesというのは、深層学習の実験トラッキングツールの一つで、自前で実験ログ用のストレージを持つ必要のないのが特徴のトラッキングツールです。そのため、アカウント一つで素早く導入することができ、必要に応じて情報の共有が楽だったりします。他にも、ハイパーパラメータサーチのための機能が入っていたり、実験ログをMarkdown形式でまとめたりと、様々な機能が揃っています。(また、コンタクトを取るとレスポンスが素早いという噂も)個人使用を想定している場合は、何も不足することはないと思います。

大体の使い方はこちらを参考にしていただければと思います。

note.com

今回は、実験のログ(学習率、損失、検証スコアなど)をWeights & Biasesに送るための自前のやり方を紹介しようと思います。

pytorch-pfn-extrasとの接続

自分は、学習コードに必要な諸機能をpytorch-pfn-extras4でまかなっています。この話については、同AdCのこちらの記事を参照してください。

deoxy.hatenablog.com

pytorch-pfn-extrasには学習のログを吐く機能があるのですが、このログこそが、W&Bに送りたい情報なのです。なので、この情報をどうやって取り出すかというのが、今回の話題になります。

extensionの作成

pytorch-pfn-extrasでは、実験のログの取得や可視化など、実験コードに付け加えたい便利機能の諸々はextensionと呼ばれるオブジェクトによって、実装されています。基本的な考え方は、ログの可視化(stdoutへの出力)を行うPrintReportの出力先をW&Bに変えてあげればいいということになります。なので、 pytorch-pfn-extrasのgithubコード中のPrintReportのコードを参照し、次のようなクラスを設計しました。

class SendWandB(ppe.training.extensions.PrintReport):
    def __init__(self, entries=None, log_report='LogReport', wandb=None):
        super().__init__(entries, log_report, None)
        self.wandb = wandb

    def __call__(self, manager):
        log_report = self.get_log_report(manager)
        log = log_report.log
        log_len = self._log_len
        while len(log) > log_len:
            self.wandb.log(log[log_len], step=log_len)
            log_len += 1
        self._log_len = log_len

あとは、

manager.extend(E.PrintReport(['epoch', 'iteration', 'lr', 'train/loss', 'valid/loss', 'valid/acc', 'elapsed_time']), trigger=standard_trigger)
manager.extend(SendWandB(['epoch', 'iteration', 'lr', 'train/loss', 'valid/loss', 'valid/acc', 'elapsed_time'], wandb=wandb), trigger=standard_trigger)

といった感じで、PrintReportと同じように書いてあげるとwandbへ送ることができるようになります。

引数のwandbには

import wandb

としてimportしたwandbモジュールを渡すことで、将来的にもしpytorch-pfn-extrasに導入された場合でも依存性を回避できるかなぁとか、別ファイルにモジュール化した場合でも、import文書かなくて良くて依存性回避できるなぁとか、出力先を明示した方がわかりやすいかなぁとか、結局wandbを起動するためにはメインのコード側で一度initメソッドとか呼ばないといけないしなぁとかいろいろ考えて、モジュールを渡すってデザイン大丈夫なのかとか思いながらこれにしました。

ちなみにwandb.initを呼ぶタイミングはログを開始する前であれば、extensionを作成した後でも大丈夫です。

これを使えば、pytorch-pfn-extrasを用いたテンプレートコードからそこまで実装を変えることなく、ログを転送できるようになります。

(確か、PrintReportで表示しようとしたログのカラムとSendWandBで送ろうとしたログのカラムの和集合が送られて表示されちゃってた気がして、manager周りで何か共有しちゃってる気がするけれど、別に困らないから直してない。)

テンプレートコード

class Evaluator(ppe.training.extension.Extension):
    priority = ppe.training.extension.PRIORITY_WRITER

    def __init__(self, model, device, metrics, loader, prefix='valid/'):
        self.model = model
        self.device = device
        self.metrics = metrics
        self.prefix = prefix
        self.loader = loader
        
    def __call__(self, manager):
        self.model.eval()
        logs = {name: [] for name, metric in self.metrics.items()}
        with torch.no_grad():
            for data in self.loader:
                data = to(data, device)
                out = self.model(**data)
                for name, metric in self.metrics.items():
                    met = metric(**out).item()
                    logs[name].append(met)
        for name, value in logs.items():
            ppe.reporting.report({
                self.prefix + name: np.mean(value)
            })

manager = ppe.training.ExtensionsManager(
    model, optimizer, num_epochs,
    iters_per_epoch=len(train_loader),
    out_dir=out_dir
)

standard_trigger = (1, 'epoch')
manager.extend(E.observe_lr(optimizer=optimizer), trigger=standard_trigger)
manager.extend(E.LogReport(trigger=standard_trigger))
manager.extend(E.PrintReport(['epoch', 'iteration', 'lr', 'train/loss', 'valid/loss', 'valid/acc', 'elapsed_time']), trigger=standard_trigger)
manager.extend(SendWandB(['epoch', 'iteration', 'lr', 'train/loss', 'valid/loss', 'valid/acc', 'elapsed_time'], wandb=wandb), trigger=standard_trigger)
manager.extend(Evaluator(model, device, {'loss': criterion, 'acc': metric}, valid_loader, prefix='valid/'), trigger=standard_trigger)
manager.extend(E.snapshot(target=model, filename='best.pth'), trigger=ppe.training.triggers.MaxValueTrigger(key='valid/acc', trigger=standard_trigger))
manager.extend(E.snapshot(target=model, filename='model.pth'), trigger=standard_trigger)

while not manager.stop_trigger:
    for data in train_loader:
        with manager.run_iteration():
            model.train()
            optimizer.zero_grad()
            data = to(data, device)
            out = model(**data)
            loss = criterion(**out)
            ppe.reporting.report({
                'train/loss': loss.item()
            })
            loss.backward()
            optimizer.step()
            scheduler.step()

前述しましたが、pytorch-pfn-extrasの使い方の方も参照していただければと思います。

まとめ

実際、自分がやり易ければ、トラッキングの方法なんてなんでもいいと思ってる。

もちろん自分はこれがやりやすいと思っているので、おすすめです。