🪡

PyTorchのConv2dで画像処理フィルターを作る

2024/01/12に公開

NumpyとOpenCVでやっていた処理をPyTorchで書き換えてGPU上で動かしたい欲求がある。

端的にコード例を示す。微分フィルター以外もkernelさえ書ければ任意のフィルターが作れるはず。

class DifferencialFilter(nn.Module):
    def __init__(self, channel):
        super().__init__()
        kernel = torch.tensor(
            [[[ [-1.0, 0.0, 1.0],
                [-1.0, 0.0, 1.0],
                [-1.0, 0.0, 1.0] ]]]).repeat(channel,1,1,1)
        self.conv = nn.Conv2d(
            channel, channel, 3, 
            padding=1, padding_mode="reflect", 
            bias=False, groups=channel, dtype=torch.float)
        self.conv.weight = torch.nn.Parameter(kernel1)
        return None

    def forward(self, x):
        return self.conv(x)

ガウスぼかしは少しテクい方法がいる。次のディスカッションが役に立った。

https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351

class Gaussian(nn.Module):
    def __init__(self, channels, kernel_size, sigma=1.0):
        super().__init__()
        kernel_size = [kernel_size] * 2
        sigma = [sigma] * 2

        kernel = 1
        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float) for size in kernel_size], indexing="ij")
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp((-((mgrid - mean) / std) ** 2) / 2)
        kernel = kernel / torch.sum(kernel)
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))

        self.conv = nn.Conv2d(
            channels, channels, kernel_size, 
            padding=(int(kernel_size[0]//2), int(kernel_size[1]//2)), padding_mode="reflect",
            groups=channels, bias=False, dtype=torch.float)
        self.conv.weight = torch.nn.Parameter(kernel)
        return None

    def forward(self, x):
        return self.conv(x)

Discussion