🔭

OpenTelemetryを使ってSpanをW&Bに記録する

2024/10/07に公開

はじめに

機械学習 (ML) 分野において、モデルの精度評価や実験のトラッキングに Weights & Biases (以降、W&B) がよく使われます。少し複雑なことをしようとすると、複数のモデルを組み合わせたパイプラインを構築する必要があります。ML パイプラインの開発においては、処理速度やメモリの使用量も重要な指標です。例えば、モデルの更新で大きな精度向上を得られたとしても、処理速度が要求を満たせなければ、そのモデルを採用することはできません。

メモリの使用量などシステムのメトリクスは W&B 側で自動的に記録されます。一方で、各モデルの処理速度など、アプリケーションのメトリクスは、wandb.log()などを用いて自分で記録する必要があります。開発時には、wandb.log() を挿入して値を記録すれば良いのですが、運用時には同じメトリクスを取るために、別の実装を追加しなければなりません。同じことをするコードを、2度書くことは保守の観点でも避けたいものです。

何かいい方法はないかと考えていたところ、OpenTelemetry (OTel) を活用することを思いついたので、その方法を紹介します。

今回想定する ML パイプライン

今回は以下の構成のパイプラインを例に実装しました。Span が今回計測対象のブロックです。

Dummy Pipeline Architecture
パイプラインの構成

3 つのモデルで構成されますが、今回は簡単のため実際のモデルではなく、ランダムな処理時間を持つクラスで置き換えています。処理時間は平均mean 分散 std の正規分布に従う乱数で決定されます。

class DummyModel(torch.nn.Module):
    def __init__(self, name: str, mean: float, std: float, seed: int = 12345) -> None:
        super().__init__()
        self.name = name
        self.mean = mean
        self.std = std
        self.rng = np.random.default_rng(seed)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        duration = self.rng.normal(loc=self.mean, scale=self.std)
        logger.info(f"Sleep for {duration:.2f}sec")
        time.sleep(duration)
        return x

カスタム Exporter WandBSpanmetricsExporter の実装

OpenTelemetry には、Span Metrics Connector という、span を duration などの metrics に変換して出力する機能があります。これに近いものを W&B 用に実装することで、運用時と近いメトリクスを W&B に記録できるようにします。

具体的には、Python SDK における Exporter のベースクラス SpanExporter を継承して、トレースを受け取り、メトリクスに変換して wandb.log()でメトリクスを W&B に記録します。実験用には外部の依存が少ない方が使い勝手が良いだろうということで、外部の OTel Connector などに送信せず、すべてローカルで処理を完結させます。

具体的な実装はこんな感じです。

from typing import Sequence

import wandb
from loguru import logger
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult

IS_ROOT_SPAN_KEY_NAME = "is_root_span"


class WandBSpanmetricsExporter(SpanExporter):
    """Convert spans to metrics and log them with `wandb.log()`."""

    def __init__(self, commit_evry_call: bool = False) -> None:
        """_summary_

        Args:
            commit_evry_call (bool, optional): Set True when you have a only one span in every iteration.
                If not and you have multiple spans including nested spans, set False and set `is_root_span` attribute
                to spans that end in the last to increment 'step' in wandb. Defaults to False.
        """
        super().__init__()
        self.commit_evry_call = commit_evry_call

    def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
        for span in spans:
            # Get duration from the span
            duration = (span.end_time - span.start_time) / 1e6  # Duration in milliseconds
            if self.commit_evry_call is True:
                commit = True
            else:
                commit = span.attributes.get(IS_ROOT_SPAN_KEY_NAME, False)
            metric = {f"otel/span/{span.name}/duration/ms": duration}
            # Get other attributes from the span
            for key, value in span.attributes.items():
                if key == IS_ROOT_SPAN_KEY_NAME:
                    continue
                metric[f"otel/span/{span.name}/{key}"] = value

            # Log the metric
            logger.debug("span: name={}, duration={:,}[ns]", span.name, duration)
            wandb.log(metric, commit=commit)

        # Return SUCCESS if successful
        return SpanExportResult.SUCCESS

    def shutdown(self) -> None:
        logger.info(f"Shutdown {self.__class__.__name__}")

SpanExporter は、export(spans)shutdown() の 2 つのメソッドを持ちます。export() メソッド今回の実装の要で、外部にデータを送信する部分です。今回は span のリストを受け取り、span の duration (単位: millisecond) を計算して、W&B に記録します。

実装のポイントは、commit_evry_call という引数です。wandb.log() はステップ毎に一回だけ呼ばれることが基本的な想定です。span に階層構造がある場合でも、span 自体は個別のオプジェクトとして export() に渡されます。その場合、span ごとにステップ数が更新され、同じステップのメトリクスがバラバラに記録されてしまいます。今回の場合、入力フレームを処理する親 span の中に各モデルの推論という子 span が 3 つあるという階層構造がありますが、個別にステップ数が更新されてしまうため、処理したフレーム数の 4 倍 (= 親 + 子 x3) のステップ数が記録されてしまいます。

この問題を解決するために commit_evry_call という引数を追加しました。デフォルト値はFalseです。span の取得時に、最上位の親の span にis_root_span=True という属性を追加します。この span がステップの中で最後に終わるため、export() メソッドは is_root_span 属性を随時チェックし、is_root_span=True の時のみ wandb.log(data, commit=True) とすることで子 span のメトリクスとあわせて 1 つのステップとして記録します。それ以外は wandb.log(data, commit=False) とし、ステップの更新を行いません。

shutdown() メソッドは、今回は特に実装の必要はありません。

計装の追加 (トレースの収集)

カスタム Exporter を実装した後は、パイプラインの中でスパンを収集を実装します。基本的なトレースの実装と同じで、tracer に先ほど実装したカスタム Exporter を渡すだけです。

まずは、tracer を準備します。

resource = Resource.create({"service.name": "pipeline"})
tracer_provider = TracerProvider(resource=resource)
trace.set_tracer_provider(tracer_provider=tracer_provider)
tracer_provider.add_span_processor(
    span_processor=BatchSpanProcessor(
        # SpanExporterにカスタムExporterを渡す
        span_exporter=WandBSpanmetricsExporter()
    )
)
tracer = trace.get_tracer_provider().get_tracer(__name__)

次に、パイプラインの中で tracer.start_as_current_span() を用いて、span を収集します。tracer.start_as_current_span() 環境内で実行された処理が 1 つの span として記録されます。

@click.command()
@click.option("-n", "--num-frames", default=10, help="Number of frames to process")
@click.option("-w", "--wandb-mode", default="offline", show_default=True, help="WandB mode")
def main(num_frames: int = 10, wandb_mode: str = "offline"):
    wandb.init(
        project=WANDB_PROJECT_NAME,
        mode=wandb_mode,
    )

    # Init dummy models
    model1 = DummyModel(name="model1", mean=0.1, std=0.001)
    model2 = DummyModel(name="model2", mean=0.2, std=0.010)
    model3 = DummyModel(name="model3", mean=0.3, std=0.100)

    # Run the pipeline
    for i in range(num_frames):
        logger.info(f"Processing frame {i}")
        # 親 Span
        with tracer.start_as_current_span("process_single_frame") as root_span:
            x = torch.randn(1, 3, 100, 100)

            # 子span
            with tracer.start_as_current_span("model1"):
                x = model1(x)
            with tracer.start_as_current_span("model2"):
                x = model2(x)
            with tracer.start_as_current_span("model3"):
                x = model3(x)

            # wandbのstepを更新するためのフラグ
            root_span.set_attribute("is_root_span", True)
            # 現在のフレームに関する情報をspanに追加
            root_span.set_attribute("frame_idx", i)
            root_span.set_attribute("num_detecton", num_detection_generator(i))

    trace.get_tracer_provider().shutdown()
    wandb.finish()

参考ですが、model_1 には 物体検出モデルが入ることを想定し、検出されたオブジェクト数を span の属性に記録しています。後続のmodel_2model_3 には、オブジェクト毎に処理を行うモデルが想定され、これらのモデルの処理時間はそのオブジェクト数に依存するため、非常に重要な値です。

実行結果

パイプライン全体の実装は、GitHub を参照してください。

https://github.com/getty708/mlops-sandbox/tree/main/services/monitoring

実行方法は以下のとおりです。

# Clone the repository and Install dependencies
git clone git@github.com:getty708/mlops-sandbox.git
cd mlops-sandbox
poetry install

# Run the sample pipeline
python services/monitoring/tools/run_otel_in_wandb.py -n 100 -w online

これにより 100 フレーム分の Span が生成され、以下の用に W&B にメトリクスが記録されます。model_1はおよそ 100ms 付近で分散は小さく、model_3 は 300ms を中心に大きな分散を持ちます。事前に設定した正規分布の値が確認できました。また、num_detection は sin 波をもとに生成されており、それを確認することができます。x 軸も全てのグラフで揃っていることも確認できました。

W&B Dashboard
W&B ダッシュボードで記録されたメトリクス

まとめ

各モデルの処理時間を、OTel のカスタム Exportor を実装することで、W&B に記録することができるようになりました。次は運用環境を想定し、同じトレースとメトリクスを外部のシステム (JeagerGrafana) で長期間にわたって監視できることを確認していきたいと思います。

参考文献

Discussion