🔰

SageMaker の推論処理を理解する:4つの関数の役割まとめ(PyTorch)

に公開

はじめに

推論とは何かと言われると、
「学習済みモデルを使って予測を行うこと」ということはわかったものの、
SageMaker上で実際にどんな処理が行われているのかについては、まだ把握しきれていませんでした。

今の自分の理解を整理してみます。
(本記事は、 PyTorch 用の inference.py を用いた場合の処理を前提にしています。)

推論処理の流れ(SageMaker PyTorch モデルサーバーの場合)

SageMaker では、推論処理が次のような 4つの関数に分かれて呼び出されます。

1. model_fn(model_dir)

学習済みモデルを ロードする関数
最初に呼ばれ、推論で使うモデルを準備します。
例えば torch.load() などを使ってモデルファイルを読み込みます。

2. input_fn(request_body, request_content_type)

リクエストデータを 逆シリアル化する関数
リクエストボディを PyTorch Tensor などに変換します。
JSONやCSVなど、受け取るフォーマットに応じて変換処理を行います。

3. predict_fn(input_object, model)

モデルにデータを渡して 予測を行う関数
model(input_object) のように使います。

4. output_fn(prediction, response_content_type)

予測結果を シリアル化する関数
返す形式(例:JSONなど)に応じて整形し、APIレスポンスとして返します。

まとめ

SageMaker の推論は、「モデルの読み込み → 入力変換 → 推論 → 出力変換」の順に処理されます。
inference.py の中でこの4つの関数を定義しておくと、SageMaker側で自動的に呼び出してくれる仕組みです。

おわりに

「推論とは何か」について、まだ十分に腑に落ちていないところもありますが、
処理の流れを知ることで、全体像が少しずつ見えてきたように感じます。

引き続き、一つ一つ理解を積み重ねていきたいと思います。


🔗関連記事

Discussion