Steps to merge two ONNX files into one
1. Introduction
This section describes a simple procedure for combining two ONNX models into one ONNX model. It saves you the trouble of modifying the original PyTorch program and exporting the ONNX from PyTorch again. The prerequisites for combining two ONNX with this procedure are:
- The number of outputs of the first ONNX must match the number of inputs of the second ONNX.
- The shape of the output of the first ONNX must match the shape of the input of the second ONNX.
ONNX.1: lite_hr_depth_k_t_encoder_192x640.onnx
INPUT: ['0']
OUTPUT: ['317', '852', '870', '897', '836']
ONNX.2: lite_hr_depth_k_t_depth_192x640.onnx
INPUT: ['0', 'input.1', 'input.13', 'input.25', 'input.37']
OUTPUT: ['732', '757', '782', '807']
Result ONNX (ONNX.1 + ONNX.2): lite_hr_depth_k_t_encoder_depth_192x640.onnx
INPUT: ['0a']
OUTPUT: ['732b', '757b', '782b', '807b']
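The two prerequisites can be checked mechanically before merging. The sketch below compares the first model's outputs with the second model's inputs by position; the helper names io_signature and check_mergeable are hypothetical, and treating dynamic dimensions (None) as wildcards is an assumption of this sketch.

```python
"""Sketch: check the two merge prerequisites for a pair of ONNX models."""
from typing import List, Optional, Tuple

# (tensor_name, shape) pairs; None marks a dynamic or unknown dimension.
Signature = List[Tuple[str, Tuple[Optional[int], ...]]]


def io_signature(values) -> Signature:
    """Convert onnx ValueInfoProto entries (graph.input / graph.output) to pairs."""
    sig = []
    for v in values:
        dims = tuple(
            d.dim_value if d.HasField("dim_value") else None
            for d in v.type.tensor_type.shape.dim
        )
        sig.append((v.name, dims))
    return sig


def check_mergeable(outputs1: Signature, inputs2: Signature) -> List[str]:
    """Return a list of problems; an empty list means both prerequisites hold."""
    problems = []
    # Prerequisite 1: same number of outputs and inputs.
    if len(outputs1) != len(inputs2):
        problems.append(
            f"count mismatch: {len(outputs1)} outputs vs {len(inputs2)} inputs"
        )
    # Prerequisite 2: each output shape matches the input it will feed.
    for (oname, oshape), (iname, ishape) in zip(outputs1, inputs2):
        if len(oshape) != len(ishape) or any(
            o is not None and i is not None and o != i
            for o, i in zip(oshape, ishape)
        ):
            problems.append(f"shape mismatch: {oname}{oshape} -> {iname}{ishape}")
    return problems
```

For example, pass io_signature(onnx.load('lite_hr_depth_k_t_encoder_192x640.onnx').graph.output) and io_signature(onnx.load('lite_hr_depth_k_t_depth_192x640.onnx').graph.input). If a model lists its weights as graph inputs, drop the names that also appear in graph.initializer first.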
2. Environment
- Python3.x
- ONNX 1.8.x
- onnx-graphsurgeon
- sclblonnx
3. Procedure
3-1. Install
$ python3 -m pip install onnx_graphsurgeon \
--index-url https://pypi.ngc.nvidia.com
$ pip install sclblonnx
3-2. Merge
- Create merge_onnx.py.
- The input/output (tensor) names of the operations in the first ONNX must not overlap at all with those in the second ONNX, so the program renames them so that they can never overlap.
- Force the suffix 'a' onto the end of every tensor name in the first ONNX.
- Force the suffix 'b' onto the end of every tensor name in the second ONNX.
- Rename the tensors with onnx_graphsurgeon.
- Combine the two ONNX with sclblonnx.
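The no-overlap requirement is easy to check directly. Both exported models above name an input '0', so merging them as-is would collide. The sketch below uses hypothetical helpers (collisions, with_suffix) to report shared non-empty names. It also shows that appending 'a' on one side and 'b' on the other always yields disjoint sets, because every renamed name then differs in its last character.

```python
"""Sketch: find tensor names that would collide when two graphs are merged."""
from typing import Iterable, List, Set


def collisions(names1: Iterable[str], names2: Iterable[str]) -> Set[str]:
    """Return the non-empty names that appear in both graphs."""
    # Empty names mark omitted optional inputs in ONNX, not real tensors.
    return {n for n in names1 if n} & {n for n in names2 if n}


def with_suffix(names: Iterable[str], suffix: str) -> List[str]:
    """Append suffix to every non-empty name, mirroring the renaming step."""
    return [f"{n}{suffix}" if n else n for n in names]


names1 = ["0", "317", "852"]
names2 = ["0", "input.1", "input.13"]
print(collisions(names1, names2))
print(collisions(with_suffix(names1, "a"), with_suffix(names2, "b")))
```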
import onnx
import onnx_graphsurgeon as gs
import sclblonnx as so

H=192
W=640
MODEL1=f'lite_hr_depth_k_t_encoder_{H}x{W}.onnx'
MODEL2=f'lite_hr_depth_k_t_depth_{H}x{W}.onnx'
MODEL3=f'lite_hr_depth_k_t_encoder_depth_{H}x{W}.onnx'

# Append the suffix 'a' to every tensor name of the first model.
# A tensor shared by several nodes is visited more than once, so names that
# already end in 'a' are left alone; either way every name ends in 'a'.
# Empty names mark omitted optional inputs and must stay empty.
graph1 = gs.import_onnx(onnx.load(MODEL1))
for n in graph1.nodes:
    for cn in n.inputs + n.outputs:
        if cn.name and not cn.name.endswith('a'):
            cn.name = f'{cn.name}a'
graph1_outputs = [o.name for o in graph1.outputs]
print(f'graph1 outputs: {graph1_outputs}')
onnx.save(gs.export_onnx(graph1), "graph1.onnx")

# Same renaming for the second model, with the suffix 'b'.
graph2 = gs.import_onnx(onnx.load(MODEL2))
for n in graph2.nodes:
    for cn in n.inputs + n.outputs:
        if cn.name and not cn.name.endswith('b'):
            cn.name = f'{cn.name}b'
graph2_inputs = [i.name for i in graph2.inputs]
print(f'graph2 inputs: {graph2_inputs}')
onnx.save(gs.export_onnx(graph2), "graph2.onnx")

# Expected output:
# graph1 outputs: ['317a', '852a', '870a', '897a', '836a']
# graph2 inputs: ['0b', 'input.1b', 'input.13b', 'input.25b', 'input.37b']

# Merge: sclblonnx builds its io_match list from these two lists,
# connecting graph1's outputs to graph2's inputs.
sg1 = so.graph_from_file('graph1.onnx')
sg2 = so.graph_from_file('graph2.onnx')
sg3 = so.merge(
    sg1,
    sg2,
    outputs=graph1_outputs,
    inputs=graph2_inputs
)
so.graph_to_file(sg3, MODEL3)
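sclblonnx reports that it constructs an io_match list from the outputs and inputs arguments. The sketch below builds such (output, input) pairs by position. Positional pairing is an assumption about how the two lists are matched. The sketch also rejects lists of different lengths or with duplicate names. build_io_match is a hypothetical helper, not part of sclblonnx.

```python
"""Sketch: build explicit (output, input) pairs for merging two graphs."""
from typing import List, Sequence, Tuple


def build_io_match(outputs: Sequence[str], inputs: Sequence[str]) -> List[Tuple[str, str]]:
    """Pair the i-th output of graph 1 with the i-th input of graph 2."""
    if len(outputs) != len(inputs):
        raise ValueError(f"{len(outputs)} outputs cannot be paired with {len(inputs)} inputs")
    if len(set(outputs)) != len(outputs) or len(set(inputs)) != len(inputs):
        raise ValueError("tensor names must be unique on each side")
    return list(zip(outputs, inputs))


for out_name, in_name in build_io_match(
    ['317a', '852a', '870a', '897a', '836a'],
    ['0b', 'input.1b', 'input.13b', 'input.25b', 'input.37b'],
):
    print(f'{out_name} -> {in_name}')
```

Printing the pairs before calling so.merge lets you confirm the wiring, for example that 836a (the last encoder output) feeds input.37b under positional pairing.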
- Run merge_onnx.py.
$ python3 merge_onnx.py
graph1 outputs: ['317a', '852a', '870a', '897a', '836a']
graph2 inputs: ['0b', 'input.1b', 'input.13b', 'input.25b', 'input.37b']
Constructing the io_match list from your input and output.
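To confirm that the merge did not change the computation, compare the merged model's outputs with the result of running the two original models one after the other. This sketch makes several assumptions:
- numpy and onnxruntime are installed (neither is in the environment list).
- The encoder has a single float32 input with a fully static shape.
- The merged model keeps the second model's output order.
compare_merged_with_sequential is a hypothetical helper.

```python
"""Sketch: check that the merged model matches running the two models in sequence."""


def compare_merged_with_sequential(model1: str, model2: str, merged: str) -> float:
    """Return the largest absolute difference between the two pipelines' outputs."""
    # Imported here so this module loads even where these packages are absent.
    import numpy as np
    import onnxruntime as ort

    providers = ["CPUExecutionProvider"]
    s1 = ort.InferenceSession(model1, providers=providers)
    s2 = ort.InferenceSession(model2, providers=providers)
    s3 = ort.InferenceSession(merged, providers=providers)

    # Random float32 input built from the first model's (assumed static) shape.
    inp = s1.get_inputs()[0]
    x = np.random.rand(*inp.shape).astype(np.float32)

    # Sequential path: feed model1's outputs into model2's inputs by position.
    outs1 = s1.run(None, {inp.name: x})
    seq = s2.run(None, {i.name: o for i, o in zip(s2.get_inputs(), outs1)})

    # Merged path: a single session whose input carries the renamed name.
    fused = s3.run(None, {s3.get_inputs()[0].name: x})

    return max(float(np.max(np.abs(a - b))) for a, b in zip(seq, fused))
```

Call it with the three file names (encoder, depth decoder, merged model). A result close to zero suggests that the merged graph reproduces the two-step pipeline.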
4. Appendix