🎨

ControlNetで文字画像を生成する【雑コード付き】

2023/02/26に公開約6,500字

最近ControlNetで遊ぶのにハマっているwatankoです。

https://twitter.com/yamkaz/status/1628891312299839489?s=61&t=a2y7iu1YzOVPOzsx2-caHw

先日このようなツイートを見かけ、面白そうなので試してみることにしました。

生成結果




やっていること

以下のような白黒の文字画像を、"直接"depthモデルに突っ込むだけです。(解説:例えばdepthの場合ControlNetは第一段階として入力画像から深度画像を推定したのち、第二段階で推定した深度画像を入力として画像を生成します。第一段階の推定を飛ばして文字画像をそのまま深度画像として使うということですね)

参考
https://www.reddit.com/r/StableDiffusion/comments/119j8qr/clear_text_using_controlnet/

prompt

prompt: "A movie title, sfx, 3d, blue lasers, lights, full HD, 4k",
negative prompt: "cartoon"

環境構築

注意

生成はweb UIではなくローカルで行っており、環境周りの知識が十分にあることを前提に解説しています。ご了承ください。エラー出たらごめんなさい。

  1. ControlNetの本家をclone
git clone https://github.com/lllyasviel/ControlNet.git
cd ControlNet
  1. condaで環境作る
    本家のREADME通りです。
conda env create -f environment.yaml
conda activate control
  1. 重みを落としてくる
    今回は本家の重みではなく https://huggingface.co/edwardwangxy/ControlNet_Deliberate/tree/main 
    の重みを使用しました。手元のSD1.5互換のモデルをControlNet用に変換できるのですが、僕の環境ではVRAMが足りなかったため拾ってくることにしました。
cd models/
wget https://huggingface.co/edwardwangxy/ControlNet_Deliberate/resolve/main/deliberate_depth.pth
  1. フォントを落としてくる
    今回はkeifontを使わせていただきました。お好きなので大丈夫です。
    http://font.sumomo.ne.jp/font_1.html

生成

本家のコードはgradio経由のみで煩わしかったため、雑に書きました。

from share import *
import config

import cv2
from PIL import Image, ImageDraw, ImageFont
import einops
import numpy as np
import torch
import random

from annotator.util import resize_image, HWC3
from annotator.midas import MidasDetector

from pytorch_lightning import seed_everything
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler


def process(
    text="5000兆円\n  欲しい!",
    prompt="A movie title, sfx, 3d, blue lasers, lights, full HD, 4k",
    a_prompt="",
    n_prompt="cartoon",
    num_samples=1,
    image_resolution=512,
    detect_resolution=384,
    ddim_steps=20,
    guess_mode=False,
    strength=1.5,
    scale=9.0,
    seed=-1,
    eta=0.0,
):
    """_summary_
    Args:
        prompt (str, optional): Defaults to None.
        a_prompt (str, optional): added prompt. Defaults to "best quality, extremely detailed".
        n_prompt (str, optional): negative prompt. Defaults to "longbody, lowres, bad anatomy, bad hands, missing fingers,
        extra digit, fewer digits, cropped, worst quality, low quality".
        num_samples (int): Defaults to 1.
        image_resolution (int): Defaults to 512.
        detect_resolution (int): Defaults to 512.
        ddim_steps (int): Defaults to 20.
        guess_mode (bool): Defaults to False.
        strength (float): control strength. Defaults to 1.0.
        scale (float): Defaults to 9.0.
        seed (int): Defaults to 0.
        eta (float): Defaults to 0.
    """
    apply_midas = MidasDetector()

    model = create_model('./models/cldm_v15.yaml').cpu()
    model.load_state_dict(load_state_dict('./models/deliberate_depth.pth', location='cuda'))
    model = model.cuda()
    ddim_sampler = DDIMSampler(model)

    with torch.no_grad():
        W = H = 512
        img = Image.new("RGB", (W, H), (((255, 255, 255))))
        draw = ImageDraw.Draw(img)
        font = ImageFont.truetype('keifont.ttf', 100)
        draw.multiline_text((30, 150), text, fill=((0, 0, 0)), font=font)

        text_anno = np.array(img, dtype=np.uint8)
        text_anno = cv2.cvtColor(text_anno, cv2.COLOR_RGB2BGR)
        cv2.imwrite("anno.jpg", text_anno)
        input_image = text_anno.copy()

        # detected_map, _ = apply_midas(resize_image(input_image, detect_resolution))
        detected_map = HWC3(input_image)
        img = resize_image(input_image, image_resolution)
        H, W, C = img.shape

        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

        control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()

        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)

        cond = {
            "c_concat": [control],
            "c_crossattn": [model.get_learned_conditioning([prompt + ", " + a_prompt] * num_samples)],
        }
        un_cond = {
            "c_concat": None if guess_mode else [control],
            "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)],
        }
        shape = (4, H // 8, W // 8)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=True)

        model.control_scales = (
            [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)
        )  # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
        samples, intermediates = ddim_sampler.sample(
            ddim_steps,
            num_samples,
            shape,
            cond,
            verbose=False,
            eta=eta,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=un_cond,
        )

        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)

        x_samples = model.decode_first_stage(samples)
        x_samples = (
            (einops.rearrange(x_samples, "b c h w -> b h w c") * 127.5 + 127.5)
            .cpu()
            .numpy()
            .clip(0, 255)
            .astype(np.uint8)
        )

        results = [x_samples[i] for i in range(num_samples)]
        for i, result in enumerate(results):
            cv2.imwrite(f"r{i}.jpg", cv2.cvtColor(result, cv2.COLOR_RGB2BGR))


if __name__ == "__main__":
    processz()

元の文字画像がanno.jpgとして、結果がr{i}.jpg (i=0, 1, ..., num_samples)として保存されます。

雑解説

img = Image.new("RGB", (W, H), (((255, 255, 255))))
draw = ImageDraw.Draw(img)
font = ImageFont.truetype('keifont.ttf', 100)
draw.multiline_text((30, 150), text, fill=((0, 0, 0)), font=font)

pillowで文字画像を作った後に突っ込んでいます。すでに用意したものを使う場合は

img = Image.open(path)

で置き換えてください。なお、pillowはRGBなのに対してcv2はBGRなので後者で読み込む場合はcvtColorが必要です。

Discussion

ログインするとコメントできます