🤪
Deleting an arbitrary layer (node) from an ONNX model
1. Environment
- Python
- onnx_graphsurgeon
- Model: Learning-to-See-in-the-Dark (https://github.com/cchen156/Learning-to-See-in-the-Dark)
2. Procedure
$ python3 -m pip install onnx_graphsurgeon \
--index-url https://pypi.ngc.nvidia.com
import onnx_graphsurgeon as gs
import onnx
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--onnx_file_path", required=True, type=str)
parser.add_argument("--remove_node_name", required=True, type=str)
args = parser.parse_args()
graph = gs.import_onnx(onnx.load(args.onnx_file_path))
for i in graph.nodes:
print(i.name)
remove_node = [
node for node in graph.nodes if node.name == args.remove_node_name
][0]
# Get the input node of the fake node
# Node provides i() and o() functions that can optionally
# be provided an index (default is 0)
# These serve as convenience functions for the alternative,
# which would be to fetch the input/output
# tensor first, then fetch the input/output node of the tensor.
# For example, node.i() is equivalent to node.inputs[0].inputs[0]
inp_node = remove_node.i()
# Reconnect the input node to the output tensors of the fake node,
# so that the first identity node in the example graph now
# skips over the fake node.
inp_node.outputs = remove_node.outputs
remove_node.outputs.clear()
# Remove the fake node from the graph completely
graph.cleanup()
h = graph.inputs[0].shape[2]
w = graph.inputs[0].shape[3]
scale = 0
if graph.inputs[0].shape[1] == 4:
scale = 2
else:
scale = 3
graph.outputs[0].shape = [1,3,h*scale,w*scale]
print(graph.outputs)
onnx.save(gs.export_onnx(graph), args.onnx_file_path)
python3 remove_transpose.py \
    --remove_node_name output__80 \
    --onnx_file_path saved_model_sony_240x320/model_float32.onnx
import onnx_graphsurgeon as gs
import onnx
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--onnx_file_path", required=True, type=str)
parser.add_argument("--remove_node_name", required=True, type=str)
args = parser.parse_args()
graph = gs.import_onnx(onnx.load(args.onnx_file_path))
remove_node = None
remove_node_idx = -1
for idx, node in enumerate(graph.nodes):
if node.name == args.remove_node_name:
remove_node = node
remove_node_idx = idx
break
graph.inputs[0].dtype = graph.nodes[remove_node_idx+1].inputs[0].dtype
graph.nodes[remove_node_idx+1].inputs[0] = graph.inputs[0]
remove_node.outputs.clear()
graph.cleanup()
onnx.save(gs.export_onnx(graph), args.onnx_file_path)
Discussion