SageMaker Model Monitor を触ってみた
こんにちは、初めましての方は初めまして。株式会社 Fusic の瓦です。本棚に本を詰め込みすぎて本棚が撓んできて、いつ壊れるか戦々兢々としながら日々を過ごしています。
この記事では、SageMaker Model Monitor(以下 Model Monitor) に関して GitHub で公開されているコードを動かしつつ理解を深めていこうと思います。SageMaker Studio の JupyterLab を動かせる環境がある方は実際に動かしてみた方が理解が早いと思いますので、この記事はそのような環境がない、あるけど動かすのが面倒くさい方におすすめです。
SageMaker Model Monitor とは
SageMaker でデプロイしたモデルを監視できる機能で、推論リクエストで投げられたデータの保存や、モデルの精度変化の検知、モデルのバイアスの検知やデータの偏りなど様々な面からモデルの評価を行うことが出来ます。そもそもなぜ監視が必要か、監視のために Model Monitor で何が出来るかのついては AWS BlackBelt で詳しく説明されているのでそちらを見ていただくとして、この記事では実際に Model Monitor を実際に動かして、どのような情報が得られるのかを確認していきます
実際に動かしてみる
Model Monitor では実際にデプロイしたモデルの監視を行います。そのため、
- モデルのデプロイ
- デプロイしたモデルでベースラインの作成
- 疑似的な正解データの作成
- 疑似的な正解データで推論
の順に行っていき、それぞれの結果を確認していきます。また説明の都合上、サンプルのコードから一部を抜粋して説明しています。
準備
SageMaker Python SDK は sagemaker
という名前のライブラリで提供されているので、まずは sagemaker
ライブラリをインストールしておきます(pip 経由であれば pip install sagemaker
でインストールできます)ライブラリのドキュメントはここにあるので適宜参照するとよいです。
モデルのデプロイ
ここでは推論用のモデルをデプロイして API 経由で使用できるようにします。自分で訓練したモデルをデプロイしてもいいのですが、今回はリポジトリで用意されているモデルとイメージを使用します。
まずはモデルを S3 にアップロードし、モデルの推論に使用するイメージを指定して推論環境の設定を行います。その次に、作成したモデルを用いてデプロイを行います。
model_url = S3Uploader.upload("model/xgb-churn-prediction-model.tar.gz", s3_key)
# モデルの設定
model_name = f"DEMO-xgb-churn-pred-model-monitor-{datetime.utcnow():%Y-%m-%d-%H%M}"
image_uri = image_uris.retrieve(framework="xgboost", version="0.90-1", region=region)
model = Model(image_uri=image_uri, model_data=model_url, role=role, sagemaker_session=session)
# エンドポイントの作成
endpoint_name = f"DEMO-xgb-churn-model-quality-monitor-{datetime.utcnow():%Y-%m-%d-%H%M}"
data_capture_config = DataCaptureConfig(
enable_capture=True, sampling_percentage=100, destination_s3_uri=s3_capture_upload_path
)
model.deploy(
initial_instance_count=1,
instance_type="ml.m4.xlarge",
endpoint_name=endpoint_name,
data_capture_config=data_capture_config,
)
Model Monitor を使用したい場合は DataCaptureConfig
の設定を行う必要があることに注意しておきましょう。data_capture_config
を指定しない場合もデプロイ自体は出来ますが、どのようなリクエストが投げられたかが取得しづらく、そのためモデルの監視も難しくなります(もちろんイメージ内で S3 に投げるコードを自分で書いてリクエストを保存してもよいです)DataCaptureConfig
の sampling_percentage
は、その名の通りリクエストのうちどのくらいの割合を保存するかを表しており、デフォルトでは 20
となっています。少なすぎると実際のリクエストの分布が反映しづらくなってしまうのて、ここの値は実際にどのくらいのリクエストが来そうかと監視のためにどのくらいのデータが必要になるかを議論して決めるといいでしょう。
次にエンドポイントに投げられたデータが実際に S3 にどのように保存されるか見てみます。
# S3 にアップロードされたデータを取得
for _ in range(120):
capture_files = sorted(S3Downloader.list(f"{s3_capture_upload_path}/{endpoint_name}"))
if capture_files:
capture_file = S3Downloader.read_file(capture_files[-1]).split("\n")
capture_record = json.loads(capture_file[0])
if "inferenceId" in capture_record["eventMetadata"]:
break
print(".", end="", flush=True)
sleep(1)
print(json.dumps(capture_record, indent=2))
# 実際に書き出されたデータ
# {
# "captureData": {
# "endpointInput": {
# "observedContentType": "text/csv",
# "mode": "INPUT",
# "data": "186,0.1,137.8,97,187.7,118,146.4,85,8.7,6,1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,0.10,0.11,0.12,0.13,0.14,0.15,0.16,0.17,1.1,0.18,0.19,0.20,0.21,0.22,0.23,0.24,0.25,0.26,0.27,0.28,0.29,0.30,0.31,0.32,0.33,0.34,0.35,0.36,0.37,0.38,0.39,0.40,0.41,0.42,0.43,0.44,0.45,0.46,0.47,0.48,0.49,0.50,0.51,0.52,0.53,1.2,1.3,0.54,1.4,0.55",
# "encoding": "CSV"
# },
# "endpointOutput": {
# "observedContentType": "text/csv; charset=utf-8",
# "mode": "OUTPUT",
# "data": "0.01584203727543354",
# "encoding": "CSV"
# }
# },
# "eventMetadata": {
# "eventId": "eacf2c5e-7535-4c50-88a5-2e055bb1f8d0",
# "inferenceId": "0",
# "inferenceTime": "2024-10-12T06:18:53Z"
# },
# "eventVersion": "0"
# }
このように、デフォルトでは入出力の両方について jsonl
形式で S3 にデータが書き込まれます。このデータを用いて、例えば endpointInput
を確認して変な入力が与えられていないか調べたり、endpointOutput
を確認してモデルが変な出力をしていないか調べたりすることが出来ます。
ベースラインの作成
次にベースライン(現在のモデルの推論結果)の作成を行います。ここではベースラインをファイルに書き込んで S3 にアップロードします。この後でこのベースラインを元にしてモデルの変化を確認していきます。ここでのファイルもサンプルに含まれているものを使います。
predictor = Predictor(
endpoint_name=endpoint_name, sagemaker_session=session, serializer=CSVSerializer()
)
churn_cutoff = 0.8
validate_dataset = "validation_with_predictions.csv"
limit = 200 # Need at least 200 samples to compute standard deviations
i = 0
with open(f"test_data/{validate_dataset}", "w") as baseline_file:
baseline_file.write("probability,prediction,label\n") # our header
with open("test_data/validation.csv", "r") as f:
for row in f:
(label, input_cols) = row.split(",", 1)
probability = float(predictor.predict(input_cols))
prediction = "1" if probability > churn_cutoff else "0"
baseline_file.write(f"{probability},{prediction},{label}\n")
i += 1
if i > limit:
break
print(".", end="", flush=True)
sleep(0.5)
上のコードでは単純にエンドポイントに対してリクエストを投げ、その結果とラベル、正解を連結してファイルに書き出しています。boto3
を使ってエンドポイントを叩いてもいいですが、sagemaker
ライブラリを使うとコードがスッキリしていいですね。
churn_model_quality_monitor = ModelQualityMonitor(
role=role,
instance_count=1,
instance_type="ml.m5.xlarge",
volume_size_in_gb=20,
max_runtime_in_seconds=1800,
sagemaker_session=session,
)
baseline_job_name = f"DEMO-xgb-churn-model-baseline-job-{datetime.utcnow():%Y-%m-%d-%H%M}"
job = churn_model_quality_monitor.suggest_baseline(
job_name=baseline_job_name,
baseline_dataset=baseline_dataset_uri,
dataset_format=DatasetFormat.csv(header=True),
output_s3_uri=baseline_results_uri,
problem_type="BinaryClassification",
inference_attribute="prediction",
probability_attribute="probability",
ground_truth_attribute="label",
)
job.wait(logs=False)
次にモデルの品質の評価を行います。上のコードでは suggest_baseline
を使用して、ベースラインの評価について AUC
や F1
、F2
スコアなど分類問題でよく使用される評価指標での評価を行います。
baseline_job = churn_model_quality_monitor.latest_baselining_job
binary_metrics = baseline_job.baseline_statistics().body_dict["binary_classification_metrics"]
print(pd.json_normalize(binary_metrics).T) # 各指標での結果を見たい場合はこっち
print(pd.DataFrame(baseline_job.suggested_constraints().body_dict["binary_classification_constraints"]).T) # 各指標での制約を見たい場合はこっち
binary_classification_constraints の表示結果
以上のように、各指標に対してのベースラインと比較方法を提示してくれます。例えば recall
については LessThanThreshold
であるため、このスコアより下がるとモデルの品質が低下しているということになります。これらは S3 にも json
形式でアップロードされているため、この指標はいつでも確認できます。
疑似的な正解データの作成
def ground_truth_with_id(inference_id):
random.seed(inference_id) # to get consistent results
rand = random.random()
return {
"groundTruthData": {
"data": "1" if rand < 0.7 else "0", # randomly generate positive labels 70% of the time
"encoding": "CSV",
},
"eventMetadata": {
"eventId": str(inference_id),
},
"eventVersion": "0",
}
def upload_ground_truth(records, upload_time):
fake_records = [json.dumps(r) for r in records]
data_to_upload = "\n".join(fake_records)
target_s3_uri = f"{ground_truth_upload_path}/{upload_time:%Y/%m/%d/%H/%M%S}.jsonl"
print(f"Uploading {len(fake_records)} records to", target_s3_uri)
S3Uploader.upload_string_as_file_body(data_to_upload, target_s3_uri)
上のコードでは疑似的なデータを作っています。入力データは変わっていないが、ラベル付けの条件が変わって正解ラベルが変化したという設定ですね。変化したかどうかを確かめるためには正解データが必要となるため、本来は実際に来たリクエストに対して人などがラベリングしたデータが必要となりますが、今回は検知できるかを確かめるためにランダムにラベル付けして疑似的なデータを作成しています。
疑似的な正解データで推論
churn_monitor_schedule_name = (
f"DEMO-xgb-churn-monitoring-schedule-{datetime.utcnow():%Y-%m-%d-%H%M}"
)
endpointInput = EndpointInput(
endpoint_name=predictor.endpoint_name,
probability_attribute="0",
probability_threshold_attribute=0.5,
destination="/opt/ml/processing/input_data",
)
response = churn_model_quality_monitor.create_monitoring_schedule(
monitor_schedule_name=churn_monitor_schedule_name,
endpoint_input=endpointInput,
output_s3_uri=baseline_results_uri,
problem_type="BinaryClassification",
ground_truth_input=ground_truth_upload_path,
constraints=baseline_job.suggested_constraints(),
schedule_cron_expression=CronExpressionGenerator.hourly(),
enable_cloudwatch_metrics=True,
)
上のコードではモニタリングのスケジュールを設定しています。モニタリングは定期的に行った方がいいため、ここでは一時間ごと(毎時 0 分ごと)のスケジュールの設定をしています。
latest_execution = churn_model_quality_monitor.list_executions()[-1]
pd.options.display.max_colwidth = None
violations = latest_execution.constraint_violations().body_dict["violations"]
violations_df = pd.json_normalize(violations)
violations_df.head(10)
検知した変化の一覧
最後に何か変化が起こったを確認します。今回はランダムにラベルを振っているので色々な指標でドリフトが検知されています。例えば auc
は 0.513 ± 0.008 となっており、元の 0.9395 よりも精度が低下していることが分かります。実際にモデルを本番環境にデプロイして時間がたつと、ラベル付けの条件やユーザーからのデータの変化によってデプロイしたモデルは十分な性能が発揮できなくなることが多々あります。そのようなときに、この Model Monitor を定期的に走らせることで、性能低下を検知出来ます。また、監視したい評価指標と「下がっては/上がってはいけないライン」を決めて CloudWatch
を使ってアラートを発生させることも出来ます(これについてもサンプルのノートブックの最後に書かれているのでぜひ目を通してみてください)例えば「Model Monitor で性能低下を検知して CloudWatch でアラートが発生したら、S3 に保存してあるデータで再学習を行って学習後のモデルをデプロイする」というような仕組みを作ることも出来ます。
まとめ
この記事では公開されている SageMaker Model Monitor のコードを、少し分かりづらい部分を補足しながら動かしてみました。Model Monitor に関してはコンソールから見えづらい部分もあり、SageMaker 上で手を出しづらいと思っている方やそもそも機能を知らない方も多いのではないかと思います。この記事がそのような方のために役立てば幸いです。
最後に宣伝になりますが、機械学習でビジネスの成長を加速するために、Fusic の機械学習チームがお手伝いたします。機械学習の PoC から MLOps まですべての場面でサポートした実績があり、ご相談に合わせて様々な支援を行うことが出来ます。もし、困っている方がいましたら、ぜひ Fusic までご相談ください。お問い合わせからでも気軽にご連絡いただけます。また質問など Twitter の DM に対してのメッセージも大歓迎です。
Discussion