Workaround for exporting torch.median to ONNX
ONNX has no Median operator, so the trick below emulates torch.median with reshape / transpose / TopK, all of which export cleanly.
- When the total number of elements is odd
import tensorflow as tf

def get_median(v):
    # Flatten, then take the k-th largest value with k = N // 2 + 1:
    # for an odd-length tensor this is exactly the median.
    v = tf.reshape(v, [-1])
    mid = v.get_shape()[0] // 2 + 1
    return tf.nn.top_k(v, mid).values[-1]
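A minimal eager-mode check of the snippet above (a sketch assuming TensorFlow 2.x; the sample values are only illustrative):
import numpy as np

v = tf.constant([3.0, 1.0, 4.0, 1.0, 5.0])  # 5 elements (odd)
print(float(get_median(v)))          # 3.0
print(float(np.median(v.numpy())))   # 3.0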
- When the total number of elements is even (more precisely, when the median falls between two values and neither can be chosen outright, so their mean is computed and used instead)
def get_real_median(v):
    # Flatten and take the top (N // 2 + 1) values in descending order;
    # the median is the last value for odd N, or the mean of the last two for even N.
    v = tf.reshape(v, [-1])
    l = v.get_shape()[0]
    mid = l // 2 + 1
    val = tf.nn.top_k(v, mid).values
    if l % 2 == 1:
        return val[-1]
    else:
        return 0.5 * (val[-1] + val[-2])
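The same kind of check for the even case (again TensorFlow 2.x eager mode, reusing the imports from the check above):
v = tf.constant([3.0, 1.0, 4.0, 1.0, 5.0, 9.0])  # 6 elements (even)
print(float(get_real_median(v)))     # 3.5
print(float(np.median(v.numpy())))   # 3.5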
- PyTorch
import torch

def get_real_median(v):
    # Same idea as the TensorFlow version; torch.topk maps to the ONNX TopK operator.
    v = torch.reshape(v, [-1])
    l = v.shape[0]
    mid = l // 2 + 1
    val = torch.topk(v, mid).values  # torch.topk returns (values, indices)
    if l % 2 == 1:
        return val[-1]
    else:
        return 0.5 * (val[-1] + val[-2])
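A quick sanity check. Note that for an even number of elements torch.median returns the lower of the two middle values, while the port above averages them; this difference is what the 'floor' / 'ceil' / 'mean' modes below are for:
v = torch.tensor([3.0, 1.0, 4.0, 1.0, 5.0])        # odd
print(get_real_median(v), torch.median(v))          # tensor(3.) tensor(3.)
v = torch.tensor([3.0, 1.0, 4.0, 1.0, 5.0, 9.0])    # even
print(get_real_median(v), torch.median(v))          # tensor(3.5000) tensor(3.)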
- Final version. Model_pseudo_median reproduces torch.median(x, dim=dim, keepdim=True) using only ONNX-exportable ops (permute / reshape / topk); mode chooses which of the two middle values (or their mean) to return when the reduced length is even, and inputs with fewer than two dimensions fall back to the flattened 1-D logic above.
import random
import torch
import torch.nn as nn
import numpy as np
import onnxruntime as ort
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
class Model_normal_median(nn.Module):
    # Reference model: calls torch.median directly (used only for the eager-mode comparison).
    def __init__(self, dim):
        super(Model_normal_median, self).__init__()
        self.dim = dim

    def forward(self, x):
        values, indices = torch.median(
            input=x,
            dim=self.dim,
            keepdim=True,
        )
        return values
class Model_pseudo_median(nn.Module):
    # ONNX-exportable replacement for torch.median(x, dim=dim, keepdim=True).
    def __init__(self, dim, mode):
        super(Model_pseudo_median, self).__init__()
        self.dim = dim
        self.mode = mode

    def forward(self, x):
        x_shape = [int(d) for d in x.shape]
        transpose_perm = None
        reverse_transpose_perm = None
        if len(x_shape) >= 2:
            # Move the target dim to the end, remembering how to undo the permutation.
            transpose_perm = [
                idx for idx in range(len(x_shape)) if idx != self.dim
            ] + [self.dim]
            reverse_transpose_perm = [
                transpose_perm.index(idx)
                for idx in range(len(transpose_perm))
            ]
            shape_before_compression = None
            shape_after_compression = None
            if len(transpose_perm) >= 2:
                transposed_x = x.permute(transpose_perm)
                transposed_x_shape = [
                    int(d) for d in transposed_x.shape
                ]
                shape_before_compression = \
                    transposed_x_shape[:-1] + [transposed_x_shape[-1]]
                # Collapse all leading dims into one so topk runs on a 2-D tensor.
                shape_after_compression = [
                    np.prod(transposed_x_shape[:-1])
                ] + [transposed_x_shape[-1]]
            else:
                transposed_x = x
                shape_before_compression = [
                    int(d) for d in transposed_x.shape
                ]
                shape_after_compression = [
                    int(d) for d in transposed_x.shape
                ]
            transposed_reshaped_x = torch.reshape(
                input=transposed_x,
                shape=shape_after_compression,
            )
            l = transposed_reshaped_x.shape[-1]
            mid = l // 2 + 1
            # Top (l // 2 + 1) values per row, in descending order.
            values, indices = torch.topk(
                input=transposed_reshaped_x,
                k=mid,
                dim=len(transposed_reshaped_x.shape) - 1,
            )
            last_dim = [int(d) for d in values.shape][-1]
            median_1_idx = last_dim - 1 if last_dim >= 2 else last_dim
            median_2_idx = last_dim if last_dim >= 2 else last_dim
            if l % 2 == 1:
                # Odd length: the last of the top-k values is the median.
                median_values = values[:, median_1_idx:median_1_idx + 1]
            else:
                # Even length: pick one of the two middle values, or their mean.
                if self.mode == 'floor':
                    median_values = values[:, median_1_idx - 1:median_1_idx]
                elif self.mode == 'ceil':
                    median_values = values[:, median_2_idx - 1:median_2_idx]
                elif self.mode == 'mean':
                    median_values = 0.5 * (
                        values[:, median_1_idx - 1:median_1_idx]
                        + values[:, median_2_idx - 1:median_2_idx]
                    )
            # Restore the original layout (keepdim=True on the reduced axis).
            reshaped_median_values = torch.reshape(
                input=median_values,
                shape=shape_before_compression[:-1] + [1],
            )
            result = reshaped_median_values.permute(reverse_transpose_perm)
        else:
            # 0-D / 1-D input: flatten and use the simple top-k trick.
            x = torch.reshape(x, [-1])
            l = x.shape[0]
            mid = l // 2 + 1
            values, indices = torch.topk(x, mid)
            result = None
            if l % 2 == 1:
                result = values[-1]
            else:
                if self.mode == 'floor':
                    result = values[-1]
                elif self.mode == 'ceil':
                    result = values[-2]
                elif self.mode == 'mean':
                    result = 0.5 * (values[-1] + values[-2])
        return result
x = torch.randn([4,5,6])
# x = torch.randn([4])
# x = torch.randn([5])
dim = 2
mode = 'ceil'  # on an even-length axis, 'ceil' here selects the lower of the two middle values, matching torch.median
normal_model = Model_normal_median(dim=dim)
normal_output = normal_model.forward(x).numpy()
pseudo_model = Model_pseudo_median(dim=dim, mode=mode)
onnx_file = 'median_11_pseudo.onnx'
torch.onnx.export(
    pseudo_model,
    args=(x,),
    f=onnx_file,
    opset_version=11,
    input_names=[
        'input',
    ],
    output_names=[
        'output',
    ],
)
import onnx
from onnxsim import simplify
model_onnx2 = onnx.load(onnx_file)
model_simp, check = simplify(model_onnx2)
onnx.save(model_simp, onnx_file)
onnx_session = ort.InferenceSession(
    path_or_bytes=onnx_file,
    providers=[
        'CUDAExecutionProvider',
        'CPUExecutionProvider',
    ],
)
pseudo_result = \
    onnx_session.run(
        None,
        {'input': x.numpy()},
    )[0]
print(
    f'normal_output == pseudo_result: {np.allclose(normal_output, pseudo_result)}'
)
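For the commented-out 1-D shapes above, the flattened branch can also be spot-checked in eager mode against NumPy (a sketch assuming the classes above are already defined; mode='mean' is used so the even case averages the two middle values):
x1 = torch.tensor([3.0, 1.0, 4.0, 1.0, 5.0, 9.0])  # 6 elements -> even, 1-D branch
pseudo_1d = Model_pseudo_median(dim=0, mode='mean')
print(float(pseudo_1d(x1)))           # 3.5
print(float(np.median(x1.numpy())))   # 3.5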