
Workaround for exporting torch.median to ONNX

PINTO
  • When the total number of elements is odd
def get_median(v):
    v = tf.reshape(v, [-1])
    mid = v.get_shape()[0]//2 + 1
    return tf.nn.top_k(v, mid).values[-1]
  • When the total number of elements is even (more precisely, when the median falls between two values and it is impossible to decide which one to pick, so their mean is computed and used instead) — a numeric check of both snippets follows below
def get_real_median(v):
    v = tf.reshape(v, [-1])
    l = v.get_shape()[0]
    mid = l//2 + 1
    val = tf.nn.top_k(v, mid).values
    if l % 2 == 1:
        return val[-1]
    else:
        return 0.5 * (val[-1] + val[-2])
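A quick numeric illustration of the trick (a minimal NumPy sketch, not from the original scrap): for n elements, the (n // 2 + 1)-th largest value is the median when n is odd, and the smaller of the two middle values when n is even.
import numpy as np

v = np.array([3., 1., 4., 1., 5.])               # n = 5 (odd)
mid = v.size // 2 + 1                            # 3
top = np.sort(v)[::-1][:mid]                     # 3 largest values: [5, 4, 3]
print(top[-1], np.median(v))                     # 3.0 3.0

v = np.array([3., 1., 4., 1., 5., 9.])           # n = 6 (even)
mid = v.size // 2 + 1                            # 4
top = np.sort(v)[::-1][:mid]                     # [9, 5, 4, 3]
print(0.5 * (top[-1] + top[-2]), np.median(v))   # 3.5 3.5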
PINTO
  • pytorch (sanity check after the snippet)
def get_real_median(self, v):
    v = torch.reshape(v, [-1])
    l = v.shape[0]
    mid = l // 2 + 1
    # torch.topk returns (values, indices); keep only the values
    values, _ = torch.topk(v, mid)
    if l % 2 == 1:
        return values[-1]
    else:
        return 0.5 * (values[-1] + values[-2])
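A quick sanity check of the port (a minimal sketch; the method is called with a dummy self here since it does not use it):
import torch

v = torch.tensor([3., 1., 4., 1., 5.])            # odd count
print(get_real_median(None, v).item())            # 3.0, same as torch.median(v)
v = torch.tensor([3., 1., 4., 1., 5., 9.])        # even count
print(get_real_median(None, v).item())            # 3.5, the mean of the two middle values;
                                                  # torch.median would return 3.0 (the lower one)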
PINTO
  • Completed version (pseudo-median model, ONNX export at opset 11, and a check against torch.median)
import random
import torch
import torch.nn as nn
import numpy as np
import onnxruntime as ort

random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

class Model_normal_median(nn.Module):
    def __init__(self, dim):
        super(Model_normal_median, self).__init__()
        self.dim = dim

    def forward(self, x):
        values, indices = torch.median(
            input=x,
            dim=self.dim,
            keepdim=True,
        )
        return values

class Model_pseudo_median(nn.Module):
    def __init__(self, dim, mode):
        super(Model_pseudo_median, self).__init__()
        self.dim = dim
        self.mode = mode

    def forward(self, x):
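        # Strategy: move the target dim to the last axis, flatten the leading
        # axes into one, take the top (l // 2 + 1) values along the last axis,
        # pick the middle value(s) from that sorted slice, and finally restore
        # the original layout with the inverse permutation.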
        x_shape = [int(d) for d in x.shape]
        transpose_perm = None
        reverse_transpose_perm = None

        if len(x_shape) >= 2:
            transpose_perm = [
                idx for idx in range(len(x_shape)) if idx != self.dim
            ] + [self.dim]
            reverse_transpose_perm = [
                transpose_perm.index(idx) \
                    for idx in range(len(transpose_perm))
            ]

            shape_before_compression = None
            shape_after_compression = None
            if len(transpose_perm) >= 2:
                transposed_x = x.permute(transpose_perm)
                transposed_x_shape = [
                    int(d) for d in transposed_x.shape
                ]
                shape_before_compression = \
                    transposed_x_shape[:-1] + [transposed_x_shape[-1]]
                shape_after_compression = [
                    np.prod(transposed_x_shape[:-1])
                ] + [transposed_x_shape[-1]]
            else:
                transposed_x = x
                shape_before_compression = [
                    int(d) for d in transposed_x.shape
                ]
                shape_after_compression = [
                    int(d) for d in transposed_x.shape
                ]

            transposed_reshaped_x = torch.reshape(
                input=transposed_x,
                shape=shape_after_compression,
            )
            l = transposed_reshaped_x.shape[-1]
            mid = l // 2 + 1
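            # k = l // 2 + 1 top values suffice: the median is the last element
            # of the descending top-k output (for even l, the second-to-last
            # element is the other middle value).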
            values, indices = torch.topk(
                input=transposed_reshaped_x,
                k=mid,
                dim=len(transposed_reshaped_x.shape) - 1,
            )
            # Indices into the descending top-k slice: the last element is the
            # (l // 2 + 1)-th largest value, the second-to-last the (l // 2)-th.
            last_dim = [int(d) for d in values.shape][-1]
            median_1_idx = last_dim - 1
            median_2_idx = last_dim
            if l % 2 == 1:
                median_values = values[:, median_1_idx:median_1_idx + 1]
            else:
                if self.mode == 'floor':
                    median_values =  values[:, median_1_idx - 1:median_1_idx]
                elif self.mode == 'ceil':
                    median_values =  values[:, median_2_idx - 1:median_2_idx]
                elif self.mode == 'mean':
                    median_values =  0.5 * (
                        values[:, median_1_idx - 1:median_1_idx] \
                            + values[:, median_2_idx - 1:median_2_idx]
                    )
            reshaped_median_values = torch.reshape(
                input=median_values,
                shape=shape_before_compression[:-1] + [1],
            )
            result = reshaped_median_values.permute(reverse_transpose_perm)

        else:
            x = torch.reshape(x, [-1])
            l = x.shape[0]
            mid = l // 2 + 1
            values, indices = torch.topk(x, mid)
            result = None
            if l % 2 == 1:
                result = values[-1]
            else:
                if self.mode == 'floor':
                    result = values[-1]
                elif self.mode == 'ceil':
                    result = values[-2]
                elif self.mode == 'mean':
                    result = 0.5 * (values[-1] + values[-2])

        return result

x = torch.randn([4,5,6])
# x = torch.randn([4])
# x = torch.randn([5])
dim = 2
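# With an even-sized dim, 'ceil' picks the smaller of the two middle values,
# which is what torch.median returns; 'floor' and 'mean' would differ there.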
mode = 'ceil'

normal_model = Model_normal_median(dim=dim)
normal_output = normal_model(x).numpy()

pseudo_model = Model_pseudo_median(dim=dim, mode=mode)
onnx_file = 'median_11_pseudo.onnx'
torch.onnx.export(
    pseudo_model,
    args=(x,),
    f=onnx_file,
    opset_version=11,
    input_names=[
        'input',
    ],
    output_names=[
        'output',
    ],
)
import onnx
from onnxsim import simplify
model_onnx2 = onnx.load(onnx_file)
model_simp, check = simplify(model_onnx2)
onnx.save(model_simp, onnx_file)

onnx_session = ort.InferenceSession(
    path_or_bytes=onnx_file,
    providers=[
        'CUDAExecutionProvider',
        'CPUExecutionProvider',
    ],
)
pseudo_result = \
    onnx_session.run(
        None,
        {'input': x.numpy()},
    )[0]

print(
    f'normal_output == pseudo_result: {np.allclose(normal_output, pseudo_result)}'
)
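For reference, the same comparison can also be run across every dim and mode in pure PyTorch, without the ONNX round trip (a minimal sketch assuming the script above has already been executed in the same session). With an even-sized dim, only mode='ceil' reproduces torch.median, because torch.median returns the smaller of the two middle values.
for test_dim in range(x.dim()):
    for test_mode in ['floor', 'ceil', 'mean']:
        ref = Model_normal_median(dim=test_dim)(x).numpy()
        out = Model_pseudo_median(dim=test_dim, mode=test_mode)(x).numpy()
        print(test_dim, test_mode, np.allclose(ref, out))
# With x of shape [4, 5, 6]: dim=1 (odd size) matches for every mode,
# while dim=0 and dim=2 (even sizes) match only for mode='ceil'.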