SageMaker の推論処理を理解する:output_fn の役割と使い方(PyTorch)
はじめに
これまで、SageMaker における model_fn、input_fn、predict_fn のそれぞれの役割について整理してきました。
今回はその最後のステップとなる output_fn に焦点を当てて、
モデルの推論結果をどのように整形してクライアントへ返すかについて確認します。
個人の学習記録として、今の理解を整理しておきます。
output_fn とは?
output_fn は、モデルの推論結果をレスポンスとして返す形式に変換する関数です。
SageMaker のエンドポイントに推論リクエストを送ると、
最終的にこの output_fn が呼び出され、クライアントに返される出力が生成されます。
def output_fn(prediction, content_type):
# 推論結果(prediction)を content_type に応じた形式で返す
...
例えば、application/json を指定された場合には、
Pythonオブジェクトを JSON 形式に変換して返すことになります。
使用例(PyTorch + JSON の場合)
以下は、典型的な output_fn の例です。
import json
def output_fn(prediction, content_type):
if content_type == 'application/json':
# tensor → list → json 文字列
result = prediction.tolist()
return json.dumps(result)
else:
raise ValueError(f"Unsupported content type: {content_type}")
このように、出力の形式(Tensor → list → JSON)を整える処理を行います。
SageMaker では、デフォルトで 'application/json' が利用されることが多いため、
まずはこの形式に対応しておけば十分な場合が多いです。
呼び出される順序の整理
SageMaker の推論処理は以下の順序で進みます。
リクエスト受信 → input_fn → predict_fn → output_fn → レスポンス返却
このうち output_fn は、predict_fn の出力を元に、
クライアントが期待するフォーマットに変換して返す最後のステップを担っています。
設計時の注意点
-
content_typeを確認し、適切な形式で返す(例:JSON、CSV、textなど) - 推論結果が Tensor の場合は
.tolist()などで変換が必要 - 必要に応じて精度やラベル名の追加など、軽微な後処理をここで行ってもよい
- 未対応の
content_typeに対しては明示的に例外を出すと親切
今の時点での自分の理解(仮)
-
output_fnは「推論結果を返す形に整える」最後の処理 -
predictionはpredict_fnの出力(Tensorなど) -
content_typeはクライアントが求める返却形式(多くは JSON) - Tensor のままでは返せないので、Pythonの list や JSON 文字列に変換する
- 特別な出力処理がなければ、
json.dumps()だけでも十分なことが多い
おわりに
output_fn は、モデルの出力を「クライアントが受け取れる形式」に整形して返す役割を担う関数です。
推論結果をどのように返すかは用途によって異なりますが、
今回整理した内容が基本となる流れだと感じました。
今後は、CSV やバイナリといった他の content_typeにも対応する必要が出てくるかもしれませんが、
まずは JSON をベースに仕組みを理解することで、全体像がつかめてきたように思います。
🔗 関連記事
Discussion