Closed3

Diffusers 使って 無料 Colab で StableCascade をとりあえず動かしてみる

PlatPlat

StableCascade:

https://huggingface.co/stabilityai/stable-cascade

カスケード方式(Würstchenアーキテクチャ)の新しいモデルです。
生成を試すだけなら↓のデモを試すといいと思います。
https://huggingface.co/spaces/multimodalart/stable-cascade

diffusers ライブラリでStableCascadeを動かそうと思うのですが、これを書いてる時点ではまだパイプラインのコードがマージされてないので Wurstchen v3 のブランチを使います。

PlatPlat

まず、前提としてこのモデルは多段階の生成を行う(DeepFloyd IFみたいに)ので、(VRAMが足りない場合は)複数のモデルを読み込んだりアンロードしたりすることになります。

StableCascade のアーキテクチャ図

この Stage C と書かれている部分がコアになっており、ここでベース画像(潜在空間上)を生成します。図では Generator となってますが、単純に Prior と呼ばれるっぽいです。レポは stabilityai/stable-cascade-prior です。

https://huggingface.co/stabilityai/stable-cascade-prior

その画像的な何かを Stage B と Stage A (VAE) で高解像度にアップスケール&実際の画像に変換します。この段階では画像をデコードする役割になっているので、まとめて Decoder と呼ばれます。レポは stabilityai/stable-cascade です。

https://huggingface.co/stabilityai/stable-cascade

PlatPlat

無料の Colab (T4) で動きます。ほぼモデルカードのコードと同じです。

%pip install accelerate transformers
%pip install git+https://github.com/kashif/diffusers.git@wuerstchen-v3
import torch
from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
import gc
device = "cuda"
prompt = "Anthropomorphic cat dressed as a pilot"
negative_prompt = "blurry, low quality, signature"
guidance_scale = 4.0
num_images_per_prompt = 1

num_inference_steps_prior = 20
num_inference_steps_decode = 10

prior を読み込みます。T4 ですがなぜか bfloat16 を使うようです。float16 だと NaN になりました。

prior = StableCascadePriorPipeline.from_pretrained(
    "stabilityai/stable-cascade-prior",
    torch_dtype=torch.bfloat16
).to(device)

潜在空間の画像を生成します。

prior_output = prior(
  prompt=prompt,
  height=1024,
  width=1024,
  negative_prompt=negative_prompt,
  guidance_scale=guidance_scale,
  num_images_per_prompt=num_images_per_prompt,
  num_inference_steps=num_inference_steps_prior
)
prior_output # NaN じゃない確認

パイプラインをアンロードします

del prior
# よくわからんけど一度実行してもVRAM開放されなかったら何回か実行するといいかも
gc.collect()
torch.cuda.empty_cache()

decoder を読み込みます。こっちはfloat16でも大丈夫です。

decoder = StableCascadeDecoderPipeline.from_pretrained(
    "stabilityai/stable-cascade",
    torch_dtype=torch.float16
).to(device)

デコードします。

decoder_output = decoder(
  image_embeddings=prior_output.image_embeddings.to(decoder.dtype),
  prompt=prompt,
  negative_prompt=negative_prompt,
  guidance_scale=0.0,
  output_type="pil",
  num_inference_steps=num_inference_steps_decode
).images

decoder_output[0] # 一枚目の画像を表示

StableCascade で生成したパイロット服の猫の画像

このスクラップは2024/02/21にクローズされました