Creating an overlap-patched tensor in PyTorch
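I want to cut a k×k rectangular region (the kernel) out of an image and turn it into one channel, then repeat this across the whole image, sliding one pixel at a time so that neighbouring patches overlap, to build a c×k×k tensor (c = H×W, one channel per pixel position).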
Source: Qian, Yan & Barthélemy, Johan & Iqbal, Umair & Perez, Pascal. (2022). V2ReID: Vision-Outlooker-Based Vehicle Re-Identification. Sensors. 22. 8651. 10.3390/s22228651.
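To pin the definition down, here is a slow, loop-based sketch: for each pixel (i, j), take the K×K window of the padded image whose top-left corner is (i, j) and store it as channel i·W + j. The reflection padding mirrors the one used in the module below; the row-major patch ordering and the name `overlap_patching_naive` are assumptions for illustration, not something taken from the cited paper.

```python
import torch
import torch.nn.functional as F


def overlap_patching_naive(img: torch.Tensor, k: int) -> torch.Tensor:
    """Slow reference (hypothetical helper): [B, 1, H, W] -> [B, H*W, K, K].

    Assumes reflection padding of k//2 on the left/top and (k-1)//2 on the right/bottom.
    """
    b, c, h, w = img.shape
    assert c == 1, "this sketch assumes a single-channel image"
    # (left, right, top, bottom); "reflect" mirrors without repeating the edge pixel
    padded = F.pad(img, (k // 2, (k - 1) // 2, k // 2, (k - 1) // 2), mode="reflect")
    out = img.new_empty((b, h * w, k, k))
    for i in range(h):
        for j in range(w):
            # Patch index n = i * W + j (row-major over pixel positions);
            # inside the patch, axis 0 is the row offset and axis 1 the column offset.
            out[:, i * w + j] = padded[:, 0, i:i + k, j:j + k]
    return out


img = torch.arange(2 * 5 * 6, dtype=torch.float32).reshape(2, 1, 5, 6)
out = overlap_patching_naive(img, 3)
print(out.shape)  # torch.Size([2, 30, 3, 3])
```

Because the left/top padding is k//2, element [k//2, k//2] of each patch is the pixel the patch belongs to, for both odd and even k.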
Using `Tensor.unfold`, this can be written as follows. (The padding part could probably be written more neatly...)
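```python
import torch
import torch.nn as nn


class OverlapPatching(nn.Module):
    """
    Extract a K×K patch around every pixel (stride 1, reflection padding)
    and stack the patches in the channel direction.
    [B,1,H,W] -> [B,H*W,K,K]
    """
    def __init__(self, kernel_size):
        super().__init__()
        self.kernel_size = kernel_size
        # Padding is (left, right, top, bottom) so that every pixel gets a full K×K window.
        # For even K the window cannot be centred, so the right/bottom side gets one pixel less.
        if kernel_size % 2 == 1:
            p = (kernel_size//2, kernel_size//2, kernel_size//2, kernel_size//2)
        else:
            p = (kernel_size//2, kernel_size//2-1, kernel_size//2, kernel_size//2-1)
        self.pad = nn.ReflectionPad2d(p)

    def forward(self, img):
        b, _, h, w = img.shape
        kernel_size = self.kernel_size
        img = self.pad(img)  # [B,1,H+K-1,W+K-1]
        # Unfold H first, then W, so each patch is indexed [row, col];
        # unfolding W first would return every patch transposed.
        return (
            img
            .unfold(2, kernel_size, 1)  # windows along H -> [B,1,H,W+K-1,K]
            .unfold(3, kernel_size, 1)  # windows along W -> [B,1,H,W,K,K]
            .reshape(b, h*w, kernel_size, kernel_size)  # assumes a single input channel
        )
```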
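On the padding remark: `(k - 1) // 2` equals `k // 2` when k is odd and `k // 2 - 1` when k is even, so the two branches in `__init__` collapse into one expression. A small sketch; the helper name `reflection_padding` is hypothetical.

```python
import torch
import torch.nn as nn


def reflection_padding(kernel_size: int) -> tuple:
    """Return (left, right, top, bottom) so every pixel gets a full K×K window."""
    before = kernel_size // 2       # left / top
    after = (kernel_size - 1) // 2  # right / bottom: equal to `before` for odd K, one less for even K
    return (before, after, before, after)


# Drop-in replacement for the if/else in __init__:
#     self.pad = nn.ReflectionPad2d(reflection_padding(kernel_size))
pad = nn.ReflectionPad2d(reflection_padding(8))
padded = pad(torch.ones(1, 1, 10, 10))
print(padded.shape)  # torch.Size([1, 1, 17, 17]): 10 + 4 + 3 on each spatial axis
```

Either way the padded image grows by exactly K - 1 per axis, which is what makes the stride-1 unfold return H×W windows.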
Usage
```
In [1]: op = OverlapPatching(8)
OverlapPatching(
  (pad): ReflectionPad2d((4, 3, 4, 3))
)
In [2]: output = op(torch.ones(5,1,100,100))
In [3]: output.shape
torch.Size([5, 10000, 8, 8])
```
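As an alternative (not the approach used above), `torch.nn.functional.unfold`, PyTorch's im2col operation, produces the same `[B, H*W, K, K]` shape: it returns `[B, C*K*K, L]` with one flattened patch per column, so a transpose and reshape are enough. A sketch under the same single-channel and reflection-padding assumptions; the function name `overlap_patching_unfold` is hypothetical.

```python
import torch
import torch.nn.functional as F


def overlap_patching_unfold(img: torch.Tensor, k: int) -> torch.Tensor:
    """[B, 1, H, W] -> [B, H*W, K, K] via F.unfold (im2col)."""
    b, c, h, w = img.shape
    assert c == 1, "this sketch assumes a single-channel image"
    # Reflection padding: k//2 on left/top, (k-1)//2 on right/bottom
    padded = F.pad(img, (k // 2, (k - 1) // 2, k // 2, (k - 1) // 2), mode="reflect")
    # [B, K*K, H*W]: each column is one patch flattened row-major, columns ordered row-major over pixels
    cols = F.unfold(padded, kernel_size=k)
    return cols.transpose(1, 2).reshape(b, h * w, k, k)


out = overlap_patching_unfold(torch.ones(5, 1, 100, 100), 8)
print(out.shape)  # torch.Size([5, 10000, 8, 8])
```

Both versions end up storing H×W×K×K values per image, so memory grows with K²; for a 100×100 image and K = 8 that is 640,000 values per image.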