Open2

"RuntimeError: Unsupported: ONNX export of instance_norm for unknown channel size." のワークアラウンド

PINTO
class InstanceNormAlternative(nn.InstanceNorm2d):
    """Drop-in ``nn.InstanceNorm2d`` that avoids the ONNX instance_norm export error.

    Works around "RuntimeError: Unsupported: ONNX export of instance_norm for
    unknown channel size." by expressing instance normalization with basic
    tensor ops (mean, var, sub, pow, div, mul) instead of the fused op.

    Each (sample, channel) plane is normalized over its two trailing spatial
    dimensions using the biased (population) variance, as ``nn.InstanceNorm2d``
    does.  When the module is constructed with ``affine=True`` the learned
    per-channel ``weight`` and ``bias`` are applied afterwards.

    NOTE: running statistics (``track_running_stats=True``) are neither used
    nor updated; statistics always come from the current input.
    """

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """Normalize ``inp`` of shape (N, C, H, W) or unbatched (C, H, W)."""
        self._check_input_dim(inp)
        # Reduce over the last two (spatial) dims so that batched and
        # unbatched inputs are both handled.
        spatial_dims = (-2, -1)
        mean = inp.mean(dim=spatial_dims, keepdim=True)
        # unbiased=False -> population variance, matching nn.InstanceNorm2d.
        var = inp.var(dim=spatial_dims, keepdim=True, unbiased=False)
        # 1 / sqrt(var + eps), written with pow/div so it exports as plain ops.
        inv_std = 1 / (var + self.eps) ** 0.5
        out = (inp - mean) * inv_std
        if self.affine:
            # (C,) parameters reshaped to (C, 1, 1) broadcast over the spatial
            # dims for both (N, C, H, W) and (C, H, W) inputs.
            out = out * self.weight.view(-1, 1, 1) + self.bias.view(-1, 1, 1)
        return out