🔰
SageMaker の推論処理を理解する:4つの関数の役割まとめ(PyTorch)
はじめに
推論とは何かと言われると、
「学習済みモデルを使って予測を行うこと」ということはわかったものの、
SageMaker上で実際にどんな処理が行われているのかについては、まだ把握しきれていませんでした。
今の自分の理解を整理してみます。
(本記事は、 PyTorch 用の inference.py
を用いた場合の処理を前提にしています。)
推論処理の流れ(SageMaker PyTorch モデルサーバーの場合)
SageMaker では、推論処理が次のような 4つの関数に分かれて呼び出されます。
model_fn(model_dir)
1. 学習済みモデルを ロードする関数。
最初に呼ばれ、推論で使うモデルを準備します。
例えば torch.load()
などを使ってモデルファイルを読み込みます。
input_fn(request_body, request_content_type)
2. リクエストデータを 逆シリアル化する関数。
リクエストボディを PyTorch Tensor などに変換します。
JSONやCSVなど、受け取るフォーマットに応じて変換処理を行います。
predict_fn(input_object, model)
3. モデルにデータを渡して 予測を行う関数。
model(input_object)
のように使います。
output_fn(prediction, response_content_type)
4. 予測結果を シリアル化する関数。
返す形式(例:JSONなど)に応じて整形し、APIレスポンスとして返します。
まとめ
SageMaker の推論は、「モデルの読み込み → 入力変換 → 推論 → 出力変換」の順に処理されます。
inference.py
の中でこの4つの関数を定義しておくと、SageMaker側で自動的に呼び出してくれる仕組みです。
おわりに
「推論とは何か」について、まだ十分に腑に落ちていないところもありますが、
処理の流れを知ることで、全体像が少しずつ見えてきたように感じます。
引き続き、一つ一つ理解を積み重ねていきたいと思います。
🔗関連記事
Discussion