🎡

ColabのTPU v2でStable Diffusionを動かしてみる

2024/06/28に公開

AIの処理にはNVIDIAのGPUがよく使われていますが、GoogleのTPUやAWSのTrainiumのシェアがもう少し高くなってもよいのになあと思っています。

私自身はTPUやTrainiumは使ってみたいと思いつつ、実際に使うのはGPUばかりでした。そこでColabから利用できるTPUでStable Diffusionの画像生成をしてみようと思いました。

ColabのランタイムにTPU v2を選択して、以下のコードを実行していきます。

パッケージのインストール

!pip uninstall -y tensorflow && pip install tensorflow-cpu
!pip install --upgrade jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install --upgrade jaxlib flax transformers ftfy diffusers

TPUの確認

TPUのデバイスが認識されているかを確認します。

import jax

num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind

print(f"Found {num_devices} JAX devices of type {device_type}.")

以下のように表示されました。

Found 8 JAX devices of type TPU v2.

パイプラインのセットアップ

import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline
from flax.jax_utils import replicate

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=jnp.bfloat16,
)

p_params = replicate(params)

画像生成

プロンプトを指定して画像を生成します。最初の実行は3分ほどかかりますが、2回目以降は10秒程度で画像が生成できるようになります。

from jax import pmap
from flax.training.common_utils import shard
import random
from IPython.display import display

prompt = "masterpiece, best quality, 1girl" # @param {type:"string"}
prompts = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)

rng = jax.random.PRNGKey(random.randrange(1000000))
rng = jax.random.split(rng, jax.device_count())

images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)

for image in images:
    display(image)

参考

https://huggingface.co/blog/sdxl_jax

https://huggingface.co/docs/diffusers/using-diffusers/stable_diffusion_jax_how_to

https://tech.dentsusoken.com/entry/2022/10/14/Stable_Diffusion_TPU版の使い方

Discussion