Stable Diffusion で遊んでみる (3) — わざわざ Diffusers で VAE を実装する
目的
diffusers
を使ったコードを見ると、AutoencoderKL
というものが使われていることがあって、arXiv:1312.6114 Auto-Encoding Variational Bayes とどういう関係にあるのかな?と思ったので少し調べた話。ほぼ備忘録。
AutoencoderKL
AutoencoderKL
は autoencoder_kl.py で定義されていて、forward を見ると、
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
という実装である。冒頭で
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
をしていて、AutoencoderKL.encode
がやっていることは、ほぼ Encoder.forward
だし、decode
も同様なので、結局は色々機能が追加された VAE(変分オートエンコーダ)ということかな?と思った。
VAE を実装してみる
「鶏を割くに焉んぞ牛刀を用いん」な部分は否めないが、わざわざ diffusers
を使って VAE を実装してみる。
必要なモジュールは全て diffusers
の中に揃っているので、後は VAE の実装テンプレがあれば良い。何でも良いのだが、ゼロから作る Deep Learning ❺ 用のコードが GitHub の vae.py にあって、道筋として単純だしこれに沿ってみることにした。
データセットは QMNIST
を使ってみる。
必要なモジュールの import をする:
from __future__ import annotations
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
import torchvision
from torchvision import transforms, datasets
from diffusers.models.autoencoders.vae import (
Decoder,
DiagonalGaussianDistribution,
Encoder
)
パラメータ類の設定:
epochs = 30
learning_rate = 3e-4
batch_size = 32
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda")
データローダの作成:
transform = transforms.Compose([
transforms.ToTensor(),
])
root = os.path.join(os.getenv("HOME"), ".torch")
dataset = datasets.QMNIST(root=root, train=True, download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(
dataset, batch_size=batch_size, shuffle=True
)
VAE のクラスの実装:
ゼロから作る Deep Learning ❺ の GitHub のコードをベースに AutoencoderKL
の実装を合わせこんでみた。
class VAE(nn.Module):
def __init__(self, in_channels=1, out_channels=1):
super().__init__()
latent_channel = 4
self.encoder = Encoder(in_channels=in_channels, out_channels=latent_channel)
self.decoder = Decoder(in_channels=latent_channel, out_channels=out_channels)
def get_loss(self, x):
h = self.encoder(x)
posterior = DiagonalGaussianDistribution(h)
z = posterior.sample() # reparametrization trick
x_hat = self.decoder(z)
batch_size = len(x)
L1 = F.mse_loss(x_hat, x, reduction="sum")
L2 = torch.sum(posterior.kl())
return (L1 + L2) / batch_size
訓練
%%time
model = VAE()
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
losses = []
for epoch in range(epochs):
loss_sum = 0.0
cnt = 0
for x, label in dataloader:
x = x.to(device)
optimizer.zero_grad()
loss = model.get_loss(x)
loss.backward()
optimizer.step()
loss_sum += loss.item()
cnt += 1
loss_avg = loss_sum / cnt
losses.append(loss_avg)
print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss_avg}")
Epoch [1/30], Loss: 47.38421940511068
Epoch [2/30], Loss: 42.57707533162435
...
Epoch [29/30], Loss: 37.03056034037272
Epoch [30/30], Loss: 37.003285912068684
CPU times: user 32min 16s, sys: 3.79 s, total: 32min 20s
Wall time: 32min 15s
潜在空間の次元が大きくて結構時間がかかってしまった。どれかの引数を指定して次元をもっと下げたら良いのかもしれないが、今回雑にやっているのでそこまでちゃんと調べていない。
可視化
with torch.no_grad():
sample_size = 64
latent_channel = 4
z = torch.randn(sample_size, 4*28*28).reshape(sample_size, latent_channel, 28, 28)
z = z.to(device)
x = model.decoder(z)
x = x.detach().cpu()
generated_images = x.view(sample_size, 1, 28, 28)
grid_img = torchvision.utils.make_grid(
generated_images, nrow=8, padding=2, normalize=True
)
plt.imshow(grid_img.permute(1, 2, 0))
plt.axis("off")
plt.show()
手書き数字風味の画像がランダムノイズから生成された。
まとめ
AutoencoderKLOutput を見ても潜在変数の入れ物に過ぎず、AutoencoderKL
が名前ほどに KL 成分が陽に表に出ているかは疑問なのだが、内部でやっていることは概ね VAE に近いものだなというのが分かった(但し、バニラ VAE として使っているわけではないので、標準正規分布
ついでに “内部でやっていることは概ね VAE に近い” であろうことを確認するために、内部実装の一部を書き換えて普通の変分オートエンコーダを実装してみて確認してみた。
Discussion