Open7

ONNX エクスポートのための torch.inverse (torch.linalg.inv) の置き換えワークアラウンド(2D〜6D)

PINTOPINTO
import torch
from torch.linalg import det

def cof1(M, index):
    """Return the determinant of the minor of ``M`` at a 1-based position.

    Args:
        M: 2-D tensor to take the minor from.
        index: Sequence whose first two items are the 1-based row and
            column to delete from ``M``.

    Returns:
        The value of ``det`` applied to ``M`` with that row and column
        removed.
    """
    row, col = index[0], index[1]
    # Remove the row first, then the column; the surviving entries keep
    # their relative order.
    rows_kept = torch.cat((M[:row - 1], M[row:]), dim=0)
    minor = torch.cat((rows_kept[:, :col - 1], rows_kept[:, col:]), dim=1)
    return det(minor)

def alcof(M, index):
    """Return the signed cofactor of ``M`` at a 1-based ``(row, col)`` index.

    The minor determinant from ``cof1`` is multiplied by
    ``(-1) ** (row + col)``.
    """
    sign = (-1) ** (index[0] + index[1])
    return sign * cof1(M, index)

def adj(M):
    """Return the adjugate (transposed cofactor matrix) of a 2-D tensor.

    Args:
        M: 2-D tensor; it has to be square for the transposed writes below
            to stay in range.

    Returns:
        A tensor with the same shape, dtype and device as ``M`` whose entry
        ``[j, i]`` holds the cofactor of ``M`` at row ``i + 1``, column
        ``j + 1``.
    """
    rows, cols = M.shape[0], M.shape[1]
    # Allocate with the input's dtype/device. A plain torch.zeros() buffer is
    # float32, which silently truncated float64 cofactors.
    result = torch.zeros((rows, cols), dtype=M.dtype, device=M.device)
    for i in range(1, rows + 1):
        for j in range(1, cols + 1):
            # Swapped indices on the left-hand side perform the transpose.
            result[j - 1][i - 1] = alcof(M, [i, j])
    return result

def invmat(M):
    """Return the inverse of the 2-D tensor ``M`` as ``(1 / det(M)) * adj(M)``."""
    inv_det = 1.0 / det(M)
    return inv_det * adj(M)

# Sanity check: print the cofactor-based inverse next to torch.inverse so the
# two results can be compared by eye.
M = torch.tensor(
    [[1, 2, -1], [2, 3, 4], [3, 1, 2]],
    dtype=torch.float32,
)
print(invmat(M))
print(torch.inverse(M))
PINTOPINTO

あるいは、Export Built-In Contrib Ops というものを取り込んでエクスポートをする。

from onnxruntime.tools import pytorch_export_contrib_ops
import torch

# Register ONNX Runtime's contrib-op export support before exporting.
# NOTE(review): this is expected to let the exporter emit the com.microsoft
# Inverse contrib op for torch.inverse -- confirm against the installed
# onnxruntime version.
pytorch_export_contrib_ops.register()
# Placeholder call: replace the literal ``...`` with the real model, example
# inputs and output path.
torch.onnx.export(...)
PINTOPINTO
import torch
# Batch of two 3x3 matrices; the second is the first with its rows reversed.
M = torch.tensor([
    [[1.0, 2.0, 3.0], [1.5, 2.0, 2.3], [0.1, 0.2, 0.5]],
    [[0.1, 0.2, 0.5], [1.5, 2.0, 2.3], [1.0, 2.0, 3.0]],
])

# Reference output of the built-in batched inverse.
Binv = torch.linalg.inv(M)
print(Binv)
a = 0  # not used afterwards; presumably a debugger breakpoint anchor

"""
M.shape
torch.Size([2, 3, 3])

tensor(
    [
        [
            [-2.7000e+00,  2.0000e+00,  7.0000e+00],
            [ 2.6000e+00, -1.0000e+00, -1.1000e+01],
            [-5.0000e-01,  1.9868e-08,  5.0000e+00]
        ],
        [
            [ 7.0000e+00,  2.0000e+00, -2.7000e+00],
            [-1.1000e+01, -1.0000e+00,  2.6000e+00],
            [ 5.0000e+00, -4.7684e-08, -5.0000e-01]
        ]
    ]
)
"""

# Same check with a batch that holds a single 3x3 matrix.
M = torch.tensor([
    [[1.0, 2.0, 3.0], [1.5, 2.0, 2.3], [0.1, 0.2, 0.5]],
])

# Reference output of the built-in batched inverse.
Binv = torch.linalg.inv(M)
print(Binv)
a = 0  # not used afterwards; presumably a debugger breakpoint anchor

"""
M.shape
torch.Size([1, 3, 3])

tensor(
    [
        [
            [-2.7000e+00,  2.0000e+00,  7.0000e+00],
            [ 2.6000e+00, -1.0000e+00, -1.1000e+01],
            [-5.0000e-01,  1.9868e-08,  5.0000e+00]
        ]
    ]
)
"""


import torch.nn as nn
from torch.linalg import det

class Model(nn.Module):
    """Invert square matrices through the adjugate instead of ``torch.inverse``.

    Each matrix is inverted as ``(1 / det(M)) * adj(M)`` using only slicing,
    concatenation, ``det`` and element writes. ``forward`` accepts a single
    2-D matrix, or a tensor whose last two dimensions hold the matrices and
    whose leading dimensions are batch dimensions.
    """

    def __init__(self):
        super(Model, self).__init__()

    def cof1(self, M, index):
        """Return the determinant of the minor of ``M`` at a 1-based position.

        Args:
            M: 2-D tensor to take the minor from.
            index: Sequence whose first two items are the 1-based row and
                column to delete from ``M``.
        """
        row, col = index[0], index[1]
        # Remove the row first, then the column.
        rows_kept = torch.cat((M[:row - 1], M[row:]), dim=0)
        minor = torch.cat((rows_kept[:, :col - 1], rows_kept[:, col:]), dim=1)
        return det(minor)

    def alcof(self, M, index):
        """Return the signed cofactor of ``M`` at a 1-based ``(row, col)``."""
        sign = (-1) ** (index[0] + index[1])
        return sign * self.cof1(M, index)

    def adj(self, M):
        """Return the adjugate (transposed cofactor matrix) of a 2-D tensor.

        The result has the same shape, dtype and device as ``M``.
        """
        rows, cols = M.shape[0], M.shape[1]
        # Allocate with the input's dtype/device. A plain torch.zeros()
        # buffer is float32, which silently truncated float64 cofactors.
        result = torch.zeros((rows, cols), dtype=M.dtype, device=M.device)
        for i in range(1, rows + 1):
            for j in range(1, cols + 1):
                # Swapped indices on the left-hand side perform the transpose.
                result[j - 1][i - 1] = self.alcof(M, [i, j])
        return result

    def _invert(self, M):
        """Invert ``M`` by recursing over leading dimensions down to 2-D."""
        if len(M.shape) == 2:
            return 1.0 / det(M) * self.adj(M)
        # Same unsqueeze + cat pattern the former per-rank loops used, so
        # every batch level is rebuilt with identical operations.
        return torch.cat(
            [torch.unsqueeze(self._invert(part), dim=0) for part in M],
            dim=0,
        )

    def forward(self, M):
        """Return the inverse of every matrix in ``M``.

        Args:
            M: Tensor of rank >= 2; the last two dimensions are the square
                matrices to invert.

        Returns:
            Tensor with the same shape as ``M`` holding the inverses.

        Raises:
            ValueError: If ``M`` has fewer than two dimensions.
        """
        M_rank = len(M.shape)
        if M_rank < 2:
            # Previously an unsupported rank fell through and returned None.
            raise ValueError(
                f'Model expects a tensor of rank >= 2, got rank {M_rank}'
            )
        return self._invert(M)

# Export the cofactor-based inverse once per sample input, then simplify each
# exported graph in place.
import onnx
from onnxsim import simplify

model = Model()
test_tensors = [
    # Rank-6 sample left disabled in the original:
    # torch.randn([1,2,3,4,5,5], dtype=torch.float32),
    torch.randn([1,2,3,4,4], dtype=torch.float32),
    torch.randn([1,3,224,224], dtype=torch.float32),
    torch.randn([2,224,224], dtype=torch.float32),
]

for x in test_tensors:
    onnx_file = f'pseudo_invert_11_rank{len(x.shape)}.onnx'
    torch.onnx.export(
        model,
        # One-element tuple; the former ``(x)`` was just ``x`` in parentheses.
        args=(x,),
        f=onnx_file,
        opset_version=11,
        input_names=['input'],
        output_names=['output'],
    )
    model_onnx2 = onnx.load(onnx_file)
    model_simp, check = simplify(model_onnx2)
    if not check:
        # Keep the unsimplified export instead of overwriting it with a
        # graph the simplifier could not validate.
        raise RuntimeError(
            f'Simplified ONNX model could not be validated: {onnx_file}'
        )
    onnx.save(model_simp, onnx_file)
PINTOPINTO
  • 再帰処理による次元数無制限版(かなり冗長)。det が正方行列にしか対応していないため、2次元の正方行列になるまで再帰的に分解してから逆行列を求める。
class Model(nn.Module):
    """Recursive adjugate-based matrix inverse without a rank limit.

    The input is split along its leading dimension until 2-D matrices remain.
    Each matrix is inverted as ``(1 / det(M)) * adj(M)`` and the pieces are
    concatenated back together.
    """

    def __init__(self):
        super(Model, self).__init__()

    def cof1(self, M, index):
        """Return the determinant of the minor of ``M`` at a 1-based position.

        Args:
            M: 2-D tensor to take the minor from.
            index: Sequence whose first two items are the 1-based row and
                column to delete from ``M``.
        """
        row, col = index[0], index[1]
        # Remove the row first, then the column.
        rows_kept = torch.cat((M[:row - 1], M[row:]), dim=0)
        minor = torch.cat((rows_kept[:, :col - 1], rows_kept[:, col:]), dim=1)
        return det(minor)

    def alcof(self, M, index):
        """Return the signed cofactor of ``M`` at a 1-based ``(row, col)``."""
        sign = (-1) ** (index[0] + index[1])
        return sign * self.cof1(M, index)

    def adj(self, M):
        """Return the adjugate (transposed cofactor matrix) of a 2-D tensor.

        The result has the same shape, dtype and device as ``M``.
        """
        rows, cols = M.shape[0], M.shape[1]
        # Allocate with the input's dtype/device. A plain torch.zeros()
        # buffer is float32, which silently truncated float64 cofactors.
        result = torch.zeros((rows, cols), dtype=M.dtype, device=M.device)
        for i in range(1, rows + 1):
            for j in range(1, cols + 1):
                # Swapped indices on the left-hand side perform the transpose.
                result[j - 1][i - 1] = self.alcof(M, [i, j])
        return result

    def forward(self, x):
        """Return the inverse of every matrix in ``x``.

        Args:
            x: Tensor of rank >= 2; the last two dimensions are the square
                matrices to invert.

        Returns:
            Tensor with the same shape as ``x`` holding the inverses.

        Raises:
            ValueError: If ``x`` has fewer than two dimensions.
        """
        x_rank = len(x.shape)
        if x_rank < 2:
            # Fail early; the recursion would otherwise end in an unrelated
            # "iteration over a 0-d tensor" error.
            raise ValueError(
                f'Model expects a tensor of rank >= 2, got rank {x_rank}'
            )

        def _inverse_matrix_recursion(t):
            # Every level returns its result with one extra leading axis so
            # that the caller can concatenate the pieces along dim 0.
            if len(t.shape) == 2:
                return torch.unsqueeze(1.0 / det(t) * self.adj(t), dim=0)
            pieces = [_inverse_matrix_recursion(part) for part in t]
            return torch.unsqueeze(torch.cat(pieces, dim=0), dim=0)

        # Drop the extra leading axis added by the outermost level.
        return torch.squeeze(_inverse_matrix_recursion(x), dim=0)


# Export the recursive inverse for one rank-4 sample input, then simplify the
# exported graph in place.
import onnx
from onnxsim import simplify

model = Model()
x = torch.randn([1,3,4,4], dtype=torch.float32)

onnx_file = f'pseudo_invert_11_rank{len(x.shape)}.onnx'
torch.onnx.export(
    model,
    # One-element tuple; the former ``(x)`` was just ``x`` in parentheses.
    args=(x,),
    f=onnx_file,
    opset_version=11,
    input_names=['input'],
    output_names=['output'],
)
model_onnx2 = onnx.load(onnx_file)
model_simp, check = simplify(model_onnx2)
if not check:
    # Keep the unsimplified export instead of overwriting it with a graph
    # the simplifier could not validate.
    raise RuntimeError(
        f'Simplified ONNX model could not be validated: {onnx_file}'
    )
onnx.save(model_simp, onnx_file)
PINTOPINTO

https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftinverse

Inverse_11.json
{
  "irVersion": "8",
  "graph": {
    "node": [
      {
        "input": [
          "input"
        ],
        "output": [
          "output"
        ],
        "name": "Inverse_0",
        "opType": "Inverse",
        "domain": "com.microsoft"
      }
    ],
    "name": "inverse_graph",
    "input": [
      {
        "name": "input",
        "type": {
          "tensorType": {
            "elemType": 1,
            "shape": {
              "dim": [
                {
                  "dimValue": "1"
                },
                {
                  "dimValue": "3"
                },
                {
                  "dimValue": "224"
                },
                {
                  "dimValue": "224"
                }
              ]
            }
          }
        }
      }
    ],
    "output": [
      {
        "name": "output",
        "type": {
          "tensorType": {
            "elemType": 1,
            "shape": {
              "dim": [
                {
                  "dimValue": "1"
                },
                {
                  "dimValue": "3"
                },
                {
                  "dimValue": "224"
                },
                {
                  "dimValue": "224"
                }
              ]
            }
          }
        }
      }
    ]
  },
  "opsetImport": [
    {
      "domain": "",
      "version": "11"
    },
    {
      "domain": "com.microsoft",
      "version": "1"
    }
  ]
}