🪨

ONNX について

2023/12/25に公開

はじめに

ONNX is an open format built to represent machine learning models.

ONNX defines a common set of operators -
the building blocks of machine learning and deep learning models -
and a common file format to enable AI developers to use models with
a variety of frameworks, tools, runtimes, and compilers.

onnx.ai より

ONNX とは、機械学習モデルを表現するためのオープンなフォーマット(や周辺のエコシステム)を指します。

この記事では、あまり日本語の資料が見つからない部分、特に ONNX フォーマットの構造や上手な扱い方について触れていきます。

なお、必ずしも正確な情報を提供できるわけではないため、間違いなどはコメントしていただけると助かります。

便利なリンク集

ONNX の構造

ONNX フォーマット(以下 ONNX)は、IR.md にもあるように、Protocol Buffers として定義されています。
つまり、*.onnx というファイルを見かけたら、Protocol Buffers としてデシリアライズすれば中身を覗くことができるというわけです。

まずは大まかな構造を把握して、その後に細かい部分を見ていきましょう。

全体像

ONNX は以下のような階層構造から成り立っています(いくつか省いている要素があります)。

また、これらの要素に含まれるフィールドとして以下が存在します(個人的に重要だと判断したものを抜粋しています)。

解釈

ONNX は基本的に、なにか入力値を受け取って、なにか出力値を返すという計算グラフを表現するためのものです。

ここで、入力値(の配列)は model.graph.input (つまり ModelProto model > GraphProto graph > ValueInfoProto[] input) に対応します(これはモデルの実行時に渡される値)。
同様に、出力値(の配列)は model.graph.output に対応します。
また、重みなどの定数値はしばしば model.graph.initializer として定義されており、これも入力値とみなせます(これは実行前から用意されている値)。

これらの入力値を model.graph.node (つまり ModelProto model > GraphProto graph > NodeProto[] node) で指定された計算に従って処理することで、出力値を得ます。

Operator

Node には op_type(文字列)というフィールドがあり、ここに Operators にて紹介されている命令を指定することで、計算の種類を指定することができます。

ある operator はある domain に属するという定義となっていて、例えば Operators で定義されている operator はほとんどが ai.onnx というドメインに属しています。他には com.microsoft.experimental などのドメインが存在しています。

自ら新たなドメインを追加・そのドメインに属する operator を追加することで、ONNX の仕様を超えた計算を可能にすることもできます。

Operator versioning

Operator にはバージョンが付与されます。
バージョンは、Operator の定義に何らかの変更が加えられるごとに増える数値です。
結論としては、Operator は三つ組 (domain, op_type, since_version)(表記: ドメイン:Operator名:変更が加えられる最初のバージョン; e.g., ai.onnx:Conv:3)により区別できます。

また、Operator set(opset)という概念にも注意しましょう。
Opset は組 (domain, version) として与えられ、あるドメインのどのバージョンの Operator が利用可能かを表します。
Opset を ModelProto.opset_import(これは配列)に追加することで、そのモデルで利用可能な opset を指定できます。

たとえば、opset_import(ai.onnx, 18)(com.myownops, 2) が指定されており、
ある計算ノードにはドメインとして com.myownops が指定されている場合、
その計算ノードの Operator のバージョンは 2 となります。

PyTorch などのフレームワークから ONNX をエクスポートする場合、
(そのモデルを表現できる Operator が存在しない場合)指定する opset_version によってはエラーになる可能性があるので注意しましょう。
(PyTorch 側から自分で Operator を定義して、独自ドメインから利用することもできます。詳しくは調べてみてください)

ONNX の扱い方

実際に ONNX ファイルを用いて開発を行う上で、便利な豆知識をまとめておきます。

モデルの可視化

便利なコマンドとして netron があります。
ブラウザからモデルを確認することができ便利です。

# netron /path/to/model.onnx
Serving '/path/to/model.onnx' at http://localhost:8080

netron

便利なスニペット

モデルの読み込み

PyPI からインストールできる onnx ライブラリを使うことで、Protocol Buffers をあまり意識せずに、モデルを扱うことができます。

import onnx # pip install onnx
model = onnx.load("/path/to/model.onnx") # ModelProto として読み込まれる

モデルの書き込み

import onnx
onnx.save(model, "/path/to/model.onnx")

モデルの整合性チェック

import onnx
try:
    onnx.checker.check_model(model)
except onnx.checker.ValidationError as e:
    print(f"Invalid: {e}")
else:
    print("Valid")

モデルの Shape 推論

import onnx
from onnx import shape_inference
inferred_model = shape_inference.infer_shapes(model) # 値に Shape が付与される

モデルの Opset version 変換

from onnx import version_converter
model = onnx.load("/path/to/model.onnx")
# 古い opset version → 新しい opset version なら大抵成功する
new_model = version_converter.convert_version(original_model, TARGET_VERSION)

モデルに含まれる Operator の種類の列挙

import onnx
model = onnx.load("/path/to/model.onnx")
ops = set()
for node in model.graph.node:
    ops.add(node.op_type)
print(ops)

モデルに含まれる値の列挙

import onnx
model = onnx.load("/path/to/model.onnx")
value_names = set()
# 入力
for input in model.graph.input:
    value_names.add(input.name)
# 出力
for output in model.graph.output:
    value_names.add(output.name)
# 初期値
for init in model.graph.initializer:
    value_names.add(init.name)
# Node の出力
for node in model.graph.node:
    for output in node.output:
        value_names.add(output)
print(value_names)

Numpy との連携

onnx.numpy_helper を使うことで、numpy 配列等と protobuf の変換が簡単に行えます。

import numpy
from onnx import numpy_helper

# e.g., numpy.array → TensorProto
init = numpy_helper.from_array(numpy.array([3.14, 3.14, 3.14], dtype=numpy.float32))
init.name = "pi"
print(init)

モデルのグラフ構造を単純化

しばしば、モデルのグラフ構造が複雑(e.g., 使われてほしくない operator が含まれている)な場合があります。
その場合、onnx-simplifier を用いることで、単純化できる場合があります。

import onnx
import onnxsim # pip install onnxsim
model = onnx.load("/path/to/model.onnx")
new_model, ok = onnxsim.simplify(model)
assert ok, "Failed to simplify"

PyTorch モデルを ONNX としてエクスポート

torch.onnx.export を使うことが多いです。

import torch
import torchvision
model = torchvision.model.mobilenet_v3_large(pretrained=True) # 簡単なモデルなら大抵は大丈夫
path = "/path/to/out_model.onnx"
torch.onnx.export(
    model,
    torch.randn(1, 3, 224, 224), # shape が正しければ値は関係ない
    path,
    opset_version=14 # 適切な opset version を自ら指定
)

Optimum を用いた ONNX エクスポート

Optimum を用いることで、PyTorch などから素直にエクスポートできないモデルに対応することができます。

# pip install git+https://github.com/huggingface/optimum.git
# optimum-cli export onnx -m deepset/roberta-base-squad2 --optimize O2 roberta_base_qa_onnx
# ls
roberta_base_qa_onnx

まだありそう

TODO...

ModelProto の使いやすいラッパー

OnnxModel のようなラッパーを用意することで、グラフの解析がしやすくなります。
いくつもスニペットを作成して、散らかっていくよりは良いかもしれません。

# ...
model = onnx.load("/path/to/model.onnx")
new_model = OnnxModel(model)

# e.g., あるノードの children(そのノードの出力を使っているノード)を知りたい
children = new_model.get_children(specific_node)

# e.g., あるノードに定数入力があるならそれを知りたい
nth, const_node = new_model.get_constant_input(specific_node)

Discussion