📝

クラスを使いこなしてPyTorchのSequentialで効率よく色々なモデル設計を試す

2022/03/02に公開

学習済みなモデルにしても、自分で1から作るにしても、ちょっと変えただけのモデルをいろいろ試したいことがあります。
 そこで、PyTorchのSequentialを使って少し手を加えたモデルをたくさん試す方法をご紹介します。

また、ここで紹介する方法で、PyTorchに付随している学習済みモデルのレイヤーも変更することができるようになります。

解決したい課題

  • ちょっとレイヤーの構成変えただけなのにクラスをコピペして作るやつ
  • どこを変えたかわからなくなる

ダメな例

ダメな例をそのダメな理由とともにご紹介します。

ちょっとレイヤーの構成変えただけなのにクラスをコピペして作るやつ

クラスの変更が一部だけなのに継承しないで作成するやつ

class SimpleModel(nn.Module):  

    def __init__(self):
        super().__init__()
        self.layers1 = nn.Sequential(
            nn.Linear(100, 100),
            nn.ReLU()
        )

        self.layers2 = nn.Sequential(
            nn.Linear(100, 70),
            nn.ReLU()
        )
        
        self.layers3 = nn.Sequential(
            nn.BatchNorm1d(70),
            nn.Linear(70, 2)
        )

    def forward(self, x):
        
        x = self.layers1(x)
        x = self.layers2(x)
        out = self.layers3(x)
        
        return out

これに「あ、DropOut層つけたい」と思った時に、このクラスをコピペして以下のようSimpleModel2とかにしていませんか?
 実際にしている奴がいました。

class SimpleModel(nn.Module):  

    def __init__(self):
        super().__init__()
        self.layers1 = nn.Sequential(
            nn.Linear(100, 100),
            nn.ReLU()
        )
        
        self.layers2 = nn.Sequential(
            nn.Linear(100, 70),
            nn.ReLU()
        )
        
        self.layers3 = nn.Sequential(
            nn.BatchNorm1d(70),
            nn.Linear(70, 2)
        )

    def forward(self, x):
        
        x = self.layers1(x)
        x = self.layers2(x)
        out = self.layers3(x)
        
        return out

class SimpleModel2(nn.Module):  

    def __init__(self):
        super().__init__()
        self.layers1 = nn.Sequential(
            nn.Linear(100, 100),
            nn.ReLU()
        )
        
        # これ以降のレイヤは、カテゴリ変数、連続値ともに同じレイヤを通過するようにしている。
        self.layers2 = nn.Sequential(
            nn.Linear(100, 70),
            nn.ReLU()
        )
        
        self.layers3 = nn.Sequential(
            nn.Dropout(p=0.5, inplace=False),
            nn.BatchNorm1d(70),
            nn.Linear(70, 2)
        )

    def forward(self, x):
        
        x = self.layers1(x)
        x = self.layers2(x)
        out = self.layers3(x)
        
        return out

このスパゲティコーディングは、別に研究室で個人でやっている分に問題ないように思えます。ですが、研究を引き継いだ後輩が「これはなんだー?」とコードリーディングのため、無駄な時間を過ごすことになってしまいよくありません。
 また、人間はすぐに忘れる生き物なので、1ヶ月経ったらそんな変更の内容なんてほっとんど覚えていません。

どこを変更したのかわからない

上のコードで、どこを変更したかぱっと見わからないです。

作った自分でさえ、何を変更したのかわからない状態になってしまっています。
 これはなぜ起こるのでしょうか?
 ただ単に、クラスを継承して変更部分だけを更新すれば良いのに、全て書いてしまっているためです。
 
 実際に、これを書いていた人(先輩)はクラスの継承の方法がわからなかったようです。

Sequentialを使いこなす

解決は簡単で、ちゃんとPyTorchのSequentialを使いこなして、親クラスを継承して作成しましょう。
 PyTorchのSequentialを使うことで、複数のレイヤーを一つにまとめることができて、forward関数の中をすっきり書くことができたりします。

また、公式で配布されている学習済みモデルのアーキテクチャーはSequentialを使って書かれているので、PyTorchのSequentialの操作方法を習得すれば学習済みモデルのレイヤーの調整にも応用できます。

新しいSequentialオブジェクトまたはレイヤーの追加方法

親クラスを以下のSimpleクラスにしたときの操作方法を見ていきます。

親クラス
class SimpleModel(nn.Module):  

    def __init__(self):
        super().__init__()
        self.layers1 = nn.Sequential(
            nn.Linear(100, 100),
            nn.ReLU()
        )
        
        # これ以降のレイヤは、カテゴリ変数、連続値ともに同じレイヤを通過するようにしている。
        self.layers2 = nn.Sequential(
            nn.Linear(100, 70),
            nn.ReLU()
        )
        
        self.layers3 = nn.Sequential(
            nn.BatchNorm1d(70),
            nn.Linear(70, 2)
        )

    def forward(self, x):
        
        x = self.layers1(x)
        x = self.layers2(x)
        out = self.layers3(x)
        
        return out

レイヤーの追加(PyTorchのSequentialオブジェクトを新たに追加する場合)

レイヤーを追加する際は継承して追加しましょう。
 追加といっても、
  新たにPyTorchのSequentialオブジェクトを追加する場合

既にあるSequentialオブジェクトにレイヤーを新たに追加する場合

インスタンス単位で追加する場合

があります。

こっちは、新たにPyTorchのSequentialオブジェクトを追加する場合です。
 この方法は、多くのレイヤーを追加する際に向いています。

class SimpleChildModel(SimpleModel):
    
    def __init__(self):
        super().__init__()
        
        self.layers4 = nn.Sequential(
            nn.BatchNorm1d(70),
            nn.Linear(70, 1),
            nn.Dropout(p=0.5, inplace=False),
        )
    
    def forward(self, x)

        x = self.layers1(x)
        x = self.layers2(x)
        x = self.layers3(x)
        out = self.layers4(x)

        return out

このようにするだけで、SimpleChildModelは新たなレイヤーをつかしたことがすぐにわかります。

レイヤーの追加(既にあるSequentialオブジェクトにレイヤーを新たに追加する場合)

この方法は、少数のレイヤーを追加する際に向いています。

class SimpleChildModel(SimpleModel):
    
    def __init__(self):
        super().__init__()
        
        self.layers3.add_module('dropout', nn.Dropout(0.5))

こうするだけでfowardの中も書き換えずに済むので、大変便利で可読性も向上します。

レイヤーの追加(インスタンス単位で追加する場合)

事前学習済みモデルを利用する際などのレイヤーの書き換えがこれに該当します。
 以下のように書き換えます。

from torch.nn import Sequential, Dropout
model = SimpleModel    
model.layers3 = Sequential(Dropout(0.5), model.layrs3)

で追加することができます。モデルの挿入箇所を確認するには、print(model)などとすることで書き換える箇所を探すことが可能です。

新しいSequentialオブジェクトまたはレイヤーの変更・置き換え方法

Sequentialでのレイヤーの変更方法を見ていきます。
 変更で間違い無いと思うのですが、多くの人が置き換えと表記していたので、一応置き換えとも表記されることがありますくらいの気持ちで書いてます。
 追加がわかって終えば対して難しいことはないのですが、見ていきましょう。

レイヤーを置き換えする際は継承して置き換えしましょう。
 置き換えといっても、
  新たにPyTorchのSequentialオブジェクトを置き換えする場合

既にあるSequentialオブジェクトにレイヤーを新たに置き換えする場合

インスタンス単位で置き換えする場合

があります。

レイヤーの置き換え(PyTorchのSequentialオブジェクトを新たに置き換えする場合)

こっちは、新たにPyTorchのSequentialオブジェクトを置き換えする場合です。
 この方法は、多くのレイヤーを置き換えする際または、Sequentialオブジェクトごと置き換えるのに向いています。

class SimpleChildModel(SimpleModel):
    
    def __init__(self):
        super().__init__()
        
        self.layers3 = self.layers3 = nn.Sequential(
            nn.BatchNorm1d(70),
            nn.Linear(70, 1),
            nn.Dropout(p=0.5, inplace=False),
        )

このようにするだけで、SimpleChildModelはSimpleModelのlayers3を置き換えしたことがすぐにわかります。

レイヤーの追加(既にあるSequentialオブジェクトにレイヤーを新たに置き換えする場合)

この方法は、少数のレイヤーを置き換えする際に向いています。

class SimpleChildModel(SimpleModel):
    
    def __init__(self):
        super().__init__()
        
        self.layers3[0]=('dropout', nn.Dropout(0.5))

こうするだけでfowardの中も書き換えずに済むので、大変便利で可読性も向上します。
 この場合、BatchNorm1dをDropoutに置き換えています。

レイヤーの追加(インスタンス単位で置き換えする場合)

事前学習済みモデルを利用する際などのレイヤーの書き換えがこれに該当します。
 以下のように書き換えます。

from torch.nn import Sequential, Dropout
model = SimpleModel    
model.layers3[0] = Sequential(Dropout(0.5), model.layrs3[0])

で置き換えすることができます。モデルの置き換え箇所を確認するには、print(model)などとすることで書き換える箇所を探すことが可能です。

これは、学習済みモデルの活性化関数の変更などでよく用いると思いますが、その際は、出力の特徴量などのわずかなパラメータだけをいじりたい場合もあると思います。
 そんな時は、特定のパラメータだけいじることも可能です。

modelのレイヤーのパラメーターを変更する

modelレイヤーの中の一部のパラメーターを変更する方法をご紹介します。例としてResNetを取り上げます。

import torch
model = torch.hub.load('rwightman/pytorch-image-models:master', 'resnet18', pretrained=True)
print(model)

#ResNet(
#   (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#   (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   (act1): ReLU(inplace=True)
# ...etc

この時、bn1のBatchNorm2dのeps=1e-05をeps=1e-04に取り替える場合を見てみましょう。

model.bn1 = torch.nn.BatchNorm2d(64, eps=1e-04, momentum=0.1, affine=True, track_running_stats=True)

とするだけで完了です。

新しいSequentialオブジェクトまたはレイヤーの削除方法

Sequentialでのレイヤーの削除方法を見ていきます。
 追加がわかって終えば対して難しいことはないのですが、見ていきましょう。

レイヤーを削除する際は継承して削除しましょう。
 削除といっても、
  新たにPyTorchのSequentialオブジェクトを削除する場合

既にあるSequentialオブジェクトにレイヤーを削除する場合

インスタンス単位で削除する場合

があります。

レイヤーの削除(PyTorchのSequentialオブジェクトを削除する場合)

こっちは、新たにPyTorchのSequentialオブジェクトを削除する場合です。
 この方法は、多くのレイヤーを削除する際または、Sequentialオブジェクトごと削除するのに向いています。

class SimpleChildModel(SimpleModel):
    
    def __init__(self):
        super().__init__()
        
        self.layers3 = torch.nn.Identity()

このようにするだけで、SimpleChildModelはSimpleModelのlayers3を削除したことがすぐにわかります。

レイヤーの削除(既にあるSequentialオブジェクトにレイヤーを削除する場合)

この方法は、少数のレイヤーを削除する際に向いています。

class SimpleChildModel(SimpleModel):
    
    def __init__(self):
        super().__init__()
        
        self.layers3[0]=torch.nn.Identity()

こうするだけでfowardの中も書き換えずに済むので、大変便利で可読性も向上します。

レイヤーの削除(インスタンス単位で削除する場合)

事前学習済みモデルを利用する際などのレイヤーの削除がこれに該当します。
 以下のように書き換えます。

from torch.nn import Sequential, Dropout
model = SimpleModel    
model.layers3[0] = torch.nn.Identity()

で削除することができます。モデルの削除箇所を確認するには、print(model)などとすることで削除箇所を探すことが可能です。

最後に

要するにちゃんとクラスの使い方覚えようという話でした。
 クラスの概念がc言語などに比べればPythonは意識しなくても書けるようになっていると思うのですが、やはりしっかりとしたクラスの概念の理解がないとこいうちょっとした時に困るものだなと実感させられます。

GitHubで編集を提案

Discussion