Open3

PyTorch v2.0.0 を改造して opset=18 の ONNX をエクスポートできるかどうか実験

PINTOPINTO
# Fetch PyTorch and pin it to the commit this experiment was run against.
# `git clone --recursive` checks out submodules for the default branch; after
# switching commits they must be re-synced to the versions pinned by that commit.
git clone --recursive https://github.com/pytorch/pytorch \
&& cd pytorch && git checkout 789b1437e945336f83c915ab2f2dd283ac472191 \
&& git submodule sync \
&& git submodule update --init --recursive

# Insert "#define THRUST_IGNORE_CUB_VERSION_CHECK" right before the
# "#ifndef THRUST_IGNORE_CUB_VERSION_CHECK" line of Thrust's CUDA config header.
# NOTE: this edits the system-wide CUDA toolkit header (needs sudo), so it
# affects every project built against this CUDA installation.
sudo sed -i -e \
"/^#ifndef THRUST_IGNORE_CUB_VERSION_CHECK$/i #define THRUST_IGNORE_CUB_VERSION_CHECK" \
/usr/local/cuda/targets/x86_64-linux/include/thrust/system/cuda/config.h

# Replace the pinned ONNX submodule with upstream ONNX (presumably to pick up
# opset-18 definitions). The destination path must be given explicitly: a bare
# `git clone` would create ./onnx and leave third_party/onnx missing for the build.
# NOTE: a later `git submodule update` may reset third_party/onnx to the pinned commit.
rm -rf third_party/onnx
git clone --recursive https://github.com/onnx/onnx.git third_party/onnx
torch/onnx/_constants.py
"""Constant values used in ONNX.

Experiment patch: ONNX_MAX_OPSET below is set to 18 so that this PyTorch build
can be asked for opset-18 exports (presumably it bounds the accepted
``opset_version`` values — confirm in torch/onnx/utils.py).
"""

# NOTE(review): presumably the entry name of the serialized ModelProto inside
# exported archives — confirm against its users.
ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO"

# Opset-version bounds. ONNX_MAX_OPSET is the value patched to 18 for this
# experiment; it presumably must stay in step with kMainOpsetVersion in
# torch/csrc/jit/serialization/export.cpp and with the bundled ONNX library — TODO confirm.
ONNX_BASE_OPSET = 9
ONNX_MIN_OPSET = 7
ONNX_MAX_OPSET = 18
# ONNX_DEFAULT_OPSET generated by tools/onnx/update_default_opset_version.py
ONNX_DEFAULT_OPSET = 14
# Presumably the lowest opset at which constant folding is permitted — confirm.
ONNX_CONSTANT_FOLDING_MIN_OPSET = 9

PYTORCH_GITHUB_ISSUES_URL = "https://github.com/pytorch/pytorch/issues"

# Largest signed 64-bit integer (2**63 - 1).
INT64_MAX = 9223372036854775807
torch/csrc/jit/serialization/export.cpp
// Sentinel stored for opset versions that have no IR-version mapping.
static constexpr int kInvalidOpsetVersion = -1;
// Highest ONNX opset this exporter knows about (raised to 18 for this
// experiment). The table below needs exactly one entry per opset from 0
// through this value: missing trailing initializers would silently be 0.
static constexpr int kMainOpsetVersion = 18;
// Opset version (array index) -> ONNX IR version.
// Based on OP_SET_ID_VERSION_MAP in
// https://github.com/onnx/onnx/blob/main/onnx/helper.py.
static constexpr std::array<int64_t, kMainOpsetVersion + 1>
    kOpsetVersionToIRVersion = {
        kInvalidOpsetVersion, // opset 0 (unused index)
        3, // opset 1
        kInvalidOpsetVersion, // opset 2
        kInvalidOpsetVersion, // opset 3
        kInvalidOpsetVersion, // opset 4
        3, // opset 5
        3, // opset 6
        3, // opset 7
        3, // opset 8
        4, // opset 9
        5, // opset 10
        6, // opset 11
        7, // opset 12
        7, // opset 13
        7, // opset 14
        8, // opset 15
        8, // opset 16
        8, // opset 17
        8, // opset 18
};
# Build PyTorch without NCCL and package it as a wheel.
# USE_NCCL=OFF is passed to BOTH setup.py invocations: `bdist_wheel` re-runs
# the build command, so it must see the same configuration as `build`.
# `python3 -m pip` guarantees the same interpreter is used throughout.
python3 -m pip install -r requirements.txt \
&& USE_NCCL=OFF python3 setup.py build \
&& USE_NCCL=OFF python3 setup.py bdist_wheel

# Install the wheel produced above into the interpreter that built it. The
# file name encodes version, git hash and Python ABI tag (cp38 here); adjust
# it to match the file actually written to dist/.
python3 -m pip install dist/torch-2.0.0a0+git789b143-cp38-cp38-linux_x86_64.whl --force-reinstall
PINTOPINTO
make_GroupNormalization.py
#! /usr/bin/env python

import torch
import torch.nn as nn
import numpy as np
import onnx
from onnxsim import simplify
import numpy as np
# Seed NumPy's global RNG. This does not seed torch's RNG, so torch.randn is unaffected.
np.random.seed(seed=0)


class pseudo_GroupNorm(nn.Module):
    """Group normalization written with primitive tensor ops.

    Reproduces ``nn.GroupNorm(num_groups, num_features, eps)`` for 4-D
    ``(N, C, H, W)`` inputs using only reshape/mean/var/sqrt and
    broadcasting. Presumably intended as an alternative to ``nn.GroupNorm``
    for ONNX export — confirm with the caller.

    Args:
        num_features: number of channels ``C``; must be divisible by
            ``num_groups``.
        num_groups: number of channel groups ``G``.
        eps: value added to the variance for numerical stability.
    """

    def __init__(self, num_features, num_groups=3, eps=1e-5):
        super().__init__()
        # Per-channel affine parameters, shaped to broadcast over (N, C, H, W).
        self.weight = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.num_groups = num_groups
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` of shape ``(N, C, H, W)`` per group, then apply the affine transform.

        Raises:
            ValueError: if ``C`` is not divisible by ``num_groups``.
        """
        N, C, H, W = x.size()
        G = self.num_groups
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if C % G != 0:
            raise ValueError(f"number of channels ({C}) must be divisible by num_groups ({G})")

        # reshape (not view) so non-contiguous inputs such as channels_last also work.
        x = x.reshape(N, G, -1)
        mean = x.mean(-1, keepdim=True)
        # nn.GroupNorm normalizes with the biased (population) variance;
        # Tensor.var defaults to the unbiased estimator, so turn that off.
        var = x.var(-1, unbiased=False, keepdim=True)

        x = (x - mean) / (var + self.eps).sqrt()
        x = x.reshape(N, C, H, W)
        return x * self.weight + self.bias


class Model(nn.Module):
    """Minimal wrapper around a single ``nn.GroupNorm`` layer, used as the export target.

    The layer splits 6 channels into 3 groups, so inputs must have 6 channels.
    """

    def __init__(self):
        super().__init__()
        self.gn = nn.GroupNorm(num_groups=3, num_channels=6)

    def forward(self, x):
        """Return ``x`` group-normalized by ``self.gn``."""
        normalized = self.gn(x)
        return normalized


def _export_and_simplify(model, x, opset, model_name="GroupNormalization"):
    """Export ``model`` to ``{model_name}_{opset}.onnx``, then shape-infer and simplify it in place.

    Args:
        model: the ``nn.Module`` to export.
        x: example input tensor used for tracing.
        opset: ONNX opset version passed to ``torch.onnx.export``.
        model_name: prefix for the file name and the graph input/output names.

    Returns:
        The path of the written ONNX file.

    Raises:
        RuntimeError: if onnxsim reports that the simplified model failed its check.
    """
    onnx_file = f"{model_name}_{opset}.onnx"
    torch.onnx.export(
        model,
        args=(x,),  # a real 1-tuple; `(x)` is just `x`
        f=onnx_file,
        opset_version=opset,
        input_names=[
            f"{model_name}_input",
        ],
        output_names=[
            f"{model_name}_output",
        ],
        do_constant_folding=False,
    )
    inferred = onnx.shape_inference.infer_shapes(onnx.load(onnx_file))
    simplified, check = simplify(inferred)
    # Do not silently save a model that onnxsim could not validate.
    if not check:
        raise RuntimeError(f"onnxsim could not validate the simplified model for {onnx_file}")
    onnx.save(simplified, onnx_file)
    return onnx_file


if __name__ == "__main__":
    MODEL = "GroupNormalization"
    # Export the same model at opset 11 and at opset 18 so the two graphs can be compared.
    for OPSET in (11, 18):
        model = Model()
        x = torch.randn(20, 6, 10, 10)
        _export_and_simplify(model, x, OPSET, MODEL)
PINTOPINTO

[結果] エクスポート自体はできたが、期待した結果にはならなかった