Zenn
🌟

rustでonnxモデルを呼び出して必要なメタデータを取得

に公開

概要

物体検出モデルなどを呼び出す際に、学習モデルで定義されているクラス情報などを事前に推論前に取得したくなりました。

しかし、生成AIに聞いてもメタデータをうまく取り出す方法を教えてくれないし、事前にyamlファイル用意して呼び出せと言われてしまうし、tractの該当してそうなPRを見ても完全に対応が終わってなさそうなので、なんとかならないかと思い色々中身のソースコードを確認して弄っていたらできたので記事にしました。

解説

今回はtractを使ってonnxモデルを呼び出して、必要なメタデータを取り出したいと思います。

ちなみに、メタデータと呼んでいるのは下図の右側に表示されているバッチサイズや画像サイズ、クラス情報などを指しています。今回はクラス情報を取得したいと思います。

下記はNetronでyolov8nのonnxモデルを読み込ませています。

Netronは下記でアクセスできます。
https://netron.app/

ONNXモデルの生成

まずは、適当なONNXモデルを作成します。今回は、ultralyticsのyolov8nのonnxモデルを生成します。ultralyticsのライブラリが入っていれば、下記の3行をpythonで実行するだけです。

from ultralytics import YOLO

model = YOLO("yolov8n.pt")
model.export(format="onnx", opset=12, dynamic=True)

どうやってメタデータを取得するか

下記の部分に記述があります。
https://github.com/sonos/tract/blob/06edcef7f3d90fcd61c1e58fcde7510398725ffe/onnx/src/prost/onnx.rs#L232C9-L232C23

ModelProto はMLモデルをバンドルするためのトップレベルのファイル/コンテナ形式で、グラフ情報やメタデータの情報を取得することができます。

proto_modelで呼び出した後、metadata_propsにアクセスすることで取得することが可能です。

実際にメタデータ情報を表示してみる

use tract_onnx::prelude::*;
use tract_onnx::model::Onnx;
use serde_json;
use std::collections::HashMap;

fn main() -> TractResult<()> {
    let model_proto = Onnx::default().proto_model_for_path("yolov8n.onnx")?;
    println!("{:?}", &model_proto.metadata_props);
    Ok(())
}

上記のコードを実行すると、下記が表示されます。

[StringStringEntryProto { key: "description", value: "Ultralytics YOLOv8n model trained on coco.yaml"
    }, StringStringEntryProto { key: "author", value: "Ultralytics"
    }, StringStringEntryProto { key: "date", value: "2025-03-23T17:00:29.906890"
    }, StringStringEntryProto { key: "version", value: "8.3.51"
    }, StringStringEntryProto { key: "license", value: "AGPL-3.0 License (https://ultralytics.com/license)"
    }, StringStringEntryProto { key: "docs", value: "https://docs.ultralytics.com"
    }, StringStringEntryProto { key: "stride", value: "32"
    }, StringStringEntryProto { key: "task", value: "detect"
    }, StringStringEntryProto { key: "batch", value: "1"
    }, StringStringEntryProto { key: "imgsz", value: "[640, 640]"
    }, StringStringEntryProto { key: "names", value: "{0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'}"
    }
]

ここで、10番目のStringStringEntryProtoに検出されたbboxのクラス情報(80種類)が定義されています。この情報をHashMapで取り出せれば、ONNXモデルを読み込んで検出されたidがどんなクラスなのか特定できるようになりますね。

では、HashMapを取得しましょう。10番目のStringStringEntryProtoの要素は文字列なので、必要な情報を取り出すためにparse_classes_map関数を用意します。

use tract_onnx::prelude::*;
use tract_onnx::model::Onnx;
use std::collections::HashMap;

fn parse_classes_map(raw: &str) -> HashMap<u32, String> {
    let json_like = raw
        .replace('\'', "\"")
        .replace(": ", "\": ")
        .replace("{", "{\"")
        .replace(", ", ", \"");

    serde_json::from_str::<HashMap<String, String>>(&json_like)
        .expect("Failed to parse class names")
        .into_iter()
        .map(|(k, v)| (k.parse::<u32>().unwrap(), v))
        .collect()
}

fn main() -> TractResult<()> {
    let model_proto = Onnx::default().proto_model_for_path("yolov8n.onnx")?;
    let classes = parse_classes_map(&model_proto.metadata_props[10].value);

    let mut sorted: Vec<_> = classes.into_iter().collect();
    sorted.sort_by_key(|(k, _)| *k);

    for (k, v) in sorted {
        println!("{}: {}", k, v);
    }

    Ok(())
}

実際に実行してみると、昇順で必要なクラスラベルの情報を取り出せるようになりました。

0: person
1: bicycle
2: car
3: motorcycle
4: airplane
5: bus
6: train
7: truck
8: boat
9: traffic light
10: fire hydrant

まとめ

今回は、rustのtractを用いてonnxモデルを読み込んで、メタデータを取得して必要な情報であるクラス情報を取り出しました。

Discussion

ログインするとコメントできます