🅿️

研究開発のためのPinjectedによるDependency Injection (Short Version)

2024/12/18に公開

はじめに

こんにちは、CyberAgent AILab リサーチサイエンティストの増井です。
普段、画像生成や画像認識などの機械学習研究を行っています。

この記事はAILab Advent Calendarの12月18日向けとして作成しております!

研究でPythonによる実験コードを書いていると色々と課題に直面しますが、
今回はその解決のために開発したライブラリ(Pinjected)を紹介させていただきたいと思います。
(ロング版も是非)

研究開発で生じる課題

研究開発における機械学習実験では、モデル、オプティマイザ、データセット、学習率やバッチサイズといった多種多様なハイパーパラメータを管理することがよくあります。これらを扱う過程では、以下のような状況が生じがちです。

  • 設定パラメータを集約したcfgオブジェクトが肥大化し、その中にあるパラメータがどこで参照されるか把握しにくくなる
  • モデルやデータセットを切り替えるたびにif文が増え、可読性が低下する
  • 部分的なテストやデバッグを行うには、毎回フルセットアップが必要で、手軽な確認が難しくなる

このような複雑さは、研究を進めていく上で小さな負担を積み重ね、実験の切り替えや新しい手法の試行に時間をかけることにつながります。

Dependency Injection (DI) への関心とpinjected

こうした課題に対し、依存関係を明示的に外部から注入する「Dependency Injection (DI)」と呼ばれる設計手法は有効なヒントとなります。DIを取り入れることで、

  • パラメータを必要な箇所に絞って明示的に注入でき、全体依存を解消しやすくなる
  • if文を減らして、パラメータやオブジェクトの切り替えを外部設定だけで行える
  • 単一コンポーネントだけの試行やテストを容易にし、柔軟な再利用を可能にする

pinjectedは、Python向けのDIツールで、研究開発における実験コード管理を支援しようとするライブラリです。既存のDIツールと比較して、次のような点を目指しています。

  • デコレータや簡易的な関数呼び出しによる、簡潔な定義方法
  • CLI上でのパラメータ上書きや、ローカル設定ファイル(~/.pinjected.py)との組み合わせによる簡易な条件切り替え
  • 単独コンポーネントの実行やテストを手軽に行える実行インターフェース

本記事では、pinjectedの基本的な機能と考え方をできるだけシンプルに紹介します。詳細な設計背景や他ツールとの比較、より高度な機能については、別途用意したLong Versionで取り上げています。

pinjectedの基本的な仕組み

@instanceと@injectedで依存オブジェクトを定義する

pinjectedでは、@instance@injectedというデコレータを用いて依存関係となるオブジェクトを定義します。これらのデコレータを付与した関数は、その関数名をキーとしてDIコンテナに登録され、後から参照可能になります。

from pinjected import instance

# cnn_in_channelsという要求キーに対して1を返す宣言
@instance
def cnn_in_channels():
    return 1

# cnn_in_channelsに依存して、SimpleCNNを返す依存オブジェクト
@instance
def model__simplecnn(cnn_in_channels):
    # 例としてSimpleCNNインスタンスを返す
    # cnn_in_channelsには自動的に1が入ってくる
    return SimpleCNN(in_channels=cnn_in_channels)

ここではmodel__simplecnnという関数名を付けていますが、これは慣例的なネーミングであり、自動的にmodelキーが割り当てられるわけではありません。

この時、cnn_in_channelsはmodel__simplecnnの要求する依存オブジェクトであり、DIコンテナによって自動的に作成、解決されます。

@injectedは、@instanceに似ていますが、関数引数を「DIからの注入対象」と「実行時に指定する引数」に区別しやすくするための機能を備えています。詳細は後述しますが、「固定的なリソースはDIから注入し、変動しやすいパラメータは呼び出し時に渡す」といった使い分けが可能です。

design()で依存をまとめて組み合わせる

design()関数は、依存オブジェクトやパラメータをキーと値でまとめたDesignオブジェクトを作成します。これにより、実験条件のバリエーションを簡潔に切り替えたり、オブジェクトの組み合わせを直感的に記述できます。

from pinjected import design

base_design = design(
    learning_rate=0.001,
    batch_size=64
)

# 別のdesignと組み合わせて実験条件をカスタマイズ
mnist_design = base_design + design(
    model=model__simplecnn,  # model__simplecnnキーの依存オブジェクトをmodelキーへ割り当て
    dataset=dataset__mnist,  # dataset__mnistも同様に割り当て
    trainer=Trainer
)

このように、model__simplecnnmodelキーへ、dataset__mnistdatasetキーへと対応付けられるため、実際にコードを動かすときにはmodeldatasetという分かりやすい名前で参照できるようになります。

CLIやローカルファイルで条件を上書きする

pinjectedは、python -m pinjected run ... --batch_size=32のようなCLI引数や~/.pinjected.pyでの設定上書きをサポートしています。これにより、コードを書き換えなくてもパラメータを即時変更でき、異なる条件での試行を手軽に行うことができます。

@injectedで実行時引数を分離する

@injectedを用いると、DIで注入される要素と、実行時に可変な引数を分かりやすく分離できます。たとえば、大きなモデルを一度だけロードし、promptだけは実行時に指定したいといったケースで便利です。こうした分離によって、重い初期化を何度も行わずに済み、試行のスピードアップにつながる可能性があります。

簡単なコード例

ここでは、modeldatasettrainerを組み合わせ、pinjectedによる実験実行を試すシンプルな例を示します。実際にはSimpleCNNMNISTDatasetなどを定義・インポートする必要がありますが、ここではイメージとしてのサンプルコードになります。

# example.py
from pinjected import instance, design

# 仮のモデル・データセット・トレーナークラス
class SimpleCNN:
    def __init__(self, in_channels, hidden_units):
        self.in_channels = in_channels
        self.hidden_units = hidden_units

    def forward(self, x):
        pass  # 実際のモデル処理は割愛

class MNISTDataset:
    def __init__(self, batch_size):
        self.batch_size = batch_size

    def __iter__(self):
        # ダミーで10バッチ分のイテレータとする
        yield from range(10)

class Trainer:
    def __init__(self, model, dataset, learning_rate):
        self.model = model
        self.dataset = dataset
        self.learning_rate = learning_rate

    def train(self):
        print(f"Training {self.model.__class__.__name__} on {self.dataset.__class__.__name__} "
              f"with lr={self.learning_rate}")
        for batch in self.dataset:
            # 本来はここで学習ステップを実行
            pass

# 依存オブジェクトを定義する関数
@instance
def model__simplecnn():
    return SimpleCNN(in_channels=1, hidden_units=128)

@instance
def dataset__mnist(batch_size):
    return MNISTDataset(batch_size=batch_size)

# トレーナーはオブジェクトではなくクラス自体をdesignで割り当てる想定
# こうすることで、trainer=Trainer とすると、自動的に依存解決される
# (modelとdataset、learning_rateが注入される)

# 基本設定
base_design = design(
    learning_rate=0.001,
    batch_size=64
)

# MNIST + SimpleCNN用のデザインを作成
mnist_design = base_design + design(
    model=model__simplecnn,
    dataset=dataset__mnist,
    trainer=Trainer
)

# run_trainエントリーポイント
@instance
def run_train(trainer: Trainer):
    trainer.train()

# __meta_design__は必須の特別変数で、pinjectedがデフォルトで参照するデザインを示す
# ここではmnist_designをデフォルトとして指定
__meta_design__ = design(
    overrides=mnist_design
)

実行:

# run_trainを実行
python -m pinjected run example.run_train
# modelを表示
python -m pinjected run example.model__simplecnn

上記例では、

  • @instanceで依存オブジェクトを定義し、model__simplecnndataset__mnistといった関数がDIへ登録されます。
  • design()modeldatasettrainerといったキーにこれらを割り当て、mnist_designとしてまとめています。
  • run_traintrainerが注入され、trainer.train()を呼ぶことで実験が実行されます。
  • __meta_design__を定義することで、このファイルをpython -m pinjected run example.run_trainと実行すれば、mnist_designが適用され、SimpleCNNMNISTDatasetを用いたトレーニングが行われます。
  • コマンドラインから--batch_size=32のようなオプションを指定すれば、コード編集なしでバッチサイズの変更が可能です。
  • model__simplecnnなど、@instanceが付与された関数を直接動作確認できます。

この例はあくまで一例ですが、pinjectedを用いることで、条件やパラメータ変更を容易に行えます。さらに高度な記述方法や複雑な条件の管理方法については、ロング版でより詳しく解説しています。

効果的な活用ポイントとベストプラクティス

部分テストや軽量デバッグ

pinjectedを活用すれば、全体起動なしで特定のオブジェクトやモジュールのみを初期化・確認できます。
たとえば、データローダーやモデル出力を単独で検証する専用のエントリーポイントを用意し、python -m pinjected run your_module.run_dataset_checkといったコマンドで即座に実行可能です。これによって、不具合の早期発見や実験サイクルの効率向上が期待できます。

if分岐削減によるコードの明快化

実験条件の切り替えをdesign()CLIオプションで行えるようになると、条件ごとのif分岐が不要になります。
これにより、コード中の条件分岐を最小限に抑え、シンプルで読みやすい構造を維持しやすくなります。

まとめと次のステップ

本記事では、pinjectedを用いた実験コード管理の基本的な考え方や、@instancedesign()による依存オブジェクト定義、CLIオプションによる実行時パラメータ変更などを紹介しました。ここで説明したのはあくまでも基礎的な部分で、pinjectedにはさらに多くの機能が備わっています。

例えば、以下のような機能を活用することで、より柔軟かつ効率的な実験管理が可能になります。

  1. @injectedによる実行時引数分離
    固定的なリソースはDIでロードしつつ、promptseedなど可変要素は呼び出し時に指定できます。

  2. ローカル設定ファイル (~/.pinjected.py)との組み合わせ
    ユーザー固有のパスやAPIキーをコード変更なしで差し替えられ、環境依存情報を安全かつ簡潔に扱えます。

  3. Injected/IProxyによる高度なDSL的表現
    複雑な依存関係や計算ロジックを宣言的に構築し、複数のオブジェクトやパラメータを効率的につなぎ合わせることができます。

  4. 複数エントリーポイントの容易な管理
    run_trainrun_dataset_checkなど、用途別のエントリーポイントをファイル内で整理し、必要な機能だけを素早く起動できます。

  5. CLI上書きとデザイン合成によるバリエーション生成
    design()を+演算子で組み合わせたり、CLIオプションでオーバーライドすることで、学習率やモデル構成などの無数のバリエーションを簡潔に試すことが可能です。

  6. Async対応による非同期依存関係解決
    非同期処理が必要なケースにも対応でき、async関数をDIに組み込んで並列実行をスムーズに行う手段が用意されています。

  7. IDEプラグインによるワンクリック実行
    PyCharm, VSCode(pinjected-runner)と連携することで、pinjectedが定義したエントリーポイントをボタンクリック一発で実行でき、生産性を向上させることが期待できます。

  8. 依存関係の可視化機能
    IDE連携により開発中に「どのオブジェクトが何に依存しているか」をグラフィカルに把握できます。構造を理解しやすくなり、デバッグや拡張計画を立てやすくなります。

これらの機能を適切に組み合わせれば、より豊かな実験管理環境が実現できます。詳細な設計思想や比較、さらなる活用ノウハウはロング版で詳しく解説しています。、pinjectedの魅力や可能性をさらに見ていただけると幸いです!

Discussion