import onnx_graphsurgeon as gs
import onnx
from pprint import pprint

folder = 'saved_model_sony_240x320'
remove_node_name = 'output__80'

graph = gs.import_onnx(onnx.load(f"{folder}/model_float32.onnx"))

# Find the node to bypass by its name. Use next() with a default so a missing
# name produces a readable error (plus the list of available names) instead
# of a bare IndexError.
remove_node = next(
    (node for node in graph.nodes if node.name == remove_node_name), None
)
if remove_node is None:
    pprint([node.name for node in graph.nodes])
    raise ValueError(
        f"Node {remove_node_name!r} not found in {folder}/model_float32.onnx "
        "(available node names printed above)"
    )

# node.i() is shorthand for node.inputs[0].inputs[0]: the node producing the
# first input tensor. It raises IndexError when that input is a graph input or
# constant (no producer), so check explicitly and explain the failure.
if not remove_node.inputs or not remove_node.inputs[0].inputs:
    raise ValueError(
        f"Node {remove_node_name!r} has no producing node on its first input; "
        "cannot bypass it"
    )
inp_node = remove_node.i()

# Reroute: the producer now writes directly to the removed node's output
# tensors, so downstream consumers skip over remove_node.
# NOTE(review): assigning to .outputs replaces ALL of inp_node's outputs; if
# inp_node had other outputs, or its old output also fed other consumers,
# those connections are lost. Confirm that is acceptable for this model.
inp_node.outputs = remove_node.outputs
# Detach remove_node from those tensors. Otherwise each tensor keeps
# remove_node as a second producer, and cleanup() would not remove it
# because its outputs are still in use.
remove_node.outputs.clear()

# remove_node is now dangling (nothing consumes its outputs); cleanup() drops
# it along with any other unused nodes/tensors.
graph.cleanup()

onnx.save(gs.export_onnx(graph), "removed.onnx")