リアルタイムAIアプリケーションにおけるONNXのチューニング
Parakeet株式会社でResearcherをしている金子(nadare)です。CPUのみで動作するリアルタイムAIボイスチェンジャーのParavoの研究開発をしております。
ParavoはAIモデルをPythonのPyTorchで学習した後、モデルをONNXというフォーマットに変換し、Rust上でONNX Runtimeを用いて動かしています。Paravoは音声変換時に最短で10msごとに推論しており、これの処理間隔や処理時間が短くなるほど、変換した音声をループバック再生する際の遅延が減ってしゃべりやすくなります。また、わずかにでも遅れると音声がプツっと途絶え体験を損ねてしまう問題もあります。そのため、Paravoではモデル推論が高速化するように様々なチューニングに力を入れています。
本記事ではONNXを用いたリアルタイムAIアプリケーションを作成する方向けに、ONNX作成時や推論時のパフォーマンスを上げるためのTipsを共有します。ONNX全般で役立つコツから、リアルタイムAIアプリケーションに特化したコツまでありますのでぜひご覧ください。
PyTorchでONNXを出力する際のノウハウ
TorchDynamo-basedのexportを使う
PyTorchのonnx.exportではTorchScript-basedとTorchDynamo-basedの二種類を選ぶことができます。torch.onnx.exportにおいてdynamo=True
を選ぶことでTorchDynamo-basedのexportが利用可能です。TorchDynamo-basedの方が新しく、最適化が進んでいるのでこちらを選ぶ方が良いです。
export時にoptimize=Trueを指定する。
これはTorchDynamo-basedのexportのみ可能です。PyTorch 2.7からはデフォルトになりました。この引数をONにすることで不要なグラフノードを除いたり、一つのグラフにまとめられる処理はまとめてくれます。
opset versionはあまり気にしなくてよい
ONNXにはopset versionがあり、TorchDynamoベースのexportでは18まで対応しています。バージョンが高いほど、複雑な演算が一つのグラフにまとまっており、現在は2025/8/27公開のONNX 1.19.0で更新された24まであります。ただ、2025/8/27時点のONNX Runtimeの最新バージョン1.20では21までの対応で、18→21への進化は量子化や低精度計算の最適化などが中心のため、このバージョンを無理に上げる恩恵は少ないです。ParavoでもPyTorchでopset versionを18で出力した後、ONNXライブラリで21に変更してみましたが、パフォーマンスに影響ありませんでした。
input_namesとoutput_namesは指定しておく
onnxでは入出力について、input_namesとoutput_namesで名前を指定しなくても動かすことはできます、が利用側のライブラリによっては名前で指定するメソッドなどがあるので、最初から面倒くさがらずにinput_namesとoutput_namesを明示しておくことでpython以外の環境で動かす際に楽になります。
(必要なら)メタデータを消しておく
TorchDynamo-basedのexportでは各ノードにpythonの元の関数の情報などが追加されます。これらは後述のNetronという可視化ツールで可視化すると簡単に見えるので、見られたくない場合は消しておきましょう。
model = onnx.load("model.onnx")
model.metadata_props.clear()
for node in model.graph.node:
node.metadata_props.clear()
onnx.save(model, "model.onnx")
可変長入力への対応
Paravoでは一回の処理で処理する音声の長さを10ms ~ 100msの間で可変長にして負荷の調節ができるようにしています。ただ、onnx.exportは固定長入力・固定長出力を想定しており、可変長の入力や入力によってループの回数が変わるモデルを実行する際には次に上げるような工夫が必要です。
早いうちからdynamic_shapesの指定で動くようにコードを書いておく
可変長に関する部分やfor/while文を使うモデルはTorchで動いてもonnxで出力できない、出力しても動かないことがあります。モデルを作って学習し、出力するときになって実はonnxでは動かないモデルだったというのはつらいので、事前にチェックしておきましょう。exportが成功し、かつexport時に与えた引数とは異なる形状でも動くことまで事前に確認してください。
dynamic axesの指定との比較
動的形状の次元を与える方法は推奨されるdynamic_shapesの指定の他に、TorchScript-based版からのdynamic_axesによる指定もあります。ただ、dynamic_shapesの指定の方が厳密なようで、dynamic_axesの指定で動いていたモデルがdynamic_shapesの指定では動かない場合もあります。できるだけdynamic_shapesで通るようにモデルを書いておくのが良いです。
convolutionの活用
例えばスペクトログラムにメルバンクフィルタをかけてメルスペクトログラムにする処理はdotやeinsumだと可変長のチェックで弾かれます。conv1dを使うと可変長にも対応できます。
# (batch, freq, time), (freq, n_mels) -> (..., n_mels, time)
fb = torchaudio.functional.melscale_fbanks(params)
# NG: x = torch.einsum("nft,fm->nmt", x, fb)
x = torch.nn.functional.conv1d(x, fb.transpose(0, 1).unsqueeze(-1), stride=1, padding=0)
ONNX RuntimeでONNXを扱うときのノウハウ
ONNX Runtimeのバージョンをできるだけ上げよう
opsetは結構最適化されているのでそこまでパフォーマンスに影響しないですが、ONNX Runtimeのバージョンは性能を左右します。2025/8/27時点の最新は1.22.2です。
サポートするWindowsのバージョンに注意
ただし、v1.22.1の時点からWindowsのサポートするバージョンは20H1以降になっています。古いPCで動かす場合は古いONNX Runtimeを使う必要があるかもしれません。
実行時のパラメーターについて
optimization level
実行時にプロバイダーに合わせて最適化を行います。最適化はDISABLEと三段階のレベルによる設定があります。デフォルトでは最大のoptimizeを行うようになっています。(最適化レベルごとに行う処理は公式ドキュメントをご確認ください。)
optimizeによって若干実行結果が変化することがありますが、Paravoで確認したところ音質に変化はなかったのでENABLE_ALLに設定しています。
INTRA parallel
ONNXのグラフにおける、ノード内の演算の並列化の設定です。defaultは0で全スレッドを使うようにします。
以前のONNX RuntimeではParavoのようなリアルタイム処理では並列化をしない1
に設定するのが最も早く動いたのですが、現在は計算効率は落ちるものの一定の数まで並列化が効くようになりました。Paravoでは1を指定する省電力モード、物理コア数の1/3までを使うバランスモード・物理コア数に対し8コアまでは使い切るパフォーマンスモードとして用意しています。(バランスモードについては、Pコア・Eコアの指定はできないので、1/3ならPコアだけ使ってくれるだろうという想定です)
同じような設定でINTER parallelの設定もありますが、こちらはノード間の並列の設定になります。execution modeの設定をデフォルトのSEQUENTIALからPARALLELに変更したうえで並列化の数を指定すると有効になりますが、ParavoではSEQUENTIALのまま利用していません。
allow_spinning
待機時にもCPUを動かすかどうかです。これが1だと待機時も動かし続け消費電力が増え、CPUの表示も使用しているコアでは100%になります。Paravoの場合は0に設定しています。
providers
GPUで動かす場合、"CUDAExecutionProvider"も指定できますが、Paravoのリアルタイム推論ではCPUExecutionProviderの方が早いです。
ONNX関連の可視化ノウハウ
グラフ可視化ツール Netron
Netronにより、ONNXで吐き出されたモデルについてグラフや値を可視化することができます。このグラフを可視化することで、例えばグラフの最適化前後でどのように変化したか、dynamic_shapesはどのように扱われているかなどをチェックすることができます。
プロファイリングツール
ONNX Runtimeのプロファイリングを使うと、各ノードでどれくらいの時間が使われていたかjsonで吐き出されます。このjsonをchrome://tracingやPerfetto UIなどの専用ツールで開くと、各実行時間が可視化されます。
計測の際複数回呼び出すと膨大な量のjsonが吐き出されるので、一回の推論で十分です。
まとめ
ONNXを利用する際のTips/ノウハウを紹介しました。ParavoのAIモデルの設計を行う際は、見かけの計算量だけでなくONNXで実際にexportした際のパフォーマンスを比較しながら作成しています。リアルタイムAIアプリケーションの開発においてこれらのコツが役に立てば幸いです。
Discussion