🖼️

画像生成AI:StableDiffusionの仕組み

2024/09/13に公開

https://huggingface.co/spaces/stabilityai/stable-diffusion

SD(Stable Diffusion)は、画像生成AIのデファクトスタンダードです。

前回SDXLの解説を書いたのですが、そもそもその大前提であるSDが難しいということで、今回の記事を用意しました。相補的に使っていただければと思います。

ちょっと今回は技術オタク寄りの内容かもしれません。(書いてたら説明したくなってしまったため)

対象

今回は、元論文

https://arxiv.org/abs/2112.10752

を参考にしています。Submitted on 2021/12とあり、2024/9現在、3年近くも経過してしまったのでもう古い話題かもしれないのですが、SD2, SD3, SDXLなどを理解する上でも重要です。

次元の確認

「このモデルはプロンプトからH×Wの解像度の、RGBフルカラー画像を生成するものだ」というのを常に念頭においておきましょう。

混乱を避けるために時々次元を確認するのが便利です。
今回は単純なモデルなので項目は少なめです。

変数名 種類 次元
x 画像 H×W×3
z = E(x) latent h×w×c
f downsample factor scalar(f = H/h = W/w)
m = log 2 f log f scalar(0, 1, 2, 3, 4, 5)
x~ = D(E(x)) 画像 H×W×3

推論の仕組み

SDの推論はy(プロンプト) → x~(画像)という処理で、大まかに以下の7stepで行われています。

①画像xのencodeを行う

②noisingを行う

これは下のような正規分布ノイズを加えるという意味

αtx0(潜在空間のベクトル長さを調整したもの)を中心に、等方向的なノイズを加えます。

α・σはSignal Noise Ratioを決めるパラメーター

③ϕi(zt)とτθ(y)を次のようにconcatする

④ただし、Q, K, Vは学習可能射影Wを使って次のように作る

ϕi(zt) → Q

y → τ(y) → K, V

Q, K, Vはcross-attentionなので、denoising前と後との一貫性を保持する

⑤denoisingを行う

ただしεはconditional denoising autoencoderと呼ばれる、オートエンコーダーxθのリパラメトリゼーションで、noisingのパラメーターを含む。

⑥denoisingをあとT-1回行う

⑦画像x~へのdecodeを行う。

こうして欲しかった推論結果y(プロンプト) → x~(画像)が得られるわけです。

trainの仕組み

loss

ただし

\epsilon = \frac{\alpha_t x_0}{\sigma_t}

である。

学習可能なパラメーターはθであるから、θでこれをSDGすれば、εθとτθはjointly optimized

正則化

KL-reg, VQ-regという、以前発明された方法を使います。

実装の確認

https://github.com/CompVis/latent-diffusion

predictはタスクごとに行う仕様で、以下のような実行可能pyファイルがたくさんあります。

https://github.com/CompVis/latent-diffusion/blob/main/scripts/txt2img.py

いずれにせよforwardはmodels以下のddpm(後述)でハンドリングされています。

https://github.com/CompVis/latent-diffusion/blob/a506df5756472e2ebaf9078affdde2c4f1502cd4/ldm/models/diffusion/ddpm.py#L1402

trainはモデルごとに行う仕様で、以下の2種類があります。

# AutoEncoder, E
CUDA_VISIBLE_DEVICES=<GPU_ID> python main.py --base configs/autoencoder/<config_spec>.yaml -t --gpus 0,    
# Latent Diffusion Model, tau_theta, epsilon_theta
CUDA_VISIBLE_DEVICES=<GPU_ID> python main.py --base configs/latent-diffusion/<config_spec>.yaml -t --gpus 0,

そのため、ステップごとに対応したクラスがあるはずです。

Class解説

LightningDataModule

train/validation/test splitやデータ準備、変換が標準化されモデル間で一貫します。

https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningDataModule.html

ImageNetを例にとると、具体的なデータは以下のように呼んでいます。

https://github.com/CompVis/latent-diffusion/blob/main/ldm/data/imagenet.py

①AutoEncoder

first_stage_modelというのが入り口で、この最初のステップに対応します。
pythonではなくyamlが貼られていてギョッとすると思いますが、diffuserフレームワークは設定ファイルをインスタンス化するのです。(慣れると読みやすいかも)

https://github.com/CompVis/latent-diffusion/blob/a506df5756472e2ebaf9078affdde2c4f1502cd4/models/first_stage_models/kl-f32/config.yaml

これはAutoencoderKLクラスを参照しています。

このconfigから、

https://github.com/CompVis/latent-diffusion/blob/main/configs/autoencoder/autoencoder_kl_64x64x3.yaml

このAutoencoderKLクラスをinstantiateする仕組みです。以下を読むと、__init__でinit_from_ckptされていることがわかります。

https://github.com/CompVis/latent-diffusion/blob/a506df5756472e2ebaf9078affdde2c4f1502cd4/ldm/models/autoencoder.py#L285

②noising

以下のように正規分布からサンプルしているだけです。

posterior = DiagonalGaussianDistribution(moments)
z = posterior.sample()

実装はtorch.randnです。

https://github.com/CompVis/latent-diffusion/blob/a506df5756472e2ebaf9078affdde2c4f1502cd4/ldm/modules/distributions/distributions.py#L24

③concat・⑤denoising

このステップではx = ϕi(zt)とc = τθ(y)をconcatしますが、それを実行しているのはこの行です。

def forward(self, x, c, *args, **kwargs):
	  t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
	  if self.model.conditioning_key is not None:
	      assert c is not None
	      if self.cond_stage_trainable:
	          c = self.get_learned_conditioning(c)
	      if self.shorten_cond_schedule:  # TODO: drop this option
	          tc = self.cond_ids[t].to(self.device)
	          c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
	  return self.p_losses(x, c, t, *args, **kwargs)

それに渡すc = τθ(y)を得ているのはこの行です

c = self.cond_stage_model.encode(c)

このクラスはddpm(Denoising Diffusion PyTorch Model)であり、U-Netのckptをロードしています。

④projection

学習可能な射影はConv2Dで実装されています。LinearAttentionクラスですね。

https://github.com/CompVis/latent-diffusion/blob/a506df5756472e2ebaf9078affdde2c4f1502cd4/ldm/modules/attention.py#L80

self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
k = k.softmax(dim=-1)  
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)

SpatialSelfAttention・CrossAttentionもあります。

⑦decoding

encodingの時に使ったのと同じモデルのdecodeを呼んでいます。

https://github.com/CompVis/latent-diffusion/blob/a506df5756472e2ebaf9078affdde2c4f1502cd4/ldm/models/diffusion/ddpm.py#L737

output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
force_not_quantize=predict_cids or force_not_quantize)
for i in range(z.shape[-1])]

おわりに

いかがでしたでしょうか?
本記事では「SDとは何か」の解説をしました。

この記事ではpredict/trainの原理、公式にどう実装されているかを解説しました。

この記事が、SD2, SD3, SDXLなどを理解するのに役立てば幸いです。

Discussion