SSIM Lossを用いてConditionalVAEの損失関数を定義する(PyTorch)
ConditionalVAEとは
ConditionalVAE (CVAE) は、最も基底のAEから始まっていくつかの派生が加わったもので、順にAE→VAE→CVAEと変遷してきました。
- AE(AutoEncoder):データの圧縮を行うエンコーダと、再構成を行うデコーダを組み合わせたもの
- VAE(VariationalAE):AEが圧縮した特徴が、多次元正規分布に従うようにモデル構造と損失関数を調整したもの
- CVAE(ConditionalVAE):VAEに、教師ラベルの概念を追加して潜在空間に条件を加えたもの
AEたちの損失関数
AEは、通常は入力データと再構成データとの乖離を測る再構成誤差としてMSE Lossが利用されます。これはVAEもCVAEも変わらず、それに加えて、VAE・CVAEは潜在表現を正規分布に近似するためにKLダイバージェンスを損失に利用します。
KLダイバージェンスは、もともとは二つの分布がどの程度離れているかを測る尺度で、正規分布と潜在表現を比較することで損失関数として機能します。らしい。(証明略)
定義をもとに整理すると、損失関数として扱うKLダイバージェンスは以下のように計算されます。
このとき、
今回は、画像を入出力に使うCVAEを構築することを仮定して、SSIM Lossを加えた損失関数を定義してみます。
SSIM Loss
SSIM Loss[1]とは、画像間の輝度、コントラスト、構造の3つの観点に着目した類似度を表すSSIM(Structural SIMilarity)を、損失関数に流用したものです。
AEの損失関数として一般的なMSEは、ピクセルごとの距離を測るのみで、大域的な特徴を考慮することができないため再構成画像がぼやけてしまう課題がありました。そこでSSIMを利用することで周辺の特徴を考慮した学習が可能になったそうです。
PyTorchで実装
以上のKLダイバージェンスとSSIM LossをPyTorchで実装してみます。
SSIM Lossはこちらの方のコードを参考にしました。
class SSIMLoss(nn.Module):
def __init__(self, kernel_size: int = 11, sigma: float = 1.5) -> None:
"""2つの画像間のstructural similarity (SSIM) を計算する
Args:
kernel_size (int): ガウシアンフィルタ(平滑化のため)のサイズ
sigma (float): ガウシアンフィルタの標準偏差
"""
super().__init__()
self.kernel_size = kernel_size
self.sigma = sigma
self.groups = 1
self.gaussian_kernel = self._create_gaussian_kernel(self.kernel_size, self.sigma)
def forward(self, x: Tensor, y: Tensor, as_loss: bool = True) -> Tensor:
if not self.gaussian_kernel.is_cuda:
self.gaussian_kernel = self.gaussian_kernel.to(x.device)
ssim_map = self._ssim(x, y[:,:self.groups,:,:])
# 損失関数として利用する場合は as_loss = True
if as_loss:
return 1 - ssim_map.mean()
# 通常のSSIMとして出力
else:
return ssim_map
def _ssim(self, x: Tensor, y: Tensor) -> Tensor:
# 平均の計算
ux = F.conv2d(x, self.gaussian_kernel, padding=self.kernel_size // 2, groups=self.groups)
uy = F.conv2d(y, self.gaussian_kernel, padding=self.kernel_size // 2, groups=self.groups)
# 分散の計算
uxx = F.conv2d(x * x, self.gaussian_kernel, padding=self.kernel_size // 2, groups=self.groups)
uyy = F.conv2d(y * y, self.gaussian_kernel, padding=self.kernel_size // 2, groups=self.groups)
uxy = F.conv2d(x * y, self.gaussian_kernel, padding=self.kernel_size // 2, groups=self.groups)
vx = uxx - ux * ux
vy = uyy - uy * uy
vxy = uxy - ux * uy
c1 = 0.01 ** 2
c2 = 0.03 ** 2
numerator = (2 * ux * uy + c1) * (2 * vxy + c2)
denominator = (ux ** 2 + uy ** 2 + c1) * (vx + vy + c2)
return numerator / (denominator + 1e-12)
def _create_gaussian_kernel(self, kernel_size: int, sigma: float) -> Tensor:
start = (1 - kernel_size) / 2
end = (1 + kernel_size) / 2
kernel_1d = torch.arange(start, end, step=1, dtype=torch.float)
kernel_1d = torch.exp(-torch.pow(kernel_1d / sigma, 2) / 2)
kernel_1d = (kernel_1d / kernel_1d.sum()).unsqueeze(dim=0)
kernel_2d = torch.matmul(kernel_1d.t(), kernel_1d)
kernel_2d = kernel_2d.expand(self.groups, 1, kernel_size, kernel_size).contiguous()
return kernel_2d
class MSE_SSIM_KL_Loss(nn.Module):
def __init__(self, alpha: float = 1, beta: float = 1) -> None:
"""MSE LossとSSIM LossとKLダイバージェンスを統合した損失の値を計算
Args:
alpha(float): SSIM Lossの比重
beta(float): KLダイバージェンスの比重
"""
super(MSE_SSIM_KL_Loss, self).__init__()
self.mse = nn.MSELoss(reduction="mean")
self.ssim = SSIMLoss()
self.alpha = alpha
self.beta = beta
def forward(self, recon_x: Tensor, raw_x: Tensor, mu: Tensor, logvar: Tensor) -> Tensor:
mse_loss = self.mse(recon_x, raw_x)
# 比重が0の時は損失計算を行わない
if self.alpha != 0:
ssim_loss = self.ssim(recon_x, raw_x)
else:
ssim_loss = 0
if self.beta != 0:
kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(),dim=-1).mean()
else:
kl_div = 0
return mse_loss + self.alpha * ssim_loss + self.beta * kl_div
ただ定義に従って実装しています。いろいろ粗があるかもしれません。
確認
ダミーデータで確認してみます。
再構成画像を適当なTensor
にして、元画像をその0.8倍にします。(つまりよく似ている)
そして、平均mu
と分散のlogvar
はそれぞれ0に近しいランダムな値にします。
mu
とlogvar
は標準正規分布によく近似されているとします。
この条件で損失の値を調べます。Tensor
の形状は適当です。
なお、MSE_SSIM_KL_Loss
のforward
の返り値に、MSE Loss, SSIM Loss, KLダイバージェンスの各値のタプルを追加しました。
学習時にはそのままログとして残してもいいし、無くしてLossの値だけを返すようにしても無問題です。
x = torch.randn(5,1,128,128)
y = x * 0.8
mu = torch.randn(5,256) * 0.01
logvar = torch.randn(5,256) * 0.01
criterion = MSE_SSIM_KL_Loss()
loss, mse, ssim, kl = criterion(x,y,mu,logvar)
print(loss, mse, ssim, kl)
tensor(0.1054) 0.0400715135037899 0.046952247619628906 0.018387366086244583
各Lossの値が小さくなるような条件でダミーデータを作ったので、それぞれのLossは小さくなっています。
うまくいってそうです。
Discussion