[ONNX] AdaptiveAvgPool2d: replace AdaptiveAvgPool2d with avg_pool2d, mean, or ret / (length_h * length_w)
Modified PyTorch code with the unimportant size checks removed
adaptive_avg_pool2d
from typing import Tuple

import torch


def _unsqueeze_to_dim(self, x: torch.Tensor, dim: int):
    for _ in range(dim - x.dim()):
        x = x.unsqueeze(-1)
    return x


def adaptive_avg_pool2d(self, input: torch.Tensor, output_size: Tuple[int, int]):
    # Preconditions
    device = input.device
    shape = input.shape

    # Optimisation (we should also do this in the kernel implementation)
    if shape[-2] % output_size[-2] == 0 and shape[-1] % output_size[-1] == 0:
        stride = tuple(i // o for i, o in zip(shape[-2:], output_size))
        kernel = tuple(
            i - (o - 1) * s for i, o, s in zip(shape[-2:], output_size, stride)
        )
        return torch.nn.functional.avg_pool2d(input, kernel, stride)

    def start_index(a, b, c):
        return torch.div(a * c, b, rounding_mode="trunc")

    def end_index(a, b, c):
        return torch.div((a + 1) * c + b - 1, b, rounding_mode="trunc")

    def compute_idx(in_size, out_size):
        orange = torch.arange(out_size, device=device, dtype=torch.int64)
        i0 = start_index(orange, out_size, in_size)
        # Let length = end_index - start_index, i.e. the length of the pooling kernels
        # length.max() can be computed analytically as follows:
        maxlength = in_size // out_size + 1
        in_size_mod = in_size % out_size
        # adaptive = True iff there are kernels with different lengths
        adaptive = not (in_size_mod == 0 or out_size % in_size_mod == 0)
        if adaptive:
            maxlength += 1
        elif in_size_mod == 0:
            maxlength -= 1

        range_max = torch.arange(maxlength, device=device, dtype=torch.int64)
        idx = i0.unsqueeze(-1) + range_max
        if adaptive:
            # Need to clamp to avoid accessing out-of-bounds memory
            # TODO make minimum accept scalars
            maxval = torch.scalar_tensor(
                in_size - 1, dtype=idx.dtype, device=idx.device
            )
            idx = torch.minimum(idx, maxval)

            # Compute the lengths
            i1 = end_index(orange, out_size, in_size)
            length = i1 - i0
        else:
            length = maxlength
        return idx, length, range_max, adaptive

    # length is not None if it's constant, otherwise we'll need to compute it
    idxh, length_h, range_max_h, adaptive_h = compute_idx(shape[-2], output_size[-2])
    idxw, length_w, range_max_w, adaptive_w = compute_idx(shape[-1], output_size[-1])

    # Workaround for large consumption of RAM:
    # split the index slice process into two stages
    # vals = input[..., self._unsqueeze_to_dim(idxh, 4), idxw]
    tmp_vals = input[..., self._unsqueeze_to_dim(idxh, 2), :]
    vals = tmp_vals[..., idxw]

    # Shortcut for the simpler case
    if not adaptive_h and not adaptive_w:
        return torch.mean(vals, dim=(-3, -1))

    def maybe_mask(vals, length, range_max, adaptive, dim):
        if isinstance(length, int):
            return vals, length
        else:
            # zero-out the things we didn't really want to select
            assert dim < 0
            # hack
            mask = range_max >= length.unsqueeze(-1)
            if dim == -2:
                mask = self._unsqueeze_to_dim(mask, 4)
            vals = torch.masked_fill(vals, mask, 0.0)
            # Compute the length of each window
            length = self._unsqueeze_to_dim(length, -dim)
            return vals, length

    vals, length_h = maybe_mask(
        vals, length_h, range_max_h, adaptive=adaptive_h, dim=-2
    )
    vals, length_w = maybe_mask(
        vals, length_w, range_max_w, adaptive=adaptive_w, dim=-1
    )

    # We unroll the sum as we assume that the kernels are going to be small
    from itertools import product

    ret = None
    for i, j in product(range(vals.shape[-3]), range(vals.shape[-1])):
        if ret is None:
            ret = vals[..., i, :, j]
        else:
            ret = ret + vals[..., i, :, j]
    return ret / (length_h * length_w)
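
For a quick parity check against the stock operator, the two functions above can be attached to a small nn.Module (they are written as methods, hence the self parameter). A minimal sketch, assuming the code above is defined at module level; the Pool class and the 30x30 -> 7x7 sizes are illustrative only:

import torch.nn.functional as F

class Pool(torch.nn.Module):
    pass

# Attach the functions above as methods (they already take `self`).
Pool._unsqueeze_to_dim = _unsqueeze_to_dim
Pool.adaptive_avg_pool2d = adaptive_avg_pool2d

x = torch.randn(1, 3, 30, 30)           # 30 % 7 != 0 -> index-based path
out = Pool().adaptive_avg_pool2d(x, (7, 7))
ref = F.adaptive_avg_pool2d(x, (7, 7))
print((out - ref).abs().max())          # expected: ~0 (float rounding only)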
import torch
import numpy as np
############# Numpy
a=np.ones([1, 1, 32, 32], dtype=np.float32)
b=np.ones([16384, 1, 1, 1], dtype=np.int32)
print(a[..., b, :].shape)
(1, 1, 16384, 1, 1, 1, 32)
c=np.ones([16384, 1], dtype=np.int32)
print(a[..., b, c].shape)
(1, 1, 16384, 1, 16384, 1)
############# PyTorch
a=torch.ones([1, 1, 32, 32], dtype=torch.float32)
b=torch.ones([16384, 1, 1, 1], dtype=torch.int32)
print(a[..., b, :].shape)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
IndexError: tensors used as indices must be long, byte or bool tensors
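For reference, casting the index tensors to int64 makes the same fancy indexing work in PyTorch and reproduces the NumPy shapes (depending on the PyTorch version, int32 may also be accepted; .long() works everywhere). A minimal check:

import torch

a = torch.ones([1, 1, 32, 32], dtype=torch.float32)
b = torch.ones([16384, 1, 1, 1], dtype=torch.int32)
c = torch.ones([16384, 1], dtype=torch.int32)

print(a[..., b.long(), :].shape)        # torch.Size([1, 1, 16384, 1, 1, 1, 32])
print(a[..., b.long(), c.long()].shape) # torch.Size([1, 1, 16384, 1, 16384, 1])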
When input.shape[2] > output.shape[0] and input.shape[3] > output.shape[1], and the goal is only to keep RAM consumption down, simply resizing with F.interpolate(input, size=[output.shape[0], output.shape[1]], mode='bilinear') appears to agree with adaptive average pooling to within an error of 1e-7. The index-based replacement above, however, consumes an abnormal amount of RAM during ONNX export (apparently due to the index-expansion processing or something similar) and runs into OOM; there was even a case with input: [1,1,32,32] that could not be processed with 128 GB of CPU RAM.
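
A quick way to sanity-check the 1e-7 claim on a concrete shape (a hypothetical 32x32 -> 16x16 reduction; close agreement like this is expected when the scale factor divides evenly, e.g. an exact 2x reduction):

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 32, 32)
ref = F.adaptive_avg_pool2d(x, (16, 16))
approx = F.interpolate(x, size=[16, 16], mode='bilinear')
print((approx - ref).abs().max())  # for an exact 2x reduction, ~1e-7 or below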
# def inference(self, low_dep, high_dep):
def forward(self, low_dep, high_dep):
    # low_dep = self.pool(low_dep)
    # high_dep = self.pool(high_dep)
    # low_dep = nn.AvgPool2d((low_dep.shape[2]//2*2**self.Fuse.upsize, low_dep.shape[3]//2*2**self.Fuse.upsize), ceil_mode=True)(low_dep)
    # high_dep = nn.AvgPool2d((high_dep.shape[2]//2*2**self.Fuse.upsize, high_dep.shape[3]//2*2**self.Fuse.upsize), ceil_mode=True)(high_dep)
    # low_dep = self.adaptive_avg_pool2d(low_dep, (low_dep.shape[2]//2*2**self.Fuse.upsize, low_dep.shape[3]//2*2**self.Fuse.upsize))
    # high_dep = self.adaptive_avg_pool2d(high_dep, (high_dep.shape[2]//2*2**self.Fuse.upsize, high_dep.shape[3]//2*2**self.Fuse.upsize))
    low_dep = F.interpolate(low_dep, size=[low_dep.shape[2]//2*2**self.Fuse.upsize, low_dep.shape[3]//2*2**self.Fuse.upsize], mode='bilinear')
    high_dep = F.interpolate(high_dep, size=[high_dep.shape[2]//2*2**self.Fuse.upsize, high_dep.shape[3]//2*2**self.Fuse.upsize], mode='bilinear')
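
For reference, the interpolate-based replacement exports without the index-expansion blow-up. A minimal export sketch with a hypothetical stand-alone module (the InterpPool name, file name, and 32x32 input are illustrative, not part of the model above):

import torch
import torch.nn as nn
import torch.nn.functional as F

class InterpPool(nn.Module):
    # Halve the spatial size with a bilinear resize instead of AdaptiveAvgPool2d.
    def forward(self, x):
        return F.interpolate(x, size=[x.shape[2] // 2, x.shape[3] // 2], mode='bilinear')

model = InterpPool().eval()
dummy = torch.randn(1, 1, 32, 32)
torch.onnx.export(
    model,
    dummy,
    'interp_pool.onnx',
    opset_version=11,  # bilinear Resize requires opset >= 11
    input_names=['input'],
    output_names=['output'],
)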