🌸

RustでOnnxモデルを読み込んで物体検出してみる

に公開

概要

前回は、tractを使ってonnxモデルを呼び出して、クラスカテゴリなどのメタデータを取り出しました。

https://zenn.dev/bamboo_nova/articles/b9c4e880083662

今回は、実際にonnxモデルを読み込んで、対象の画像に対して物体を検出するところまで実施してみます。

使用したコードは下記のリポジトリに公開済みです。

https://github.com/bamboo-nova/onnx-calling-rs

実装

主に必要な実装は、下記の三種類です。

  • 画像の前処理
  • Bboxの設計
  • YOLOで推論する処理

onnxモデルを作成する

こちらについては、前回の記事で説明してるので省略します。

画像の前処理

画像の前処理として、アスペクト比を維持して対応する入力形式に変換する必要があります。
デフォルト設定でultralyticsからエクスポートされたyolov8nは640x640の画像を対象にしています。そのため、スケーリング調整して、アスペクト比を維持したまま640x640の画像に変換する処理を行います。

src/yolo.rs
pub fn letterbox(input_image: &DynamicImage, target_size: u32) -> DynamicImage {
    let img_width = input_image.width();
    let img_height = input_image.height();

    // スケールは短辺基準ではなく「長辺を target_size にフィット」
    let scale = target_size as f32 / img_width.max(img_height) as f32;
    let new_width = (img_width as f32 * scale).round() as u32;
    let new_height = (img_height as f32 * scale).round() as u32;

    // リサイズされた画像
    let resized = image::imageops::resize(
        &input_image.to_rgb8(),
        new_width,
        new_height,
        image::imageops::FilterType::Triangle,
    );

    // target_size の正方形キャンバスを作成
    let mut canvas = image::RgbImage::from_pixel(target_size, target_size, image::Rgb([0, 0, 0]));

    // 貼り付けオフセット(中央)
    let x_offset = ((target_size as i32 - new_width as i32) / 2).max(0) as u32;
    let y_offset = ((target_size as i32 - new_height as i32) / 2).max(0) as u32;

    // リサイズ画像を中央に貼り付け
    image::imageops::replace(&mut canvas, &resized, x_offset as i64, y_offset as i64);

    DynamicImage::ImageRgb8(canvas)
}

これは、ultralyticsのutilsで提供されているletterboxと同じ仕組みになっています。

画像のサイズを適切に変換したら、tractで取り扱っているTensor形式に変換します。これで準備完了です。

src/yolo.rs
fn preprocess(input_image: &DynamicImage, target_size: u32) -> Tensor {
    let padded = letterbox(input_image, target_size);
    
    // Convert tract tensor.
    // (Batch, Channel, Height, Width)
    // Choice c: channel and get pixel [u8; 4], and normalize.
    let image: Tensor = tract_ndarray::Array4::from_shape_fn(
        (1, 3, target_size as usize, target_size as usize),
        |(_, c, y, x)| padded.get_pixel(x as u32, y as u32)[c] as f32 / 255.0,
    ).into();
    image
}

Bboxの設計

一般的なxywhn形式のフォーマットに加えて、検出された領域の切り出しもしたかったのでxyxy形式も加えました。

src/bbox_struct.rs
use anyhow::{Error, Result};
use image::DynamicImage;

#[derive(Debug, Clone)]
pub struct Bbox {
    pub xywhn: Xywhn,
    pub xyxy: Xyxy,
    pub conf: f32,
    pub cls: String,
}

#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct Xywhn {
    /// xywhn format.
    pub x: f32,
    pub y: f32,
    pub w: f32,
    pub h: f32,
}

impl Xywhn {
    pub fn is_normalized(&self) -> bool {
        self.x >= 0.0 && self.x <= 1.0 &&
        self.y >= 0.0 && self.y <= 1.0 &&
        self.w >= 0.0 && self.w <= 1.0 &&
        self.h >= 0.0 && self.h <= 1.0
    }
}

YOLOで推論する処理

tractを用いてonnxモデルを読み込んで推論を実施します。
まず、必要なモデルを読み込みます。

src/yolo.rs
pub fn load_yolo_model(model_path: &str, input_size: (u32, u32)) -> YoloModel {
    let pred_model = tract_onnx::onnx()
        .model_for_path(model_path)
        .unwrap()
        .with_input_fact(0, f32::fact([1, 3, input_size.0 as i32, input_size.1 as i32]).into())
        .unwrap()
        .into_optimized()
        .unwrap()
        .into_runnable()
        .unwrap();
    YoloModel { model: pred_model }
}

そして、推論を実施します。今回は前述したBbox構造体を用意しており、そちらにbboxの位置情報やconfidence、クラスなどの情報を記録します。onnxモデルの推論結果は正規化されてないxywh情報と各クラス(80種類)のconfidenceになります。各xywhを640で割って0~1の値に正規化した上で、confidenceの最も高いクラスを割り当てます。

また、non-maximum supression(NMS)も後処理で実装しており、これは複数のbboxが検出された際に、IoUが閾値を超えるbboxは最もconfidenceの高いbbox以外は削除する処理を実装しています。これによって、重複してるbboxは排除され、ユニークなbboxが検出されるようになります。

src/yolo.rs
impl YoloModel {
    pub fn get_bbox(
        &self,
        input_image: &DynamicImage,
        confidence_threshold: f32,
        iou_threshold: f32,
        imgsz: u32,
        class_maps: HashMap<String, String>,
    ) -> Result<Vec<Bbox>, Error> {
        // Preprocess
        let preprocess_image = preprocess(input_image, imgsz);

        // run forward pass and then convert result to f32
        let forward = self.model.run(tvec![preprocess_image.to_owned().into()]).unwrap();
        let output = forward.get(0).unwrap().to_array_view::<f32>().unwrap().view().t().into_owned();

        // process results(reference: https://github.com/AndreyGermanov/yolov8_onnx_rust/blob/main/src/main.rs)
        let mut bboxes: Vec<Bbox> = vec![];
        let output = output.slice(s![..,..,0]);
        for row in output.axis_iter(Axis(0)) {
            let row:Vec<_> = row.iter().map(|x| *x).collect();
            let (class_id, confidence) = row.iter().skip(4).enumerate()
                .map(|(index,value)| (index,*value))
                .reduce(|accum, row| if row.1>accum.1 { row } else {accum}).unwrap();
            if confidence < confidence_threshold {
                continue
            }

            // if confidence >= confidence_threshold {
            let x = row[0] / imgsz as f32;
            let y = row[1] / imgsz as f32;
            let w = row[2] / imgsz as f32;
            let h = row[3] / imgsz as f32;
            let class_name = class_maps
                .get(&class_id.to_string())
                .unwrap_or(&"unknown".to_string())
                .clone();
            // 推論結果は640x640で行われているので、元の画像でやるとズレる
            // xywhnの元になった画像サイズにすること(入力画像とは限らないので)
            let bbox = Bbox::new(
                x, y, w, h,
                confidence, class_name,
                imgsz, imgsz,
            );
            bboxes.push(bbox);
        }
        Ok(nms_boxes(bboxes, iou_threshold))
    }
}

引数の設計と実行スクリプト

引数を指定するargs.rsと推論して切り出すところまでを実行するlib.rsを実装します。

まず引数ですが、今回は簡単にモデルのファイルパスと入力画像のパス、confidenceの閾値とNMSにおける後処理のためのIoUの閾値を反映しました。

src/args.rs
use clap::Parser;

#[derive(Parser, Clone)]
#[command(name = "onnx-calling", version = "0.1.0", author = "chiikawa", about = "onnx model usage")]
pub struct Args {
    /// ONNX model path
    #[arg(value_name = "MODEL", help = "ONNX model path")]
    pub yolo_model: String,

    /// Image path
    #[arg(value_name = "SOURCE", help = "source image path")]
    pub source: String,

    /// Confidence threshold.
    #[arg(short='c', long="conf-threshold", default_value="0.3")]
    pub conf_threshold: f32,

    /// IoU threshold.
    #[arg(short='i', long="iou-threshold", default_value="0.7")]
    pub iou_threhold: f32,
}

この引数を元に推論するコードは、下記にまとめています。

src/lib.rs
use clap::Parser;
//use serde_json::{Value, to_writer_pretty};
//use std::fs::File;
use std::error::Error;

mod args;
mod bbox_struct;
mod yolo;
mod onnx_metadata;

use args::Args;
use yolo::YoloModel;
use crate::yolo::{load_yolo_model, letterbox};
use crate::onnx_metadata::get_class_ids;

type MyResult<T> = Result<T, Box<dyn Error>>;

#[derive(Debug)]
pub struct Config {
    source: String,
    model_hash: String,
    conf_threshold: f32,
    iou_threhold: f32,
}

pub fn get_args() -> MyResult<Config> {
    let matches = Args::parse();

    Ok(Config{
        source: matches.source,
        model_hash: matches.yolo_model,
        conf_threshold: matches.conf_threshold,
        iou_threhold: matches.iou_threhold,
    })
}

#[allow(deprecated)]
pub fn run(config: Config) -> MyResult<()> {
    let image = image::open(config.source)?;
    let original_image = image.clone();

    let yolo_model: YoloModel = load_yolo_model(&config.model_hash, (640, 640));
    let class_maps = get_class_ids(config.model_hash)?;
    let results = yolo_model.get_bbox(
        &image,
        config.conf_threshold,
        config.iou_threhold,
        640,
        class_maps,
    )?;
    println!("{:?}", results);
    let resized_image = letterbox(&image, 640);
    for (i, mut result) in results.into_iter().enumerate() {
        let mut crop_image = result.crop_bbox(&resized_image)?;
        let mut save_path = format!("output_{}.jpg", i.to_string());
        crop_image.save(save_path).unwrap();
    }
    Ok(())
}

結果

今回はサンプル画像として、下記で提供されてる画像を使用します。
https://www.photo-ac.com/main/detail/32062764/

では、実際に推論してみます。

# 実行
$ cargo run --bin main -- yolov8n.onnx <image_path>

# 推論結果
[Bbox { xywhn: Xywhn { x: 0.4742462, y: 0.55981433, w: 0.05040426, h: 0.156077 }, xyxy: Xyxy { x1: 287, y1: 308, x2: 319, y2: 408 }, conf: 0.6236874, cls: "tie" },
Bbox { xywhn: Xywhn { x: 0.39460135, y: 0.51867336, w: 0.4226303, h: 0.5541665 }, xyxy: Xyxy { x1: 117, y1: 154, x2: 387, y2: 509 }, conf: 0.8008655, cls: "person" },
Bbox { xywhn: Xywhn { x: 0.69637465, y: 0.56644905, w: 0.33534288, h: 0.45844665 }, xyxy: Xyxy { x1: 338, y1: 215, x2: 552, y2: 509 }, conf: 0.8570766, cls: "person" }]

このように、三つのBbox情報でネクタイ(tie)と二人の人物(person)が取り出せました。
実際に切り取った画像も見てみましょう。


このように、適切に検出された領域を切り出せました。

まとめ

今回は、onnxモデルをtractで読み込んで、実際に物体検出するところまでを実装しました。

参考資料

https://github.com/sonos/tract/blob/06edcef7f3d90fcd61c1e58fcde7510398725ffe/examples/face_similarity_arcface_onnx/src/yolo_face.rs#L16

Discussion