Open2

PyTorchのモデルを各種解像度で一括ONNX出力(エクスポート)するコードスニペット torch.onnx.export

PINTO
        # Batch-export `self.model.module` to ONNX: one fixed-shape file per
        # (H, W) in RESOLUTION, one file with dynamic H/W, and one with dynamic
        # N/H/W. Each file is shape-inferred and then simplified with onnxsim.
        # The process exits afterwards. Dummy inputs are created with `.cuda()`,
        # so a CUDA device is required.
        import sys
        import warnings

        import onnx
        from onnxsim import simplify

        # (height, width) pairs; one fixed-shape ONNX file is written per entry.
        RESOLUTION = [
            (192, 320),
            (192, 416),
            (192, 640),
            (192, 800),
            (256, 320),
            (256, 416),
            (256, 448),
            (256, 640),
            (256, 800),
            (256, 960),
            (288, 480),
            (288, 640),
            (288, 800),
            (288, 960),
            (288, 1280),
            (320, 320),
            (384, 480),
            (384, 640),
            (384, 800),
            (384, 960),
            (384, 1280),
            (416, 416),
            (480, 640),
            (480, 800),
            (480, 960),
            (480, 1280),
            (512, 512),
            (512, 640),
            (512, 896),
            (544, 800),
            (544, 960),
            (544, 1280),
            (640, 640),
            (736, 1280),
        ]
        MODEL = 'xxxx'  # file-name prefix for every exported model
        # onnxsim is run repeatedly on the same file (the original snippet ran
        # exactly three passes); a later pass can pick up what an earlier one
        # exposed.
        SIMPLIFY_PASSES = 3

        def _export_and_simplify(onnx_file, x, **export_kwargs):
            """Export ``self.model.module`` to ``onnx_file`` and post-process it.

            Traces the model with dummy input ``x`` (opset 11, input named
            ``'input'``), forwarding any extra keyword arguments such as
            ``output_names`` / ``dynamic_axes`` to ``torch.onnx.export``. Then
            runs ONNX shape inference, then up to ``SIMPLIFY_PASSES`` onnxsim
            passes, saving the file after each step. A pass whose onnxsim check
            fails is not saved; the file keeps the last model that passed.
            """
            torch.onnx.export(
                self.model.module,
                args=(x,),  # a real 1-tuple; `(x)` is just `x`
                f=onnx_file,
                opset_version=11,
                input_names=['input'],
                **export_kwargs,
            )
            model_inferred = onnx.shape_inference.infer_shapes(onnx.load(onnx_file))
            onnx.save(model_inferred, onnx_file)

            for _ in range(SIMPLIFY_PASSES):
                # Reload from disk each pass so every pass starts from the last
                # saved (validated) model, independent of in-memory mutation.
                model_simp, check = simplify(onnx.load(onnx_file))
                if not check:
                    warnings.warn(
                        f'onnxsim check failed for {onnx_file}; '
                        'keeping the last successfully saved model'
                    )
                    break
                onnx.save(model_simp, onnx_file)

        for H, W in RESOLUTION:
            x = torch.randn(1, 3, H, W).cuda()
            _export_and_simplify(f'{MODEL}_1x3x{H}x{W}.onnx', x)

        # The dummy size only seeds tracing; the axes listed in dynamic_axes are
        # exported as symbolic. Reusing 'height'/'width' for the output assumes
        # the model's output keeps the input's spatial size — confirm.
        x = torch.randn(1, 3, 192, 320).cuda()
        # `output_names` must be given: otherwise the 'output' key in
        # dynamic_axes matches no graph output and the exporter ignores it,
        # leaving the output shape fixed at the dummy size.
        _export_and_simplify(
            f'{MODEL}_1x3xHxW.onnx',
            x,
            output_names=['output'],
            dynamic_axes={
                'input': {2: 'height', 3: 'width'},
                'output': {2: 'height', 3: 'width'},
            },
        )
        _export_and_simplify(
            f'{MODEL}_Nx3xHxW.onnx',
            x,
            output_names=['output'],
            dynamic_axes={
                'input': {0: 'N', 2: 'height', 3: 'width'},
                'output': {0: 'N', 2: 'height', 3: 'width'},
            },
        )

        # Stop the host program once every export has been written.
        sys.exit(0)
PINTO
        # Batch-export `self.model.module` to ONNX: one fixed-shape file per
        # (H, W) in RESOLUTION, one file with dynamic H/W, and one with dynamic
        # N/H/W. Each file is shape-inferred and then simplified with onnxsim.
        # The process exits afterwards. Dummy inputs are created with `.cuda()`,
        # so a CUDA device is required.
        import sys
        import warnings

        import onnx
        from onnxsim import simplify

        # (height, width) pairs; one fixed-shape ONNX file is written per entry.
        RESOLUTION = [
            (180, 320),
            (180, 416),
            (180, 512),
            (180, 640),
            (180, 800),
            (240, 320),
            (240, 416),
            (240, 512),
            (240, 640),
            (240, 800),
            (240, 960),
            (256, 448),
            (288, 480),
            (288, 512),
            (288, 640),
            (288, 800),
            (288, 960),
            (288, 1280),
            (320, 320),
            (360, 480),
            (360, 512),
            (360, 640),
            (360, 800),
            (360, 960),
            (360, 1280),
            (376, 1344),
            (416, 416),
            (480, 640),
            (480, 800),
            (480, 960),
            (480, 1280),
            (512, 512),
            (512, 896),
            (540, 800),
            (540, 960),
            (540, 1280),
            (640, 640),
            (640, 960),
            (720, 1280),
            (720, 2560),
            (1080, 1920),
        ]
        MODEL = 'xxxx'  # file-name prefix for every exported model
        # onnxsim is run repeatedly on the same file (the original snippet ran
        # exactly three passes); a later pass can pick up what an earlier one
        # exposed.
        SIMPLIFY_PASSES = 3

        def _export_and_simplify(onnx_file, x, **export_kwargs):
            """Export ``self.model.module`` to ``onnx_file`` and post-process it.

            Traces the model with dummy input ``x`` (opset 11, input named
            ``'input'``), forwarding any extra keyword arguments such as
            ``output_names`` / ``dynamic_axes`` to ``torch.onnx.export``. Then
            runs ONNX shape inference, then up to ``SIMPLIFY_PASSES`` onnxsim
            passes, saving the file after each step. A pass whose onnxsim check
            fails is not saved; the file keeps the last model that passed.
            """
            torch.onnx.export(
                self.model.module,
                args=(x,),  # a real 1-tuple; `(x)` is just `x`
                f=onnx_file,
                opset_version=11,
                input_names=['input'],
                **export_kwargs,
            )
            model_inferred = onnx.shape_inference.infer_shapes(onnx.load(onnx_file))
            onnx.save(model_inferred, onnx_file)

            for _ in range(SIMPLIFY_PASSES):
                # Reload from disk each pass so every pass starts from the last
                # saved (validated) model, independent of in-memory mutation.
                model_simp, check = simplify(onnx.load(onnx_file))
                if not check:
                    warnings.warn(
                        f'onnxsim check failed for {onnx_file}; '
                        'keeping the last successfully saved model'
                    )
                    break
                onnx.save(model_simp, onnx_file)

        for H, W in RESOLUTION:
            x = torch.randn(1, 3, H, W).cuda()
            _export_and_simplify(f'{MODEL}_1x3x{H}x{W}.onnx', x)

        # The dummy size only seeds tracing; the axes listed in dynamic_axes are
        # exported as symbolic. Reusing 'height'/'width' for the output assumes
        # the model's output keeps the input's spatial size — confirm.
        x = torch.randn(1, 3, 180, 320).cuda()
        # `output_names` must be given: otherwise the 'output' key in
        # dynamic_axes matches no graph output and the exporter ignores it,
        # leaving the output shape fixed at the dummy size.
        _export_and_simplify(
            f'{MODEL}_1x3xHxW.onnx',
            x,
            output_names=['output'],
            dynamic_axes={
                'input': {2: 'height', 3: 'width'},
                'output': {2: 'height', 3: 'width'},
            },
        )
        _export_and_simplify(
            f'{MODEL}_Nx3xHxW.onnx',
            x,
            output_names=['output'],
            dynamic_axes={
                'input': {0: 'N', 2: 'height', 3: 'width'},
                'output': {0: 'N', 2: 'height', 3: 'width'},
            },
        )

        # Stop the host program once every export has been written.
        sys.exit(0)