Trying TextDiffuser on Google Colab
What is TextDiffuser?
TextDiffuser is a diffusion model released by Microsoft that can generate images with text rendered into them. Until now, image-generation models have struggled to produce clean, legible text, but with this model it is possible to place readable text inside generated images.
Links
Paper: https://arxiv.org/abs/2305.10855
Code: https://github.com/microsoft/unilm/tree/master/textdiffuser
Preparation
Open Google Colab and switch the runtime to GPU via "Runtime → Change runtime type" in the menu.
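To confirm that a GPU is actually attached, you can run the following in a cell (a standard Colab check, not part of the original steps):
!nvidia-smi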
Environment setup
Here are the installation steps.
!git clone https://github.com/microsoft/unilm.git
%cd unilm/textdiffuser
!pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
!pip install datasets==2.11.0 tokenizers==0.13.3 transformers==4.27.4 xformers==0.0.16 accelerate==0.18.0 triton==2.0.0.post1 termcolor==2.3.0 tinydb flask
!pip install TorchSnooper
# install diffusers
%cd /content/unilm/textdiffuser
!git clone https://github.com/huggingface/diffusers
!cp ./assets/files/scheduling_ddpm.py ./diffusers/src/diffusers/schedulers/scheduling_ddpm.py
!cp ./assets/files/unet_2d_condition.py ./diffusers/src/diffusers/models/unet_2d_condition.py
!cp ./assets/files/modeling_utils.py ./diffusers/src/diffusers/models/modeling_utils.py
%cd diffusers
!pip install -e .
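As a quick sanity check (my own addition, not in the original steps), you can verify that the pinned PyTorch build sees the GPU before moving on:
import torch
print(torch.__version__)           # expect 1.13.1+cu117
print(torch.cuda.is_available())   # should print True on a GPU runtime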
Inference
(1) Download the model
%cd /content/unilm/textdiffuser
!wget https://layoutlm.blob.core.windows.net/textdiffuser/textdiffuser-ckpt.zip
!unzip textdiffuser-ckpt.zip
!rm -rf textdiffuser-ckpt.zip
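After unzipping, the textdiffuser-ckpt folder should contain at least layout_transformer.pth and a diffusion_backbone directory, both of which are referenced later in this post. You can confirm with:
!ls textdiffuser-ckpt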
(2) Inference
There is one caveat before running inference: due to the Pillow version, the font-related parts of the code need to be modified in a few places.
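Every affected spot follows the same pattern: the original ImageFont.truetype(...) call that uses the repo's font path is commented out and replaced with the DejaVuSans.ttf bundled with opencv-python. For example (taken from the patched code below):
# before
# font = ImageFont.truetype(font_path, 8)
# after: use the DejaVuSans.ttf shipped with opencv-python
font_path = os.path.join(cv2.__path__[0], 'qt', 'fonts', 'DejaVuSans.ttf')
font = ImageFont.truetype(font_path, size=8)
The two fully patched files follow.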
/content/unilm/textdiffuser/util.py
# ------------------------------------------
# TextDiffuser: Diffusion Models as Text Painters
# Paper Link: https://arxiv.org/abs/2305.10855
# Code Link: https://github.com/microsoft/unilm/tree/master/textdiffuser
# Copyright (c) Microsoft Corporation.
# This file defines a set of commonly used utility functions.
# ------------------------------------------
import os
import re
import cv2
import math
import shutil
import string
import textwrap
import numpy as np
from PIL import Image, ImageFont, ImageDraw, ImageOps
from typing import *

# define alphabet and alphabet_dic
alphabet = string.digits + string.ascii_lowercase + string.ascii_uppercase + string.punctuation + ' '  # len(alphabet) = 95
alphabet_dic = {}
for index, c in enumerate(alphabet):
    alphabet_dic[c] = index + 1  # the index 0 stands for non-character


def transform_mask(mask_root: str):
    img = cv2.imread(mask_root)
    img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_NEAREST)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    ret, binary = cv2.threshold(gray, 250, 255, cv2.THRESH_BINARY)  # pixel value is set to 0 or 255 according to the threshold
    return 1 - (binary.astype(np.float32) / 255)


def segmentation_mask_visualization(font_path: str, segmentation_mask: np.array):
    segmentation_mask = cv2.resize(segmentation_mask, (64, 64), interpolation=cv2.INTER_NEAREST)
    # font = ImageFont.truetype(font_path, 8)
    font_path = os.path.join(cv2.__path__[0], 'qt', 'fonts', 'DejaVuSans.ttf')
    font = ImageFont.truetype(font_path, size=8)
    blank = Image.new('RGB', (512, 512), (0, 0, 0))
    d = ImageDraw.Draw(blank)
    for i in range(64):
        for j in range(64):
            if int(segmentation_mask[i][j]) == 0 or int(segmentation_mask[i][j]) - 1 >= len(alphabet):
                continue
            else:
                d.text((j*8, i*8), alphabet[int(segmentation_mask[i][j])-1], font=font, fill=(0, 255, 0))
    return blank


def make_caption_pil(font_path: str, captions: List[str]):
    caption_pil_list = []
    # font = ImageFont.truetype(font_path, 18)
    font_path = os.path.join(cv2.__path__[0], 'qt', 'fonts', 'DejaVuSans.ttf')
    font = ImageFont.truetype(font_path, size=18)

    for caption in captions:
        border_size = 2
        img = Image.new('RGB', (512-4, 48-4), (255, 255, 255))
        img = ImageOps.expand(img, border=(border_size, border_size, border_size, border_size), fill=(127, 127, 127))
        draw = ImageDraw.Draw(img)
        border_size = 2
        text = caption
        lines = textwrap.wrap(text, width=40)
        x, y = 4, 4
        line_height = font.getsize('A')[1] + 4

        start = 0
        for line in lines:
            draw.text((x, y+start), line, font=font, fill=(200, 127, 0))
            y += line_height
        caption_pil_list.append(img)
    return caption_pil_list


def filter_segmentation_mask(segmentation_mask: np.array):
    segmentation_mask[segmentation_mask == alphabet_dic['-']] = 0
    segmentation_mask[segmentation_mask == alphabet_dic[' ']] = 0
    return segmentation_mask


def combine_image(args, sub_output_dir: str, pred_image_list: List, image_pil: Image, character_mask_pil: Image, character_mask_highlight_pil: Image, caption_pil_list: List):
    # # create a "latest" folder to store the results
    # if os.path.exists(f'{args.output_dir}/latest'):
    #     shutil.rmtree(f'{args.output_dir}/latest')
    # os.mkdir(f'{args.output_dir}/latest')

    # save each predicted image
    # os.makedirs(f'{args.output_dir}/{sub_output_dir}', exist_ok=True)
    for index, img in enumerate(pred_image_list):
        img.save(f'{args.output_dir}/{sub_output_dir}/{index}.jpg')
        # img.save(f'{args.output_dir}/latest/{index}.jpg')

    length = len(pred_image_list)
    lines = math.ceil(length / 3)
    blank = Image.new('RGB', (512*3, 512*(lines+1)+48*lines), (0, 0, 0))
    blank.paste(image_pil, (0, 0))
    blank.paste(character_mask_pil, (512, 0))
    blank.paste(character_mask_highlight_pil, (512*2, 0))
    for i in range(length):
        row, col = i // 3, i % 3
        blank.paste(pred_image_list[i], (512*col, 512*(row+1)+48*row))
        blank.paste(caption_pil_list[i], (512*col, 512*(row+1)+48*row+512))
    blank.save(f'{args.output_dir}/{sub_output_dir}/combine.jpg')
    # blank.save(f'{args.output_dir}/latest/combine.jpg')
    return blank.convert('RGB')


def get_width(font_path, text):
    # font = ImageFont.truetype(font_path, 24)
    font_path = os.path.join(cv2.__path__[0], 'qt', 'fonts', 'DejaVuSans.ttf')
    font = ImageFont.truetype(font_path, size=24)
    width, _ = font.getsize(text)
    return width


def get_key_words(text: str):
    words = []
    matches = re.findall(r"'(.*?)'", text)  # find the keywords enclosed by ''
    if matches:
        for match in matches:
            words.extend(match.split())

    if len(words) >= 8:
        return []

    return words


def adjust_overlap_box(box_output, current_index):
    if current_index == 0:
        return box_output
    else:
        # judge whether it overlaps with the last output
        last_box = box_output[0, current_index-1, :]
        xmin_last, ymin_last, xmax_last, ymax_last = last_box
        current_box = box_output[0, current_index, :]
        xmin, ymin, xmax, ymax = current_box
        if xmin_last <= xmin <= xmax_last and ymin_last <= ymin <= ymax_last:
            print('adjust overlapping')
            distance_x = xmax_last - xmin
            distance_y = ymax_last - ymin
            if distance_x <= distance_y:
                # avoid overlap
                new_x_min = xmax_last + 0.025
                new_x_max = xmax - xmin + xmax_last + 0.025
                box_output[0, current_index, 0] = new_x_min
                box_output[0, current_index, 2] = new_x_max
            else:
                new_y_min = ymax_last + 0.025
                new_y_max = ymax - ymin + ymax_last + 0.025
                box_output[0, current_index, 1] = new_y_min
                box_output[0, current_index, 3] = new_y_max
        elif xmin_last <= xmin <= xmax_last and ymin_last <= ymax <= ymax_last:
            print('adjust overlapping')
            new_x_min = xmax_last + 0.05
            new_x_max = xmax - xmin + xmax_last + 0.05
            box_output[0, current_index, 0] = new_x_min
            box_output[0, current_index, 2] = new_x_max
        return box_output


def shrink_box(box, scale_factor=0.9):
    x1, y1, x2, y2 = box
    x1_new = x1 + (x2 - x1) * (1 - scale_factor) / 2
    y1_new = y1 + (y2 - y1) * (1 - scale_factor) / 2
    x2_new = x2 - (x2 - x1) * (1 - scale_factor) / 2
    y2_new = y2 - (y2 - y1) * (1 - scale_factor) / 2
    return (x1_new, y1_new, x2_new, y2_new)


def adjust_font_size(args, width, height, draw, text):
    size_start = height
    while True:
        # font = ImageFont.truetype(args.font_path, size_start)
        font_path = os.path.join(cv2.__path__[0], 'qt', 'fonts', 'DejaVuSans.ttf')
        font = ImageFont.truetype(font_path, size=size_start)
        text_width, _ = draw.textsize(text, font=font)
        if text_width >= width:
            size_start = size_start - 1
        else:
            return size_start


def inpainting_merge_image(original_image, mask_image, inpainting_image):
    original_image = original_image.resize((512, 512))
    mask_image = mask_image.resize((512, 512))
    inpainting_image = inpainting_image.resize((512, 512))
    # Image.convert() returns a new image, so the result must be assigned back
    mask_image = mask_image.convert('L')
    threshold = 250
    table = []
    for i in range(256):
        if i < threshold:
            table.append(1)
        else:
            table.append(0)
    mask_image = mask_image.point(table, "1")
    merged_image = Image.composite(inpainting_image, original_image, mask_image)
    return merged_image
/content/unilm/textdiffuser/model/layout_generator.py
# ------------------------------------------
# TextDiffuser: Diffusion Models as Text Painters
# Paper Link: https://arxiv.org/abs/2305.10855
# Code Link: https://github.com/microsoft/unilm/tree/master/textdiffuser
# Copyright (c) Microsoft Corporation.
# This file aims to predict the layout of keywords in user prompts.
# ------------------------------------------
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import re
import numpy as np
import torch
import torch.nn as nn
from transformers import CLIPTokenizer
from PIL import Image, ImageDraw, ImageFont
from util import get_width, get_key_words, adjust_overlap_box, shrink_box, adjust_font_size, alphabet_dic
from model.layout_transformer import LayoutTransformer, TextConditioner
from termcolor import colored
import os
import cv2

# import layout transformer
model = LayoutTransformer().cuda().eval()
model.load_state_dict(torch.load('textdiffuser-ckpt/layout_transformer.pth'))

# import text encoder and tokenizer
text_encoder = TextConditioner().cuda().eval()
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')


def process_caption(font_path, caption, keywords):
    # remove punctuations. please remove this statement if you want to paint punctuations
    caption = re.sub(u"([^\u0041-\u005a\u0061-\u007a\u0030-\u0039])", " ", caption)

    # tokenize it into ids and get length
    caption_words = tokenizer([caption], truncation=True, max_length=77, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
    caption_words_ids = caption_words['input_ids']  # (1, 77)
    length = caption_words['length']  # (1, )

    # convert ids to words
    words = tokenizer.convert_ids_to_tokens(caption_words_ids.view(-1).tolist())
    words = [i.replace('</w>', '') for i in words]
    words_valid = words[:int(length)]

    # store the box coordinates and state of each token
    info_array = np.zeros((77, 5))  # (77, 5)

    # split the caption into words and convert them to lower case
    caption_split = caption.split()
    caption_split = [i.lower() for i in caption_split]

    start_dic = {}  # get the start index of each word
    state_list = []  # 0: start, 1: middle, 2: special token
    word_match_list = []  # the index of the word in the caption
    current_caption_index = 0
    current_match = ''
    for i in range(length):
        # the first and last tokens are special tokens
        if i == 0 or i == length - 1:
            state_list.append(2)
            word_match_list.append(127)
            continue
        if current_match == '':
            state_list.append(0)
            start_dic[current_caption_index] = i
        else:
            state_list.append(1)
        current_match += words_valid[i]
        word_match_list.append(current_caption_index)
        if current_match == caption_split[current_caption_index]:
            current_match = ''
            current_caption_index += 1

    while len(state_list) < 77:
        state_list.append(127)
    while len(word_match_list) < 77:
        word_match_list.append(127)

    length_list = []
    width_list = []
    for i in range(len(word_match_list)):
        if word_match_list[i] == 127:
            length_list.append(0)
            width_list.append(0)
        else:
            length_list.append(len(caption.split()[word_match_list[i]]))
            width_list.append(get_width(font_path, caption.split()[word_match_list[i]]))
    while len(length_list) < 77:
        length_list.append(127)
        width_list.append(0)
    length_list = torch.Tensor(length_list).long()  # (77, )
    width_list = torch.Tensor(width_list).long()  # (77, )

    boxes = []
    duplicate_dict = {}  # some words may appear more than once
    for keyword in keywords:
        keyword = keyword.lower()
        if keyword in caption_split:
            if keyword not in duplicate_dict:
                duplicate_dict[keyword] = caption_split.index(keyword)
                index = caption_split.index(keyword)
            else:
                if duplicate_dict[keyword] + 1 < len(caption_split) and keyword in caption_split[duplicate_dict[keyword]+1:]:
                    index = duplicate_dict[keyword] + caption_split[duplicate_dict[keyword]+1:].index(keyword)
                    duplicate_dict[keyword] = index
                else:
                    continue
            index = caption_split.index(keyword)
            index = start_dic[index]
            info_array[index][0] = 1
            box = [0, 0, 0, 0]
            boxes.append(list(box))
            info_array[index][1:] = box

    boxes_length = len(boxes)
    if boxes_length > 8:
        boxes = boxes[:8]
    while len(boxes) < 8:
        boxes.append([0, 0, 0, 0])

    return caption, length_list, width_list, torch.from_numpy(info_array), words, torch.Tensor(state_list).long(), torch.Tensor(word_match_list).long(), torch.Tensor(boxes), boxes_length


def get_layout_from_prompt(args):
    # prompt = args.prompt
    font_path = args.font_path
    keywords = get_key_words(args.prompt)
    print(f'{colored("[!]", "red")} Detected keywords: {keywords} from prompt {args.prompt}')

    text_embedding, mask = text_encoder(args.prompt)  # (1, 77, 768) / (1, 77)

    # process all relevant info
    caption, length_list, width_list, target, words, state_list, word_match_list, boxes, boxes_length = process_caption(font_path, args.prompt, keywords)
    target = target.cuda().unsqueeze(0)  # (77, 5)
    width_list = width_list.cuda().unsqueeze(0)  # (77, )
    length_list = length_list.cuda().unsqueeze(0)  # (77, )
    state_list = state_list.cuda().unsqueeze(0)  # (77, )
    word_match_list = word_match_list.cuda().unsqueeze(0)  # (77, )

    padding = torch.zeros(1, 1, 4).cuda()
    boxes = boxes.unsqueeze(0).cuda()
    right_shifted_boxes = torch.cat([padding, boxes[:, 0:-1, :]], 1)  # (1, 8, 4)

    # inference
    return_boxes = []
    with torch.no_grad():
        for box_index in range(boxes_length):
            if box_index == 0:
                encoder_embedding = None
            output, encoder_embedding = model(text_embedding, length_list, width_list, mask, state_list, word_match_list, target, right_shifted_boxes, train=False, encoder_embedding=encoder_embedding)
            output = torch.clamp(output, min=0, max=1)  # (1, 8, 4)
            # add overlap detection
            output = adjust_overlap_box(output, box_index)  # (1, 8, 4)
            right_shifted_boxes[:, box_index+1, :] = output[:, box_index, :]
            xmin, ymin, xmax, ymax = output[0, box_index, :].tolist()
            return_boxes.append([xmin, ymin, xmax, ymax])

    # print the location of keywords
    print(f'index\tkeyword\tx_min\ty_min\tx_max\ty_max')
    for index, keyword in enumerate(keywords):
        x_min = int(return_boxes[index][0] * 512)
        y_min = int(return_boxes[index][1] * 512)
        x_max = int(return_boxes[index][2] * 512)
        y_max = int(return_boxes[index][3] * 512)
        print(f'{index}\t{keyword}\t{x_min}\t{y_min}\t{x_max}\t{y_max}')

    # paint the layout
    render_image = Image.new('RGB', (512, 512), (255, 255, 255))
    draw = ImageDraw.Draw(render_image)
    segmentation_mask = Image.new("L", (512, 512), 0)
    segmentation_mask_draw = ImageDraw.Draw(segmentation_mask)
    for index, box in enumerate(return_boxes):
        box = [int(i*512) for i in box]
        xmin, ymin, xmax, ymax = box
        width = xmax - xmin
        height = ymax - ymin
        text = keywords[index]
        font_size = adjust_font_size(args, width, height, draw, text)
        # font = ImageFont.truetype(args.font_path, font_size)
        font_path = os.path.join(cv2.__path__[0], 'qt', 'fonts', 'DejaVuSans.ttf')
        font = ImageFont.truetype(font_path, size=font_size)

        # draw.rectangle([xmin, ymin, xmax, ymax], outline=(255,0,0))
        draw.text((xmin, ymin), text, font=font, fill=(0, 0, 0))

        boxes = []
        for i, char in enumerate(text):
            # paint character-level segmentation masks
            # https://github.com/python-pillow/Pillow/issues/3921
            bottom_1 = font.getsize(text[i])[1]
            right, bottom_2 = font.getsize(text[:i+1])
            bottom = bottom_1 if bottom_1 < bottom_2 else bottom_2
            width, height = font.getmask(char).size
            right += xmin
            bottom += ymin
            top = bottom - height
            left = right - width
            char_box = (left, top, right, bottom)
            boxes.append(char_box)
            char_index = alphabet_dic[char]
            segmentation_mask_draw.rectangle(shrink_box(char_box, scale_factor=0.9), fill=char_index)

    print(f'{colored("[√]", "green")} Layout is successfully generated')
    return render_image, segmentation_mask
Modify both of the files above as shown.
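Incidentally, which words get painted is decided by get_key_words() in util.py: it pulls out the words wrapped in single quotes in the prompt and returns an empty list once eight or more words are found. A quick check (my own snippet, assuming the working directory is /content/unilm/textdiffuser):
from util import get_key_words
print(get_key_words("A sign that says 'Hello'"))  # -> ['Hello']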
Now run inference.
%cd /content/unilm/textdiffuser
# Text-to-Image
!CUDA_VISIBLE_DEVICES=0 python inference.py \
--mode="text-to-image" \
--resume_from_checkpoint="textdiffuser-ckpt/diffusion_backbone" \
--prompt="A sign that says 'Hello'" \
--output_dir="./output" \
--vis_num=4
This renders the word "Hello" into the image.
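Since keywords are just the single-quoted words, you can also ask for several of them in one prompt. A hypothetical variation with the same flags (the prompt here is my own example, not from the run above):
!CUDA_VISIBLE_DEVICES=0 python inference.py \
--mode="text-to-image" \
--resume_from_checkpoint="textdiffuser-ckpt/diffusion_backbone" \
--prompt="A storefront sign that says 'OPEN' 'DAILY'" \
--output_dir="./output" \
--vis_num=4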
Here is the result of the 'Hello' run.
No way! The text really is rendered into the image. This is a game changer!
Closing thoughts
In this post I tried TextDiffuser, released by Microsoft, on Google Colab. An open-source diffusion model that can cleanly embed text into images has finally arrived. This should considerably widen what is possible in image editing, and I plan to keep experimenting with it!
I will continue to post hands-on articles about LLMs, diffusion models, image analysis, and 3D, so stay tuned.