♨️

Sagemaker EndpointとWhisperで文字起こし

2023/03/11に公開

要約

ServerlessInference(サーバレスエンドポイント)、AsyncInference(非同期エンドポイント)を用いて従量課金(利用時のみ課金)な機械学習endpoint設計を実現。
OpenAIのWhisperを動かした。
比較的新しいサーバレス推論だが、
23年3月時点だと下記点がデメリットに該当するので、今後に期待。

  • GPUが使えない
  • 推論タイムアウトが60秒
  • コールドスタート

読者対象

  • AWS SageMakerを使った/使いたい人
  • 機械学習モデルのデプロイ設計に悩んでいる人
  • 文字起こしAIのWhisperを使ったことがある/使う予定がある人

背景、目的

2022年に登場した文字起こしAIであるWhisperをサーバレス環境下で実行したいと思ったことが、本記事のきっかけでした。AWS環境下において、DeepLearningモデルを実行したい場合、コンピューティングリソースであるEC2のインスタンスを立てることが定石だと思います。一方でEC2インスタンスを常時稼働させると、それなりのランニングコストが発生したりサーバー管理のデメリットがあります。特に今回扱うWhisperで可能な文字起こしを四六時中行いたい機会は稀ですので、従量課金に稼働させる方法を探ります。またWhisperは計算に時間がかかるモデルですので、可能であればGPUを使いたいです。
まとめると要件は3点です。

  • オンデマンドは嫌
  • フルマネージドにデプロイしたい
  • GPU使いたい

AWS SageMakerの概要

AIモデルの学習、推論、開発とトータルサポートをしてくれるAWSのサービスです。ノーコードでの開発等も行えるようですが、pythonが分かる人であれば個人的にはノートブックインスタンスがおすすめです。こちらはEC2ベースでjupyter notebook環境とストレージ(EBS)を提供してくれます。osがAmazonLinux2になることを許容できるのであれば、環境構築に時間を取られず分析や開発コーディングを始められるためおすすめです。

またSageMaker Endpointという機能では、あらかじめ作成した学習モデルを用いて推論処理を行い、結果を返す事ができます。GPU含むコンピューティング環境をフルマネージドで提供してくれるのが大きなメリットです。自前で推論サーバーを用意する際に求められる、レイテンシの保証や利用するAPIへの理解といったインフラエンジニアリング稼働を抑えられるのが魅力ですね。

SageMaker Endpointで必要なこと

Endpointの種類によらず、必要なものは以下4つです。

  • model.tar.gz
    • ロードしたい学習モデルをおいておきます
  • コンテナイメージ
    • 実行環境をコンテナで指定できます。今回はWhisperで必要なFFMPEGを簡単にインストールできるubuntuイメージを利用します。デフォルトのAmazonLinux2だと少しインストール大変ですが、以下の記事が参考になると思います。 Lambdaでの音声処理を完全攻略した
  • entry_point.py
    • 入力->モデル読み込み->推論->出力の一連の役割を記述するコード
  • boto3を扱えるpyコード
    • APIgatewayと組み合わせる方法もありますが、今回はboto3クライアントからエンドポイントにリクエストします。

entry_point.pyの内部処理についても記載します。

  • input_fn()
    • ペイロード(リクエスト時に含まれるデータのこと)を受け取って処理する関数。今回は文字起こししたい動画のS3URIを受け取ってダウンロードし、tmp/直下に保存します。
  • model_fn()
    • model.tar.gzにあるモデルをロードします。
  • predict_fn()
    • model_fnから受け取ったモデルで推論。今回は推論部分を簡略にしたかったのでWhisperのtranscribeメソッドをそのまま流用してます。
  • output_fn()
    • 最後の処理を担当。今回は文字起こししたてテキストを返します。(お好みでS3にtxtファイルで保存してね。)

実際に自分がWhisperを動かすうえで設定した内容を下記に記載します。

entry_point.py
def model_fn(model_dir:str, context=None):
    """Loads a model. For PyTorch, a default function to load a model cannot be provided.
        Users should provide customized model_fn() in script.

        Args:
            model_dir: a directory where model is saved.
            context (obj): the request context (default: None).

        Returns: A PyTorch model.
    """
    return torch.load(model_dir)

def input_fn(input_data:str,content_type,context=None):
    """A default input_fn that can handle JSON, CSV and NPZ formats.

        Args:
            input_data: the request payload serialized in the content_type format
            content_type: the request content_type
            context (obj): the request context (default: None).

        Returns: input_data deserialized into torch.FloatTensor or torch.cuda.FloatTensor depending if cuda is available.
    """
    bucket_name = s3_path.split("/")[2]
    object_name = "/".join(s3_path.split("/")[3:])
    
    s3 = boto3.client('s3')
    s3.download_file(bucket_name, object_name, '/tmp/target.mp4')
    
    
    
    return "/tmp/target.mp4"
    
    
def predict_fn(data:str, model, context=None):
    """A default predict_fn for PyTorch. Calls a model on data deserialized in input_fn.
        Runs prediction on GPU if cuda is available.

        Args:
            data: input data (torch.Tensor) for prediction deserialized by input_fn
            model: PyTorch model loaded in memory by model_fn
            context (obj): the request context (default: None).

        Returns: a prediction
    """
    result = model.transcribe(data)
    return result["text"]

    
def output_fn(prediction, accept, context=None):
    """A default output_fn for PyTorch. Serializes predictions from predict_fn to JSON, CSV or NPY format.

        Args:
            prediction: a prediction result from predict_fn
            accept: type which the output data needs to be serialized
            context (obj): the request context (default: None).

        Returns: output data serialized
    """
    return json.dumps(prediction),accept #直接返すことにする

SageMaker Endpointについて

今回は代表的な3つを紹介します。
各項目ごとの比較表を掲載しますが、まとめると以下3点です。

  • 高速なレスポンス必要 -> リアルタイム推論
  • 従量課金でGPUが欲しい -> 非同期推論
  • CPUのみでOK & 従量課金で構成したい -> サーバレス推論 or AWS Lambda

Serverless Inferenceを使ってみた。

2021年12月に発表されたのがSageMaker Serverless Inferenceです。アイドルタイムへの課金も発生しません従量課金型です。これだけ見ると、予測モデルをアプリケーション実装する上では最適解のような気がしますが、今回の実装を進める中で、以下の弱点が見えました。

  • GPUが使えない
  • 推論実行時間(ランタイム)が最大で60秒(1分) <-短すぎる!?
  • コールドスタート
  • Endpointに使用するイメージサイズは10GBまで
  • 非同期推論InvokeEndpointAsyncが使えない


    あれ。。。?AWS Lambdaでよくね?
serverless_inference.py
from sagemaker import Model
from sagemaker.predictor import Predictor
from sagemaker.serverless import ServerlessInferenceConfig

###########################
"""
サーバレスエンドポイントのデプロイ
"""
###########################
model = Model(
    image_uri='YOUR CONTAINER IMAGE URI',
    model_data='YOUR S3 URI of model.tar.gz',
    entry_point='entry_point.py',
    source_dir='./code',
    role=sagemaker.get_execution_role(),
    name="YOUR MODEL NAME",
)

serverless_config = ServerlessInferenceConfig(
    memory_size_in_mb=6144,
    max_concurrency=5
)


predictor = model.deploy(initial_instance_count=1,
                         instance_type='ml.m5.xlarge',
                         serverless_inference_config=serverless_config)

###########################
"""
作成したエンドポイントで推論
"""
###########################
serverless_result = Predictor(endpoint_name="YOUR ENDPOINT NAME").predict(data=f"YOUR DATA")

Async Inference(非同期推論)を使ってみた

こちらは即時結果を返さない非同期タイプになっています。
特にオートスケールが適用できるので、使わないときはインスタンス数を0にできる点が魅力ですね。
0から1にスケールアウトするときにめっちゃ時間かかる点が欠点ではありますが。。。
たまにしか使わない、けどGPUを従量課金で使いたい人におすすめです。



コチラの非同期エンドポイントをsagemakerSDKからリクエストするとエラーを出ます。
issueにもバグとして残り続けてるので、boto3経由でリクエストしましょう。
該当issue

async_inference.py
from sagemaker import Model
import sagemaker
from sagemaker.async_inference import AsyncInferenceConfig

###########################
"""
エンドポイントのデプロイ
"""
###########################

async_model = Model(
    image_uri = "YOUR IMAGE URI",
    model_data='YOUR MODEL`s S3 URI',
    entry_point='entry_point.py',
    source_dir='./code',
    role=sagemaker.get_execution_role(),
    name="YOUR MODEL NAME",
)

async_inference_config = AsyncInferenceConfig(
    output_path="YOUR OUTPUT S3 URI",
)


predictor = async_model.deploy(initial_instance_count=1,
                               instance_type = "ml.g4dn.8xlarge",
                               async_inference_config=async_inference_config,
                              )

###########################
"""
作成したエンドポイントで推論
"""
###########################
sagemaker_runtime = boto3.client("sagemaker-runtime", region_name='YOUR REGION NAME')

response = sagemaker_runtime.invoke_endpoint_async(
                            EndpointName=async_endpoint_name,
                            InputLocation="YOUR INPUT DATA S3 URI")

Async Inferenceにオートスケーリングを適用する

公式ドキュメントに従って、非同期エンドポイントにオートスケールを設定します。
使っていないときには、インスタンス0にすることで疑似的な従量課金を実現します。
(インスタンス数0から1にスケールアウトするときに結構時間かかる難点があります。。。)

EC2 Scaling PolicyにはないApproximateBacklogSizePerInstanceという監視指標を設定できます。現在起動中の1インスタンス当たりどれだけリクエストが残っているかどうかを追跡するもので、任意の整数値をTargetValueとして与えます。
AWS公式ドキュメントのコピペで設定可能なので、本記事での記載は省略します。
ApproximateBacklogSizePerInstanceについて
Auto Scaling設定のコード内容

最後に

SageMaker Endpointについて日本語でまとまっている記事が少ないように感じたので執筆しました。一通りエンドポイントの実行までの流れをつかんでもらえたかなと思います。
個人的には、応答性悪いですけど従量課金でGPUが使える非同期エンドポイントの知見を得られたことが収穫でした。



23年3月1日にWhisperAPIが登場したため、自前でデプロイする必要もなくなりましたが、本記事が誰かの参考になれば幸いです。

Discussion