Open1
GPUでOOMを引くようなサイズが大きなPyTorchモデルをONNXへエクスポートして最適化するワークアラウンド
H=240
W=320
# INPUT Name: input.1, gray
# OUTPUT Name: 242, 252
onnx_file = f"enlightengan_HxW.onnx"
x1 = torch.randn(1, 3, H, W).cuda()
x2 = torch.randn(1, 1, H, W).cuda()
torch.onnx.export(
model.netG_A.module,
args=(x1,x2),
f=onnx_file,
opset_version=11,
dynamic_axes={
'input.1' : {2: 'height', 3: 'width'},
'gray' : {2: 'height', 3: 'width'},
'242' : {2: 'height', 3: 'width'},
'252' : {2: 'height', 3: 'width'}
}
)
import sys
sys.exit(0)
set_static_shape.py
import onnx
from onnxsim import simplify
H=1440
W=2560
MODEL='enlightengan'
model = onnx.load(f'{MODEL}_HxW.onnx')
model_simp, check = simplify(
model,
input_shapes={
"input.1": [1,3,H,W],
"gray": [1,1,H,W],
}
)
onnx.save(model_simp, f'{MODEL}_{H}x{W}.onnx')