📝

timm.create_modelの実装の解説

2024/03/16に公開

timm.create_modelの実装の解説

目次

  • 概要
  • timm.create_modelとは何か
  • デザインパターンとは何か
  • どのようにFactory Methodパターンを実装するか
  • どうしてFactory Methodパターンを使うのか
  • まとめ

概要

  • timm.create_modelを呼び出したことはあるものの中身を読んだことがない人向けにざっくり中身を見ていきます
  • timm.create_modelにはFactory Methodパターンというデザインパターンが使用されています
  • 機械学習分野だと同じ入出力で中身が異なるアルゴリズムを実装することが多いため、このパターンを知っておくと便利かもしれません

timm.create_modelとは何か

画像認識用のライブラリとしてtimmがあります。ここではtimm(v0.9.12)のうち、timm.crate_modelについて説明します。

以下は呼び出したことはあるけど中身を読んだことがない人向けの説明です。関数自体の引数と返り値は簡潔です。

実装

def create_model(
        model_name: str,
        pretrained: bool = False,
        pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
        pretrained_cfg_overlay:  Optional[Dict[str, Any]] = None,
        checkpoint_path: str = '',
        scriptable: Optional[bool] = None,
        exportable: Optional[bool] = None,
        no_jit: Optional[bool] = None,
        **kwargs,
):
   (中略)
   return model

主な引数

  • model_nameにモデル名のstrを指定します

  • pretrainedに事前学習の有無を指定します

  • **kwargsにモデルごとの引数を指定します

返り値

  • 画像認識モデルが返ってきます
    • ImageNetで事前学習したモデルの場合、入力は画像、出力は1000クラスの確率値となるようなクラスが実装されています。

使用例

timm.create_modelは**kwargsによりモデルごとに固有の引数を使用できます。

以下のように、例えばResNetに対してはreplace_stem_poolViTに対してはqkv_biasなどの変数を渡して初期化できます。

import timm


resnet = timm.create_model("resnet50", replace_stem_pool=True)
vit = timm.create_model("vit_base_patch8_224", qkv_bias=False)

以下のように無効な引数(resnet50に対してqkv_bias)を指定した場合、エラーになります。

timm.create_model("resnet50", qkv_bias=False)

>> TypeError: ResNet.__init__() got an unexpected keyword argument 'qkv_bias'

では、これをどのように実装しているのでしょうか?

素朴にやるならばBaseVisionModelのような名前の基底クラスを定義し、そこからResNetやViTなどの派生クラスを作成します。その後、timm.create_model内でIF文で分岐させたり、Dictで対応付けしたりすることが考えられます。

上記の実装パターンのアイデアをより良くした実装の型(デザインパターン)が、実際の実装には使用されています。

デザインパターンとは何か

デザインパターン

デザインパターンとは、ある問題を解決するための典型的な実装・設計の型のことです。実装様式とモデル構造を対比させて例えると、Linter, Formatterなどの数行単位での実装様式がCNNやReLUならば、デザインパターンはEncoder-DecoderやBackbone-Headにあたります。画像認識モデルでは解像度を落とさずに特徴量を計算したい問題や、最終的にいくつかの値を予測したい問題において、大抵は前述の構造が用いられます。デザインパターンも同様で、現実の特定の問題に対して特定の設計が使えることが分かっています。

ここではそのような概念がある、くらいで大丈夫です。詳細はwikipedia: デザインパターン(ソフトウェア)などがあります。本だとHead Firstデザインパターン 第2版 ―頭とからだで覚えるデザインパターンの基本がメジャーだと思います。ただし第1版はJavaで書かれていた(第2版は不明)ので、Pythonのみ書く人にとっては少し読みにくいかもしれません。

Factory Methodパターン

timm.create_modelにもデザインパターンが使われています。具体的にはFactory Methodパターンが使用されています。実はDocumentにもその旨の記載があります。

It is that simple to create a model using timm. The create_model function is a factory method that can be used to create over 300 models that are part of the timm library.

Factory Methodパターンの詳細な説明はRefactoring.Guru: Factory Methodが分かりやすいと思います。ここではデザインパターン自体の詳細には踏み入らないので、timm.create_modelにより様々なモデルが作られるようになっている。その根底にはFactory Methodパターンというものがある、くらいの認識で大丈夫です。

どのようにFactory Methodパターンを実装するか

ここでは、Factory Methodパターンを具体的に理解するため、timm.create_modelを詳しく見ていきます。

timm.create_model

_factory.pyのcreate_model関数にて定義されています。

Documentを除くと30行程度の関数で、肝となるのは115行~122行目の以下の部分です。

    create_fn = model_entrypoint(model_name)
    with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
        model = create_fn(
            pretrained=pretrained,
            pretrained_cfg=pretrained_cfg,
            pretrained_cfg_overlay=pretrained_cfg_overlay,
            **kwargs,
        )

上記から抜粋して、説明のために単純に書くと以下のようになります。

def create_model(
    model_name: str,
    pretrained: bool = False,
    **kwargs,
):
	create_fn = model_entrypoint(model_name)
	model = create_fn(
    	pretrained=pretrained,
	    **kwargs,
	)
	return model

大まかに、処理は以下のような3段階になっていることが分かります。

  1. model名model_nameを受け取った後、model_entrypointに渡してモデルを作成するための関数create_fnを作成します。
  2. create_fnにモデル共通の引数(ここではpretrained; 事前学習の有無のみ)とモデル固有の引数**kwargsを渡し、画像認識モデルmodelを作成します
  3. modelを返します

model_entorypoint以外はどうなっているか、何となく想像がつくと思います。では、model_entorypointはどうなっているでしょうか?

単純に考えればキーをモデル名、値をモデル作成のための関数とする辞書で良さそうです。

_model_entrypoints

_registry.pyの_model_entrypointsで定義されています。

_model_entrypoints: Dict[str, Callable[..., Any]] = {}  # mapping of model names to architecture entrypoint fns

見ての通り、空の辞書です。ただしTypeHintによって、キーをモデル名、値をモデル作成のための関数とする辞書っぽいことが分かります。

では、どの部分で_model_entrypointsを更新しているのでしょうか? コードを検索して見ると、register_model関数が引っ掛かります。

register_model

_registry.pyのregister_modelで定義されています。

抜粋して単純に書くと、以下のように_model_entrypointsが更新されていることが分かります。

def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
	model_name = fn.__name__
	_model_entrypoints[model_name] = fn
	return fn

_model_entrypointsの値はモデル作成関数なので、fnがモデル作成関数と分かります。また、キーはモデル名model_nameなので、モデル名としてモデル作成関数の名前をそのまま使用していることも分かります。

ただ、_model_entrypointsを更新するだけならば返り値にfnは渡さなくても良さそうです。どうしてこうなっているのか、実際に使われている箇所を見てみます。

efficientnet_b0での使用例

register_modelは、例えばefficientnet.pyで使用されています。

抜粋して単純に書くと、以下のように使われています。

@register_model
def efficientnet_b0(pretrained=False, **kwargs) -> EfficientNet:
    model = _gen_efficientnet(
        'efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
    return model

上記より、register_modelがデコレータとして使われていることが分かりました。また、efficientnet_b0関数はモデル作成用の関数であり、モデル共通の引数としてpretrainedがあること、モデル固有の引数として**kwargsが設定されていることも確認できました。

もしモデルを追加する場合を考えると、最小限に実装する場合は以下のようなコードになりそうです。

from ._registry import register_model


@register_model
def custom_model_name(pretrained: bool = False, **kwargs) -> CUSTOM_MODEL_CLASS
    model = _create_custom_model(pretrained=pretrained, **kwargs)
    return model

今までのことから、もし新しいモデルを利用したい人がいた場合、頑張ってCUSTOM_MODEL_CLASSをコードから探してimportする必要がないことが分かります。任意のモデルは既存のtimm.create_modelから新しいモデルを呼び出せます。

また、もし新しいモデルを実装したい人がいた場合、新しいモデルの追加は他の部分の実装を変更・修正せずにできることが分かります。モデル固有の引数も他のモデルに左右されずに自由度が高く設計できます。

どうしてFactory Methodパターンを使うのか

利点

Factory Methodパターンを使うことで、以下のメリットがあります

既存の実装を修正せずに新しいモデルを実装し、従来と同じ関数から使用できる

作成可能なオブジェクトの確認・利用が容易

もし仮にモデル作成方法がtimm.create_modelを使わずにモデルのクラスをimportして利用する場合、以下の問題が生じます。

  1. 欲しいモデルを探す際にコードを見る必要があり、使いにくいです。

    • インターフェイスと実装が分離されていません。(インタフェース分離の原則に反しています)

    • もし仮にファイルの場所が変わった場合、モデルをimportしている部分を修正する必要があります。

  2. 追加されたモデルを網羅することは困難です。

    • 実装とは別の何かしらのドキュメントを人手で管理することになります。
    • ドキュメントがない場合、実装されているかいないかをコードから探す必要があります

一方、Factory Methodパターンであれば、上記の問題は解決できます。モデルはすべてtimm.create_modelから呼び出します。また、timm.create_modelから作成可能なモデルは_model_entorypointsのキーを確認すれば分かります。人手によるドキュメントを作らずにコードと結びついた形で管理できます。

(デコレータを使用した場合)新しいモデルを追加する際に既存の実装を変更しない

もし仮にモデルの管理を_model_entorypointsを直接編集して行った場合、モデルを追加するたびに既存の_factory.pyが変更されます。これは並列して複数の開発が行われている場合、コンフリクトを招く可能性が高いです。おそらく解法閉鎖原則に反しています。

一方、デコレータを使用して拡張する場合、モデルの追加は実装の変更を伴わない実装の追加のみで行えます。そのため、コンフリクト等は気にせずに済みます。

欠点

とはいえ万能ではなく、以下のケースでは得られるメリットより実装コストの方が高いと思います。

  • モデルは1つのみを使用する
  • 開発規模が小さい
  • 利用範囲が狭い

例えばドメイン固有の画像認識モデルを1つだけ育てる場合、Factory Methodパターンのメリットであるモデル選択の拡張性は不要です。また、開発規模が小さかったり利用範囲が狭かったりする場合、同様に必要とする拡張性はそこまで高くないことが想定されます。

上記のような場合、Factory Methodパターンは得られる利点よりも実装コストの方が高くなってしまう可能性があります。

とはいえ予期せず発生したトレードオフによりモデルを選びたくなった、当初の想定よりも開発規模が大きくなった、利用範囲が広まった、などのケースは間々想定されます。そのため、最初から使っていた方が結局は早いケースもあります。

まとめ

  • 一つの関数からモデルを作り分けたい、といったケースに対してよく使われている実装の型(デザインパターン)を紹介しました
  • Factory Methodパターンと呼ばれるもので、ここではtimm.create_modelを詳細に見ていきました
  • Factory Methodパターンを使うことで、以下のメリットがあります
    • オブジェクトの生成方法を一つにまとめることで、管理・拡張・利用しやすくなる
    • (デコレータを使用した場合、)新しいモデルを追加する際に既存の実装を変更しない

Discussion