🌲

SageMakerでLightGBMモデルをデプロイしてみた

2024/04/15に公開

はじめに

Boto3とAWSが提供しているPyTorchコンテナイメージを利用してLightGBMのモデルをサーバーレスエンドポイントとしてデプロイしてみたときのメモになります。

動作環境

  • python-vesion: 3.11.6
  • boto3: 1.34.69
  • sagemaker: 2.214.0
  • lightgbm: 4.3.0
  • pythonの実行はローカルPCで行いました

モデルの準備

LightGBM回帰モデルをpickleファイル形式で保存しておきます。
以下は保存する際のイメージです。

train.py
import pickle
from pathlib import Path
from lightgbm import LGBMRegressor


# モデルを学習する
model = LGBMRegressor(**params)
model.fit(...)

# モデルを保存する
with Path("model.pkl").open("wb") as f:
    pickle.dump(model, f)

なお、今回は5分割交差検証で得た5つのモデルを保存し、推論時にはこれらの平均値を出力するようにしました。

ただし、マルチモデルエンドポイントでこれらのモデルをホスティングするのではなく、シンプルな1つのエンドポイントが平均値を出力するようにしました。

依存関係の準備

PyTorchコンテナイメージはrequirements.txtに必要なパッケージを記載しておけば、エンドポイントにおける推論時にそのパッケージを使用することができます。

しかし、コンテナにおけるLightGBMのインストールに時間がかかってしまうことで、コンテナの起動後4分以内にヘルスチェックに応答しなければならないというSageMakerの制約に引っかかってしまうため、requirements.txtにlightgbmと記載するだけではデプロイ時にエラーになってしまうということが起きました。

そこで、Python Wheel形式でパッケージをインストールして、その時間を短縮させることにしました。
今回はPyPIから"lightgbm-4.3.0-py3-none-manylinux_2_28_x86_64.whl"をダウンロードしました。

requirements.txtの中身は以下になります。
"/opt/ml/model/code/~"については後述します。

requirements.txt
/opt/ml/model/code/lightgbm-4.3.0-py3-none-manylinux_2_28_x86_64.whl

推論コードの準備

推論コード(inference.py)はデフォルトで用意されており、既存のPyTorchコンテナイメージでは、PyTorchモデル(.pth形式のファイル)を入力とした推論コードになっていますが、デフォルトの推論コードにおける関数と同名の関数を開発者側で用意すれば、それが実行されるようになっています。

今回は、LightGBMモデルを使用することとモデルへのデータの入力をPandasのDataFrameで行いたかったため、model_fn、input_fnをこちらで推論コード上に定義しました。

また、最終的な推論値を5つの推論モデルの平均値にしたいので、predict_fnもこちらで再定義しました。

inference.py
import json
import os
import pickle
from pathlib import Path

from lightgbm import LGBMRegressor
import numpy as np
import pandas as pd

# モデルの読み込み
def model_fn(model_dir: str) -> list[LGBMRegressor]:
    models: list[LGBMRegressor] = []
    for filename in os.listdir(model_dir):
        if filename.endswith(".pkl"):
            with Path(model_dir, filename).open("rb") as f:
                model = cast(LGBMRegressor, pickle.load(f))
                models.append(model)
    return models

# 入力データの変換
def input_fn(req_body: str, req_content_type: str)->pd.DataFrame:
    if req_content_type == "application/json":
        req = json.loads(req_body)
        df_input = pd.DataFrame(
            columns=[
                "column1",
                "column2",
            ],
        )
        df_input.loc[len(df_input)] = [
            req["column1"],
            req["column2"],
        ]

        # カテゴリ変数をcategory型に変換する
        for col in df_input.columns:
            if df_input[col].dtype == "O":
                df_input[col] = df_input[col].astype("category")
        return df_input
    raise ValueError("Illegal content type.")

# 予測
def predict_fn(input_data: pd.DataFrame, models: list[LGBMRegressor]) -> np.ndarray:
    predictions = []
    for model in models:
        model_predictions = model.predict(input_data)
        predictions.append(model_predictions)
    return np.mean(predictions, axis=0)

準備物のS3へのアップロード

モデル、推論コード、依存関係を以下の通りに構成して、tar.gz形式に圧縮します。圧縮したファイルをS3にアップロードします。今回は手動でアップロードしていました。

inference_pytorch
|- model_fold_1.pkl
|- model_fold_2.pkl
|- model_fold_3.pkl
|- model_fold_4.pkl
|- model_fold_5.pkl
|- code/
  |- inference.py
  |- lightgbm-4.3.0-py3-none-manylinux_2_28_x86_64.whl
  |- requirements.txt
準備物の圧縮
$ tar czvf inference_by_pytorch.tar.gz -C inference_by_pytorch .

サーバーレスエンドポイントとしてデプロイ

以下のsample_deploy.pyを実行して、モデルとエンドポイント設定のSageMakerへの保存、エンドポイントの起動を行います。

モデルの作成(sagemaker_client.create_model関数)において、"ModelDataUrl"で、S3上の準備物の保存場所を指定します。ここで指定した準備物がPyTorchコンテナの"/opt/ml/model"に展開されます。

なお、inference.pyで定義したmodel_fn関数の引数であるmodel_dirには、この"/opt/ml/model"が入力されます。そのため、準備物を圧縮する際には、展開したときにディレクトリが作られずに中身のファイルのみが直に展開されるように圧縮する必要があります。

また、既存のPyTorchコンテナイメージは、用意されている環境変数を使って推論コードの保存場所等を設定します。

環境変数の1つである"SAGEMAKER_PROGRAM"で推論コードのファイル名を指定します。

さらに、"SAGEMAKER_SUBMIT_DIRECTORY"を使って、推論コードを格納するディレクトリを指定します。このディレクトリ直下にrequirements.txtを配置しておくと依存関係をインストールしてくれます。今回は、Python Wheel形式でインストールすること、推論コードを格納するディレクトリを"/opt/ml/model/code"としていることから、requirements.txtに上記の通り記載しました。準備物を上記の通りに構成したのもこのためです。

sample_deploy.py
import os

import boto3
from sagemaker import image_uris

boto3_session = boto3.Session(profile_name="AWS_PROFILE_NAME")
sagemaker_client = boto3_session.client(
    service_name="sagemaker",
)

container_image_uri = image_uris.retrieve(
    framework="pytorch",
    region="ap-northeast-1",
    version="2.1",
    py_version="py310",
    image_scope="inference",
    instance_type="ml.m5.large",
)

sagemaker_client.create_model(
    ModelName="sample-model",
    PrimaryContainer={
        "Image": container_image_uri,
        "ModelDataUrl": "s3://AWS_S3_BUCKET_NAME/inference_pytorch.tar.gz",
        "Environment": {
            "SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
            "SAGEMAKER_PROGRAM": "inference.py",
            "SAGEMAKER_REGION": "ap-northeast-1",
            "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
        },
    },
    ExecutionRoleArn="SAGEMAKER_EXECUTION_ROLE_ARN",
)

sagemaker_client.create_endpoint_config(
    EndpointConfigName="sample-model-endpoint-config",
    ProductionVariants=[
        {
            "ModelName": "sample-model",
            "VariantName": "AllTrafic",
            "ServerlessConfig": {
                "MemorySizeInMB": 1024,
                "MaxConcurrency": 1,
            },
        },
    ],
)

sagemaker_client.create_endpoint(
    EndpointName="sample-model-endpoint",
    EndpointConfigName="sample-model-endpoint-config",
)

エンドポイントへのリクエスト

以下のsample_invoke_endpoint.pyを実行してレスポンスとして推論値が返ってくることを確認しました。

sample_invoke_endpoint.py
import json
import os

# Boto3のSageMakerランタイムクライアントを取得する
boto3_session = boto3.Session(profile_name="AWS_PROFILE_NAME")
sagemaker_runtime_client = boto3_session.client(
    service_name="sagemaker-runtime",
)

# 推論エンドポイントを呼び出す
response = sagemaker_runtime_client.invoke_endpoint(
    EndpointName="sample-model-endpoint",
    ContentType="application/json",
    Body=json.dumps({
        "Column1": "value1",
        "Column2": 20,
    }),
    Accept="application/json",
)

# 推論結果が返ってくることを確認する
result = json.loads(response["Body"].read().decode())
print(result)

おわりに

AWSは、まずはAWSが提供しているコンテナイメージを拡張する方針で検討することを推奨しており、コンテナイメージに制限を感じなければ、こちらのデプロイ方法で大体良いのかなと思います。データ分析やモデルの学習にできるだけ専念したいですしね。

参考

NCDCエンジニアブログ

Discussion