Open5

torch.nn.functional.interpolate の scale_factor 指定でモデルの出力サイズを破壊する方法

PINTOPINTO

Pythonの丸め誤差が生じるギリギリの値で scale_factor 指定をした想定の出力サイズの計算

>>> print(int(9 * 1.9999999999999998))
17
PINTOPINTO

scale_factor 指定の interpolate を使用してテンソルをリサイズし、ONNXへエクスポートしたモデルの出力サイズを確認する。

class ResizeModel(nn.Module):
    def __init__(
        self,
    ):
        super(ResizeModel, self).__init__()

    def forward(self, x):
        return  \
            nn.functional.interpolate(
                input=x,
                scale_factor=(1.5, 1.9999999999999998),
            )

OPSET = 11
MODEL = f'xxxx'
ONNX_FILE = f"{MODEL}_{OPSET}.onnx"

model = ResizeModel()

x = torch.randn(1, 3, 9, 9)

torch.onnx.export(
    model,
    args=(x),
    f=ONNX_FILE,
    opset_version=OPSET,
    input_names=[
        f'{MODEL}_input',
    ],
    output_names=[
        f'{MODEL}_output',
    ],
)
model_onnx1 = onnx.load(onnx_file)
model_onnx1 = onnx.shape_inference.infer_shapes(model_onnx1)
onnx.save(model_onnx1, onnx_file)
model_onnx2 = onnx.load(onnx_file)
model_simp, check = simplify(model_onnx2)
onnx.save(model_simp, onnx_file)

ONNXは出力サイズが [1,3,13,18] になる

PINTOPINTO

PyTorch は出力サイズが [1,3,13,17] になる

y = model(x)
y.shape
torch.Size([1, 3, 13, 17])
PINTOPINTO

TensorFlow は出力サイズが [1,13,18,3] になる

import tensorflow as tf
from tf_keras.layers import Layer

class ResizeLayer(Layer):
    def __init__(self, scale_factor=(1.5, 1.9999999999999998), **kwargs):
        super(ResizeLayer, self).__init__(**kwargs)
        self.scale_factor = tf.convert_to_tensor(scale_factor, dtype=tf.float32)

    def call(self, inputs):
        input_shape = tf.shape(inputs)
        height = tf.cast(input_shape[1], dtype=tf.float32)
        width = tf.cast(input_shape[2], dtype=tf.float32)
        new_height = tf.cast(height * self.scale_factor[0], tf.int32)
        new_width = tf.cast(width * self.scale_factor[1], tf.int32)
        return tf.image.resize(inputs, [new_height, new_width])

resize_layer = ResizeLayer()
input_tensor = tf.random.normal([1, 9, 9, 3])
output_tensor = resize_layer(input_tensor)
output_tensor.shape
TensorShape([1, 13, 18, 3])
PINTOPINTO

つまり、下記のように scale_factor によるリサイズはテンソル破壊が発生するリスクを内在する。

Python:     [1,3,13,17]
ONNX:       [1,3,13,18]
PyTorch:    [1,3,13,17]
TensorFlow: [1,13,18,3]