[ONNX] AdaptiveAvgPool2d: replace AdaptiveAvgPool2d with avg_pool2d, mean, or ret/(length_h*length_w)

PINTO

Modified PyTorch code with the unimportant size checks removed:

adaptive_avg_pool2d
    def _unsqueeze_to_dim(self, x: torch.Tensor, dim: int):
        for _ in range(dim - x.dim()):
            x = x.unsqueeze(-1)
        return x

    from typing import Tuple
    def adaptive_avg_pool2d(self, input: torch.Tensor, output_size: Tuple[int, int]):
        # Preconditions (the original size/rank checks were removed; see the note above)
        device = input.device
        shape = input.shape

        # Optimisation (we should also do this in the kernel implementation)
        if shape[-2] % output_size[-2] == 0 and shape[-1] % output_size[-1] == 0:
            stride = tuple(i // o for i, o in zip(shape[-2:], output_size))
            kernel = tuple(
                i - (o - 1) * s for i, o, s in zip(shape[-2:], output_size, stride)
            )
            return torch.nn.functional.avg_pool2d(input, kernel, stride)

        def start_index(a, b, c):
            return torch.div(a * c, b, rounding_mode="trunc")

        def end_index(a, b, c):
            return torch.div((a + 1) * c + b - 1, b, rounding_mode="trunc")

        def compute_idx(in_size, out_size):
            orange = torch.arange(out_size, device=device, dtype=torch.int64)
            i0 = start_index(orange, out_size, in_size)
            # Let length = end_index - start_index, i.e. the length of the pooling kernels
            # length.max() can be computed analytically as follows:
            maxlength = in_size // out_size + 1
            in_size_mod = in_size % out_size
            # adaptive = True iff there are kernels with different lengths
            adaptive = not (in_size_mod == 0 or out_size % in_size_mod == 0)
            if adaptive:
                maxlength += 1
            elif in_size_mod == 0:
                maxlength -= 1

            range_max = torch.arange(maxlength, device=device, dtype=torch.int64)
            idx = i0.unsqueeze(-1) + range_max
            if adaptive:
                # Need to clamp to avoid accessing out-of-bounds memory
                # TODO make minimum accept scalars
                maxval = torch.scalar_tensor(
                    in_size - 1, dtype=idx.dtype, device=idx.device
                )
                idx = torch.minimum(idx, maxval)

                # Compute the lengths of the pooling windows
                i1 = end_index(orange, out_size, in_size)
                length = i1 - i0
            else:
                length = maxlength
            return idx, length, range_max, adaptive

        # length is a plain int when every window has the same size; otherwise it is a tensor
        idxh, length_h, range_max_h, adaptive_h = compute_idx(shape[-2], output_size[-2])
        idxw, length_w, range_max_w, adaptive_w = compute_idx(shape[-1], output_size[-1])

        # Workaround for excessive RAM consumption:
        # split the advanced indexing into two stages
        # vals = input[..., self._unsqueeze_to_dim(idxh, 4), idxw]
        tmp_vals = input[..., self._unsqueeze_to_dim(idxh, 2), :]
        vals = tmp_vals[..., idxw]

        # Shortcut for the simpler case
        if not adaptive_h and not adaptive_w:
            return torch.mean(vals, dim=(-3, -1))

        def maybe_mask(vals, length, range_max, adaptive, dim):
            if isinstance(length, int):
                return vals, length
            else:
                # zero-out the things we didn't really want to select
                assert dim < 0
                # hack
                mask = range_max >= length.unsqueeze(-1)
                if dim == -2:
                    mask = self._unsqueeze_to_dim(mask, 4)
                vals = torch.masked_fill(vals, mask, 0.0)
                # Compute the length of each window
                length = self._unsqueeze_to_dim(length, -dim)
                return vals, length

        vals, length_h = maybe_mask(
            vals, length_h, range_max_h, adaptive=adaptive_h, dim=-2
        )
        vals, length_w = maybe_mask(
            vals, length_w, range_max_w, adaptive=adaptive_w, dim=-1
        )

        # We unroll the sum as we assume that the kernels are going to be small
        from itertools import product
        ret = None
        for i, j in product(range(vals.shape[-3]), range(vals.shape[-1])):
            if ret is None:
                ret = vals[..., i, :, j]
            else:
                ret = ret + vals[..., i, :, j]
        return ret / (length_h * length_w)
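
For reference, a quick sanity check against the built-in operator, using a spatial size that forces the adaptive (non-divisible) path. `Pool` is a placeholder name for any nn.Module that contains the two methods above verbatim:

    import torch
    import torch.nn.functional as F

    pool = Pool()  # placeholder: an nn.Module defining the two methods above
    x = torch.randn(1, 3, 37, 53)  # 37 % 7 != 0 and 53 % 11 != 0 -> adaptive path
    ref = F.adaptive_avg_pool2d(x, (7, 11))
    out = pool.adaptive_avg_pool2d(x, (7, 11))
    print((out - ref).abs().max())  # should be at float32 rounding level (~1e-7)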
PINTO
import torch
import numpy as np

############# NumPy: int32 index tensors are accepted
a = np.ones([1, 1, 32, 32], dtype=np.float32)
b = np.ones([16384, 1, 1, 1], dtype=np.int32)
print(a[..., b, :].shape)
# (1, 1, 16384, 1, 1, 1, 32)

c = np.ones([16384, 1], dtype=np.int32)
print(a[..., b, c].shape)
# (1, 1, 16384, 1, 16384, 1)

############# PyTorch: the same indexing rejects int32 indices
a = torch.ones([1, 1, 32, 32], dtype=torch.float32)
b = torch.ones([16384, 1, 1, 1], dtype=torch.int32)
print(a[..., b, :].shape)

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
IndexError: tensors used as indices must be long, byte or bool tensors
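
On the PyTorch side the fix is simply to cast the index tensor to int64 before indexing (same experiment, with the indices converted):

b = b.long()  # int32 -> int64; equivalently, create it with dtype=torch.int64
print(a[..., b, :].shape)
# torch.Size([1, 1, 16384, 1, 1, 1, 32])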
PINTO

When input.shape[2] > output.shape[0] and input.shape[3] > output.shape[1], and the only goal is to keep RAM consumption down, simply resizing with F.interpolate(input, size=[output.shape[0], output.shape[1]], mode='bilinear') appears to agree with the adaptive pooling result to within an error of 1e-7. The indexing-based adaptive_avg_pool2d above, on the other hand, consumes an abnormal amount of RAM during ONNX export (apparently due to the index expansion) and runs into OOM; in one case, even an input of [1,1,32,32] could not be processed with 128 GB of CPU RAM.
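
A quick way to measure that discrepancy for a given pair of shapes (a sketch; whether the difference actually stays below 1e-7 depends on the data and the exact sizes):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 1, 32, 32)
    out_hw = (30, 30)  # example target with input > output in both spatial dims
    ref = F.adaptive_avg_pool2d(x, out_hw)
    approx = F.interpolate(x, size=out_hw, mode='bilinear')
    print((approx - ref).abs().max())  # worst-case difference between the two ops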

https://github.com/YuiNsky/Gradient-based-depth-map-fusion

    # def inference(self, low_dep, high_dep):
    def forward(self, low_dep, high_dep):
        # low_dep = self.pool(low_dep)
        # high_dep = self.pool(high_dep)

        # low_dep = nn.AvgPool2d((low_dep.shape[2]//2*2**self.Fuse.upsize, low_dep.shape[3]//2*2**self.Fuse.upsize), ceil_mode=True)(low_dep)
        # high_dep = nn.AvgPool2d((high_dep.shape[2]//2*2**self.Fuse.upsize, high_dep.shape[3]//2*2**self.Fuse.upsize), ceil_mode=True)(high_dep)

        # low_dep = self.adaptive_avg_pool2d(low_dep, (low_dep.shape[2]//2*2**self.Fuse.upsize, low_dep.shape[3]//2*2**self.Fuse.upsize))
        # high_dep = self.adaptive_avg_pool2d(high_dep, (high_dep.shape[2]//2*2**self.Fuse.upsize, high_dep.shape[3]//2*2**self.Fuse.upsize))

        # Replace the pooling variants above with bilinear resizing so that
        # ONNX export does not blow up RAM
        low_dep = F.interpolate(low_dep, size=[low_dep.shape[2]//2*2**self.Fuse.upsize, low_dep.shape[3]//2*2**self.Fuse.upsize], mode='bilinear')
        high_dep = F.interpolate(high_dep, size=[high_dep.shape[2]//2*2**self.Fuse.upsize, high_dep.shape[3]//2*2**self.Fuse.upsize], mode='bilinear')
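
With the interpolate-based forward in place, the export itself is the usual torch.onnx.export call (a sketch; `model`, the input sizes, and the file name below are placeholders):

    import torch

    # model: the fusion network from the repository above, with the patched forward (placeholder)
    low_dep = torch.randn(1, 1, 32, 32)
    high_dep = torch.randn(1, 1, 32, 32)
    torch.onnx.export(
        model,
        (low_dep, high_dep),
        'gbdf_fusion.onnx',  # hypothetical output file name
        input_names=['low_dep', 'high_dep'],
        output_names=['fused_dep'],
        opset_version=11,
    )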