GLIGEN diffusersを使ってみた。
GLIGENとは
GLIGENはtext-to-imageのdiffusion modelの一つであり、bounding boxとテキストを指定することによって画面内にそのObjectを生成することができたり、Depth/Segment/Canny/Pose Imageなど色々な画像に対してcaptionを加えることでそのguide通りにImage Objectを生成することが可能になっているモデルです。ControlNetと似ているところがありますが、複数のObjectを位置を指定してコントロールできるのは少し異なる部分かなと個人的に感じています。以下のProject pageに結果やモデルの説明が載っています。是非参考にしてみてください。huggingface spacesにもDemoが載っています。
リンク
準備
Google Colabを開き、メニューから「ランタイム→ランタイムのタイプを変更」でランタイムを「GPU」に変更します。
環境構築
インストール手順は以下の通りです。
今回はbounding boxを作成するためにzeroshot object detectionのSOTAであるGrounding DINOを利用します。
!pip install transformers accelerate scipy safetensors
!git clone https://github.com/gligen/diffusers.git
!pip install git+https://github.com/gligen/diffusers.git
# Installation for GroundingDINO
%cd /content
!git clone https://github.com/IDEA-Research/GroundingDINO.git
%cd /content/GroundingDINO
!pip install -q -e .
!pip install -q roboflow
推論
(1) ライブラリのインポート
import argparse
from functools import partial
import cv2
import requests
from io import BytesIO
from PIL import Image
import numpy as np
from pathlib import Path
import random
import warnings
warnings.filterwarnings("ignore")
import torch
from torchvision.ops import box_convert
from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from groundingdino.util.inference import annotate, load_image, predict
import groundingdino.datasets.transforms as T
from huggingface_hub import hf_hub_download
(2) Utility関数の設定
def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)
args = SLConfig.fromfile(cache_config_file)
model = build_model(args)
args.device = device
cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
checkpoint = torch.load(cache_file, map_location='cpu')
log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
print("Model loaded from {} \n => {}".format(cache_file, log))
_ = model.eval()
return model
(3) モデルのロード
・detection model
ckpt_repo_id = "ShilongLiu/GroundingDINO"
ckpt_filenmae = "groundingdino_swint_ogc.pth"
ckpt_config_filename = "GroundingDINO_SwinT_OGC.cfg.py"
dino_model = load_model_hf(ckpt_repo_id, ckpt_filenmae, ckpt_config_filename)
・gligen pipeline
# load gligen pipeline
from diffusers import StableDiffusionGLIGENPipeline
pipe = StableDiffusionGLIGENPipeline.from_pretrained("gligen/diffusers-inpainting-text-box", revision="fp16", torch_dtype=torch.float16)
pipe.to("cuda")
(4) 推論データの準備
# download data for inference
%cd /content
!wget https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/art_dog_birthdaycake.png
(5) 推論
1 step: grounding dino detection
# grounding dino detection
import os
import supervision as sv
local_image_path = "art_dog_birthdaycake.png"
TEXT_PROMPT = "dog. cake."
BOX_TRESHOLD = 0.35
TEXT_TRESHOLD = 0.25
image_source, image = load_image(local_image_path)
boxes, logits, phrases = predict(
model=dino_model,
image=image,
caption=TEXT_PROMPT,
box_threshold=BOX_TRESHOLD,
text_threshold=TEXT_TRESHOLD
)
annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
annotated_frame = annotated_frame[...,::-1] # BGR to RGB
2 step: Gligen Image Inpainting
image_source = Image.fromarray(image_source)
annotated_frame = Image.fromarray(annotated_frame)
image_mask = Image.fromarray(image_mask)
# Resize
image_source_for_inpaint = image_source.resize((512, 512))
image_mask_for_inpaint = image_mask.resize((512, 512))
# get bbox
xyxy_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").tolist()
# define prompts for each box
gligen_phrases = ['a cat', 'a rose']
prompt = "'a cat', 'a rose'"
num_box = len(boxes)
image_inpainting = pipe(
prompt,
num_images_per_prompt = 2,
gligen_phrases = gligen_phrases,
gligen_inpaint_image = image_source_for_inpaint,
gligen_boxes = xyxy_boxes,
gligen_scheduled_sampling_beta=1,
output_type="numpy",
num_inference_steps=50
).images
3 step: display Image
# display image
image_inpainting = (image_inpainting * 255).astype(np.uint8)
image_inpainting = np.concatenate(image_inpainting, axis=1)
Image.fromarray(image_inpainting).resize((image_source.size[0]*2, image_source.size[1]))
Adavanced Application
ChatGPT+Grounding DINO+GLIGENでInstructPix2Pixをやってみた。で紹介しています。
最後に
今回はGLIGENをdiffusersのpipelineの中で実行するデモを試してみました。object detectionと組み合わせると写真を編集することがかなり容易になったと感じました。Instruct Pix2Pixでは部分の指定まではなかなか難しかったので、そこの部分が改善されているのが非常にGoodな感じでした。zeroshot object detection, segment anythingを利用して、promptで「replace chair into blue sofa at the left side」みたいに打つと簡単に画像を編集できるようになるとかなり人間の直感に近いツールを開発できそうな予感でした。このあたりのアプリケーションについても実際に実装できるか試してみたいです。
今後ともLLM, Diffusion model, Image Analysis, 3Dに関連する試した記事を投稿していく予定なのでよろしくお願いします。
Discussion