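# Open1
# onnxファイルから不要なノードをまとめて削除しつつ、さらに別のノードを新規生成して指定位置に挿入する一例
# Example: remove a batch of unneeded nodes from an ONNX file, then create a new
# node and insert it at a specified position in the graph.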
import onnx
from onnx import helper
from onnx import AttributeProto, TensorProto, GraphProto
import onnx_graphsurgeon as gs
import numpy as np
# Delete garbage nodes.
# Remove the nodes named below and splice the graph back together using the
# onnx-graphsurgeon idiom:
#   1. make the producer feeding the first removed node emit the tensor(s)
#      that the removed subgraph used to emit,
#   2. detach every removed node from its outputs,
#   3. let graph.cleanup() drop the nodes that are no longer used.
remove_node_names = ['Shape_4385', 'Gather_4386', 'Equal_4388', 'If_4389']
graph = gs.import_onnx(onnx.load("msnet3d_sf_maxdisp192_240x320.onnx"))

# Resolve nodes by name, in the order listed above, instead of relying on the
# iteration order of graph.nodes. Fail loudly on missing names rather than with
# an opaque IndexError on the positional lookups below.
nodes_by_name = {node.name: node for node in graph.nodes}
missing_names = [name for name in remove_node_names if name not in nodes_by_name]
if missing_names:
    raise ValueError(f"Nodes not found in graph: {missing_names}")
remove_nodes = [nodes_by_name[name] for name in remove_node_names]

# Producer of the first input tensor of Shape_4385.
inp_node = remove_nodes[0].i()
# First consumer of the first output tensor of Equal_4388
# (presumably If_4389 -- verify against the model).
out_node = remove_nodes[2].o()
# Rewire: inp_node's outputs are replaced by out_node's output tensors, so
# downstream consumers now read directly from inp_node.
inp_node.outputs = out_node.outputs
# Debug: show the tensor(s) that were rewired.
print(out_node.outputs)
# Detach each removed node from its outputs so cleanup() can discard it.
for remove_node in remove_nodes:
    remove_node.outputs.clear()
graph.cleanup()

# Temporarily output to ONNX file to debug that the node was deleted successfully,
# even though it is not necessary.
onnx.save(gs.export_onnx(graph), "removed.onnx")
# Adding a Squeeze node.
# Insert a new Squeeze node so the graph becomes:
#   Resize_4383 -> dummy_squeeze -> Transpose_4392
graph = gs.import_onnx(onnx.load("removed.onnx"))
pre_node_names = ['Resize_4383']
post_node_names = ['Transpose_4392']

# Resolve nodes by name and report missing ones explicitly, instead of failing
# later with an IndexError on pre_nodes[0] / post_nodes[0].
nodes_by_name = {node.name: node for node in graph.nodes}
missing_names = [
    name for name in pre_node_names + post_node_names if name not in nodes_by_name
]
if missing_names:
    raise ValueError(f"Nodes not found in graph: {missing_names}")
pre_nodes = [nodes_by_name[name] for name in pre_node_names]
post_nodes = [nodes_by_name[name] for name in post_node_names]

# Output tensor of the new node.
# NOTE(review): the tensor name "10000" must not collide with an existing
# tensor name in the model -- confirm.
dummy_squeeze_out = gs.Variable("10000", dtype=np.float32)
# NOTE(review): `axes` as a node attribute is only valid for Squeeze below
# opset 13. From opset 13 on, axes must be supplied as an int64 input tensor.
# Confirm the model's opset.
dummy_squeeze = gs.Node(
    op="Squeeze",
    name="dummy_squeeze",
    attrs={'axes': [1]},
    inputs=pre_nodes[0].outputs,
    outputs=[dummy_squeeze_out],
)
graph.nodes.append(dummy_squeeze)
# Feed Transpose_4392 from the Squeeze output; this replaces all of its inputs.
post_nodes[0].inputs = dummy_squeeze.outputs
# Drop anything left unused and re-sort nodes topologically before export.
graph.cleanup().toposort()
onnx.save(gs.export_onnx(graph), "removed_appended.onnx")