PyTorchのテンプレコードを用意してどんなデータセットにも楽々ディープラーニング

この投稿はrioyokotalab Advent Calendar 2020 5日目の投稿です。

adventar.org

PyTorchは自由すぎる

PyTorchは自動微分ライブラリとしての側面が強く、一方で中途半端に深層学習としての機能を提供しているため、コードが書く人によってまちまちになりがち(個人的見解)。かといって、コードの書き方を強要するライブラリを大量に提供すれば、万人に受け入れられるコードにはならない。PyTorchはサードパーティーライブラリを作ってもらうことによって、書き方の共通化を進めているが、結局、そこには好みが出てきて、そのライブラリを使ったことがない人にとっては可読性の低いコードとなってしまう。

今回は、PyTorchの機能のみを使って(共有性の向上)、様々なデータセット、モデル、Loss関数の深層学習コードを実装できるようにした、自己流のテンプレートコードを紹介する。

Python**オペレータの利用

テンプレートコードでは、辞書を**オペレータによってunpackingしてキーワード引数として渡す機能1を利用する。さらに、中間の変数はなるべく辞書として持つことで、テンプレートコードに引数を追加する必要性を減らし、コードの変更量を減らす。こうすることで、データの意味や、値の意味を表現したまま、少ない変更量で学習コードを実現できる。これは、コードの修正が容易になるだけでなく、コードの可読性の向上にもつながると思っている。

テンプレートコード

データセット

class TestDataset:
    
    def __init__(self, df):
        self.data = df.data.to_list()
        self.root = root
        self.preprocess = Preprocess()
        self.postprocess = Postprocess()
        self.augmentation = None
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, i):
        x = Image.open(os.path.join(self.root, self.data[i]))
        x = self.preprocess(x)
        if self.augmentation is not None:
            x = self.augmentation(x)
        x = self.postprocess(x)
        return {
            'x': x,
        }
    
class ValidDataset(TestDataset):
    
    def __init__(self, df):
        super().__init__(df)
        self.label = df.label.to_list()
        
    def __getitem__(self, i):
        ret = super().__getitem__(i)
        ret['target'] = self.label[i]
        return ret
    
class TrainDataset(ValidDataset):
    
    def __init__(self, df, augmentation):
        super().__init__(df)
        self.augmentation = augmentation

データセット部分はTestDataset -> ValidDataset -> TrainDatasetの順で継承させる。

  • TestDatasetには、推論時に行うデータの前処理、正規化やモデルに入力するためのテンソル化などの作業を行う。
  • ValidDatsetでは学習ラベルを入力データとペアにするための作業を行う。
  • TrainDatasetでは主にデータ拡張の加える。

このようにコードを構成することで、推論用にモデルをデプロイする際に、TestDatsetを写すことで、テストデータの読み出しが想定と異なる動作をすることを防ぐことができる。

データローダー

train_dataset = TrainDataset(train_df, Augmentation())
valid_dataset = ValidDataset(valid_df)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True, drop_last=False)

特に目立つ記述はないが、先ほどのデータセットは辞書オブジェクトを要素に持つ配列となるが、それをPyTorchのDataLoaderに通すと、なんと、辞書要素ごとにバッチ化してくれる

つまり、

[{x: データ, target: ラベル} ... {x: データ, target: ラベル}]となっているデータセットなら、データローダーから取り出されるミニバッチは{x:[データ, ..., データ], target: [ラベル, ..., ラベル]}となっている。

なぜこんな仕様になっているかは知らない。

モデル定義

class Model(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        
    def forward(self, x, **kwargs):
        out = self.model(x)
        kwargs['out'] = out
        return kwargs

後々のコードではデータローダーから取り出された辞書オブジェクトに格納されたバッチを**オペレータによってunpackingして渡す。その時、すべての辞書要素を無差別に渡すようにするので、モデルには関係ないデータ(例えば、ラベル情報)なども入力される。それを変数として認識せず、かつ、返り値に残すために**kwargs変数を使う。これは、キーワード引数で渡された、明示的に書かれていない引数を辞書として持つ機能で、kwargsという名前で、辞書オブジェクトとして扱うことができる。モデルに通した後、その後に扱いたい情報をkwargsに追加して返すことで、拡張性の高いコードとすることができる。(あまり行儀がいいとは言えないかもしれないが)

余談ではあるが、huggingface氏が提供しているtransformers2という有名な自然言語向けの深層学習ライブラリがある。 ここで提供されているモデルは返り値を辞書で返してくるので、辞書で変数をまとめて扱う、というコーディングスタイルはかなり有用であると言える。

損失関数

class Criterion(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.cross_entropy = nn.CrossEntropyLoss()
        
        
    def forward(
        self,
        out,
        target,
        **kwargs
    ):
        return self.cross_entropy(out, target)

損失関数についても、モデル定義の時と同様に辞書オブジェクトを**オペレータでunpackingして渡されることを想定して、**kwargs引数を持った状態で作成する。モデル定義の時と違うのは、損失は.backwardメソッドを呼んで、勾配を計算する必要があるので、辞書ではなく、torch.tensorを直接返すようにする。

class AccMetric(nn.Module):
    
    def __init__(self):
        super().__init__()
        
    def forward(
        self,
        out,
        target,
        **kwargs
    ):
        pred = out.argmax(dim=1)
        return (pred == target).float().mean()

その他metricsも必要に応じて設計する。

バイス転送関数

def to(x, device, *args, **kwargs):
    return {
        key: value.to(device, *args, **kwargs) for key, value in x.items()
    }

辞書オブジェクトで渡されたミニバッチはtorch.tensorではないので、.toメソッドを直接呼ぶことができない。従って、各辞書要素について、.toメソッドを呼ぶ関数を用意する。

*args, **kwargs引数は.toメソッドに渡す他の引数(例えば、non_blockingなど)のために用意しておく。

その他学習に必要なオブジェクトの用意

num_epochs = 20
model = Model(_model)
model = model.to(device)
criterion = Criterion()
metric = AccMetric()
optimizer = torch.optim.AdamW(model.parameters())
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, 
    max_lr=0.01, 
    epochs=num_epochs, 
    steps_per_epoch=len(train_loader),
    pct_start=0.1
)

作成したクラスなどからオブジェクトを作成し、その他optimzierやschedulerなどを設定する。

学習

for i in range(num_epochs):
    for data in train_loader:
        model.train()
        optimizer.zero_grad()
        data = to(data, device)
        out = model(**data)
        loss = criterion(**out)
        loss.backward()
        optimizer.step()
        scheduler.step()
    for data in valid_loader:
        model.eval()
        with torch.no_grad():
            data = to(data, device)
            out = model(**data)
            acc = metric(**out)
            loss = criterion(**out)
            

これだけで学習ができる。

まとめ

他の記事で言及しようと思っているが、実際には自分は、Weights and Biases3やpytorch-pfn-extras4など、実験トラッキングツールを活用している。しかし、それらは個人の好みが関わってくるので、今回は言及しない。ここまで、書いたテンプレートコードをまとめると次のようになる。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

class TestDataset:
    
    def __init__(self, df):
        self.data = df.data.to_list()
        self.root = root
        self.preprocess = Preprocess()
        self.postprocess = Postprocess()
        self.augmentation = None
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, i):
        x = Image.open(os.path.join(self.root, self.data[i]))
        x = self.preprocess(x)
        if self.augmentation is not None:
            x = self.augmentation(x)
        x = self.postprocess(x)
        return {
            'x': x,
        }
    
class ValidDataset(TestDataset):
    
    def __init__(self, df):
        super().__init__(df)
        self.label = df.label.to_list()
        
    def __getitem__(self, i):
        ret = super().__getitem__(i)
        ret['target'] = self.label[i]
        return ret
    
class TrainDataset(ValidDataset):
    
    def __init__(self, df, augmentation):
        super().__init__(df)
        self.augmentation = augmentation


train_dataset = TrainDataset(train_df, Augmentation())
valid_dataset = ValidDataset(valid_df)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True, drop_last=False)

class Model(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        
    def forward(self, x, **kwargs):
        out = self.model(x)
        kwargs['out'] = out
        return kwargs

class Criterion(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.cross_entropy = nn.CrossEntropyLoss()
        
        
    def forward(
        self,
        out,
        target,
        **kwargs
    ):
        return self.cross_entropy(out, target)

class AccMetric(nn.Module):
    
    def __init__(self):
        super().__init__()
        
    def forward(
        self,
        out,
        target,
        **kwargs
    ):
        pred = out.argmax(dim=1)
        return (pred == target).float().mean()

def to(x, device, *args, **kwargs):
    return {
        key: value.to(device, *args, **kwargs) for key, value in x.items()
    }

num_epochs = 20
model = Model(_model)
model = model.to(device)
criterion = Criterion()
metric = AccMetric()
optimizer = torch.optim.AdamW(model.parameters())
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, 
    max_lr=0.01, 
    epochs=num_epochs, 
    steps_per_epoch=len(train_loader),
    pct_start=0.1
)


for i in range(num_epochs):
    for data in train_loader:
        model.train()
        optimizer.zero_grad()
        data = to(data, device)
        out = model(**data)
        loss = criterion(**out)
        loss.backward()
        optimizer.step()
        scheduler.step()
    for data in valid_loader:
        model.eval()
        with torch.no_grad():
            data = to(data, device)
            out = model(**data)
            acc = metric(**out)
            loss = criterion(**out)
            

別にこれが最適解だとは思っていないので、まだまだ煮詰めていきたい。