Open1

ONNXファイルから不要なノードをまとめて削除し、さらに別のノードを新規生成して指定位置に挿入する一例

PINTOPINTO
import onnx
from onnx import helper
from onnx import AttributeProto, TensorProto, GraphProto
import onnx_graphsurgeon as gs
import numpy as np

# Delete garbage nodes
#
# Removes the four nodes named below and lets the node that feeds the first of
# them produce the tensors that the tail of the removed chain used to produce.
remove_node_names = ['Shape_4385', 'Gather_4386', 'Equal_4388', 'If_4389']
graph = gs.import_onnx(onnx.load("msnet3d_sf_maxdisp192_240x320.onnx"))

# Look the nodes up by name so that remove_nodes follows the order of
# remove_node_names rather than whatever order graph.nodes happens to have;
# the positional indexing below depends on that order.
found_nodes = {
    node.name: node for node in graph.nodes if node.name in remove_node_names
}
missing_names = [name for name in remove_node_names if name not in found_nodes]
if missing_names:
    # Fail early with a readable message instead of an IndexError (or a
    # silently half-edited graph) when the model does not contain the nodes.
    raise ValueError(f"Nodes not found in graph: {missing_names}")
remove_nodes = [found_nodes[name] for name in remove_node_names]

# Producer of the first input of the first node to remove.
inp_node = remove_nodes[0].i()
# Consumer of the first output of the third node to remove.
out_node = remove_nodes[2].o()
# Make the upstream node produce out_node's output tensors directly, which
# bypasses everything in between.
inp_node.outputs = out_node.outputs
print(out_node.outputs)
# Detach the unwanted nodes from their output tensors so that cleanup()
# no longer sees them as contributing to the graph and discards them.
for remove_node in remove_nodes:
    remove_node.outputs.clear()
graph.cleanup()
# Temporarily output to ONNX file to debug that the node was deleted successfully,
# even though it is not necessary.
onnx.save(gs.export_onnx(graph), "removed.onnx")

# Adding a Squeeze node
#
# Inserts a new Squeeze node between the node(s) named in pre_node_names and
# the node(s) named in post_node_names, then writes the result to a new file.
graph = gs.import_onnx(onnx.load("removed.onnx"))
pre_node_names = ['Resize_4383']
post_node_names = ['Transpose_4392']
pre_nodes = [node for node in graph.nodes if node.name in pre_node_names]
post_nodes = [node for node in graph.nodes if node.name in post_node_names]
if not pre_nodes or not post_nodes:
    # Fail with a readable message instead of an IndexError further down.
    raise ValueError(
        f"Could not find nodes {pre_node_names} and/or {post_node_names} in graph"
    )

# ONNX changed how Squeeze receives its axes: up to opset 12 they are an
# attribute, from opset 13 onwards they must be supplied as a second
# (int64 tensor) input. Pick the form that matches the loaded model.
squeeze_axes = [1]
squeeze_inputs = list(pre_nodes[0].outputs)
squeeze_attrs = {}
if graph.opset >= 13:
    squeeze_inputs.append(
        gs.Constant(
            "dummy_squeeze_axes",
            values=np.array(squeeze_axes, dtype=np.int64),
        )
    )
else:
    squeeze_attrs['axes'] = squeeze_axes

# NOTE(review): the tensor name "10000" is assumed not to clash with any
# existing tensor name in the model - confirm against the graph.
dummy_squeeze_out = gs.Variable("10000", dtype=np.float32)
dummy_squeeze = gs.Node(
    op="Squeeze",
    name="dummy_squeeze",
    attrs=squeeze_attrs,
    inputs=squeeze_inputs,
    outputs=[dummy_squeeze_out]
)
graph.nodes.append(dummy_squeeze)
# Replace ALL inputs of the downstream node with the Squeeze output.
post_nodes[0].inputs = dummy_squeeze.outputs

# The new node was appended at the end of the node list, so re-sort the
# graph topologically before exporting.
graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), "removed_appended.onnx")