ONNX について
はじめに
ONNX is an open format built to represent machine learning models.
ONNX defines a common set of operators -
the building blocks of machine learning and deep learning models -
and a common file format to enable AI developers to use models with
a variety of frameworks, tools, runtimes, and compilers.onnx.ai より
ONNX とは、機械学習モデルを表現するためのオープンなフォーマット(や周辺のエコシステム)を指します。
この記事では、あまり日本語の資料が見つからない部分、特に ONNX フォーマットの構造や上手な扱い方について触れていきます。
なお、必ずしも正確な情報を提供できるわけではないため、間違いなどはコメントしていただけると助かります。
便利なリンク集
-
https://onnx.ai
- 公式サイト
- 各種ドキュメントへ飛ぶためのポータルサイトとしても機能している
-
https://github.com/onnx/onnx
- 公式 GitHub リポジトリ
- フォーマットの定義や周辺ライブラリ・ツールなどを管理している
-
https://github.com/onnx/onnx/blob/main/docs/
- フォーマットのドキュメント
- 困ったらこのディレクトリ内を探す
-
https://github.com/onnx/onnx/blob/main/docs/IR.md
- IR (中間表現)のドキュメント
-
https://github.com/onnx/onnx/blob/main/docs/Operators.md
- オペレータのドキュメント(一覧)
-
https://github.com/onnx/models
- ONNX 形式の学習済みモデルが多数置いてあるリポジトリ
- PyTorch などからエクスポートする他にこういった場所からダウンロードすることもできる
-
https://github.com/PINTO0309/PINTO_model_zoo
- 個人で膨大な数の ONNX 形式モデルを管理されている方のリポジトリ
-
https://github.com/onnx/onnx/blob/main/docs/PythonAPIOverview.md
- Python 向け API の使い方
ONNX の構造
ONNX フォーマット(以下 ONNX)は、IR.md にもあるように、Protocol Buffers として定義されています。
つまり、*.onnx
というファイルを見かけたら、Protocol Buffers としてデシリアライズすれば中身を覗くことができるというわけです。
まずは大まかな構造を把握して、その後に細かい部分を見ていきましょう。
全体像
ONNX は以下のような階層構造から成り立っています(いくつか省いている要素があります)。
- Model (
onnx.ModelProto
) - モデル(ファイルのトップレベル要素)- Graph (
onnx.GraphProto
) - 計算グラフ(DAG; トポロジカルソート済み)- Node (
onnx.NodeProto
) - 計算ノード
- Node (
- Graph (
また、これらの要素に含まれるフィールドとして以下が存在します(個人的に重要だと判断したものを抜粋しています)。
- Tensor (
onnx.TensorProto
) - テンソル(重み)のデータ - ValueInfo (
onnx.ValueInfoProto
) - 値を説明する情報 - Attribute (
onnx.AttributeProto
) - Node に指定される情報
解釈
ONNX は基本的に、なにか入力値を受け取って、なにか出力値を返すという計算グラフを表現するためのものです。
ここで、入力値(の配列)は model.graph.input
(つまり ModelProto model > GraphProto graph > ValueInfoProto[] input
) に対応します(これはモデルの実行時に渡される値)。
同様に、出力値(の配列)は model.graph.output
に対応します。
また、重みなどの定数値はしばしば model.graph.initializer
として定義されており、これも入力値とみなせます(これは実行前から用意されている値)。
これらの入力値を model.graph.node
(つまり ModelProto model > GraphProto graph > NodeProto[] node
) で指定された計算に従って処理することで、出力値を得ます。
Operator
Node には op_type
(文字列)というフィールドがあり、ここに Operators にて紹介されている命令を指定することで、計算の種類を指定することができます。
ある operator はある domain に属するという定義となっていて、例えば Operators で定義されている operator はほとんどが ai.onnx
というドメインに属しています。他には com.microsoft.experimental
などのドメインが存在しています。
自ら新たなドメインを追加・そのドメインに属する operator を追加することで、ONNX の仕様を超えた計算を可能にすることもできます。
Operator versioning
Operator にはバージョンが付与されます。
バージョンは、Operator の定義に何らかの変更が加えられるごとに増える数値です。
結論としては、Operator は三つ組 (domain, op_type, since_version)
(表記: ドメイン:Operator名:変更が加えられる最初のバージョン
; e.g., ai.onnx:Conv:3
)により区別できます。
また、Operator set(opset)という概念にも注意しましょう。
Opset は組 (domain, version)
として与えられ、あるドメインのどのバージョンの Operator が利用可能かを表します。
Opset を ModelProto.opset_import
(これは配列)に追加することで、そのモデルで利用可能な opset を指定できます。
たとえば、opset_import
に (ai.onnx, 18)
と (com.myownops, 2)
が指定されており、
ある計算ノードにはドメインとして com.myownops
が指定されている場合、
その計算ノードの Operator のバージョンは 2 となります。
PyTorch などのフレームワークから ONNX をエクスポートする場合、
(そのモデルを表現できる Operator が存在しない場合)指定する opset_version によってはエラーになる可能性があるので注意しましょう。
(PyTorch 側から自分で Operator を定義して、独自ドメインから利用することもできます。詳しくは調べてみてください)
ONNX の扱い方
実際に ONNX ファイルを用いて開発を行う上で、便利な豆知識をまとめておきます。
モデルの可視化
便利なコマンドとして netron
があります。
ブラウザからモデルを確認することができ便利です。
# netron /path/to/model.onnx
Serving '/path/to/model.onnx' at http://localhost:8080
便利なスニペット
モデルの読み込み
PyPI からインストールできる onnx
ライブラリを使うことで、Protocol Buffers をあまり意識せずに、モデルを扱うことができます。
import onnx # pip install onnx
model = onnx.load("/path/to/model.onnx") # ModelProto として読み込まれる
モデルの書き込み
import onnx
onnx.save(model, "/path/to/model.onnx")
モデルの整合性チェック
import onnx
try:
onnx.checker.check_model(model)
except onnx.checker.ValidationError as e:
print(f"Invalid: {e}")
else:
print("Valid")
モデルの Shape 推論
import onnx
from onnx import shape_inference
inferred_model = shape_inference.infer_shapes(model) # 値に Shape が付与される
モデルの Opset version 変換
from onnx import version_converter
model = onnx.load("/path/to/model.onnx")
# 古い opset version → 新しい opset version なら大抵成功する
new_model = version_converter.convert_version(original_model, TARGET_VERSION)
モデルに含まれる Operator の種類の列挙
import onnx
model = onnx.load("/path/to/model.onnx")
ops = set()
for node in model.graph.node:
ops.add(node.op_type)
print(ops)
モデルに含まれる値の列挙
import onnx
model = onnx.load("/path/to/model.onnx")
value_names = set()
# 入力
for input in model.graph.input:
value_names.add(input.name)
# 出力
for output in model.graph.output:
value_names.add(output.name)
# 初期値
for init in model.graph.initializer:
value_names.add(init.name)
# Node の出力
for node in model.graph.node:
for output in node.output:
value_names.add(output)
print(value_names)
Numpy との連携
onnx.numpy_helper
を使うことで、numpy 配列等と protobuf の変換が簡単に行えます。
import numpy
from onnx import numpy_helper
# e.g., numpy.array → TensorProto
init = numpy_helper.from_array(numpy.array([3.14, 3.14, 3.14], dtype=numpy.float32))
init.name = "pi"
print(init)
モデルのグラフ構造を単純化
しばしば、モデルのグラフ構造が複雑(e.g., 使われてほしくない operator が含まれている)な場合があります。
その場合、onnx-simplifier を用いることで、単純化できる場合があります。
import onnx
import onnxsim # pip install onnxsim
model = onnx.load("/path/to/model.onnx")
new_model, ok = onnxsim.simplify(model)
assert ok, "Failed to simplify"
PyTorch モデルを ONNX としてエクスポート
torch.onnx.export
を使うことが多いです。
import torch
import torchvision
model = torchvision.model.mobilenet_v3_large(pretrained=True) # 簡単なモデルなら大抵は大丈夫
path = "/path/to/out_model.onnx"
torch.onnx.export(
model,
torch.randn(1, 3, 224, 224), # shape が正しければ値は関係ない
path,
opset_version=14 # 適切な opset version を自ら指定
)
Optimum を用いた ONNX エクスポート
Optimum を用いることで、PyTorch などから素直にエクスポートできないモデルに対応することができます。
# pip install git+https://github.com/huggingface/optimum.git
# optimum-cli export onnx -m deepset/roberta-base-squad2 --optimize O2 roberta_base_qa_onnx
# ls
roberta_base_qa_onnx
まだありそう
TODO...
ModelProto
の使いやすいラッパー
OnnxModel
のようなラッパーを用意することで、グラフの解析がしやすくなります。
いくつもスニペットを作成して、散らかっていくよりは良いかもしれません。
# ...
model = onnx.load("/path/to/model.onnx")
new_model = OnnxModel(model)
# e.g., あるノードの children(そのノードの出力を使っているノード)を知りたい
children = new_model.get_children(specific_node)
# e.g., あるノードに定数入力があるならそれを知りたい
nth, const_node = new_model.get_constant_input(specific_node)
Discussion