
Creating an overlap-patched tensor in PyTorch

Published 2024/02/12

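I want to cut out the k×k rectangular region (kernel) around a pixel and treat it as one channel, then do this for every pixel in the image to build a c×k×k tensor, where c = H×W is the number of pixels.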

Source: Qian, Y., Barthélemy, J., Iqbal, U., & Perez, P. (2022). V2ReID: Vision-Outlooker-Based Vehicle Re-Identification. Sensors, 22, 8651. doi:10.3390/s22228651.
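To make the target concrete, here is a slow, loop-based sketch of the same operation. The helper name overlap_patches_naive is hypothetical (not from the article). It assumes single-channel input and reflection padding of k//2 pixels before and (k-1)//2 after each spatial axis, which reproduces the padding the module uses for both odd and even k.

```python
import torch
import torch.nn.functional as F


def overlap_patches_naive(img: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical reference version: one k x k patch per pixel.

    img: [B, 1, H, W] -> returns [B, H*W, k, k]
    """
    b, _, h, w = img.shape
    # k//2 before and (k-1)//2 after each axis: k-1 extra pixels in total,
    # so every pixel gets a full k x k neighbourhood (reflect = mirror the border).
    pad = (k // 2, (k - 1) // 2, k // 2, (k - 1) // 2)  # (left, right, top, bottom)
    padded = F.pad(img, pad, mode="reflect")
    patches = []
    for y in range(h):  # row-major order: patch index = y * W + x
        for x in range(w):
            patches.append(padded[:, 0, y:y + k, x:x + k])
    return torch.stack(patches, dim=1)  # [B, H*W, k, k]


img = torch.arange(16.0).reshape(1, 1, 4, 4)
out = overlap_patches_naive(img, 3)
print(out.shape)  # torch.Size([1, 16, 3, 3])
```

The double loop is only meant as a readable specification; it is far too slow for real images, which is why the module below vectorizes it with unfold.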

Using Tensor.unfold, this can be written as follows. (The padding part could probably be written a bit more neatly...)

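import torch
from torch import nn


class OverlapPatching(nn.Module):
    """
    Extract a K x K patch around every pixel (stride 1, reflection padding)
    and stack the patches in the channel direction. Single-channel input only.
    [B,1,H,W] -> [B,H*W,K,K]
    """
    def __init__(self, kernel_size):
        super().__init__()
        self.kernel_size = kernel_size
        # ReflectionPad2d takes (left, right, top, bottom). In total K-1 pixels
        # are added per axis, so a stride-1 window yields exactly H x W patches.
        if kernel_size % 2 == 1:
            p = (kernel_size//2, kernel_size//2, kernel_size//2, kernel_size//2)
        else:
            p = (kernel_size//2, kernel_size//2-1, kernel_size//2, kernel_size//2-1)
        self.pad = nn.ReflectionPad2d(p)

    def forward(self, img):
        b, _, h, w = img.shape
        kernel_size = self.kernel_size
        img = self.pad(img)  # [B,1,H+K-1,W+K-1]
        # Tensor.unfold appends the new window dim at the end, so unfold the
        # height axis first and the width axis second to get (row, column)
        # patches; the reverse order returns every patch transposed.
        return (
            img
            .unfold(2, kernel_size, 1)  # [B,1,H,W+K-1,K]
            .unfold(3, kernel_size, 1)  # [B,1,H,W,K,K]
            .reshape(b, h*w, kernel_size, kernel_size)
            )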
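Regarding the padding: both branches reduce to k//2 pixels before and (k-1)//2 after each axis (k=7 gives (3, 3), k=8 gives (4, 3)), so a single expression covers odd and even kernels. The sketch below is a hypothetical functional variant named overlap_patching (not the article's code) that uses this one-line padding and swaps in torch.nn.functional.unfold (im2col) for the two Tensor.unfold calls, then cross-checks the result against Tensor.unfold applied height-first. It assumes single-channel input and H, W larger than k//2, which reflection padding requires.

```python
import torch
import torch.nn.functional as F


def overlap_patching(img: torch.Tensor, k: int) -> torch.Tensor:
    """[B, 1, H, W] -> [B, H*W, k, k], one k x k patch per pixel.

    Hypothetical functional variant of OverlapPatching, for illustration.
    """
    b, _, h, w = img.shape
    # One expression for both odd and even k; total padding per axis is k-1.
    pad = (k // 2, (k - 1) // 2, k // 2, (k - 1) // 2)  # (left, right, top, bottom)
    padded = F.pad(img, pad, mode="reflect")
    # F.unfold returns [B, 1*k*k, H*W]: each column is one patch flattened in
    # (row, column) order, and the columns follow row-major pixel order.
    cols = F.unfold(padded, kernel_size=k)
    return cols.transpose(1, 2).reshape(b, h * w, k, k)


img = torch.randn(2, 1, 10, 12)
for k in (3, 8):
    out = overlap_patching(img, k)
    # Cross-check against Tensor.unfold applied height-first, then width.
    ref = (F.pad(img, (k // 2, (k - 1) // 2, k // 2, (k - 1) // 2), mode="reflect")
           .unfold(2, k, 1).unfold(3, k, 1).reshape(2, 10 * 12, k, k))
    print(k, out.shape, torch.equal(out, ref))  # True for both k
```

Both routes produce the same layout; which one is faster likely depends on the device and tensor sizes, so it is worth timing before switching.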

Usage

> In [1]: op = OverlapPatching(8)
OverlapPatching(
  (pad): ReflectionPad2d((4, 3, 4, 3))
)
> In [2]: output = op(torch.ones(5,1,100,100))
> In [3]: output.shape
torch.Size([5, 10000, 8, 8])
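The shape check above cannot reveal the patch orientation, because the shape is identical whichever spatial axis is unfolded first. Since Tensor.unfold appends each new window dimension at the end, the order of the two calls decides whether the last two dimensions are (row, column) or (column, row). A minimal check on a 4×4 ramp image, with padding omitted for brevity:

```python
import torch

# Non-symmetric image so a transpose is visible: values 0..15 in row-major order.
img = torch.arange(16.0).reshape(1, 1, 4, 4)
k = 3

hw = img.unfold(2, k, 1).unfold(3, k, 1)  # height first: last dims are (row, column)
wh = img.unfold(3, k, 1).unfold(2, k, 1)  # width first: last dims are (column, row)

print(hw[0, 0, 0, 0].tolist())  # [[0.0, 1.0, 2.0], [4.0, 5.0, 6.0], [8.0, 9.0, 10.0]]
print(wh[0, 0, 0, 0].tolist())  # [[0.0, 4.0, 8.0], [1.0, 5.0, 9.0], [2.0, 6.0, 10.0]]
print(torch.equal(wh, hw.transpose(-1, -2)))  # True
```

With a constant input such as torch.ones both orders give identical results, so a test image with distinct values is needed to catch this.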
