Open1

ONNXの特定のレイヤーを削除

PINTO
import onnx_graphsurgeon as gs
import onnx
from pprint import pprint

folder = 'saved_model_sony_240x320'
remove_node_name = 'output__80'


def find_node_by_name(nodes, name):
    """Return the first node in ``nodes`` whose ``name`` attribute equals ``name``.

    Raises:
        ValueError: if no node in ``nodes`` has that name.
    """
    node = next((n for n in nodes if n.name == name), None)
    if node is None:
        raise ValueError(f"no node named {name!r} in the graph")
    return node


def bypass_node(graph, node):
    """Splice ``node`` out of ``graph`` by handing its outputs to its producer.

    The node that produces ``node.inputs[0]`` takes over ``node``'s output
    tensors. Downstream consumers and graph outputs therefore keep their
    original tensor names. ``node`` is left with no outputs, so a later
    ``graph.cleanup()`` can drop it.

    The producer's output list is replaced wholesale. That is only safe when
    none of its current outputs is needed by anything else. This is checked up
    front instead of silently writing a model with dangling tensors.

    Raises:
        ValueError: if ``node`` has no inputs; if its first input has no
            producing node (e.g. a graph input or a constant); or if any of the
            producer's current outputs is read by another node or is a graph
            output.
    """
    if not node.inputs or not node.inputs[0].inputs:
        raise ValueError(
            f"node {node.name!r} has no producer on its first input; "
            "cannot bypass it"
        )
    # Same as node.i(): the node that produces node.inputs[0].
    producer = node.inputs[0].inputs[0]

    def still_needed(tensor):
        # A tensor must keep its producer if a node other than the one being
        # removed reads it, or if the tensor is itself a graph output.
        has_other_consumer = any(c is not node for c in tensor.outputs)
        is_graph_output = any(tensor is t for t in graph.outputs)
        return has_other_consumer or is_graph_output

    if any(still_needed(t) for t in producer.outputs):
        raise ValueError(
            f"cannot bypass node {node.name!r}: outputs of its producer "
            f"{producer.name!r} are still used elsewhere in the graph"
        )

    # Reconnect the producer to node's output tensors so the graph skips over
    # node. This is the pattern from onnx-graphsurgeon's node-removal example.
    producer.outputs = node.outputs
    node.outputs.clear()


def main():
    """Remove ``remove_node_name`` from the model in ``folder``.

    The result is saved as ``removed.onnx`` in the current working directory.
    """
    graph = gs.import_onnx(onnx.load(f"{folder}/model_float32.onnx"))

    # Print every node name so the right value for remove_node_name can be found.
    for node in graph.nodes:
        print(node.name)

    bypass_node(graph, find_node_by_name(graph.nodes, remove_node_name))

    # cleanup() removes nodes/tensors that no longer contribute to the graph
    # outputs; the bypassed node now has no outputs, so it is dropped here.
    graph.cleanup()
    onnx.save(gs.export_onnx(graph), "removed.onnx")


if __name__ == "__main__":
    main()