👀

Pytorchで書いたモデルの中間層と友達になろう

2023/10/05に公開

ブラックボックスをちょっと覗くだけ!ちょっとだけだから!
蔑ろにされがちな中間層くんとお友達になろうの回です

TL:DR

torch.nn.Moduleで設定した任意の層の順伝播時における入出力、モジュールは、pytorchの機能であるregister_forward_hookを使用することで保存、書き出しなどの操作が可能

背景

機械学習を触り始めていると、このモデルのここの部分で出力されている特徴量が取得できたらな〜と思う機会が増えると思います。
しかし、以下のような悩みに直面し、 面倒くさい 着手するのが億劫になってしまいがちです。

  1. 大規模なソースコードであればあるほど、内部構造が複雑でソースコードを追うのがつらい
  2. 論文とコードを対応付けて、たぶんこの層だろう、と当たりをつけても、実際のソースコード中のどの変数に対応するかが分かりづらい
  3. データを捕捉しても、そこからすべての返り値に欲しい変数のreturnを書かなければならない

register_forward_hookを使おう

実は、pytorchの機能として各層に対する操作が行えるものがあります。

How can l load my best model as a feature extractor/evaluator? - Pytorch Forums

上記のforumで示されている、register_forward_hookが該当します。
nn.Moduleで定義されているモジュールなら呼び出すことが可能です。

使い方として、まず、関数を引数として渡し、hookとして「登録」します。

register_forward_hookに登録した関数は、データがモデルのその層を通ったときに入出力およびモジュールを引数として渡して実行されます。登録した関数の実行が終わると、次の処理に進む、といった流れになっています。

backpropagationの際に処理したい関数を登録できるregister_backward_hookもあるようです。(僕はまだ使ったことがないので詳細は書けません)

サンプルコード

中間層出力の保存

import torch
import torch.nn as nn

# サンプルモデルの定義
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 1)
        self.fc1 = nn.Linear(16 * 6 * 6, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = x.view(-1, 16 * 6 * 6)
        x = self.fc1(x)
        return x

# 中間層の出力を保存するための関数
outputs = []
def hook(module, input, output):
    outputs.append(output)

model = SimpleModel()
model.conv1.register_forward_hook(hook)

# テストデータ
input_data = torch.randn(1, 3, 8, 8)
output_data = model(input_data)
print(outputs[0].shape)

上記のコードでは、conv1という層に対しlayerを登録しています。

実行結果

torch.Size([1, 16, 6, 6])

層の名前の確認

さて、このようなモデルの定義から取得するべき、対応する層がわかっている場合はさておき、中にはわざわざ定義を見に行くのすら面倒な場合があるかもしれません。そんなときは、一度モデルをprintしてみましょう。

print(model)

実行結果

SimpleModel(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=576, out_features=10, bias=True)
)

あるいは、named_childrenでちょっとテクニカルに取得するのもありかもしれません。

for name, layer in model.named_children():
    print(name, layer)

実行結果

conv1 Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
fc1 Linear(in_features=576, out_features=10, bias=True)

printした際に()になっている部分が、層の名前(name)に対応します。複雑なモデルでは入れ子になっているため、層のどこにアクセスしているかに注意しましょう

hookに引数をさらに追加したい場合

デフォルトライブラリのfunctoolsを使用することで、登録する関数に更に引数を追加できます。

import functools

def hook_with_arg(module, input, output, arg):
    print(f"Hook called with arg: {arg}")

arg = "Sample Argument"
model.conv1.register_forward_hook(functools.partial(hook_with_arg, arg=arg))

output_data = model(input_data)

実行結果

Hook called with arg: Sample Argument

結び

中間層くんも新しい友達ができてうれしそうです

公式ドキュメント

https://pytorch.org/docs/stable/generated/torch.nn.modules.module.register_module_forward_hook.html

Discussion