Open8

onnxruntime C++ APIのsessionを見てみる

やまもとやまもと

モデルの読み込みからsessionの準備は次のような感じに書く

const std::wstring modelFile = L"....onnx";

Ort::Env env;

Ort::SessionOptions sessionOptions;
// 必要に応じてオプションを設定していく
int num_threads = 1;
sessionOptions.SetInterOpNumThreads(num_threads);
sessionOptions.SetIntraOpNumThreads(num_threads);
sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_DISABLE_ALL);

Ort::Session session(nullptr);
session = Ort::Session(env, modelFile.data(), sessionOptions);
やまもとやまもと

グラフ最適化

https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html
グラフレベルの変換や小さなグラフの単純化、ノードの削除、はたまた寄り複雑なノードの融合やレイアウトの最適化まで様々なレベルの最適化を提供している。

sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_DISABLE_ALL);
この部分でどうグラフを最適化するかのレベルを決定できる

  • GraphOptimizationLevel::ORT_DISABLE_ALL -> 最適化のオフ
  • GraphOptimizationLevel::ORT_ENABLE_BASIC -> 基本的な最適化を有効化する
  • GraphOptimizationLevel::ORT_ENABLE_EXTENDED -> 基本的な最適化と拡張最適化を有効化する
  • GraphOptimizationLevel::ORT_ENABLE_ALL -> レイアウトの最適化を含む使用可能なすべての最適化を有効化する

基本的な最適化

冗長なノードや計算を削除を実施。実行プロバイダにかかわらず適用可能。

  • 定数の畳み込み・・・あらかじめ定数部分を計算しておき実行時に計算しないようにする
  • 冗長なノードを削除する・・・ID削除、スライス除去、ドロップアウトの排除など
  • 複数のノードの融合/折りたたみ・・・例えば、Conv+AddやConv+MulなどのAdd,Mul部分をConvのバイアスとして扱うことでノードを融合する

拡張グラフの最適化

やや複雑なノードの融合を実施。CPU,CUDA,またはROCmなどの特定の実行プロバイダーでのみ適応可能

レイアウトの最適化

CPUで実行するときのみ適応可能

やまもとやまもと

オンライン/オフライン最適化

基本的に上記でやれば最適化されるが、推論セッションの初期化に最適化処理が走るとオーバーヘッドが発生する可能性があるので(特に複雑な場合)、あらかじめPythonからONNXに変換する時点で最適化を行っておき、推論時の最適化をオフにすることで起動時間を短縮できる。

PythonでONNX->ONNX変換

import onnxruntime as rt

sess_options = rt.SessionOptions()

# Set graph optimization level
sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_EXTENDED

# To enable model serialization after graph optimization set this
sess_options.optimized_model_filepath = "<model_output_path\optimized_model.onnx>"

session = rt.InferenceSession("<model_path>", sess_options)

C++ API

Ort::SessionOptions session_options;

// Set graph optimization level
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);

// To enable model serialization after graph optimization set this
session_options.SetOptimizedModelFilePath("optimized_file_path");

auto session_ = Ort::Session(env, "model_file_path", session_options);

やまもとやまもと

session.Run()

https://onnxruntime.ai/docs/api/c/struct_ort_api.html#ad8b12cad4160d43da92f49191cd91895

推論実行

Ort::RunOptions runOptions;
session.Run(runOptions, inputNames.data(), &inputTensor, 1, outputNames.data(), &outputTensor, 1);

session.Runは以下のラッパー
基本的に入出力ノードが1ずつであれば基本的に下の書き方で問題なさそう

OrtStatus * OrtApi::Run(
    OrtSession * 	session,
    const OrtRunOptions * 	run_options, // nullptr デフォルトのOrtRunOptionが使われる
    const char *const * 	input_names, // 入力名(自分で設定):UTF8エンコード文字列の配列
    const OrtValue *const * 	inputs,  // 入力配列 : OrtValueのテンソル配列(複数だったらvectorに格納して良い?)
    size_t 	input_len, // 入力ノードの個数
    const char *const * 	output_names, // 出力名(自分で設定):UTF8エンコード文字列の配列
    size_t 	output_names_len, // 出力ノードの個数
    OrtValue ** 	outputs   // 出力配列 : OrtValueのテンソル配列(複数だったらvectorに格納して良い?)
)	

入力テンソルの作り方は以下が便利
https://github.com/microsoft/onnxruntime-inference-examples/blob/dfa685fc0a5102346e3048dcfc9db8096d7d2378/c_cxx/model-explorer/model-explorer.cpp#L48-L54

ただサンプルを見る限りsession.Runのお作法が他にもあるみたい
https://github.com/microsoft/onnxruntime-inference-examples/tree/main/c_cxx
また複数の入出力ノード場合どう動かせば良いか?

やまもとやまもと

例1

https://github.com/microsoft/onnxruntime-inference-examples/blob/main/c_cxx/model-explorer/model-explorer.cpp

sessionから入出力のノード名とshapeを取ってくることができる

入力ノード情報

vectorにノード名とshapeを順番に追加していく
https://github.com/microsoft/onnxruntime-inference-examples/blob/dfa685fc0a5102346e3048dcfc9db8096d7d2378/c_cxx/model-explorer/model-explorer.cpp#L73-L88

出力ノード情報

入力ノード同様
https://github.com/microsoft/onnxruntime-inference-examples/blob/dfa685fc0a5102346e3048dcfc9db8096d7d2378/c_cxx/model-explorer/model-explorer.cpp#L90-L97

ノード数は.size()で簡単にとれる
input_names.size() ,output_names.size()

session.Runに渡せる形に変換する

https://github.com/microsoft/onnxruntime-inference-examples/blob/dfa685fc0a5102346e3048dcfc9db8096d7d2378/c_cxx/model-explorer/model-explorer.cpp#L116-L123

配列をテンソルに変換する

テンプレートを作っておくと楽
https://github.com/microsoft/onnxruntime-inference-examples/blob/dfa685fc0a5102346e3048dcfc9db8096d7d2378/c_cxx/model-explorer/model-explorer.cpp#L48-L54

https://github.com/microsoft/onnxruntime-inference-examples/blob/dfa685fc0a5102346e3048dcfc9db8096d7d2378/c_cxx/model-explorer/model-explorer.cpp#L107-L110

推論実行

https://github.com/microsoft/onnxruntime-inference-examples/blob/dfa685fc0a5102346e3048dcfc9db8096d7d2378/c_cxx/model-explorer/model-explorer.cpp#L125-L129

std::vectorOrt::Valueのかたちでoutput_tensorsが帰ってくるので、ベクトル要素ごとに推論結果を取得する...という動き方だと思う