
Trying TextDiffuser on Google Colab

Published 2023/06/06

What is TextDiffuser?

TextDiffuser is a diffusion model released by Microsoft that can generate images with text rendered into them. Getting clean, legible characters out of image-generation models has been notoriously difficult, but this model makes it possible to place readable text inside generated images. Under the hood it works in two stages: a layout transformer first predicts where each keyword should appear, and a diffusion backbone then paints the image, as the code below shows.
https://github.com/microsoft/unilm/tree/master/textdiffuser

Links

Colab
GitHub

Preparation

Open Google Colab and switch the runtime to GPU via "Runtime → Change runtime type" in the menu.
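
To make sure a GPU is actually attached, you can run the following cell first (a quick sanity check of my own, not part of the official instructions):

# check that a CUDA GPU is visible to the runtime
!nvidia-smi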

Environment setup

Here are the installation steps.

!git clone https://github.com/microsoft/unilm.git
%cd unilm/textdiffuser
!pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
!pip install datasets==2.11.0 tokenizers==0.13.3 transformers==4.27.4 xformers==0.0.16 accelerate==0.18.0 triton==2.0.0.post1 termcolor==2.3.0 tinydb flask
!pip install TorchSnooper

# install diffusers
%cd /content/unilm/textdiffuser
!git clone https://github.com/huggingface/diffusers
!cp ./assets/files/scheduling_ddpm.py ./diffusers/src/diffusers/schedulers/scheduling_ddpm.py
!cp ./assets/files/unet_2d_condition.py ./diffusers/src/diffusers/models/unet_2d_condition.py
!cp ./assets/files/modeling_utils.py ./diffusers/src/diffusers/models/modeling_utils.py
%cd diffusers
!pip install -e .
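
After the installs it's worth confirming that the pinned versions were actually picked up; a minimal check of my own (Colab may prompt you to restart the runtime after reinstalling torch, since a different version comes preinstalled):

# confirm the pinned library versions installed correctly
import torch, transformers, diffusers
print(torch.__version__)         # expect 1.13.1+cu117
print(transformers.__version__)  # expect 4.27.4
print(diffusers.__version__)     # the editable install cloned above
print(torch.cuda.is_available()) # should be True on a GPU runtime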

Inference

(1) Download the model

%cd /content/unilm/textdiffuser
!wget https://layoutlm.blob.core.windows.net/textdiffuser/textdiffuser-ckpt.zip

!unzip textdiffuser-ckpt.zip
!rm -rf textdiffuser-ckpt.zip
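
The archive unpacks into a textdiffuser-ckpt directory; the paths used later (textdiffuser-ckpt/layout_transformer.pth and textdiffuser-ckpt/diffusion_backbone) should show up in a quick listing (my own check):

# inspect the unpacked checkpoint directory
!ls textdiffuser-ckpt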

(2) Run inference

There is one caveat before running inference: due to the Pillow version installed here, the font-handling code needs to be modified. In the two files below, every spot that loads a font is changed to use the DejaVuSans.ttf bundled with OpenCV instead of the original font path.
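
Before editing, you can check that this replacement font actually exists in the Colab environment (this check is my own addition):

# confirm the DejaVuSans font bundled with OpenCV is present
import os, cv2
font_path = os.path.join(cv2.__path__[0], 'qt', 'fonts', 'DejaVuSans.ttf')
print(font_path, os.path.exists(font_path))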

/content/unilm/textdiffuser/util.py
# ------------------------------------------
# TextDiffuser: Diffusion Models as Text Painters
# Paper Link: https://arxiv.org/abs/2305.10855
# Code Link: https://github.com/microsoft/unilm/tree/master/textdiffuser
# Copyright (c) Microsoft Corporation.
# This file defines a set of commonly used utility functions.
# ------------------------------------------

import os
import re
import cv2
import math
import shutil
import string
import textwrap
import numpy as np
from PIL import Image, ImageFont, ImageDraw, ImageOps

from typing import *

# define alphabet and alphabet_dic
alphabet = string.digits + string.ascii_lowercase + string.ascii_uppercase + string.punctuation + ' ' # len(alphabet) = 95
alphabet_dic = {}
for index, c in enumerate(alphabet):
    alphabet_dic[c] = index + 1 # the index 0 stands for non-character
    

def transform_mask(mask_root: str):
    
    img = cv2.imread(mask_root)
    img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_NEAREST)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    ret, binary = cv2.threshold(gray, 250, 255, cv2.THRESH_BINARY) # pixel value is set to 0 or 255 according to the threshold
    return 1 - (binary.astype(np.float32) / 255) 


def segmentation_mask_visualization(font_path: str, segmentation_mask: np.array):
    
    segmentation_mask = cv2.resize(segmentation_mask, (64, 64), interpolation=cv2.INTER_NEAREST)
    # font = ImageFont.truetype(font_path, 8)
    font_path = os.path.join(cv2.__path__[0],'qt','fonts','DejaVuSans.ttf')
    font = ImageFont.truetype(font_path, size=8)
    blank = Image.new('RGB', (512,512), (0,0,0))
    d = ImageDraw.Draw(blank)
    for i in range(64):
        for j in range(64):
            if int(segmentation_mask[i][j]) == 0 or int(segmentation_mask[i][j])-1 >= len(alphabet): 
                continue
            else:
                d.text((j*8, i*8), alphabet[int(segmentation_mask[i][j])-1], font=font, fill=(0, 255, 0))
    return blank


def make_caption_pil(font_path: str, captions: List[str]):
    
    caption_pil_list = []
    # font = ImageFont.truetype(font_path, 18)
    font_path = os.path.join(cv2.__path__[0],'qt','fonts','DejaVuSans.ttf')
    font = ImageFont.truetype(font_path, size=18)

    for caption in captions:
        border_size = 2
        img = Image.new('RGB', (512-4,48-4), (255,255,255)) 
        img = ImageOps.expand(img, border=(border_size, border_size, border_size, border_size), fill=(127, 127, 127))
        draw = ImageDraw.Draw(img)
        lines = textwrap.wrap(caption, width=40)
        x, y = 4, 4
        line_height = font.getsize('A')[1] + 4

        for line in lines:
            draw.text((x, y), line, font=font, fill=(200, 127, 0))
            y += line_height

        caption_pil_list.append(img)
    return caption_pil_list


def filter_segmentation_mask(segmentation_mask: np.array):
    
    segmentation_mask[segmentation_mask==alphabet_dic['-']] = 0
    segmentation_mask[segmentation_mask==alphabet_dic[' ']] = 0
    return segmentation_mask
    
    

def combine_image(args, sub_output_dir: str, pred_image_list: List, image_pil: Image, character_mask_pil: Image, character_mask_highlight_pil: Image, caption_pil_list: List):
    
    
    # # create a "latest" folder to store the results 
    # if os.path.exists(f'{args.output_dir}/latest'):
    #     shutil.rmtree(f'{args.output_dir}/latest')
    # os.mkdir(f'{args.output_dir}/latest')
    
    # save each predicted image
    # os.makedirs(f'{args.output_dir}/{sub_output_dir}', exist_ok=True)
    for index, img in enumerate(pred_image_list):
        img.save(f'{args.output_dir}/{sub_output_dir}/{index}.jpg')
        # img.save(f'{args.output_dir}/latest/{index}.jpg')
        
    length = len(pred_image_list)
    lines = math.ceil(length / 3)
    
    blank = Image.new('RGB', (512*3, 512*(lines+1)+48*lines), (0,0,0)) 
    blank.paste(image_pil,(0,0))
    blank.paste(character_mask_pil,(512,0))
    blank.paste(character_mask_highlight_pil,(512*2,0))
    
    for i in range(length):
        row, col = i // 3, i % 3
        blank.paste(pred_image_list[i],(512*col,512*(row+1)+48*row))
        blank.paste(caption_pil_list[i],(512*col,512*(row+1)+48*row+512))
    
    blank.save(f'{args.output_dir}/{sub_output_dir}/combine.jpg')
    # blank.save(f'{args.output_dir}/latest/combine.jpg')
    
    return blank.convert('RGB')
    
    
def get_width(font_path, text):
    
    # font = ImageFont.truetype(font_path, 24)
    font_path = os.path.join(cv2.__path__[0],'qt','fonts','DejaVuSans.ttf')
    font = ImageFont.truetype(font_path, size=24)
    width, _ = font.getsize(text)
    return width



def get_key_words(text: str):
   

    words = []
    matches = re.findall(r"'(.*?)'", text) # find the keywords enclosed by ''
    if matches:
        for match in matches:
            words.extend(match.split())
            
    if len(words) >= 8:
        return []
   
    return words


def adjust_overlap_box(box_output, current_index):
    
    
    if current_index == 0:
        return box_output
    else:
        # judge whether it contains overlap with the last output
        last_box = box_output[0, current_index-1, :]
        xmin_last, ymin_last, xmax_last, ymax_last = last_box
        
        current_box = box_output[0, current_index, :]
        xmin, ymin, xmax, ymax = current_box
        
        if xmin_last <= xmin <= xmax_last and ymin_last <= ymin <= ymax_last:
            print('adjust overlapping')
            distance_x = xmax_last - xmin
            distance_y = ymax_last - ymin
            if distance_x <= distance_y:
                # avoid overlap
                new_x_min = xmax_last + 0.025
                new_x_max = xmax - xmin + xmax_last + 0.025
                box_output[0,current_index,0] = new_x_min
                box_output[0,current_index,2] = new_x_max
            else:
                new_y_min = ymax_last + 0.025
                new_y_max = ymax - ymin + ymax_last + 0.025
                box_output[0,current_index,1] = new_y_min
                box_output[0,current_index,3] = new_y_max  
                
        elif xmin_last <= xmin <= xmax_last and ymin_last <= ymax <= ymax_last:
            print('adjust overlapping')
            new_x_min = xmax_last + 0.05
            new_x_max = xmax - xmin + xmax_last + 0.05
            box_output[0,current_index,0] = new_x_min
            box_output[0,current_index,2] = new_x_max
                    
        return box_output
    
    
def shrink_box(box, scale_factor = 0.9):
    
    
    x1, y1, x2, y2 = box
    x1_new = x1 + (x2 - x1) * (1 - scale_factor) / 2
    y1_new = y1 + (y2 - y1) * (1 - scale_factor) / 2
    x2_new = x2 - (x2 - x1) * (1 - scale_factor) / 2
    y2_new = y2 - (y2 - y1) * (1 - scale_factor) / 2
    return (x1_new, y1_new, x2_new, y2_new)


def adjust_font_size(args, width, height, draw, text):
    
    
    size_start = height
    while True:
        # font = ImageFont.truetype(args.font_path, size_start)
        font_path = os.path.join(cv2.__path__[0],'qt','fonts','DejaVuSans.ttf')
        font = ImageFont.truetype(font_path, size=size_start)
        text_width, _ = draw.textsize(text, font=font)
        if text_width >= width:
            size_start = size_start - 1
        else:
            return size_start
    
    
def inpainting_merge_image(original_image, mask_image, inpainting_image):
    
    
    original_image = original_image.resize((512, 512))
    mask_image = mask_image.resize((512, 512))
    inpainting_image = inpainting_image.resize((512, 512))
    mask_image = mask_image.convert('L') # convert() returns a new image; without reassignment the mask stays RGB
    threshold = 250 
    table = []
    for i in range(256):
        if i < threshold:
            table.append(1)
        else:
            table.append(0)
    mask_image = mask_image.point(table, "1")
    merged_image = Image.composite(inpainting_image, original_image, mask_image)
    return merged_image
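
One detail in util.py worth calling out is get_key_words(): the keywords to paint are whatever the prompt wraps in single quotes, and prompts with eight or more quoted words return no keywords at all. A tiny usage sketch of my own (run from the textdiffuser directory):

# keywords are extracted from single-quoted spans in the prompt
from util import get_key_words
print(get_key_words("A sign that says 'Hello'"))  # expected: ['Hello']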
/content/unilm/textdiffuser/model/layout_generator.py
# ------------------------------------------
# TextDiffuser: Diffusion Models as Text Painters
# Paper Link: https://arxiv.org/abs/2305.10855
# Code Link: https://github.com/microsoft/unilm/tree/master/textdiffuser
# Copyright (c) Microsoft Corporation.
# This file aims to predict the layout of keywords in user prompts.
# ------------------------------------------

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import re
import numpy as np
import torch
import torch.nn as nn
from transformers import CLIPTokenizer
from PIL import Image, ImageDraw, ImageFont
from util import get_width, get_key_words, adjust_overlap_box, shrink_box, adjust_font_size, alphabet_dic
from model.layout_transformer import LayoutTransformer, TextConditioner
from termcolor import colored

import os
import cv2

# import layout transformer
model = LayoutTransformer().cuda().eval()
model.load_state_dict(torch.load('textdiffuser-ckpt/layout_transformer.pth'))

# import text encoder and tokenizer
text_encoder = TextConditioner().cuda().eval()
tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')


def process_caption(font_path, caption, keywords):
    # remove punctuations. please remove this statement if you want to paint punctuations
    caption = re.sub(u"([^\u0041-\u005a\u0061-\u007a\u0030-\u0039])", " ", caption) 
    
    # tokenize it into ids and get length
    caption_words = tokenizer([caption], truncation=True, max_length=77, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
    caption_words_ids = caption_words['input_ids'] # (1, 77)
    length = caption_words['length'] # (1, )
    
    # convert id to words
    words = tokenizer.convert_ids_to_tokens(caption_words_ids.view(-1).tolist())
    words = [i.replace('</w>', '') for i in words]
    words_valid = words[:int(length)]

    # store the box coordinates and state of each token
    info_array = np.zeros((77,5)) # (77, 5)

    # split the caption into words and convert them into lower case
    caption_split = caption.split() 
    caption_split = [i.lower() for i in caption_split]

    start_dic = {} # get the start index of each word
    state_list = [] # 0: start, 1: middle, 2: special token
    word_match_list = [] # the index of the word in the caption
    current_caption_index = 0 
    current_match = ''
    for i in range(length): 

        # the first and last token are special tokens
        if i == 0 or i == length-1:
            state_list.append(2) 
            word_match_list.append(127)
            continue

        if current_match == '':
            state_list.append(0)
            start_dic[current_caption_index] = i
        else:
            state_list.append(1)

        current_match += words_valid[i]
        word_match_list.append(current_caption_index)
        if current_match == caption_split[current_caption_index]:
            current_match = ''
            current_caption_index += 1

    while len(state_list) < 77:
        state_list.append(127)
    while len(word_match_list) < 77:
        word_match_list.append(127)

    length_list = []
    width_list =[]
    for i in range(len(word_match_list)):
        if word_match_list[i] == 127:
            length_list.append(0)
            width_list.append(0)
        else:
            length_list.append(len(caption.split()[word_match_list[i]]))
            width_list.append(get_width(font_path, caption.split()[word_match_list[i]]))

    while len(length_list) < 77:
        length_list.append(127)
        width_list.append(0)

    length_list = torch.Tensor(length_list).long() # (77, )
    width_list = torch.Tensor(width_list).long() # (77, )

    boxes = []
    duplicate_dict = {} # some words may appear more than once
    for keyword in keywords: 
        keyword = keyword.lower()
        if keyword in caption_split:
            if keyword not in duplicate_dict:
                duplicate_dict[keyword] = caption_split.index(keyword) 
                index = caption_split.index(keyword)
            else:
                if duplicate_dict[keyword]+1 < len(caption_split) and keyword in caption_split[duplicate_dict[keyword]+1:]:
                    index = duplicate_dict[keyword] + caption_split[duplicate_dict[keyword]+1:].index(keyword)
                    duplicate_dict[keyword] = index
                else:
                    continue
                
            index = caption_split.index(keyword) 
            index = start_dic[index] 
            info_array[index][0] = 1 

            box = [0,0,0,0] 
            boxes.append(list(box))
            info_array[index][1:] = box
    
    boxes_length = len(boxes)
    if boxes_length > 8:
        boxes = boxes[:8]
    while len(boxes) < 8:
        boxes.append([0,0,0,0])

    return caption, length_list, width_list, torch.from_numpy(info_array), words, torch.Tensor(state_list).long(), torch.Tensor(word_match_list).long(), torch.Tensor(boxes), boxes_length


def get_layout_from_prompt(args):

    # prompt = args.prompt
    font_path = args.font_path
    keywords = get_key_words(args.prompt)
    
    print(f'{colored("[!]", "red")} Detected keywords: {keywords} from prompt {args.prompt}')
    
    text_embedding, mask = text_encoder(args.prompt) # (1, 77, 768) / (1, 77)

    # process all relevant info
    caption, length_list, width_list, target, words, state_list, word_match_list, boxes, boxes_length = process_caption(font_path, args.prompt, keywords)
    target = target.cuda().unsqueeze(0) # (77, 5)
    width_list = width_list.cuda().unsqueeze(0) # (77, )
    length_list = length_list.cuda().unsqueeze(0) # (77, )
    state_list = state_list.cuda().unsqueeze(0) # (77, )
    word_match_list = word_match_list.cuda().unsqueeze(0) # (77, )

    padding = torch.zeros(1, 1, 4).cuda()
    boxes = boxes.unsqueeze(0).cuda()
    right_shifted_boxes = torch.cat([padding, boxes[:,0:-1,:]],1) # (1, 8, 4)
   
    # inference
    return_boxes = []
    with torch.no_grad():
        for box_index in range(boxes_length):
            
            if box_index == 0:
                encoder_embedding = None
                
            output, encoder_embedding = model(text_embedding, length_list, width_list, mask, state_list, word_match_list, target, right_shifted_boxes, train=False, encoder_embedding=encoder_embedding) 
            output = torch.clamp(output, min=0, max=1) # (1, 8, 4)
            
            # add overlap detection
            output = adjust_overlap_box(output, box_index) # (1, 8, 4)
            
            right_shifted_boxes[:,box_index+1,:] = output[:,box_index,:]
            xmin, ymin, xmax, ymax = output[0, box_index, :].tolist()
            return_boxes.append([xmin, ymin, xmax, ymax])
            
            
    # print the location of keywords
    print(f'index\tkeyword\tx_min\ty_min\tx_max\ty_max')
    for index, keyword in enumerate(keywords):
        x_min = int(return_boxes[index][0] * 512)
        y_min = int(return_boxes[index][1] * 512)
        x_max = int(return_boxes[index][2] * 512)
        y_max = int(return_boxes[index][3] * 512)
        print(f'{index}\t{keyword}\t{x_min}\t{y_min}\t{x_max}\t{y_max}')
    
    
    # paint the layout
    render_image = Image.new('RGB', (512, 512), (255, 255, 255))
    draw = ImageDraw.Draw(render_image)
    segmentation_mask = Image.new("L", (512,512), 0)
    segmentation_mask_draw = ImageDraw.Draw(segmentation_mask)

    for index, box in enumerate(return_boxes):
        box = [int(i*512) for i in box]
        xmin, ymin, xmax, ymax = box
        
        width = xmax - xmin
        height = ymax - ymin
        text = keywords[index]

        font_size = adjust_font_size(args, width, height, draw, text)
        # font = ImageFont.truetype(args.font_path, font_size)
        font_path = os.path.join(cv2.__path__[0],'qt','fonts','DejaVuSans.ttf')
        font = ImageFont.truetype(font_path, size=font_size)

        # draw.rectangle([xmin, ymin, xmax,ymax], outline=(255,0,0))
        draw.text((xmin, ymin), text, font=font, fill=(0, 0, 0))
            
        boxes = []
        for i, char in enumerate(text):
            
            # paint character-level segmentation masks
            # https://github.com/python-pillow/Pillow/issues/3921
            bottom_1 = font.getsize(text[i])[1]
            right, bottom_2 = font.getsize(text[:i+1])
            bottom = bottom_1 if bottom_1 < bottom_2 else bottom_2
            width, height = font.getmask(char).size
            right += xmin
            bottom += ymin
            top = bottom - height
            left = right - width
            
            char_box = (left, top, right, bottom)
            boxes.append(char_box)
            
            char_index = alphabet_dic[char]
            segmentation_mask_draw.rectangle(shrink_box(char_box, scale_factor = 0.9), fill=char_index)
    
    print(f'{colored("[√]", "green")} Layout is successfully generated')
    return render_image, segmentation_mask

Apply the changes to the two files above.
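
In Colab you can apply the changes either by opening the files from the file browser on the left and pasting in the listings above, or by overwriting them from a notebook cell with the %%writefile magic. A sketch (the comment is a placeholder; the cell body must be the entire patched file):

%%writefile /content/unilm/textdiffuser/util.py
# ...paste the full patched util.py from above here...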

Now run inference.

%cd /content/unilm/textdiffuser
# Text-to-Image
!CUDA_VISIBLE_DEVICES=0 python inference.py \
  --mode="text-to-image" \
  --resume_from_checkpoint="textdiffuser-ckpt/diffusion_backbone" \
  --prompt="A sign that says 'Hello'" \
  --output_dir="./output" \
  --vis_num=4

This should render the word "Hello" in the image. Note that the prompt wraps Hello in single quotes; as shown in get_key_words() above, that is how the keywords to paint are detected.

Here is the result.
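
To view the outputs inside the notebook, the grid image that combine_image() in util.py writes (combine.jpg) can be displayed like this (my own addition):

# display every generated grid image under ./output
import glob
from IPython.display import Image as IPyImage, display
for path in glob.glob('./output/**/combine.jpg', recursive=True):
    display(IPyImage(filename=path))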

No way! The text really is rendered into the image. This is revolutionary!

Closing thoughts

In this post I tried TextDiffuser, released by Microsoft, on Google Colab. A diffusion model that can cleanly embed text into images has finally arrived as open source. This should widen the range of what's possible in image editing quite a bit, and I plan to experiment with it more.

I plan to keep posting hands-on articles about LLMs, diffusion models, image analysis, and 3D, so stay tuned.
