😊

【PyTorch】モデルの中間層の出力を取得する「Hook」の使い方:ResNetでの具体例付き

に公開

PyTorchで深層学習モデルの**中間層の出力(特徴量)**を取得したいとき、モデル構造を壊さずにアクセスできるのが「Hook(フック)」です。

この記事では、WideResNet50 を例にとって、forward_hook を使って中間特徴マップを抽出する方法を、実用コードとともに解説します。

0.Hookとは?なぜ使うのか?

PyTorchの Hook(フック) とは、モデルの 特定の層の順伝播や逆伝播のタイミングで、入力や出力にアクセスできる仕組み です。とくに forward_hook を使うと、「順伝播(forward)」の 出力 をキャッチできます。

フックが活躍する場面:

  • 中間層の特徴マップ(feature map)を抽出したいとき
  • 活性化値を可視化したいとき(例:Grad-CAM)
  • 複数スケールの特徴を使った異常検知(例:PaDiM、PatchCore)
  • モデル内部のデバッグ

1. モデルと環境の準備

import torch
import torchvision.models as models
import os

# Intel MKLの重複読み込みを防止(環境依存の対策)
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

# GPUが利用可能ならGPU、なければCPUを使用
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Current Device is {device}")

# ImageNetで事前学習されたWideResNet50を読み込み
model = models.wide_resnet50_2(weights=models.Wide_ResNet50_2_Weights.IMAGENET1K_V1)
model.eval().to(device)

2. Forward Hookの定義と登録

出力保存用のリストを準備

outputs = []

フック関数の定義

def hook(module, input, output):
    outputs.append(output.clone().detach().cpu().numpy())

この関数は、対象の層が順伝播した直後に自動的に呼び出され、その出力テンソルをNumPy形式で outputs に保存します。

フックを登録する

model.layer1[-1].register_forward_hook(hook)
model.layer2[-1].register_forward_hook(hook)
model.layer3[-1].register_forward_hook(hook)

ここで layer1[-1] は、ResNetのlayer1ブロックの最後のBottleneck層を指しています。

3. ダミー画像で動作を確認

dummy_input_tensor = torch.randn(1, 3, 512, 512).to(device)

with torch.no_grad():
    _ = model(dummy_input_tensor)

4. フックの結果を確認する

for i, layer_output_np in enumerate(outputs):
    print(f"Output from Hook {i+1}")
    print(f"  Shape: {layer_output_np.shape}")
    print(f"  First values: {layer_output_np.flatten()[:5]}")

例:

Output from Hook 1
  Shape: (1, 256, 128, 128)
  First values: [0.003, 0.014, 0.018, ...]

5. 活用例:異常検知や可視化に

取得した中間特徴は、以下のような応用が可能です:

活用法 説明
PaDiM・PatchCore 異常検知用に特徴を保存し、類似度や距離を計算
Grad-CAM 勾配と組み合わせて注目領域を可視化
転移学習の前段処理 中間出力を別モデルへ渡して学習に利用

6. 注意点と補足

  • フックは forward(順伝播)時のみ 呼ばれます(backward_hook は逆伝播用)
  • 登録されたフックは解除しないと残り続けます。後述の handle.remove() を使いましょう
  • 多くの層にHookを登録するとメモリを圧迫します

補足①:ResNetのブロック構造と layer1[-1]

ブロック名 Bottleneck数 出力チャンネル 出力サイズ(入力224x224時)
layer1 3 256 56×56
layer2 4 512 28×28
layer3 6 1024 14×14
layer4 3 2048 7×7

layer4 は空間分解能が小さすぎるため、局所的な異常検知には不向きなことがあります。したがって、通常は layer1layer3 を使用します。


補足②:register_forward_hook() の構造と文法

以下のような関数を登録することで、モデルの各層の「出力」を取得できます:

def hook(module, input, output):
    print(output.shape)

handle = model.layer2[-1].register_forward_hook(hook)

引数の意味:

引数 内容
module 呼び出されたモジュール(例:Conv2d)
input 入力(タプルで渡される)
output 出力テンソル(これを保存したり可視化)

フックの解除方法:

handle.remove()

Hookの一連の流れ(図解)

まとめ

  • PyTorchのforward_hookを使うと、中間層の出力を簡単に取得できる
  • ResNetの layer1[-1] のような書き方で、任意の層の出力を抽出可能
  • 特徴抽出・異常検知・可視化など、多くの応用に使える
  • Hookは登録→使用→解除の流れで安全に活用

Discussion