🔰
SageMaker における model_fn の実行タイミングと初期化処理の役割
はじめに
SageMaker の推論処理を少しずつ学んでいます。
その中で「なるほど、そういう仕組みなのか」と気づいた点があったため、備忘録としてまとめておきます。
model_fn は最初に1度だけ呼び出される
SageMaker(PyTorch)で推論処理を行うとき、以下の4つの関数が順に呼ばれる仕組みになっています。
| 関数名 | 役割 |
|---|---|
model_fn |
モデルのロード(初回のみ) |
input_fn |
入力データの前処理(毎回実行) |
predict_fn |
推論(毎回実行) |
output_fn |
出力の整形(毎回実行) |
この中で model_fn は、推論エンドポイントの初期化時に一度だけ実行されるという仕様になっています。
つまり、推論リクエストのたびに毎回実行されるわけではありません。
初期化処理は model_fn に書いておくと効率的
例えば、以下のような処理は model_fn にまとめておくと効率的です。
- モデルのロード
- 時間がかかる前処理(辞書読み込み、正規表現のコンパイルなど)
- 再利用可能な外部リソースの準備
そして、それらをグローバル変数として保存しておけば、他の関数(input_fn / predict_fn / output_fn)からも再利用できます。
毎回同じ処理を繰り返さずに済むため、推論の処理速度を改善することにもつながります。
# inference.py の例(コメント付き)
# グローバル変数としてモデルと前処理器を定義
model = None
preprocessor = None
def model_fn(model_dir):
global model, preprocessor
# 指定されたディレクトリから学習済みモデルを読み込む
model = torch.load(os.path.join(model_dir, "model.pth"), map_location='cpu')
# 推論モードに設定
model.eval()
# 前処理用のクラス(例:トークナイザーなど)を初期化
preprocessor = MyTokenizer()
# SageMaker側の仕様として、読み込んだモデルを返す必要がある
return model
def predict_fn(input_data, model):
global preprocessor
# 入力データを前処理(例:トークン化など)する
inputs = preprocessor.tokenize(input_data)
# 推論時は勾配計算を無効化
with torch.no_grad():
outputs = model(inputs)
return outputs
おわりに
「推論ってどのように動いているのか?」という視点から段階的に調べているところですが、
model_fn の役割と呼ばれるタイミングは、とても大事なポイントだと感じました。
まだ全体像は掴めていませんが、今後も少しずつ理解を深めていきたいと思います。
🔗 関連記事
Discussion