Open1

GPUでOOMを引くようなサイズが大きなPyTorchモデルをONNXへエクスポートして最適化するワークアラウンド

PINTOPINTO
H=240
W=320
# INPUT Name: input.1, gray
# OUTPUT Name: 242, 252
onnx_file = f"enlightengan_HxW.onnx"
x1 = torch.randn(1, 3, H, W).cuda()
x2 = torch.randn(1, 1, H, W).cuda()
torch.onnx.export(
    model.netG_A.module,
    args=(x1,x2),
    f=onnx_file,
    opset_version=11,
    dynamic_axes={
        'input.1' : {2: 'height', 3: 'width'},
        'gray' : {2: 'height', 3: 'width'},
        '242' : {2: 'height', 3: 'width'},
        '252' : {2: 'height', 3: 'width'}
    }
)
import sys
sys.exit(0)
set_static_shape.py
import onnx
from onnxsim import simplify

H=1440
W=2560
MODEL='enlightengan'
model = onnx.load(f'{MODEL}_HxW.onnx')
model_simp, check = simplify(
    model,
    input_shapes={
        "input.1": [1,3,H,W],
        "gray": [1,1,H,W],
    }
)
onnx.save(model_simp, f'{MODEL}_{H}x{W}.onnx')