Fast-SCNN ONNX export: replacing `AdaptiveAvgPool2d` with `AvgPool2d`
- Set the input resolution to a multiple of 96, the least common multiple of 32 and 6, so that every pooling step divides evenly: no remainder, no padding, and no rounding up or down.
- The `kernel_size` of the replacement `AvgPool2d` is based on the input resolution divided by 32 (the size of the feature map that reaches the pyramid pooling module), divided again by each pooling `size`; the sketch right after this list works through the numbers for a 192x384 input.
- The `size` parameter passed to `AdaptiveAvgPool2d` in the model design takes the values 1, 2, 3 and 6.
- `size` means the spatial size of the output after pooling: a scalar `size` is read as size x size, a tuple `(size1, size2)` as size1 x size2.
- To replace it with `AvgPool2d` you first need to know the size of the input that reaches `AdaptiveAvgPool2d`. As a workaround, the quickest route is to swap in an `AvgPool2d` of some arbitrary size, export the model anyway, and inspect the INPUT of the (former) `AdaptiveAvgPool2d` node in Netron. For PyTorch the layout is `[N, C, H, W]`. Naively working backwards from the OUTPUT is more work, and you cannot recover the stride from it anyway.
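
A minimal sketch of that arithmetic (the 32x downsampling in front of the pyramid pooling module is taken from the notes above; 128 channels is just an illustrative width): for a 192x384 input the feature map is 6x12, and a fixed `AvgPool2d` with kernel `(H//size, W//size)` reproduces `AdaptiveAvgPool2d(size)` exactly, because 6 and 12 are divisible by 1, 2, 3 and 6.

```python
import torch
import torch.nn as nn

# 192x384 input; the feature map reaching pyramid pooling is 32x smaller.
h, w = 192 // 32, 384 // 32            # 6 x 12
x = torch.randn(1, 128, h, w)          # 128 channels is only an illustrative value

for size in (1, 2, 3, 6):
    adaptive = nn.AdaptiveAvgPool2d(size)(x)
    fixed = nn.AvgPool2d((h // size, w // size), ceil_mode=True)(x)
    # Shapes and values match because 6 and 12 divide evenly by 1, 2, 3 and 6.
    print(size, tuple(adaptive.shape), tuple(fixed.shape),
          torch.allclose(adaptive, fixed))
```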
192x384
fast_scnn.py
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PyramidPooling(nn.Module):
    """Pyramid pooling module (AdaptiveAvgPool2d replaced with AvgPool2d)"""

    def __init__(self, in_channels, out_channels, **kwargs):
        super(PyramidPooling, self).__init__()
        inter_channels = int(in_channels / 4)
        self.conv1 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)
        self.conv2 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)
        self.conv3 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)
        self.conv4 = _ConvBNReLU(in_channels, inter_channels, 1, **kwargs)
        self.out = _ConvBNReLU(in_channels * 2, out_channels, 1)

    def pool(self, x, size, base):
        # avgpool = nn.AdaptiveAvgPool2d(size)   # original layer being replaced
        # print(f'size: {size}')                 # debug print used while exporting
        # The kernel is the incoming feature-map size divided by the pooling
        # size, so the output is exactly size x size when everything divides.
        h = int(base[0])  # int() works in eager mode and during ONNX tracing
        w = int(base[1])  # (.item() in the original only works while tracing)
        avgpool = nn.AvgPool2d((h // size, w // size), ceil_mode=True)
        return avgpool(x)

    def upsample(self, x, size):
        return F.interpolate(x, size, mode='bilinear', align_corners=True)

    def forward(self, x):
        size = x.size()[2:]
        feat1 = self.upsample(self.conv1(self.pool(x, 1, size)), size)
        feat2 = self.upsample(self.conv2(self.pool(x, 2, size)), size)
        feat3 = self.upsample(self.conv3(self.pool(x, 3, size)), size)
        feat4 = self.upsample(self.conv4(self.pool(x, 6, size)), size)
        x = torch.cat([x, feat1, feat2, feat3, feat4], dim=1)
        x = self.out(x)
        return x
```
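
For reference, a rough sketch of exporting the modified model at 192x384. The `FastSCNN` class, its module path, and the `num_classes` argument are assumptions about the surrounding repository, not something stated above:

```python
import torch
from models.fast_scnn import FastSCNN  # assumed module path / class name

model = FastSCNN(num_classes=19)       # hypothetical constructor arguments
model.eval()

# 192x384: a multiple of 96, as required by the rule above.
dummy = torch.randn(1, 3, 192, 384)
torch.onnx.export(
    model, dummy, "fast_scnn_192x384.onnx",
    opset_version=11,
    input_names=["input"], output_names=["output"],
)
```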
384x576 / 768x1344
The modified fast_scnn.py shown above is used unchanged for these resolutions as well: the `AvgPool2d` kernel is computed from the incoming feature-map size, so nothing in the module differs per input size.
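
To sanity-check the exported models, a minimal sketch with onnxruntime, assuming each resolution was exported with the file naming used in the export sketch above:

```python
import numpy as np
import onnxruntime as ort

for h, w in [(192, 384), (384, 576), (768, 1344)]:
    sess = ort.InferenceSession(f"fast_scnn_{h}x{w}.onnx",
                                providers=["CPUExecutionProvider"])
    name = sess.get_inputs()[0].name
    x = np.random.rand(1, 3, h, w).astype(np.float32)
    outputs = sess.run(None, {name: x})
    print((h, w), [o.shape for o in outputs])
```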