SSIM Lossを用いてConditionalVAEの損失関数を定義する(PyTorch)

に公開

ConditionalVAEとは

ConditionalVAE (CVAE) は、最も基底のAEから始まっていくつかの派生が加わったもので、順にAE→VAE→CVAEと変遷してきました。

  • AE(AutoEncoder):データの圧縮を行うエンコーダと、再構成を行うデコーダを組み合わせたもの
  • VAE(VariationalAE):AEが圧縮した特徴が、多次元正規分布に従うようにモデル構造と損失関数を調整したもの
  • CVAE(ConditionalVAE):VAEに、教師ラベルの概念を追加して潜在空間に条件を加えたもの

AEたちの損失関数

AEは、通常は入力データと再構成データとの乖離を測る再構成誤差としてMSE Lossが利用されます。これはVAEもCVAEも変わらず、それに加えて、VAE・CVAEは潜在表現を正規分布に近似するためにKLダイバージェンスを損失に利用します。
 KLダイバージェンスは、もともとは二つの分布がどの程度離れているかを測る尺度で、正規分布と潜在表現を比較することで損失関数として機能します。らしい。(証明略)
 定義をもとに整理すると、損失関数として扱うKLダイバージェンスは以下のように計算されます。

D_{KL}(Z||Y) = -\frac{1}{2}\sum_{dim=1}^{Dim}(1+\log{(\sigma_{dim})^2}-(\mu_{dim})^2-(\sigma_{dim})^2)

このとき、Zは潜在表現、Yは正規分布、Dimは潜在表現の次元数です。

今回は、画像を入出力に使うCVAEを構築することを仮定して、SSIM Lossを加えた損失関数を定義してみます。

SSIM Loss

SSIM Loss[1]とは、画像間の輝度、コントラスト、構造の3つの観点に着目した類似度を表すSSIM(Structural SIMilarity)を、損失関数に流用したものです。
 AEの損失関数として一般的なMSEは、ピクセルごとの距離を測るのみで、大域的な特徴を考慮することができないため再構成画像がぼやけてしまう課題がありました。そこでSSIMを利用することで周辺の特徴を考慮した学習が可能になったそうです。

PyTorchで実装

以上のKLダイバージェンスとSSIM LossをPyTorchで実装してみます。
SSIM Lossはこちらの方のコードを参考にしました。

class SSIMLoss(nn.Module):
    def __init__(self, kernel_size: int = 11, sigma: float = 1.5) -> None:

        """2つの画像間のstructural similarity (SSIM) を計算する
        Args:
            kernel_size (int): ガウシアンフィルタ(平滑化のため)のサイズ
            sigma (float): ガウシアンフィルタの標準偏差
        """
        super().__init__()
        self.kernel_size = kernel_size
        self.sigma = sigma
        self.groups = 1
        self.gaussian_kernel = self._create_gaussian_kernel(self.kernel_size, self.sigma)

    def forward(self, x: Tensor, y: Tensor, as_loss: bool = True) -> Tensor:
        if not self.gaussian_kernel.is_cuda:
            self.gaussian_kernel = self.gaussian_kernel.to(x.device)
        ssim_map = self._ssim(x, y[:,:self.groups,:,:])
        # 損失関数として利用する場合は as_loss = True
        if as_loss:
            return 1 - ssim_map.mean() 
        # 通常のSSIMとして出力
        else:
            return ssim_map

    def _ssim(self, x: Tensor, y: Tensor) -> Tensor:
        # 平均の計算
        ux = F.conv2d(x, self.gaussian_kernel, padding=self.kernel_size // 2, groups=self.groups)
        uy = F.conv2d(y, self.gaussian_kernel, padding=self.kernel_size // 2, groups=self.groups)
        # 分散の計算
        uxx = F.conv2d(x * x, self.gaussian_kernel, padding=self.kernel_size // 2, groups=self.groups)
        uyy = F.conv2d(y * y, self.gaussian_kernel, padding=self.kernel_size // 2, groups=self.groups)
        uxy = F.conv2d(x * y, self.gaussian_kernel, padding=self.kernel_size // 2, groups=self.groups)
        vx = uxx - ux * ux
        vy = uyy - uy * uy
        vxy = uxy - ux * uy

        c1 = 0.01 ** 2
        c2 = 0.03 ** 2
        numerator = (2 * ux * uy + c1) * (2 * vxy + c2)
        denominator = (ux ** 2 + uy ** 2 + c1) * (vx + vy + c2)
        return numerator / (denominator + 1e-12)

    def _create_gaussian_kernel(self, kernel_size: int, sigma: float) -> Tensor:
        start = (1 - kernel_size) / 2
        end = (1 + kernel_size) / 2
        kernel_1d = torch.arange(start, end, step=1, dtype=torch.float)
        kernel_1d = torch.exp(-torch.pow(kernel_1d / sigma, 2) / 2)
        kernel_1d = (kernel_1d / kernel_1d.sum()).unsqueeze(dim=0)

        kernel_2d = torch.matmul(kernel_1d.t(), kernel_1d)
        kernel_2d = kernel_2d.expand(self.groups, 1, kernel_size, kernel_size).contiguous()
        return kernel_2d

class MSE_SSIM_KL_Loss(nn.Module):
    def __init__(self, alpha: float = 1, beta: float = 1) -> None:
        """MSE LossとSSIM LossとKLダイバージェンスを統合した損失の値を計算
        Args:
            alpha(float): SSIM Lossの比重
            beta(float): KLダイバージェンスの比重
        """
        super(MSE_SSIM_KL_Loss, self).__init__()
        self.mse = nn.MSELoss(reduction="mean")
        self.ssim = SSIMLoss()
        self.alpha = alpha
        self.beta = beta

    def forward(self, recon_x: Tensor, raw_x: Tensor, mu: Tensor, logvar: Tensor) -> Tensor:
        mse_loss = self.mse(recon_x, raw_x)
        # 比重が0の時は損失計算を行わない
        if self.alpha != 0:
            ssim_loss = self.ssim(recon_x, raw_x)
        else:
            ssim_loss = 0

        if self.beta != 0:
            kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(),dim=-1).mean()
        else:
            kl_div = 0
        return mse_loss + self.alpha * ssim_loss + self.beta * kl_div

ただ定義に従って実装しています。いろいろ粗があるかもしれません。

確認

ダミーデータで確認してみます。
再構成画像を適当なTensorにして、元画像をその0.8倍にします。(つまりよく似ている)
そして、平均muと分散の\logであるlogvarはそれぞれ0に近しいランダムな値にします。
\log{(\sigma)^2}\simeq0\iff(\sigma)^2\simeq1なので、mulogvarは標準正規分布によく近似されているとします。
この条件で損失の値を調べます。Tensorの形状は適当です。
なお、MSE_SSIM_KL_Lossforwardの返り値に、MSE Loss, SSIM Loss, KLダイバージェンスの各値のタプルを追加しました。
学習時にはそのままログとして残してもいいし、無くしてLossの値だけを返すようにしても無問題です。

x = torch.randn(5,1,128,128)
y = x * 0.8

mu = torch.randn(5,256) * 0.01
logvar = torch.randn(5,256) * 0.01

criterion = MSE_SSIM_KL_Loss()
loss, mse, ssim, kl = criterion(x,y,mu,logvar)
print(loss, mse, ssim, kl)
出力
tensor(0.1054) 0.0400715135037899 0.046952247619628906 0.018387366086244583

各Lossの値が小さくなるような条件でダミーデータを作ったので、それぞれのLossは小さくなっています。
うまくいってそうです。

脚注
  1. https://arxiv.org/abs/1807.02011 ↩︎

Discussion