
I tried FastComposer on Google Colab

Published 2023/05/18

What is FastComposer?

FastComposer is a text2img diffusion model released by MIT that combines two reference images to generate arbitrary new images.
https://github.com/mit-han-lab/fastcomposer

Links

Colab
GitHub

Preparation

Open Google Colab and change the runtime type to GPU via "Runtime → Change runtime type" in the menu.
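
To confirm that the GPU runtime is actually active, you can run a quick check (this assumes PyTorch is preinstalled on the Colab runtime, which it normally is):

# Sanity check: a GPU should be visible to PyTorch
import torch
print(torch.cuda.is_available())      # expected: True
print(torch.cuda.get_device_name(0))  # e.g. a T4 on the free tier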

Environment setup

Installation steps:

!pip install transformers accelerate datasets evaluate diffusers xformers triton scipy clip
!git clone https://github.com/mit-han-lab/fastcomposer.git
%cd fastcomposer
!python setup.py install
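
After the install, it is worth confirming that the main libraries import cleanly before moving on (the versions will simply be whatever pip resolved at the time):

# Quick import check for the core dependencies
import diffusers, transformers, accelerate
print(diffusers.__version__, transformers.__version__, accelerate.__version__)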

Inference

(1) Download the model

%cd /content/fastcomposer
!mkdir -p model/fastcomposer
%cd model/fastcomposer
!wget https://huggingface.co/mit-han-lab/fastcomposer/resolve/main/pytorch_model.bin
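
The checkpoint is a single large file, so make sure the download actually finished before moving on:

# List the downloaded checkpoint (pytorch_model.bin) and its size
!ls -lh /content/fastcomposer/model/fastcomposer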

(2) Inference
Running the provided script/run_inference.sh as-is does not seem to work on Google Colab, so replace fastcomposer/inference.py with the following script.

fastcomposer/inference.py
from fastcomposer.transforms import get_object_transforms
from fastcomposer.data import DemoDataset
from fastcomposer.model import FastComposerModel
from diffusers import StableDiffusionPipeline
from transformers import CLIPTokenizer
from accelerate.utils import set_seed
from fastcomposer.utils import parse_args
from accelerate import Accelerator
from pathlib import Path
from PIL import Image
import numpy as np
import torch
import os
from tqdm.auto import tqdm
from fastcomposer.pipeline import (
    stable_diffusion_call_with_references_delayed_conditioning,
)
import types
import itertools


@torch.no_grad()
def main():
    args = parse_args()
    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
    )

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    pipe = StableDiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path, torch_dtype=weight_dtype
    )

    # added: move the pipeline to the GPU before swapping in the FastComposer modules
    pipe = pipe.to("cuda")
    
    model = FastComposerModel.from_pretrained(args)

    ckpt_name = "pytorch_model.bin"

    model.load_state_dict(
        torch.load(Path(args.finetuned_model_path) / ckpt_name, map_location="cuda")
    )  # changed: map_location "cpu" -> "cuda"

    model = model.to(device=accelerator.device, dtype=weight_dtype)

    pipe.unet = model.unet

    if args.enable_xformers_memory_efficient_attention:
        pipe.unet.enable_xformers_memory_efficient_attention()

    pipe.text_encoder = model.text_encoder
    pipe.image_encoder = model.image_encoder

    pipe.postfuse_module = model.postfuse_module

    pipe.inference = types.MethodType(
        stable_diffusion_call_with_references_delayed_conditioning, pipe
    )

    del model

    pipe = pipe.to(accelerator.device)

    # Set up the dataset
    tokenizer = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="tokenizer",
        revision=args.revision,
    )

    object_transforms = get_object_transforms(args)

    demo_dataset = DemoDataset(
        test_caption=args.test_caption,
        test_reference_folder=args.test_reference_folder,
        tokenizer=tokenizer,
        object_transforms=object_transforms,
        device=accelerator.device,
        max_num_objects=args.max_num_objects,
    )

    image_ids = os.listdir(args.test_reference_folder)
    print(f"Image IDs: {image_ids}")
    demo_dataset.set_image_ids(image_ids)

    unique_token = "<|image|>"

    prompt = args.test_caption
    prompt_text_only = prompt.replace(unique_token, "")

    os.makedirs(args.output_dir, exist_ok=True)

    batch = demo_dataset.get_data()

    input_ids = batch["input_ids"].to(accelerator.device)
    text = tokenizer.batch_decode(input_ids)[0]
    print(prompt)
    # print(input_ids)
    image_token_mask = batch["image_token_mask"].to(accelerator.device)

    # print(image_token_mask)
    all_object_pixel_values = (
        batch["object_pixel_values"].unsqueeze(0).to(accelerator.device)
    )
    num_objects = batch["num_objects"].unsqueeze(0).to(accelerator.device)

    all_object_pixel_values = all_object_pixel_values.to(
        dtype=weight_dtype, device=accelerator.device
    )

    object_pixel_values = all_object_pixel_values  # [:, 0, :, :, :]
    if pipe.image_encoder is not None:
        global_object_embeds = pipe.image_encoder(object_pixel_values)
    else:
        global_object_embeds = None

    encoder_hidden_states = pipe.text_encoder(
        input_ids, image_token_mask, global_object_embeds, num_objects
    )[0]

    encoder_hidden_states_text_only = pipe._encode_prompt(
        prompt_text_only,
        accelerator.device,
        args.num_images_per_prompt,
        do_classifier_free_guidance=False,
    )

    encoder_hidden_states = pipe.postfuse_module(
        encoder_hidden_states,
        global_object_embeds,
        image_token_mask,
        num_objects,
    )

    cross_attention_kwargs = {}

    images = pipe.inference(
        prompt_embeds=encoder_hidden_states,
        num_inference_steps=args.inference_steps,
        height=args.generate_height,
        width=args.generate_width,
        guidance_scale=args.guidance_scale,
        num_images_per_prompt=args.num_images_per_prompt,
        cross_attention_kwargs=cross_attention_kwargs,
        prompt_embeds_text_only=encoder_hidden_states_text_only,
        start_merge_step=args.start_merge_step,
    ).images

    for instance_id in range(args.num_images_per_prompt):
        images[instance_id].save(
            os.path.join(
                args.output_dir,
                f"output_{instance_id}.png",
            )
        )


if __name__ == "__main__":
    main()
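
One convenient way to apply the change without leaving the notebook is Colab's %%writefile cell magic; paste the full modified script above into the body of a cell like the sketch below (editing the file from the Colab Files pane works just as well):

%%writefile /content/fastcomposer/fastcomposer/inference.py
# paste the full modified inference.py shown above below this line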

After making this change, run the following cell:

%cd /content/fastcomposer
# !bash script/run_inference.sh

from accelerate.utils import write_basic_config

write_basic_config()

# Each <|image|> token in the caption marks a word that will be bound to one of the
# reference subjects found in --test_reference_folder.
CAPTION = "a man <|image|> and a man <|image|> are reading book together"
DEMO_NAME = "newton_einstein"
# Note: the accelerate command below hard-codes these same values instead of reusing the variables.

!CUDA_VISIBLE_DEVICES=0 accelerate launch \
    --mixed_precision=fp16 \
    fastcomposer/inference.py \
    --pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 \
    --finetuned_model_path model/fastcomposer \
    --test_reference_folder data/newton_einstein \
    --test_caption "a man <|image|> and a man <|image|> are reading book together" \
    --output_dir outputs/newton_einstein \
    --mixed_precision fp16 \
    --image_encoder_type clip \
    --image_encoder_name_or_path openai/clip-vit-large-patch14 \
    --num_image_tokens 1 \
    --max_num_objects 2 \
    --object_resolution 224 \
    --generate_height 512 \
    --generate_width 512 \
    --num_images_per_prompt 1 \
    --num_rows 1 \
    --seed 42 \
    --guidance_scale 5 \
    --inference_steps 50 \
    --start_merge_step 10 \
    --no_object_augmentation

The inference result is displayed as follows:

from PIL import Image
display(Image.open("/content/fastcomposer/outputs/newton_einstein/output_0.png"))

Advanced Application

Let's line up the great Masayoshi Son and Bill Gates.
(1) Prepare the data

!wget https://www.softbank.jp/corp/set/data/aboutus/profile/officer/img/officer-01.jpg -P /content/fastcomposer/data/son_bill/son
!wget https://images.forbesjapan.com/media/article/60581/images/main_image_ef0d8efdc943d72876e65a13a5e957af3b954661.jpg -P /content/fastcomposer/data/son_bill/bill
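
Each subject gets its own subfolder under data/son_bill, which matches how the script lists subjects via os.listdir on --test_reference_folder. You can preview the downloaded images to make sure the wget calls succeeded (the file names below assume wget kept the original names from the URLs):

# Preview the two reference images, one per subject subfolder
from PIL import Image
display(Image.open("/content/fastcomposer/data/son_bill/son/officer-01.jpg"))
display(Image.open("/content/fastcomposer/data/son_bill/bill/main_image_ef0d8efdc943d72876e65a13a5e957af3b954661.jpg"))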

(2) Run inference

%cd /content/fastcomposer
# !bash script/run_inference.sh

from accelerate.utils import write_basic_config

write_basic_config()

!CUDA_VISIBLE_DEVICES=0 accelerate launch \
    --mixed_precision=fp16 \
    fastcomposer/inference.py \
    --pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 \
    --finetuned_model_path model/fastcomposer \
    --test_reference_folder data/son_bill \
    --test_caption "a man <|image|> and a man <|image|> are reading book together" \
    --output_dir outputs/son_bill \
    --mixed_precision fp16 \
    --image_encoder_type clip \
    --image_encoder_name_or_path openai/clip-vit-large-patch14 \
    --num_image_tokens 1 \
    --max_num_objects 2 \
    --object_resolution 224 \
    --generate_height 512 \
    --generate_width 512 \
    --num_images_per_prompt 1 \
    --num_rows 1 \
    --seed 42 \
    --guidance_scale 5 \
    --inference_steps 50 \
    --start_merge_step 10 \
    --no_object_augmentation

Here is the result.

They really are standing right next to each other lol

Conclusion

In this post I tried FastComposer, an image-guided diffusion model that generates a new image from two reference images. The accuracy was honestly so high it made me laugh. I used Stable Diffusion v1.5 this time, but swapping in an anime-oriented base model might make it possible to line up anime idols in the same way. I'd like to try that.

I plan to keep posting hands-on articles about LLMs, diffusion models, image analysis, and 3D, so stay tuned.
