物体検出モデルの推論高速化入門
はじめに
株式会社EVERSTEELで機械学習エンジニアをしている加藤です。
機械学習システムの運用において、推論の高速化は重要な課題です。特にリアルタイムでの処理が求められるアプリケーションでは、レスポンス時間の短縮がユーザー体験に直結します。また、クラウド環境のコスト削減やエッジデバイスのリソース制約など、様々な観点から推論の効率化が必要とされます。
本記事では特に物体検出モデルのCPU推論に焦点を当て、ディープラーニングモデルの推論を高速化する方法を紹介するとともに、それらのベンチマーク結果を共有します。
「鉄ナビ検収AI」における推論高速化ニーズ
弊社では鉄スクラップの画像解析を行う「鉄ナビ検収AI」というアプリケーションを開発しています。本アプリケーションを提供するために多様な画像認識モデルを運用していますが、その中でも速度要件が厳しいものとして、荷台検出モデルが存在します。
荷台検出モデルはその名の通り、トラックの荷台領域を検出する物体検出モデルです。画像内にトラックが写っているかどうかの判定やトラック位置の特定に活用され、後段の鉄スクラップ解析を正確に行う上で重要な役割を担っています。
詳細は省きますが、エッジデバイス上のCPUで推論を行う場合があることや、システム構成上の都合により、荷台検出モデルの効率的な推論が必要とされています。一方で認識性能に関する要件も甘くはなく、高い認識性能と高速な推論を両立する必要があります。
推論高速化のアプローチ
ディープラーニングモデルの推論高速化を実現するためのアプローチは主に以下です。
- モデル構造や入出力設計の変更
- 数値精度の削減・量子化
- プルーニング
- 小規模モデルへの知識蒸留
- 計算処理の最適化
1〜4は基本的に、大なり小なりのモデルの認識性能の低下と引き換えに推論を高速化するアプローチです。中でもプルーニングや知識蒸留はアルゴリズムやモデル構築フローの大きな変更を伴うため、システムへの導入ハードルが高いです。一方で計算処理の最適化は、アルゴリズムや数値精度を変えずに、同じ計算を効率的に実行することで推論を高速化するアプローチであり、比較的導入しやすいことが大きな利点です。
以下では主に、計算処理の最適化によりモデルの認識性能を保ったまま推論を高速化する方法を見ていきます。その際、物体検出モデルの構築に広く用いられるMMDetectionで構築したPyTorchモデルへの実際の適用方法を共に紹介します。(紹介するほとんどの方法は物体検出モデルに限らないディープラーニングモデル全般に適用可能です)
PyTorchモデルのまま高速化
まず、PyTorchモデルを使用しつつ推論を高速化する方法を見ていきます。これらの方法は、既存のPyTorchパイプラインに対して最小限の変更で適用できる利点があります。
torch.compile
torch.compileはPyTorch 2.x系で導入された機能で、PyTorchコードを最適化されたカーネルにJust-In-Time(JIT)コンパイルすることで実行速度の改善を図ります。コンパイルされたモデルは基本的には通常のモデルと同様に学習や推論に使用できます。
model = torch.compile(model, backend="inductor")
torch.compile
を実行すると、JITコンパイラであるTorchDynamoがPythonバイトコードからFXグラフを抽出し、そのグラフを任意のバックエンドを用いて最適化された関数にコンパイルします。デフォルトのバックエンドであるTorchInductor(inductor
)を用いた場合、NVIDIA GPU環境では主にTritonベースのカーネルが、CPU環境ではC++/OpenMPベースのカーネルが生成されます。
GPU環境ではGPU読み書きの削減などの最適化の恩恵が得られ、TorchDynamo Performance DashBoardでは、A100 GPUにおいてTorchInductorを用いたtimmモデルのコンパイルにより、PyTorch標準の実行方式であるEagerモードと比べてfloat32精度において1.24倍、Automatic Mixed Precisionにおいて1.41倍の幾何平均学習速度の改善が確認されています。
CPUバックエンドは開発初期段階において最適化が十分でなく、多くのモデルはEagerモードよりも性能が低下していました(出典)。しかし近年では最適化が進んでおり、TorchInductor CPU Performance Dashboardによると、2022年10月時点ではtimmモデルの幾何平均推論速度の改善が1.03倍だったのに対して、2025年9月時点のPyTorch 2.10.0a0では1.96倍と大きく改善していることが分かります。
TorchInductorの最新のパフォーマンスダッシュボードはこちらで、任意の条件における速度改善比率やコンパイル時間などを確認することができます。
TorchInductor以外のバックエンドとして、ONNX Runtime(onnxrt
)、OpenVINO(openvino
)、IPEX(ipex
)、Torch-TensorRT(tensorrt
)などがサポートされています。モデルをエクスポートすることなく利用でき便利なものの、計算グラフの違いや演算子のサポート状況などによりネイティブランタイムを用いた場合ほど性能が出ない可能性があるため注意が必要です。
torch.compile
の効果はバックエンドの種類やバージョン、コンパイルモード、モデルの種類、バッチサイズ、計算精度、ハードウェアなど様々な条件に依存するため、実際の動作条件で検証を行うことを推奨します。
メモリ形式の変更
見落とされがちなものの速度改善が見込める最適化施策として、メモリ形式の変更があります。
PyTorchのテンソルはデフォルトでcontiguous (NCHW) メモリ形式でメモリに格納されますが、channels last (NHWC) メモリ形式に変更することで、インターフェース上の次元順序を保ったままメモリへの格納形式を変更することができます。
モデルのchannels lastメモリ形式への変換は、以下のように簡単に行うことができます。
model = model.to(memory_format=torch.channels_last)
channels last形式では空間的に隣接するピクセルがメモリ上でも近い位置に配置されるため、畳み込み演算時のキャッシュ効率が改善されます。そのため、特に畳み込みニューラルネットワークにおいて、メモリアクセスパターンの最適化による推論速度の向上が期待できます。
例えば、channels lastメモリ形式のチュートリアルでは、複数の畳み込みニューラルネットワークを用いた学習ベンチマークにおいて、channels last形式の採用によりNvidia Volta GPUで8%~35%、Intel Xeon Ice Lake CPUで26%~76%の速度改善が確認されています。
ただし、モデルやその他条件によっては逆効果になる場合もあるため、こちらについても導入前にベンチマークを取ることを推奨します。(MatMul処理が主体のTransformer系では改善しにくい傾向があるようです)
半精度推論
PyTorchではfloat32(FP32)が標準のデータ型として使用されますが、これは歴史的なものであり、現代のディープラーニングモデルにおいてこの精度は必ずしも必要ないことが分かっています。対応するハードウェアを利用している場合は、半精度推論により認識性能をほとんど損なうことなく推論を高速化することが可能です。
半精度データ型にはfloat16(FP16)とbfloat16(BF16)があり、ビット構成や特性が異なります。FP16は符号1ビット、指数5ビット、仮数10ビットを使用します。指数部が少なくFP32と比べ表現可能な数値範囲が限られるため、オーバーフローに注意が必要です。BF16はGoogleが当初TPU向けに考案したデータ型で、AI用途で必要十分な精度と広い範囲を兼ね備えるよう設計されています。符号1ビット、指数8ビット、仮数7ビットを使用しており、FP32と同等の広い数値範囲を持つ一方で、細かい数値の表現は粗くなります。
PyTorchでは1.10以降でBF16がサポートされています。例えば、CPU上でのBF16推論は以下のようにして行うことができます。torch.amp.autocast
は数値安定性の低い処理を自動的にFP32で計算する混合精度の動作を提供します。
with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16):
outputs = model(inputs)
ただし、ハードウェアの対応状況に注意が必要です。半精度推論の性能は、使用するハードウェアがサポートするデータ型と命令セットに大きく左右されます。GPUにおいては、2017年のNVIDIA Volta GPUでTensor Coreが導入され、FP16行列演算が大幅に高速化されました。その後、2020年のAmpere世代からはTensor CoreでのBF16演算に対応しています。CPUでは、2020年のIntel第3世代Xeon(Cooper Lake / Ice Lake)からAVX-512によるBF16命令がサポートされています。一方で、FP16についてはx86 CPU上で直接のハードウェア命令が存在せず、内部的には32ビットに拡張して計算を行う必要があるため、速度面のメリットがほとんどありません。そのため、CPU上での半精度推論ではBF16の利用が一般的です。
半精度推論による高速化は主に、対応ハードウェアが16ビット演算を並列処理することで演算スループットが向上すること、またメモリ使用量が削減され帯域負荷が軽くなることによってもたらされます。改善幅はモデルやハードウェアに大きく依存しますが、一例として、Intel Cooper Lake CPUでBF16を用いることによりTorchVisionモデルの推論速度がFP32と比べ1.4〜2.2倍改善することが確認されています(出典)。
NMS設定の最適化
物体検出モデルでは基本的に、検出矩形群に対してNon-Maximum Suppression(NMS)と呼ばれる後処理が適用されます。NMSは類似した検出結果を除去する処理ですが、その設定パラメータを調整することでシステム自体の振る舞いを変えることなく推論速度を改善できる場合があります。
MMDetectionにおける主な設定項目は以下です:
-
score_thr
: スコア閾値(NMS前に適用される) -
nms_pre
: NMS前の最大候補数 -
nms
: NMS設定。NMSの種類やiou_threshold
などを指定 -
max_per_img
: 最終的な最大検出数
以下のようにしてtest_cfg
属性の内容を上書きすることで、学習済みモデルの設定パラメータを変更することができます。(ただし検出器の種類によりtest_cfg
の構造や設定可能項目が異なることに注意が必要です)
# スコア閾値を変更
model.test_cfg.score_thr = 0.5
# NMS自体を無効化
model.test_cfg.nms = None
これらのパラメータを適切に設定することで、不要な計算を削減し推論速度を向上できる可能性があります。
特に有効と思われるのがscore_thr
の調整です。score_thr
はデフォルトで0.001〜0.05程度の値に設定されているため、出力の大部分を占める場合が多い、スコアの小さな矩形群に対してもNMSが適用されることになります。しかし、物体検出器をシステムで運用する際はある程度スコアの大きな矩形の情報のみを利用する場合がほとんどかと思います。そのため、score_thr
をシステム要件に見合った大きな値に設定することで、NMSが適用される矩形の数を削減して推論速度を改善することができます。
また、最もスコアの高い矩形情報のみを利用するようなユースケースの場合は、NMS自体を無効化することにより更なる速度向上が見込まれます。
推論エンジンの使用
推論エンジンは、推論に特化した形式に変換された学習済みモデルを、最適化された実行環境で動作させます。計算グラフの最適化、メモリ使用量の削減、CPU命令の最適化などにより、大幅な速度向上を期待できます。
ONNX Runtime
ONNX Runtimeはクロスプラットフォームの推論エンジンです。ONNX(Open Neural Network Exchange)形式のモデルを最適化し、ハードウェアに応じて高速に実行します。PyTorchやTensorFlow、Keras、Scikit-learnなどの様々なフレームワークで構築したモデルをONNX形式にエクスポートし、ONNX Runtimeで実行することが可能です。
Execution Provider(EP)という仕組みを通じて、異なるハードウェア向けに最適化された実行ライブラリを使い分けることができます。Intel製プロセッサで高速動作するOpenVINO、GPUでの高速推論を可能とするCUDAやTensorRT、エッジ/モバイルデバイス向けのExecution Providerなどがサポートされており、同一のONNXモデルを様々な実行環境で効率的に動作させることが可能です。
ONNX RuntimeはPythonやJava、C/C++、JavaScript、Ruby等の多様なAPIを提供しています。
Pythonでの推論コード例は以下です。providers
を指定することで容易にExecution Providerを切り替えることが可能です(事前に依存ライブラリのインストールが必要です)。
import onnxruntime as ort
sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name
outputs = sess.run(None, {input_name: inputs})
OpenVINO
OpenVINOはIntelが開発するオープンソースの推論最適化ツールキットで、Intel製のCPUやGPU、NPUなどのハードウェアにおいて高速な推論を可能とします。
サポートするモデル形式として、ONNXやPyTorch、TensorFlowなどの他、OpenVINO独自の中間表現であるOpenVINO IR形式があります。OpenVINOのモデル変換APIを使用して任意のモデルをOpenVINO IR形式(.xml
と.bin
の2つで構成)にエクスポートすることができます。OpenVINO IR形式のモデルを使用することで、ストレージ消費量や初回推論のレイテンシを削減することができます。
OpenVINO RuntimeはCおよびPython APIを備えたC++ライブラリとなっています。
Pythonでの推論コード例は以下です。
import openvino as ov
core = ov.Core()
model = core.read_model(model_path)
compiled_model = core.compile_model(model, device_name="CPU")
infer_request = compiled_model.create_infer_request()
outputs = infer_request.infer({compiled_model.inputs[0]: inputs})
INT8量子化
OpenVINOがサポートするモデル最適化ツールであるNNCF(Neural Network Compression Framework)を利用することで、モデルの量子化やプルーニング、LLMの重み圧縮などの最適化を行うことができます。
それらの中でも8ビット(INT8)への学習後量子化は、モデルの再学習やファインチューニングを必要としないため比較的導入が容易な手法です。2019年の第2世代Xeon(Cascade Lake)以降において、INT8の行列演算を効率的に行うことができるVNNI命令がサポートされており、その恩恵を得ることができます。数値精度が落ちるため認識性能への影響は多少存在するものの、半精度推論以上に推論速度が向上することを期待できます。
以下コードのようにしてINT8量子化を行うことができます。量子化パラメータを推定するために、数百件程度のキャリブレーションデータセットを用意する必要があります。
import nncf
import torch
# キャリブレーションデータセットを用意
calibration_loader = torch.utils.data.DataLoader(...)
calibration_dataset = nncf.Dataset(calibration_loader)
# モデルを量子化
model = ov.Core().read_model(model_path)
quantized_model = nncf.quantize(model, calibration_dataset)
認識性能への影響を最小限に抑えたい場合は、検証関数を使用して精度低下幅を制御した量子化を行うnncf.quantize_with_accuracy_control
を使用することもできます。
quantized_model = nncf.quantize_with_accuracy_control(
model,
calibration_dataset=calibration_dataset,
validation_dataset=validation_dataset,
validation_fn=validate,
max_drop=0.01,
drop_type=nncf.DropType.ABSOLUTE,
)
MMDeploy Inference SDK
MMDeployはMMDetectionやMMPose等のOpenMMLab製ライブラリで構築したPyTorchモデルを各種プラットフォームにデプロイするためのツールキットです。ONNXやTensorRT、OpenVINO IR等の様々な形式へのモデル変換機能や推論SDK、ベンチマークツールなどが提供されています。
以下のようにしてOpenMMLabモデルを推論エンジン向けの形式に変換できます。
python mmdeploy/tools/deploy.py \
<mmdeploy_config> \
<model_config> \
<model_weights> \
<test_image> \
--work-dir <work_dir> \
--device <device> \
--dump-info \
--show
<mmdeploy_config>
の指定によりモデルの変換形式やその他オプションを切り替えることができます。モデルの変換により、各種バックエンド向けのファイルに加え、推論SDKが使用するdeploy.json
とpipeline.json
を含むディレクトリが生成されます。
推論SDKはC/C++で開発されており、C、C++、Python、C#、JavaなどのAPIが提供されています。モデル推論における前処理、順伝播、後処理をまとめて実行することができます。
Pythonでの推論コード例は以下です。
from mmdeploy_runtime import Detector
detector = Detector(mmdeploy_model, device_name="cpu")
bboxes, labels, _ = detector(image)
速度ベンチマーク
弊社の荷台検出データセットを用いて、上述した高速化アプローチの速度ベンチマークを行いました。
検証環境のセットアップ
Pythonパッケージマネージャのuvを用いて、以下手順でCPUでの検証環境をセットアップしました。
# Python 仮想環境を作成
uv venv --python 3.11.12
# 各種パッケージをインストール
uv pip install "numpy<2"
uv pip install torch==2.1.2 torchvision --extra-index-url https://download.pytorch.org/whl/cpu
uv pip install mmcv==2.1.0 -f https://download.openmmlab.com/mmcv/dist/cpu/torch2.1/index.html
uv pip install \
mmengine==0.10.5 \
mmdet==3.3.0 \
mmdeploy==1.3.1 \
mmdeploy-runtime==1.3.1 \
onnxruntime==1.22.1 \
openvino-dev==2023.3.0 \
onnxruntime-openvino==1.22.0 \
nncf==2.10.0
モデルの変換
MMDetectionなどで構築したOpenMMLabモデルは複雑なパイプラインを持っており、ONNX形式等への変換は同じくOpenMMLabが提供するMMDeployを介して行うことが推奨されています。今回はMMDeploy 1.3.1を用いて以下のようにしてモデルをONNX形式に変換し、各種推論エンジンで利用しました。
# MMDeploy のリポジトリをクローン
git clone -b v1.3.1 https://github.com/open-mmlab/mmdeploy.git
# モデルを ONNX 形式に変換
uv run mmdeploy/tools/deploy.py \
mmdeploy/configs/mmdet/detection/detection_onnxruntime_static.py \
<mmdet_model_config> \
<mmdet_model_weights> \
<test_image> \
--work-dir work_dirs/onnx/ \
--device cpu \
--dump-info \
--show
ベンチマーク結果
MMDetectionの標準的な推論APIであるDetInferencerをそのまま使用して推論した場合をベースラインとし、CPU上での各速度改善施策の速度ベンチマークを行いました(使用したCPUが対応していない半精度推論のみ未検証です)。DetInferencerでは内部でリサイズやパディング、正規化等の前処理が実行される一方、MMDeployで変換したONNXファイルには基本的にそれらの前処理が含まれません。そのため公平な速度比較を行うために、推論エンジンを用いたアプローチにおいてはPython実装の前処理を含めた速度を測定しました。
検証はAWSのg4dn.xlargeインスタンス(Intel Cascade Lake CPU 4コア)上で行い、検証対象のモデルとして軽量かつ高性能な物体検出モデルであるRTMDet tinyを使用しました。単一画像を入力とし、ウォームアップ用の推論を10回行った後、100回連続して推論を行い、その平均レイテンシを測定しました。
ベンチマーク結果は以下です。
手法 | 平均レイテンシ (ms) |
---|---|
ベースライン | 199.8 |
torch.compile (TorchInductor) | 201.2 |
channels last メモリ形式 | 151.0 |
NMS score_thr 0.5 | 171.1 |
NMS 無効化 | 172.0 |
MMDeploy Inference SDK | 150.7 |
ONNX Runtime (CPU EP) | 136.5 (内前処理14.0) |
ONNX Runtime (OpenVINO EP) | 92.7 (内前処理10.1) |
OpenVINO Runtime | 91.2 (内前処理9.6) |
INT8量子化 | 51.1 (内前処理9.5) |
torch.compile
は今回の検証では効果が見られませんでした。今回はmmcvとの互換性の問題から少し古めのPyTorch 2.1.2を使用しましたが、バックエンドの最適化が進んでいる最新のPyTorchを使用できれば速度改善する可能性があります。また今回は検証できていないですが、TorchInductor以外のバックエンドを使用することによる改善の可能性もあります。
channels lastメモリ形式の採用により1.32倍、NMSパラメータの調整により1.17倍程度の速度向上が確認されました。これらは導入コストが低い割に効果が大きく、試す価値の高い施策であると言えます。またNMSのスコア閾値を変更した場合とNMSを無効化した場合でほぼ同程度のレイテンシとなりました。スコア閾値の変更のみで十分に計算処理が削減されているためであると思われます。
推論エンジンを使用するアプローチでは、MMDeploy Inference SDK、ONNX Runtime、OpenVINOにおいてそれぞれ1.33倍、1.46倍、2.17倍の速度向上が得られました。Intel CPUの使用時は速度改善効果の大きなOpenVINOを、それ以外のCPUではONNX Runtimeを積極的に利用していくのが良さそうです。またONNX Runtime上でOpenVINOExecutionProviderを使用した場合とOpenVINOネイティブランタイムを用いた場合は同程度のレイテンシとなりました。ONNX Runtimeをインターフェースとして使用し、計算環境に応じてExecution Providerを切り替えることで推論コードの保守コストを低減できそうです。
NNCFを用いたINT8量子化により最も大きな3.91倍の速度改善が得られました。ただし、計算精度を落とすアプローチであるため、認識性能への影響については別途検証が必要です。
推論エンジンの利用時に前処理の実行時間が占める割合がある程度存在することが確認されました。更なる高速化を目指す場合は、前処理の最適化(ONNXへの組み込み等)も検討余地があると考えられます。また今回は基本的な設定で検証を行いましたが、推論エンジンの設定パラメータやオプションを調整することでの更なる速度改善余地があります。
まとめ
この記事では物体検出モデルの推論を高速化する方法を見ていくとともに、それらの効果を検証しました。各種アプローチによる速度改善が確認されましたが、有効性の大きな方法ほど導入や保守のコストが高い傾向があります。そのためシステムへの導入時はそれら全体のバランスを考慮して施策を決定する必要があります。
推論高速化は非常に奥の深いトピックであり、モデル構造やハードウェア環境、実行条件などによって最適なアプローチが異なってきます。各種ソフトウェアやハードウェアの開発も日進月歩で進んでおり、継続的な検証と最適化により、さらなる速度向上を実現できると思います。
本記事の内容が少しでも参考になれば幸いです。
Discussion