🐍
SSIM Loss を PyTorch で実装
- 画像
と画像X から局所領域Y とx を切り抜くy - 局所領域
とx 内の画素値から、平均y と\mu _x 、 標準偏差\mu _y と\sigma _x 、共分散\sigma_y を計算\sigma_{xy} - 式
から局所領域における(1) を計算SSIM
- 1 ピクセルずつ xy 方向に局所領域をスライドさせ SSIM を再び計算。画像サイズ 256×256、局所領域サイズ 64×64 の場合、
回の SSIM の計算を行う必要がある (padding は 0 とする)。(256 - 64 + 1) \times (256 - 64 + 1) = 37,249
PyTorch の Conv2d を使った実装
- 局所領域と同じサイズの kernel を用意し、PyTorch Conv2d を使って、平均、分散、共分散を計算します。
- 実用上は、画像
と画像X を平滑化した後に SSIM を計算するので、uniform kernel ではなく gaussian kernel を使っています。Y - 標準偏差は
より計算しています。次の章で、この式の導出を行っています。\sigma^2 = \overline {x^2} - (\overline x)^2 - 定数は
、c_1 = (k_1 L)^2 のように定義されており、c_2 = (k_2 L)^2 はダイナミックレンジ (8 bit の場合 255)、L とk_1 はハイパーパラメータでデフォルト値は 0.01 と 0.03 です。k_2
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module
class SSIMLoss(Module):
def __init__(self, kernel_size: int = 11, sigma: float = 1.5) -> None:
"""Computes the structural similarity (SSIM) index map between two images.
Args:
kernel_size (int): Height and width of the gaussian kernel.
sigma (float): Gaussian standard deviation in the x and y direction.
"""
super().__init__()
self.kernel_size = kernel_size
self.sigma = sigma
self.gaussian_kernel = self._create_gaussian_kernel(self.kernel_size, self.sigma)
def forward(self, x: Tensor, y: Tensor, as_loss: bool = True) -> Tensor:
if not self.gaussian_kernel.is_cuda:
self.gaussian_kernel = self.gaussian_kernel.to(x.device)
ssim_map = self._ssim(x, y)
if as_loss:
return 1 - ssim_map.mean()
else:
return ssim_map
def _ssim(self, x: Tensor, y: Tensor) -> Tensor:
# Compute means
ux = F.conv2d(x, self.gaussian_kernel, padding=self.kernel_size // 2, groups=3)
uy = F.conv2d(y, self.gaussian_kernel, padding=self.kernel_size // 2, groups=3)
# Compute variances
uxx = F.conv2d(x * x, self.gaussian_kernel, padding=self.kernel_size // 2, groups=3)
uyy = F.conv2d(y * y, self.gaussian_kernel, padding=self.kernel_size // 2, groups=3)
uxy = F.conv2d(x * y, self.gaussian_kernel, padding=self.kernel_size // 2, groups=3)
vx = uxx - ux * ux
vy = uyy - uy * uy
vxy = uxy - ux * uy
c1 = 0.01 ** 2
c2 = 0.03 ** 2
numerator = (2 * ux * uy + c1) * (2 * vxy + c2)
denominator = (ux ** 2 + uy ** 2 + c1) * (vx + vy + c2)
return numerator / (denominator + 1e-12)
def _create_gaussian_kernel(self, kernel_size: int, sigma: float) -> Tensor:
start = (1 - kernel_size) / 2
end = (1 + kernel_size) / 2
kernel_1d = torch.arange(start, end, step=1, dtype=torch.float)
kernel_1d = torch.exp(-torch.pow(kernel_1d / sigma, 2) / 2)
kernel_1d = (kernel_1d / kernel_1d.sum()).unsqueeze(dim=0)
kernel_2d = torch.matmul(kernel_1d.t(), kernel_1d)
kernel_2d = kernel_2d.expand(3, 1, kernel_size, kernel_size).contiguous()
return kernel_2d
分散公式の変形
References
- https://github.com/kornia/kornia/blob/master/kornia/losses/ssim.py
- https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/
- https://github.com/scikit-image/scikit-image/blob/master/skimage/metrics/_structural_similarity.py
- https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/metrics/functional/ssim.py
- https://github.com/pytorch/ignite/blob/master/ignite/metrics/ssim.py
Discussion