Zenn
Open1

Module-LLM用にonnx変換できるか調べるスクリプト

nnn112358nnn112358

Module-LLM用にonnx変換できるか調べるスクリプトです。
こちらにAX630C でサポートしているOperatorのリストがあります。
https://pulsar2-docs.readthedocs.io/en/latest/appendix/op_support_list_ax620e.html
Pythonでこれとマッチするかの確認を行っています。

import onnx
from collections import Counter
import argparse

# 指定されたオペレーターのカテゴリ
categories = {
    "基本的な算術演算子": ["Add", "Sub", "Mul", "Div", "Max", "Min", "Pow"],
    "活性化関数": [
        "Relu", "Sigmoid", "Tanh", "LeakyRelu", "Elu", "Gelu", "HardSigmoid",
        "HardSwish", "Softmax", "Softplus", "PRelu", "Mish", "Silu", "Swish"
    ],
    "正規化層": ["BatchNormalization", "LayerNormalization", "InstanceNormalization", "LpNormalization"],
    "畳み込み演算": [
        "Conv", "ConvTranspose", "AveragePool", "MaxPool", 
        "GlobalAveragePool", "GlobalMaxPool"
    ],
    "削減演算": ["ReduceSum", "ReduceMean", "ReduceMax", "ReduceL2"],
    "形状操作": ["Reshape", "Transpose", "Squeeze", "Unsqueeze", "Flatten", "DepthToSpace", "SpaceToDepth"],
    "比較演算子": ["Equal", "Greater", "GreaterOrEqual", "Less", "LessOrEqual"],
    "数学関数": ["Abs", "Sqrt", "Exp", "Sin", "Ceil", "Erf", "Clip"],
    "その他の演算子": [
        "Where", "Gather", "Split", "Concat", "Slice", "Cast", "TopK", 
        "ArgMax", "ArgMin", "Constant", "ConstantOfShape", "Identity", 
        "Expand", "LSTM", "Pad", "GridSample", "Resize", 
        "SpatialTransformer", "MatMul", "Gemm"
    ]
}

# すべての指定されたオペレーター一覧を1つのリストにまとめる
all_defined_operators = set(op for ops in categories.values() for op in ops)

def check_undefined_operators(onnx_file_path):
    # ONNXモデルを読み込む
    model = onnx.load(onnx_file_path)
    graph = model.graph

    # モデル内で使用されているオペレーターを取得
    model_operators = set(node.op_type for node in graph.node)

    # 指定されたリストにないオペレーターを特定
    undefined_operators = model_operators - all_defined_operators

    # 結果を表示
    if undefined_operators:
        print("ax630cのNPU用にPulsar2で変換できません undefined operators:")
        for op in undefined_operators:
            print(op)
    else:
        print("ax630cのNPU用にPulsar2で変換できます!")

def main():
    parser = argparse.ArgumentParser(description="Check for undefined operators in an ONNX model.")
    parser.add_argument("onnx_file", type=str, help="Path to the ONNX file.")
    args = parser.parse_args()

    check_undefined_operators(args.onnx_file)

if __name__ == "__main__":
    main()
ログインするとコメントできます