🔰

SageMaker における model_fn の実行タイミングと初期化処理の役割

に公開

はじめに

SageMaker の推論処理を少しずつ学んでいます。
その中で「なるほど、そういう仕組みなのか」と気づいた点があったため、備忘録としてまとめておきます。

model_fn は最初に1度だけ呼び出される

SageMaker(PyTorch)で推論処理を行うとき、以下の4つの関数が順に呼ばれる仕組みになっています。

関数名 役割
model_fn モデルのロード(初回のみ)
input_fn 入力データの前処理(毎回実行)
predict_fn 推論(毎回実行)
output_fn 出力の整形(毎回実行)

この中で model_fn は、推論エンドポイントの初期化時に一度だけ実行されるという仕様になっています。
つまり、推論リクエストのたびに毎回実行されるわけではありません。

初期化処理は model_fn に書いておくと効率的

例えば、以下のような処理は model_fn にまとめておくと効率的です。

  • モデルのロード
  • 時間がかかる前処理(辞書読み込み、正規表現のコンパイルなど)
  • 再利用可能な外部リソースの準備

そして、それらをグローバル変数として保存しておけば、他の関数(input_fn / predict_fn / output_fn)からも再利用できます。

毎回同じ処理を繰り返さずに済むため、推論の処理速度を改善することにもつながります。

# inference.py の例(コメント付き)

# グローバル変数としてモデルと前処理器を定義
model = None
preprocessor = None

def model_fn(model_dir):
    global model, preprocessor
    
    # 指定されたディレクトリから学習済みモデルを読み込む
    model = torch.load(os.path.join(model_dir, "model.pth"), map_location='cpu')
    # 推論モードに設定
    model.eval()  
    
    # 前処理用のクラス(例:トークナイザーなど)を初期化
    preprocessor = MyTokenizer()
    
    # SageMaker側の仕様として、読み込んだモデルを返す必要がある
    return model


def predict_fn(input_data, model):
    global preprocessor
    
    # 入力データを前処理(例:トークン化など)する
    inputs = preprocessor.tokenize(input_data)
    
    # 推論時は勾配計算を無効化
    with torch.no_grad():
        outputs = model(inputs)
    
    return outputs

おわりに

「推論ってどのように動いているのか?」という視点から段階的に調べているところですが、
model_fn の役割と呼ばれるタイミングは、とても大事なポイントだと感じました。

まだ全体像は掴めていませんが、今後も少しずつ理解を深めていきたいと思います。


🔗 関連記事

Discussion