🎡
ColabのTPU v2でStable Diffusionを動かしてみる
AIの処理にはNVIDIAのGPUがよく使われていますが、GoogleのTPUやAWSのTrainiumのシェアがもう少し高くなってもよいのになあと思っています。
私自身はTPUやTrainiumは使ってみたいと思いつつ、実際に使うのはGPUばかりでした。そこでColabから利用できるTPUでStable Diffusionの画像生成をしてみようと思いました。
ColabのランタイムにTPU v2を選択して、以下のコードを実行していきます。
パッケージのインストール
!pip install --upgrade jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install --upgrade jaxlib flax transformers ftfy diffusers
TPUの確認
TPUのデバイスが認識されているかを確認します。
import jax
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind
print(f"Found {num_devices} JAX devices of type {device_type}.")
以下のように表示されました。
Found 8 JAX devices of type TPU v2.
パイプラインのセットアップ
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
revision="bf16",
dtype=jnp.bfloat16,
)
画像生成
プロンプトを指定して画像を生成します。
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
import random
from IPython.display import display
prompt = "masterpiece, best quality, 1girl" # @param {type:"string"}
prompts = [prompt] * jax.device_count()
p_params = replicate(params)
prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)
rng = jax.random.PRNGKey(random.randrange(1000000))
rng = jax.random.split(rng, jax.device_count())
images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
for image in images:
display(image)
参考
Discussion