Open1
Module-LLM用にonnx変換できるか調べるスクリプト

Module-LLM用にonnx変換できるか調べるスクリプトです。
こちらにAX630C でサポートしているOperatorのリストがあります。
Pythonでこれとマッチするかの確認を行っています。
import onnx
from collections import Counter
import argparse
# 指定されたオペレーターのカテゴリ
categories = {
"基本的な算術演算子": ["Add", "Sub", "Mul", "Div", "Max", "Min", "Pow"],
"活性化関数": [
"Relu", "Sigmoid", "Tanh", "LeakyRelu", "Elu", "Gelu", "HardSigmoid",
"HardSwish", "Softmax", "Softplus", "PRelu", "Mish", "Silu", "Swish"
],
"正規化層": ["BatchNormalization", "LayerNormalization", "InstanceNormalization", "LpNormalization"],
"畳み込み演算": [
"Conv", "ConvTranspose", "AveragePool", "MaxPool",
"GlobalAveragePool", "GlobalMaxPool"
],
"削減演算": ["ReduceSum", "ReduceMean", "ReduceMax", "ReduceL2"],
"形状操作": ["Reshape", "Transpose", "Squeeze", "Unsqueeze", "Flatten", "DepthToSpace", "SpaceToDepth"],
"比較演算子": ["Equal", "Greater", "GreaterOrEqual", "Less", "LessOrEqual"],
"数学関数": ["Abs", "Sqrt", "Exp", "Sin", "Ceil", "Erf", "Clip"],
"その他の演算子": [
"Where", "Gather", "Split", "Concat", "Slice", "Cast", "TopK",
"ArgMax", "ArgMin", "Constant", "ConstantOfShape", "Identity",
"Expand", "LSTM", "Pad", "GridSample", "Resize",
"SpatialTransformer", "MatMul", "Gemm"
]
}
# すべての指定されたオペレーター一覧を1つのリストにまとめる
all_defined_operators = set(op for ops in categories.values() for op in ops)
def check_undefined_operators(onnx_file_path):
# ONNXモデルを読み込む
model = onnx.load(onnx_file_path)
graph = model.graph
# モデル内で使用されているオペレーターを取得
model_operators = set(node.op_type for node in graph.node)
# 指定されたリストにないオペレーターを特定
undefined_operators = model_operators - all_defined_operators
# 結果を表示
if undefined_operators:
print("ax630cのNPU用にPulsar2で変換できません undefined operators:")
for op in undefined_operators:
print(op)
else:
print("ax630cのNPU用にPulsar2で変換できます!")
def main():
parser = argparse.ArgumentParser(description="Check for undefined operators in an ONNX model.")
parser.add_argument("onnx_file", type=str, help="Path to the ONNX file.")
args = parser.parse_args()
check_undefined_operators(args.onnx_file)
if __name__ == "__main__":
main()
ログインするとコメントできます