OpenTelemetryを使ってSpanをW&Bに記録する
はじめに
機械学習 (ML) 分野において、モデルの精度評価や実験のトラッキングに Weights & Biases (以降、W&B) がよく使われます。少し複雑なことをしようとすると、複数のモデルを組み合わせたパイプラインを構築する必要があります。ML パイプラインの開発においては、処理速度やメモリの使用量も重要な指標です。例えば、モデルの更新で大きな精度向上を得られたとしても、処理速度が要求を満たせなければ、そのモデルを採用することはできません。
メモリの使用量などシステムのメトリクスは W&B 側で自動的に記録されます。一方で、各モデルの処理速度など、アプリケーションのメトリクスは、wandb.log()
などを用いて自分で記録する必要があります。開発時には、wandb.log()
を挿入して値を記録すれば良いのですが、運用時には同じメトリクスを取るために、別の実装を追加しなければなりません。同じことをするコードを、2度書くことは保守の観点でも避けたいものです。
何かいい方法はないかと考えていたところ、OpenTelemetry (OTel) を活用することを思いついたので、その方法を紹介します。
今回想定する ML パイプライン
今回は以下の構成のパイプラインを例に実装しました。Span が今回計測対象のブロックです。
パイプラインの構成
3 つのモデルで構成されますが、今回は簡単のため実際のモデルではなく、ランダムな処理時間を持つクラスで置き換えています。処理時間は平均mean
分散 std
の正規分布に従う乱数で決定されます。
class DummyModel(torch.nn.Module):
def __init__(self, name: str, mean: float, std: float, seed: int = 12345) -> None:
super().__init__()
self.name = name
self.mean = mean
self.std = std
self.rng = np.random.default_rng(seed)
def forward(self, x: torch.Tensor) -> torch.Tensor:
duration = self.rng.normal(loc=self.mean, scale=self.std)
logger.info(f"Sleep for {duration:.2f}sec")
time.sleep(duration)
return x
WandBSpanmetricsExporter
の実装
カスタム Exporter OpenTelemetry には、Span Metrics Connector
という、span を duration などの metrics に変換して出力する機能があります。これに近いものを W&B 用に実装することで、運用時と近いメトリクスを W&B に記録できるようにします。
具体的には、Python SDK における Exporter のベースクラス SpanExporter
を継承して、トレースを受け取り、メトリクスに変換して wandb.log()
でメトリクスを W&B に記録します。実験用には外部の依存が少ない方が使い勝手が良いだろうということで、外部の OTel Connector などに送信せず、すべてローカルで処理を完結させます。
具体的な実装はこんな感じです。
from typing import Sequence
import wandb
from loguru import logger
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult
IS_ROOT_SPAN_KEY_NAME = "is_root_span"
class WandBSpanmetricsExporter(SpanExporter):
"""Convert spans to metrics and log them with `wandb.log()`."""
def __init__(self, commit_evry_call: bool = False) -> None:
"""_summary_
Args:
commit_evry_call (bool, optional): Set True when you have a only one span in every iteration.
If not and you have multiple spans including nested spans, set False and set `is_root_span` attribute
to spans that end in the last to increment 'step' in wandb. Defaults to False.
"""
super().__init__()
self.commit_evry_call = commit_evry_call
def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
for span in spans:
# Get duration from the span
duration = (span.end_time - span.start_time) / 1e6 # Duration in milliseconds
if self.commit_evry_call is True:
commit = True
else:
commit = span.attributes.get(IS_ROOT_SPAN_KEY_NAME, False)
metric = {f"otel/span/{span.name}/duration/ms": duration}
# Get other attributes from the span
for key, value in span.attributes.items():
if key == IS_ROOT_SPAN_KEY_NAME:
continue
metric[f"otel/span/{span.name}/{key}"] = value
# Log the metric
logger.debug("span: name={}, duration={:,}[ns]", span.name, duration)
wandb.log(metric, commit=commit)
# Return SUCCESS if successful
return SpanExportResult.SUCCESS
def shutdown(self) -> None:
logger.info(f"Shutdown {self.__class__.__name__}")
SpanExporter
は、export(spans)
と shutdown()
の 2 つのメソッドを持ちます。export()
メソッド今回の実装の要で、外部にデータを送信する部分です。今回は span のリストを受け取り、span の duration (単位: millisecond) を計算して、W&B に記録します。
実装のポイントは、commit_evry_call
という引数です。wandb.log()
はステップ毎に一回だけ呼ばれることが基本的な想定です。span に階層構造がある場合でも、span 自体は個別のオプジェクトとして export()
に渡されます。その場合、span ごとにステップ数が更新され、同じステップのメトリクスがバラバラに記録されてしまいます。今回の場合、入力フレームを処理する親 span の中に各モデルの推論という子 span が 3 つあるという階層構造がありますが、個別にステップ数が更新されてしまうため、処理したフレーム数の 4 倍 (= 親 + 子 x3) のステップ数が記録されてしまいます。
この問題を解決するために commit_evry_call
という引数を追加しました。デフォルト値はFalse
です。span の取得時に、最上位の親の span にis_root_span=True
という属性を追加します。この span がステップの中で最後に終わるため、export()
メソッドは is_root_span
属性を随時チェックし、is_root_span=True
の時のみ wandb.log(data, commit=True)
とすることで子 span のメトリクスとあわせて 1 つのステップとして記録します。それ以外は wandb.log(data, commit=False)
とし、ステップの更新を行いません。
shutdown()
メソッドは、今回は特に実装の必要はありません。
計装の追加 (トレースの収集)
カスタム Exporter を実装した後は、パイプラインの中でスパンを収集を実装します。基本的なトレースの実装と同じで、tracer に先ほど実装したカスタム Exporter を渡すだけです。
まずは、tracer
を準備します。
resource = Resource.create({"service.name": "pipeline"})
tracer_provider = TracerProvider(resource=resource)
trace.set_tracer_provider(tracer_provider=tracer_provider)
tracer_provider.add_span_processor(
span_processor=BatchSpanProcessor(
# SpanExporterにカスタムExporterを渡す
span_exporter=WandBSpanmetricsExporter()
)
)
tracer = trace.get_tracer_provider().get_tracer(__name__)
次に、パイプラインの中で tracer.start_as_current_span()
を用いて、span を収集します。tracer.start_as_current_span()
環境内で実行された処理が 1 つの span として記録されます。
@click.command()
@click.option("-n", "--num-frames", default=10, help="Number of frames to process")
@click.option("-w", "--wandb-mode", default="offline", show_default=True, help="WandB mode")
def main(num_frames: int = 10, wandb_mode: str = "offline"):
wandb.init(
project=WANDB_PROJECT_NAME,
mode=wandb_mode,
)
# Init dummy models
model1 = DummyModel(name="model1", mean=0.1, std=0.001)
model2 = DummyModel(name="model2", mean=0.2, std=0.010)
model3 = DummyModel(name="model3", mean=0.3, std=0.100)
# Run the pipeline
for i in range(num_frames):
logger.info(f"Processing frame {i}")
# 親 Span
with tracer.start_as_current_span("process_single_frame") as root_span:
x = torch.randn(1, 3, 100, 100)
# 子span
with tracer.start_as_current_span("model1"):
x = model1(x)
with tracer.start_as_current_span("model2"):
x = model2(x)
with tracer.start_as_current_span("model3"):
x = model3(x)
# wandbのstepを更新するためのフラグ
root_span.set_attribute("is_root_span", True)
# 現在のフレームに関する情報をspanに追加
root_span.set_attribute("frame_idx", i)
root_span.set_attribute("num_detecton", num_detection_generator(i))
trace.get_tracer_provider().shutdown()
wandb.finish()
参考ですが、model_1
には 物体検出モデルが入ることを想定し、検出されたオブジェクト数を span の属性に記録しています。後続のmodel_2
とmodel_3
には、オブジェクト毎に処理を行うモデルが想定され、これらのモデルの処理時間はそのオブジェクト数に依存するため、非常に重要な値です。
実行結果
パイプライン全体の実装は、GitHub を参照してください。
実行方法は以下のとおりです。
# Clone the repository and Install dependencies
git clone git@github.com:getty708/mlops-sandbox.git
cd mlops-sandbox
poetry install
# Run the sample pipeline
python services/monitoring/tools/run_otel_in_wandb.py -n 100 -w online
これにより 100 フレーム分の Span が生成され、以下の用に W&B にメトリクスが記録されます。model_1
はおよそ 100ms 付近で分散は小さく、model_3
は 300ms を中心に大きな分散を持ちます。事前に設定した正規分布の値が確認できました。また、num_detection
は sin 波をもとに生成されており、それを確認することができます。x 軸も全てのグラフで揃っていることも確認できました。
W&B ダッシュボードで記録されたメトリクス
まとめ
各モデルの処理時間を、OTel のカスタム Exportor を実装することで、W&B に記録することができるようになりました。次は運用環境を想定し、同じトレースとメトリクスを外部のシステム (Jeager や Grafana) で長期間にわたって監視できることを確認していきたいと思います。
Discussion