🔭

OpenTelemetryを使ってSpanをW&Bに記録する

2024/10/07に公開

はじめに

機械学習を活用して少し複雑なことをしようとすると、複数のモデルを組み合わせたパイプラインを構築する必要があります。ML パイプラインの開発においては、処理速度やメモリの使用量も重要な指標です。例えば、モデルの更新で大きな精度向上を得られたとしても、処理速度が要求を満たせなければ、そのモデルを採用することはできません。また、これらの要求を満たしているか確認するためには、モデルの開発段階でメトリクスとして計測・収集・評価する必要がありますが、運用環境においても同じメトリクスを収集し、運用環境でも想定通りの挙動をしているか把握することが多いと思います。

機械学習 (ML) 分野において、モデルの精度評価や実験のトラッキングに Weights & Biases (以降、W&B) が使われることが増えてきていると思います。モデルの開発段階では W&B など を使用すれば簡単にメトリクスを集めることができます。しかし、プロダクションのシステムにおいては Grafana など別のバックエンドを用いてメトリクスを収集することが求められることがあります。ナイーブに書くと、同じ ML パイプラインで同じメトリクスを収集するためのコードを 2 度書くことが必要になります。DRY 原則的にも保守の観点から見ても、できれば避けたいものです。

mlops
MLOps のサイクルとツール群[1]

何かいい方法はないかと考えていたところ、OpenTelemetry (OTel) を活用することを思いついたので、そのアイデアを試してみました。最終的なに目指すシステムのイメージは以下の画像の通りです。

idea
目指すシステムの全体像

今回想定する ML パイプライン

今回は以下の構成のパイプラインを例に実装しました。想定するのは、全てのモデルが 1 つのプロセスで動く monolithic なパイプラインです。Span が今回計測対象のブロックです。

Dummy Pipeline Architecture
パイプラインの構成

3 つのモデルで構成されますが、今回は簡単のため実際のモデルではなく、処理時間が正弦波で変化するダミーモデルを使用します。

class DummyModel(torch.nn.Module):
    def __init__(  # noqa: R0917
        self,
        amp: float = 10,  # milliseconds
        interval: float = 200,
        beta: float = 100,  # milliseconds
    ) -> None:
        super().__init__()
        self.sin_wave_fn = lambda x: amp * np.sin(2 * np.pi * x / interval) + beta
        self.call_count: int = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        duration = self.sin_wave_fn(self.call_count)  # milliseconds
        self.call_count += 1

        logger.info(f"Sleep for {duration:.2f}ms")
        time.sleep(duration / 1e3)
        return x

# Init dummy models
model1 = DummyModel(amp=50, interval=250, beta=100)
model2 = DummyModel(amp=50, interval=500, beta=100)
model3 = DummyModel(amp=50, interval=1000, beta=100)

実装の詳細

実装の詳細は GitHub のリポジトリを参照してください。

https://github.com/getty708/mlops-sandbox/tree/20241006-wandb-spanmetrics-exporter/services/monitoring

カスタム Exporter WandBSpanmetricsExporter の実装

OpenTelemetry には、Span Metrics Connector という、span を duration などの metrics に変換して出力する機能があります。これに近いものを W&B 用に実装することで、運用時と同じメトリクスを W&B に記録することを目指します。

具体的には、Python SDK における Exporter のベースクラス SpanExporter を継承して、トレースを受け取り、メトリクスに変換して wandb.log() でメトリクスを W&B に記録します。実験用には外部の依存が少ない方が使い勝手が良いだろうということで、外部の OTel Collector などに送信せず、すべてローカルで処理を完結させます。

具体的な実装はこんな感じです。

from typing import Sequence

import wandb
from loguru import logger
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult

IS_ROOT_SPAN_KEY_NAME = "is_root_span"


class WandBSpanmetricsExporter(SpanExporter):
    """Convert spans to metrics and log them with `wandb.log()`."""

    def __init__(self, commit_every_call: bool = False) -> None:
        """_summary_

        Args:
            commit_every_call (bool, optional): Set True when you have a only one span in every iteration.
                If not and you have multiple spans including nested spans, set False and set `is_root_span` attribute
                to spans that end in the last to increment 'step' in wandb. Defaults to False.
        """
        super().__init__()
        self.commit_every_call = commit_every_call

    def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
        for span in spans:
            # Get duration from the span
            duration = (span.end_time - span.start_time) / 1e6  # Duration in milliseconds
            if self.commit_every_call is True:
                commit = True
            else:
                commit = span.attributes.get(IS_ROOT_SPAN_KEY_NAME, False)
            metric = {f"otel/span/{span.name}/duration/ms": duration}
            # Get other attributes from the span
            for key, value in span.attributes.items():
                if key == IS_ROOT_SPAN_KEY_NAME:
                    continue
                metric[f"otel/span/{span.name}/{key}"] = value

            # Log the metric
            logger.debug("span: name={}, duration={:,}[ns]", span.name, duration)
            wandb.log(metric, commit=commit)

        # Return SUCCESS if successful
        return SpanExportResult.SUCCESS

    def shutdown(self) -> None:
        logger.info(f"Shutdown {self.__class__.__name__}")

SpanExporter は、export(spans)shutdown() の 2 つのメソッドを持ちます。export() メソッド今回の実装の要で、外部にデータを送信する部分です。今回は span のリストを受け取り、span の duration (単位: millisecond) を計算して、W&B に記録します。

実装のポイントは、commit_every_call という引数です。wandb.log() はステップ毎に一回だけ呼ばれることが基本的な想定です。span に階層構造がある場合でも、span 自体は個別のオプジェクトとして export() に渡されます。その場合、span ごとにステップ数が更新され、同じステップのメトリクスがバラバラに記録されてしまいます。今回の場合、入力フレームを処理する親 span の中に各モデルの推論という子 span が 3 つあるという階層構造がありますが、個別にステップ数が更新されてしまうため、処理したフレーム数の 4 倍 (= 親 + 子 x3) のステップ数が記録されてしまいます。

この問題を解決するために commit_every_call という引数を追加しました。デフォルト値はFalseです。span の取得時に、最上位の親の span にis_root_span=True という属性を追加します。この span がステップの中で最後に終わるため、export() メソッドは is_root_span 属性を随時チェックし、is_root_span=True の時のみ wandb.log(data, commit=True) とすることで子 span のメトリクスとあわせて 1 つのステップとして記録します。それ以外は wandb.log(data, commit=False) とし、ステップの更新を行いません。

shutdown() メソッドは、今回は特に実装の必要はありません。

計装の追加 (トレースの収集)

カスタム Exporter を実装した後は、パイプラインの中でスパンを収集を実装します。基本的なトレースの実装と同じで、tracer に先ほど実装したカスタム Exporter を渡すだけです。

まずは、tracer を準備します。

resource = Resource.create({"service.name": "pipeline"})
tracer_provider = TracerProvider(resource=resource)
trace.set_tracer_provider(tracer_provider=tracer_provider)
tracer_provider.add_span_processor(
    span_processor=BatchSpanProcessor(
        # SpanExporterにカスタムExporterを渡す
        span_exporter=WandBSpanmetricsExporter()
    )
)
tracer = trace.get_tracer_provider().get_tracer(__name__)

次に、パイプラインの中で tracer.start_as_current_span() を用いて、span を収集します。tracer.start_as_current_span() 環境内で実行された処理が 1 つの span として記録されます。

@click.command()
@click.option("-n", "--num-frames", default=10, help="Number of frames to process")
@click.option("-w", "--wandb-mode", default="offline", show_default=True, help="WandB mode")
def main(num_frames: int = 10, wandb_mode: str = "offline"):
    wandb.init(
        project=WANDB_PROJECT_NAME,
        mode=wandb_mode,
    )

    # Init dummy models
    model1 = DummyModel(name="model1", mean=0.1, std=0.001)
    model2 = DummyModel(name="model2", mean=0.2, std=0.010)
    model3 = DummyModel(name="model3", mean=0.3, std=0.100)

    # Run the pipeline
    for i in range(num_frames):
        logger.info(f"Processing frame {i}")
        # 親 Span
        with tracer.start_as_current_span("process_single_frame") as root_span:
            x = torch.randn(1, 3, 100, 100)

            # 子span
            with tracer.start_as_current_span("model1"):
                x = model1(x)
            with tracer.start_as_current_span("model2"):
                x = model2(x)
            with tracer.start_as_current_span("model3"):
                x = model3(x)

            # wandbのstepを更新するためのフラグ
            root_span.set_attribute("is_root_span", True)
            # 現在のフレームに関する情報をspanに追加
            root_span.set_attribute("frame_idx", i)
            root_span.set_attribute("num_detecton", num_detection_generator(i))

    trace.get_tracer_provider().shutdown()
    wandb.finish()

参考ですが、model_1 には 物体検出モデルが入ることを想定し、検出されたオブジェクト数を span の属性に記録しています。後続のmodel_2model_3 には、オブジェクト毎に処理を行うモデルが想定され、モデルの処理時間は入力されるオブジェクト数に依存するため、非常に重要な値です。

実行結果

実行方法は以下のとおりです。

# Clone the repository and Install dependencies
git clone git@github.com:getty708/mlops-sandbox.git
cd mlops-sandbox
poetry install

# Run the sample pipeline
python services/monitoring/tools/run_otel_in_wandb.py -n 1000 -w online

これにより 1000 フレーム分の Span が生成され、以下のように W&B にメトリクスが記録されます。model_1は 250 step 周期の正弦波、model_3 は 1000 step 周期の正弦波になっており、事前に設定した各モデルのパラメータと一致していることが確認できました。また、num_detection も正弦波をもとに生成されており、それを確認することができます。x 軸も全てのグラフで揃っていることも確認できました。

W&B Dashboard
W&B ダッシュボードで記録されたメトリクス

まとめ

各モデルの処理時間を、OTel のカスタム Exportor を実装することで、W&B に記録することができるようになりました。次は運用環境を想定し、同じトレースとメトリクスを外部のシステム (JeagerGrafana) で長期間にわたって監視できることを確認していきたいと思います。

参考文献

脚注
  1. Well-Architected machine learning lifecycle - AWS ↩︎

Discussion