🤪

2021/10/24に公開

# 2. Procedure

```bash
$ python3 -m pip install onnx_graphsurgeon \
--index-url https://pypi.ngc.nvidia.com
``````
```python
import onnx_graphsurgeon as gs
import onnx
import argparse

def parse_args(argv=None):
    """Parse the command-line options of the node-removal script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``onnx_file_path`` and ``remove_node_name``.
    """
    parser = argparse.ArgumentParser(
        description="Remove one node from an ONNX model and set the output shape."
    )
    parser.add_argument(
        "--onnx_file_path",
        required=True,
        help="ONNX model to edit. The file is overwritten in place.",
    )
    parser.add_argument(
        "--remove_node_name",
        required=True,
        help="Name of the node to remove (e.g. output__80).",
    )
    return parser.parse_args(argv)


def upscale_factor(in_channels):
    """Return the spatial upscale factor used to compute the output shape.

    A 4-channel input gives 2; any other channel count gives 3.
    """
    # NOTE(review): presumably a 4-channel input is packed 2x2 Bayer RAW and
    # the other case a 3x3-packed sensor layout -- confirm against the model.
    return 2 if in_channels == 4 else 3


def main(argv=None):
    """Remove ``--remove_node_name`` from the model and fix its output shape.

    The model file given by ``--onnx_file_path`` is overwritten.

    Raises:
        SystemExit: if the node does not exist, or if the input height/width
            are not static integers.
    """
    args = parse_args(argv)
    graph = gs.import_onnx(onnx.load(args.onnx_file_path))

    # Print every node name so the user can find the value for
    # --remove_node_name.
    for node in graph.nodes:
        print(node.name)

    matches = [node for node in graph.nodes if node.name == args.remove_node_name]
    if not matches:
        raise SystemExit(f"Node not found: {args.remove_node_name}")
    remove_node = matches[0]

    # node.i() is shorthand for node.inputs[0].inputs[0]: the node producing
    # the removed node's first input tensor. It raises if that tensor has no
    # producer (e.g. it is a graph input).
    inp_node = remove_node.i()

    # Make the producer write straight to the removed node's output tensors,
    # so the removed node is bypassed. NOTE(review): this replaces all of
    # inp_node's outputs; any other consumer of them gets disconnected.
    inp_node.outputs = remove_node.outputs
    remove_node.outputs.clear()

    # The removed node now has no outputs; cleanup() drops it from the graph.
    graph.cleanup()

    in_shape = graph.inputs[0].shape
    h, w = in_shape[2], in_shape[3]
    # Dynamic dims are not ints (e.g. a string name); "h" * 2 would silently
    # produce "hh" instead of failing.
    if not (isinstance(h, int) and isinstance(w, int)):
        raise SystemExit(f"Input height/width must be static, got {h!r} x {w!r}")
    scale = upscale_factor(in_shape[1])

    # Batch size 1 and 3 output channels are hard-coded, as in the original.
    graph.outputs[0].shape = [1, 3, h * scale, w * scale]
    print(graph.outputs)

    onnx.save(gs.export_onnx(graph), args.onnx_file_path)


if __name__ == "__main__":
    main()
``````
```bash
python3 remove_transpose.py \
--onnx_file_path saved_model_sony_240x320/model_float32.onnx \
--remove_node_name output__80
``````
```python
import onnx_graphsurgeon as gs
import onnx
import argparse

def parse_args(argv=None):
    """Parse the command-line options of the input-node removal script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``onnx_file_path`` and ``remove_node_name``.
    """
    parser = argparse.ArgumentParser(
        description="Remove a node that reads the first graph input of an ONNX model."
    )
    parser.add_argument(
        "--onnx_file_path",
        required=True,
        help="ONNX model to edit. The file is overwritten in place.",
    )
    parser.add_argument(
        "--remove_node_name",
        required=True,
        help="Name of the node to remove.",
    )
    return parser.parse_args(argv)


def main(argv=None):
    """Bypass ``--remove_node_name`` by feeding the first graph input to its consumers.

    Assumes the removed node reads ``graph.inputs[0]`` (TODO confirm for the
    target model). The model file is overwritten.

    Raises:
        SystemExit: if no node with the given name exists.
    """
    args = parse_args(argv)
    graph = gs.import_onnx(onnx.load(args.onnx_file_path))

    remove_node = next(
        (node for node in graph.nodes if node.name == args.remove_node_name),
        None,
    )
    # Fail before touching the graph; previously a missing name fell through
    # with index -1 and mutated graph.nodes[0].
    if remove_node is None:
        raise SystemExit(f"Node not found: {args.remove_node_name}")

    graph_input = graph.inputs[0]
    removed_out = remove_node.outputs[0]

    # The consumers will now read the graph input directly, so give it the
    # dtype they used to read (only when that dtype is known).
    # NOTE(review): only the dtype is updated; if the removed node also changed
    # the shape (e.g. a Transpose), the input shape must be fixed too.
    if removed_out.dtype is not None:
        graph_input.dtype = removed_out.dtype

    # Rewire every consumer of the removed output, not just the next node in
    # list order, so no node is left reading a tensor without a producer.
    # Iterate over a copy: graphsurgeon keeps node inputs and tensor outputs in
    # sync, so reassigning inputs mutates removed_out.outputs.
    for consumer in list(removed_out.outputs):
        consumer.inputs = [
            graph_input if tensor is removed_out else tensor
            for tensor in consumer.inputs
        ]

    # The removed node now has no outputs; cleanup() drops it from the graph.
    remove_node.outputs.clear()
    graph.cleanup()
    onnx.save(gs.export_onnx(graph), args.onnx_file_path)


if __name__ == "__main__":
    main()
``````

ログインするとコメントできます