SSIM Loss を PyTorch で実装

4 min read読了の目安(約4300字

  • 画像 X と画像 Y から局所領域 xy を切り抜く
  • 局所領域 xy 内の画素値から、平均 \mu _x\mu _y、 標準偏差 \sigma _x\sigma_y、共分散 \sigma_{xy} を計算
  • (1) から局所領域における SSIM を計算
SSIM = \frac{(2 \mu_x \mu_y + c_1)(2 \mu_{xy} + c_2)}{(\mu_x^2 + \mu_y^2 + c_1)(\sigma_x^2 \sigma_y^2 + c_2)} \quad \quad (1)
  • 1 ピクセルずつ xy 方向に局所領域をスライドさせ SSIM を再び計算。画像サイズ 256×256、局所領域サイズ 64×64 の場合、(256 - 64 + 1) \times (256 - 64 + 1) = 37,249 回の SSIM の計算を行う必要がある (padding は 0 とする)。

PyTorch の Conv2d を使った実装

  • 局所領域と同じサイズの kernel を用意し、PyTorch Conv2d を使って、平均、分散、共分散を計算します。
  • 実用上は、画像 X と画像 Y を平滑化した後に SSIM を計算するので、uniform kernel ではなく gaussian kernel を使っています。
  • 標準偏差は \sigma^2 = \overline {x^2} - (\overline x)^2 より計算しています。次の章で、この式の導出を行っています。
  • 定数は c_1 = (k_1 L)^2c_2 = (k_2 L)^2 のように定義されており、 L はダイナミックレンジ (8 bit の場合 255)、k_1k_2 はハイパーパラメータでデフォルト値は 0.01 と 0.03 です。
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module


class SSIMLoss(Module):
    def __init__(self, kernel_size: int = 11, sigma: float = 1.5) -> None:

        """Computes the structural similarity (SSIM) index map between two images.

        Args:
            kernel_size (int): Height and width of the gaussian kernel.
            sigma (float): Gaussian standard deviation in the x and y direction.
        """

        super().__init__()
        self.kernel_size = kernel_size
        self.sigma = sigma
        self.gaussian_kernel = self._create_gaussian_kernel(self.kernel_size, self.sigma)

    def forward(self, x: Tensor, y: Tensor, as_loss: bool = True) -> Tensor:

        if not self.gaussian_kernel.is_cuda:
            self.gaussian_kernel = self.gaussian_kernel.to(x.device)

        ssim_map = self._ssim(x, y)

        if as_loss:
            return 1 - ssim_map.mean()
        else:
            return ssim_map

    def _ssim(self, x: Tensor, y: Tensor) -> Tensor:

        # Compute means
        ux = F.conv2d(x, self.gaussian_kernel, padding=self.kernel_size // 2, groups=3)
        uy = F.conv2d(y, self.gaussian_kernel, padding=self.kernel_size // 2, groups=3)

        # Compute variances
        uxx = F.conv2d(x * x, self.gaussian_kernel, padding=self.kernel_size // 2, groups=3)
        uyy = F.conv2d(y * y, self.gaussian_kernel, padding=self.kernel_size // 2, groups=3)
        uxy = F.conv2d(x * y, self.gaussian_kernel, padding=self.kernel_size // 2, groups=3)
        vx = uxx - ux * ux
        vy = uyy - uy * uy
        vxy = uxy - ux * uy

        c1 = 0.01 ** 2
        c2 = 0.03 ** 2
        numerator = (2 * ux * uy + c1) * (2 * vxy + c2)
        denominator = (ux ** 2 + uy ** 2 + c1) * (vx + vy + c2)
        return numerator / (denominator + 1e-12)

    def _create_gaussian_kernel(self, kernel_size: int, sigma: float) -> Tensor:

        start = (1 - kernel_size) / 2
        end = (1 + kernel_size) / 2
        kernel_1d = torch.arange(start, end, step=1, dtype=torch.float)
        kernel_1d = torch.exp(-torch.pow(kernel_1d / sigma, 2) / 2)
        kernel_1d = (kernel_1d / kernel_1d.sum()).unsqueeze(dim=0)

        kernel_2d = torch.matmul(kernel_1d.t(), kernel_1d)
        kernel_2d = kernel_2d.expand(3, 1, kernel_size, kernel_size).contiguous()
        return kernel_2d

分散公式の変形

\begin{aligned} \sigma^2 &= \frac{(x_1 - \overline x)^2 + (x_2 - \overline x)^2 + \dots + (x_n - \overline x)^2}{n} \\ \\ &= \frac{x_1^2 -2x_1 \overline x + (\overline x) ^2 + x_2^2 -2x_2 \overline x + (\overline x) ^2 + \dots + x_n^2 -2x_n \overline x + (\overline x) ^2}{n} \\ \\ &= \frac{(x_1^2 + x_2^2 + \dots + x_n^2) - 2 \overline x (x_1 + x_2 + \dots + x_n) + n (\overline x)^2}{n} \\ \\ &= \overline {x^2} - 2 (\overline x)^2 + (\overline x)^2 \\ \\ &= \overline {x^2} - (\overline x)^2 \end{aligned}

References