🐘

野生のモデルを飼い慣らす

に公開

今回は思考の整理を兼ねて、timm (pytorch image models)などの公開モデルを目的のタスクに合わせてカスタマイズする時に、自分がいつもやっているルーティーンについて紹介します。

後半では最近参加したコンペ(Yale, BYU)で遭遇したハマりどころを紹介しつつ、学習・推論の高速化のTipsについて書きました。

知識編

座学として最初に押さえておくと良いのは以下のような内容だと思います。

  • convolutionの機能とパラメータを正確に理解する(kernel size, padding, stride, dilation, group)
  • convolution以外の基本部品と合成部品の構造と機能を理解する(Upsample/Downsample, PixelShuffle, deconvolution, SCSE, MetaFormerBlock etc.)
  • アーキテクチャの基本構造を理解する(Unet, YOLO, FPN, Transformer, etc.)

最初は全ての部品の構造と設計意図を完全に理解している必要はないかもしれません。後述するmodelのsummaryを出力してみて、知らない名前の部品が出てきたらdepthを深くしたり論文、コードを読むなりして理解を深めていけば良いかと思います。肝心なのは、すべての部品は基本部品の組み合わせで構成されているということです。よくわからない部品に遭遇しても、分解すると必ずよく知っている部品に行き当たります。頭でイメージしにくい場合は、紙にアーキテクチャの構成図を書いてみるとイメージしやすくなるかもしれません。

モデル構成のメモの例:

Convolutionの仕組みを解説したブログ:

https://jinglescode.github.io/2020/11/01/how-convolutional-layers-work-deep-learning-neural-networks/

実践編

公開モデルの構造や処理内容を把握するために行なっていることを紹介します。

  • ネットワークの構造を可視化する(torchinfo etc.)
  • 途中の計算結果を取得する(monkeypatch)

ネットワークの構造を可視化する

論文を読んでもいまいち細部がどうなってるのかわからなかったり、自分が正しく理解できているのか不安になることがあります。コードを読めば良いわけですが、timmのように抽象化されたコードだと「具体的にここのレイヤのchannel数はいくつになるんだ?」といった詳細なパラメータを知るには、コードを深く追わないとわからない場合が多いです。そういう時に役立つのかモデルのsummaryの出力です。torchinfoは昔ながらのテキスト出力のツールですが、表示内容を柔軟に変えられるのと、オーバーヘッドが軽いので良く利用しています。

モデル構成の可視化例:

from torchinfo import summary


summary(
        model,
        input_size=(1, 5, 1000, 70),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"],
    )
)

途中の計算結果を取得する(monkeypatch)

モデルの途中の計算結果が必要な場合、summaryでは欲しい情報が得られないことがあります。こういう場合はモデルにmonkeypatchを当てることで途中の計算結果を後から利用できるようにします。BYUコンペではYOLOのanchorを取得するために、モデルのメソッドを動的に入れ替えるハックを使いました。

from typing import Callable


def replace_method(obj, method_name: str, func: Callable):
    """
    Replace a method of an object with a new function.

    Args:
        obj: The object whose method is to be replaced.
        method_name (str): The name of the method to replace.
        func (FunctionType): The new function to set as the method.
    """
    if not hasattr(obj, method_name):
        raise AttributeError(
            f"{obj.__class__.__name__} has no method '{method_name}' to replace."
        )
    bound = func.__get__(obj, obj.__class__)
    setattr(obj, method_name, bound)

monkeypatchに関するdiscussion:
https://www.kaggle.com/competitions/byu-locating-bacterial-flagellar-motors-2025/discussion/579617

応用編

モデルの重みを変えずに処理方法を変える

timmのモデルには元々入力チャネルやクラス数を柔軟に設定する仕組みがありますが、それでもタスクによってはモデルの構造を変えたくなる場合があります。ここでは学習ずみ重みを変えずに、その処理方法のみ変えるハックについて紹介します。

最初に知っておくべきなのは、モデルには重みとその処理方法(コード)の2つで構成されているということです(このルールは現在OSSで主流となっているpytorchやTensorflowなどの多くのフレームワークで共通しています)。そして、重みは変更しにくいが処理方法は比較的自由に変更できるというのがポイントです。

多くのKagglerはこの性質を利用して重みの処理方法をハックする方法を身につけています。有名なのは、stemと呼ばれる最初のconvレイヤーのstrideやpaddingを変更するテクニックです。これらのパラメータは同じ重みの処理方法を変えているだけです。

例えば、Yaleコンペでbartleyが公開したモデルでは、以下のカスタマイズが加えられていました。

  • stemのstride, paddingの変更
  • 1st blockの最初のlayerにpoolingを挿入
import torch.nn as nn


def update_stem(backbone):
    m = backbone
    m.stem.conv.stride = (4, 1)
    m.stem.conv.padding = (0, 4)
    m.stages_0.downsample = nn.AvgPool2d(kernel_size=(4, 1), stride=(4, 1))
    m.stem = nn.Sequential(
        nn.ReflectionPad2d((0, 0, 78, 78)),
        m.stem,
    )

これらの変更は、(1000, 70)という入力データの極端なアスペクト比の違いを矯正するために加えられていました。

アスペクト比が1:1に近い場合でも、画像の解像度を上げるためにstemのstrideを減らすトリックがよく使われます。

知識編で出てきたconvolutionの仕組みを理解しているとストライドやpaddingに具体的にどのような値を指定すべきかを自分で設計できるようになると思います。

コンペ中にBrtleryが公開した最強notebook:
https://www.kaggle.com/code/brendanartley/caformer-full-resolution-improved

アーキテクチャを自分で設計・改善する

タスクによっては公開モデル部分のみで完結しない場合があるので、そういう時は自分で足りない部品を作る必要があります。似たタスクの論文や過去の類似コンペの公開モデルを参考にしたりします。
自分で実装するとなるとチャネル数、次元、ストライドなど詳細なパラメータを正確に知る必要があるので、論文、コード、summary出力、手書きメモを駆使してなるべくモデルの構造を正確に把握します。
経験的に、ニューラルネットワークはナイーブな設計でもそこそこ意味のあるパフォーマンスが出るので、まずは誰でも思いつきそうな簡単なアーキテクチャを試し、そこから試行錯誤を重ねてアーキテクチャをリファインしていきます。

高速化編

最後に、最近遭遇したハマりどころを紹介しつつ、自分がよく行っている高速化のTipsについて紹介します。

torch.compile()

torch.compileは自動微分で生成する微分グラフを最適化することで推論、学習を高速化します。直近ではYaleコンペで「1000回のforループのコンパイルが非常に遅い」というのがありましたが、これは自動微分を有効化した状態でforループを回すと、微分グラフが非常に大きくなるためです。1step分のみcompileすることで現実的なコンパイル時間に収めるトリックが紹介されました。また、このトリックを50~100ステップに拡張してさらに7倍程度の高速化が実現できることがわかりました。他のタスクに転用可能な知見かは分かりませんが、コンペ中にそれなりに時間を割いて実装したconv2dやFFTの実装よりも10倍くらい早くなることがわかったので、手動でナイーブに最適化するより自動に任せた方が良いという教訓を得ました...。

torch.compileの高速化の効果に関するdiscussion:
https://www.kaggle.com/competitions/waveform-inversion/discussion/587506#3239808

AMP (Automatic Mixed Precision)

AMP(自動合精度)学習は、float32精度とそれよりも小さい精度を自動的に切り替える手法です。精度を落とすことでメモリ消費と計算量を削減します。
経験的には1~2倍程度高速化する場合が多いです。GPUによって対応していない場合があるので注意が必要です。例えばAmpair以降のGPUではbfloat16に対応していますが、T4インスタンスでは対応していないためfloat16を使用します。

Gradient Checkpointing

Gradient checkpointingは微分計算の途中結果を捨てて、要所となる結果だけ残して、必要になった時に都度計算することで、計算時間と引き換えにVLAMの消費を抑える方法です。 計算時間を犠牲にすると言っても、基本的には「計算は軽いがメモリ消費は多い」ブロックを対象にすれば、経験的に1.2-1.3倍程度の犠牲で済む感じです。反対にメモリ消費は1/2くらいになるイメージです。timmのモデルではmodel.set_grad_checkpointing(True)とすることで有効化できます。

Quantization

データの量子化は、場合によってはデータのロードが高速化するので、I/O律速の場合は学習速度が早くなる場合があります。Yaleコンペでは学習の意図もあって8+1 bitのブロック量子化を試みました(なお、元々compute boundなタスクでもあり、量子化誤差も無視できないことから、このアイデアはあまり受けは良くありませんでした)。

データのブロック量子化の例:
https://www.kaggle.com/competitions/waveform-inversion/discussion/587450

その他

Yaleコンペは非常に多くの計算リソースを要求するコンペだったため、コンペ中にも学習の高速化のTipsについて議論されていました。

学習の高速化についてのDiscussion:
https://www.kaggle.com/competitions/waveform-inversion/discussion/583896

GitHubで編集を提案

Discussion