💬

生成モデルを実装してみる ~拡散モデル実装までの道のり~(VAE編)

に公開

概要

拡散モデル(diffusion model)という単語を何かで見かけたけど、なにそれわからん。ってなったので実装して拡散モデルについてざっくり理解したいと思い、いろいろ調べたのでその備忘録になります。
データを生成するタイプのモデルだったので、クラシックなVAE、GAN、も実装して歴史を辿りながらdiffusion modelの実装を行おうと思います。

本記事では、その第1歩として VAE(Variational Autoencoder) を取り上げます。

  • VAE <-イマココ
  • GAN
  • Diffusion Model

目次

VAEとは

VAE(Variational Autoencoder)は、通常のオートエンコーダーに確率的な要素を加えた生成モデルです。
通常のオートエンコーダーでは、入力データは1つの固定された潜在変数zに圧縮されるため、生成できる画像は1種類だけです。

一方、VAEではこのzを確率分布(例:正規分布)として表現します。
これにより、同じ入力に対しても、少しずつ異なるzをサンプリングすることで、似ているけれど異なる画像を複数生成できます。

モデル構造

基本的な構造はオートエンコーダーと同様で、入力を潜在変数zに圧縮し、そこから元の画像を生成する仕組みです。

Encoder

エンコーダーでは、入力を潜在空間の平均 μ標準偏差 σ に変換します。
これらは、潜在変数が従う正規分布 \mathcal{N}(μ,σ^2) のパラメータになります。

この分布からzをサンプリングする際に用いられるのが再パラメータ化トリックです:

z = \mu + \sigma \cdot \epsilon, \epsilon\sim \mathcal{N}(0,1)

Decoder

デコーダーは、サンプリングされた潜在変数zを受け取り、元データを再構成します。
出力された再構成データと元の入力との誤差を元に損失を計算し、モデルの学習が行われます。

損失関数

VAEの損失関数は、以下の2つの要素から構成されます:

1. 再構成誤差(Reconstruction Loss)

元の入力と、デコーダーが生成したデータとの違いを表します。
これは、モデルがどれだけ正確にデータを復元できているかを示します。

2. KLダイバージェンス(Kullback-Leibler Divergence)

エンコーダーが出力する潜在変数の分布を、標準正規分布に近づけるようにするための制約項です。
標準正規分布は密度が中心に集中しており、中心付近で滑らかにサンプリングできる性質があります。
もし分布が中心から外れすぎると、学習されていない領域に入り、ノイズのようなデータが生成されやすくなります。

実装

では、ここまで書いてきたことを実装してみます。
データはMNISTを用います。

下準備

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F

from tqdm import tqdm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

class Config:
    BATCH_SIZE=50
    Z_DIM=2
    EPOCHS=20
    LR=1e-3
config = Config()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

データの準備

train_data = torchvision.datasets.MNIST(root="./data",
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)
val_data = torchvision.datasets.MNIST(root="./data",
                                           train=False,
                                           transform=transforms.ToTensor(),
                                           download=True)

trainDL = DataLoader(dataset=train_data,
                     batch_size=config.BATCH_SIZE,
                     shuffle=True,
                     num_workers=0)
valDL = DataLoader(dataset=val_data,
                     batch_size=config.BATCH_SIZE,
                     shuffle=True,
                     num_workers=0)

モデル定義

class Encoder(nn.Module):
    def __init__(self, z_dim):
        super().__init__()
        self.fc = nn.Linear(28*28, 400)
        self.fc_mu = nn.Linear(400, z_dim)
        self.fc_logvar = nn.Linear(400, z_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.relu(x)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)

        std = torch.exp(0.5 * logvar)
        epsilon = torch.randn_like(std)
        z = mu + std * epsilon
        return z, mu, logvar

class Decoder(nn.Module):
    def __init__(self, z_dim):
        super().__init__()
        self.fc = nn.Linear(z_dim, 400)
        self.fc2 = nn.Linear(400, 28*28)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.fc(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = torch.sigmoid(x)
        return x


class VAE(nn.Module):
    def __init__(self, z_dim):
        super().__init__()
        self.encoder = Encoder(z_dim)
        self.decoder = Decoder(z_dim)

    def forward(self, x):
        z, mu, logvar = self.encoder(x)
        x = self.decoder(z)
        return x, z, mu, logvar

損失関数定義

bce = nn.BCELoss(reduction='sum')

def criterion(input, predict, mu, logvar):
    input = input.view(input.size(0), -1)
    predict = predict.view(predict.size(0), -1)

    bce_loss = bce(predict, input)
    
    kld_loss = -0.5 * torch.sum(
        1 + logvar - mu.pow(2) - logvar.exp()
    )

    return bce_loss + kld_loss

学習

model = VAE(z_dim=config.Z_DIM).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.LR)

for epoch in range(config.EPOCHS):
    model.train()
    total_loss = 0

    pbar = tqdm(trainDL, desc=f"Epoch {epoch+1}/{config.EPOCHS}", leave=True)

    for batch in pbar:
        x, _ = batch
        x = x.to(device)

        optimizer.zero_grad()
        x_pred, z, mu, logvar = model(x)
        loss = criterion(x, x_pred, mu, logvar)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        current_loss = loss.item() / x.size(0)
        pbar.set_postfix({"batch_loss": f"{current_loss:.4f}"})

    avg_train_loss = total_loss / len(trainDL.dataset)
    print(f"Epoch [{epoch+1}/{config.EPOCHS}] Train Loss: {avg_train_loss:.4f}")

実験結果

損失をプロットしてみるとこんな感じ。
まだ下がってそうだけど、モデル概要が知れればいいのでまずはこんなところで。

検証データを使ってちゃんと元画像を再現できているのかみてみる。

model.eval()

with torch.no_grad():
    x, _ = next(iter(valDL))
    x=x.to(device)
    x_pred, _, _, _ = model(x)

    x = x.cpu().numpy()
    x_pred = x_pred.view(-1, 1, 28, 28).cpu().numpy()

    for i in range(6):
        plt.subplot(2, 6, i + 1)
        plt.imshow(x[i][0], cmap='gray')
        plt.axis('off')

        plt.subplot(2, 6, i + 7)
        plt.imshow(x_pred[i][0], cmap='gray')
        plt.axis('off')

    plt.suptitle("Top: Original  |  Bottom: Reconstructed")
    plt.show()

まぁまぁできてそう。
5が3になったり、4が9になったりしてる。
学習進めたりモデルもう少し大きくしたりしたらこの辺もうまく再現できるようになるのか。
けど再現できてるのもありVAEの実感はできた。

次に適当な潜在変数をDecoderに突っ込んで何が生成されるかみてみる。

model.eval()
with torch.no_grad():
    z = torch.randn(16, config.Z_DIM).to(device)
    generated = model.decoder(z)
    generated = generated.view(-1, 1, 28, 28).cpu().numpy()

    # 表示
    for i in range(16):
        plt.subplot(4, 4, i + 1)
        plt.imshow(generated[i][0], cmap='gray')
        plt.axis('off')
    plt.suptitle("Generated Images from Random z")
    plt.show()

まぁまぁそれっぽいのが生成できてる気がする。

最後に潜在変数の分布についてプロットしてみる

model.eval()
all_z = []
all_labels = []

with torch.no_grad():
    for x, labels in valDL:
        x = x.to(device)
        _, z, _, _ = model(x)
        all_z.append(z.cpu())
        all_labels.append(labels)
    
all_z = torch.cat(all_z, dim=0).numpy()        # shape: [num_samples, 2]
all_labels = torch.cat(all_labels, dim=0).numpy()

# プロット
plt.figure(figsize=(8, 6))
scatter = plt.scatter(all_z[:, 0], all_z[:, 1], c=all_labels, cmap='tab10', alpha=0.7, s=10)
plt.colorbar(scatter, ticks=range(10))
plt.title("Latent Space Visualization (z_dim=2)")
plt.xlabel("z[0]")
plt.ylabel("z[1]")
plt.grid(True)
plt.show()

やはり4,9あたりがごちゃっとなってて検証データで実験した通りの分布になってる。

まとめ

というかんじで、雑にVAEを実装してみました。なんとなく概要は掴めた気がします。
次回はGANを実装してみようと思います。では。

Discussion