Closed14

ControlNetをトレーニングしたい

PlatPlat

学習のチュートリアル解説がついているのでまずはそれに従ってやる

黒背景に白い円が描かれた画像とそれに色がついた画像のペアを使って ControlNet をトレーニングする。

完全に解説そのままやるだけ。

git でレポを clone する

git clone https://github.com/lllyasviel/ControlNet

conda で環境作る

conda env create -f environment.yaml

(Python 3.8.5 に Torch 1.12 と少し古いが、おそらくこれは大事なのでそのまま従う)

PlatPlat

データセット落とす

training フォルダを作成する

mkdir training

Hugging Face から円のデータセット fill50k.zip を落としてくる

wget https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip

解凍

unzip fill50k

データセットの確認

python tutorial_dataset_test.py

おそらくこうなる。

PlatPlat

モデルの用意

適当にモデルを用意する。

私の中では v1 モデルは存在しないことになっているので v2系をつかう。(今更v1トレーニングする意味も薄いしね)

今回は WD 1.5 を使ったけど、ここはマジでぶっ壊れてるモデルじゃなければなんでもいいと思う。気分で WD1.5 にした。あと、ここはなぜか ckpt しか対応してないので safetensors だったらどうにかして ckpt に変換しよう。

もう既に Hugging Face から落としてきたとする

モデルをセットアップするので以下のように実行する

python tool_add_control_sd21.py models/wd-15-2.6-186k.ckpt models/control-wd15-2.6-186k-ini.ckpt 

ちょっと待てば終わる。

PlatPlat

学習する

あとはほんとに学習のコードを実行するだけなのだが、学習コードがめちゃくちゃシンプルで私でも理解できる程度だったので、軽く見ていこう

データセットを管理するコード

tutorial_dataset.py
import json
import cv2
import numpy as np

from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self):
        self.data = []
        with open('./training/fill50k/prompt.json', 'rt') as f:
            for line in f:
                self.data.append(json.loads(line))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        source_filename = item['source']
        target_filename = item['target']
        prompt = item['prompt']

        source = cv2.imread('./training/fill50k/' + source_filename)
        target = cv2.imread('./training/fill50k/' + target_filename)

        # Do not forget that OpenCV read images in BGR order.
        source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB)
        target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB)

        # Normalize source images to [0, 1].
        source = source.astype(np.float32) / 255.0

        # Normalize target images to [-1, 1].
        target = (target.astype(np.float32) / 127.5) - 1.0

        return dict(jpg=target, txt=prompt, hint=source)

初期化時に prompt.json (ヒント画像、正解画像の相対パスと正解画像のキャプションが入っている) を読み込んで、 [123] のようにインデックス指定されたときに、その番号に対応する画像とキャプションを返している。

ノーマライズはおそらく学習する際に入力に合うように変換しているだけだと思う。(多分)

非常にシンプルだし、これくらいなら簡単に改造できそうに思える。素晴らしい。

おそらくここでランダムにキャプションを削ったりするのだと思われる。(3.3 Training)

次は実行する学習コード

tutorial_train_sd21.py
from share import *

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from tutorial_dataset import MyDataset
from cldm.logger import ImageLogger
from cldm.model import create_model, load_state_dict


# Configs
resume_path = './models/control_sd21_ini.ckpt'
batch_size = 4
logger_freq = 300
learning_rate = 1e-5
sd_locked = True
only_mid_control = False


# First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
model = create_model('./models/cldm_v21.yaml').cpu()
model.load_state_dict(load_state_dict(resume_path, location='cpu'))
model.learning_rate = learning_rate
model.sd_locked = sd_locked
model.only_mid_control = only_mid_control


# Misc
dataset = MyDataset()
dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
logger = ImageLogger(batch_frequency=logger_freq)
trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger])


# Train!
trainer.fit(model, dataloader)

モデルやハイパーパラメーターを指定し、読み込んで、 fit で学習を開始している。

私は機械学習をまったくやってないので学習コードがこんなにシンプルに収まるなんて想像できなかった。めちゃくちゃ複雑そうなコードを書いているイメージがあったのだが、さまざまなライブラリや先人のコードのおかげでここまでシンプルになるのだなと少し感動した。

モデルを指定したり、バッチサイズを変更したらあとは実行するだけ

python tutorial_train_sd21 .py

あとは待つ!

PlatPlat

途中経過 (30時間くらい?)

A40、バッチサイズ 12 で回している

キャプション

ヒント画像

生成されたもの

正解

PlatPlat

途中経過 (40時間)

キャプション

ヒント

生成結果

正解

PlatPlat

途中経過 (48時間くらい)

キャプション

ヒント

生成結果

正解

PlatPlat

学習を停止した (51時間ちょい)

キャプション

ヒント

生成画像

正解

これで十分だと思われる。

50k 枚の円の画像ペア、バッチサイズ12で A40 で 約50GPU時間かかって良い感じにトレーニングされた。

PlatPlat

公式のモデル

時間があるので論文に載ってたモデルについて見てみる

ベースモデルはSD1.5じゃないのを太字にした

名前 意味 データセット ベースモデル GPU 学習時間 備考
Canny Edge 輪郭 インターネットから3M SD1.5 A100 600時間 3Mはパない
Canny Edge (Alter) 輪郭 上と同じもので、1k,10k,50k,500kで実験 SD1.5 A100 150時間
Hough Line 直線 Places2 (10M) Canny Edge A100 150時間 Cannyからの継続。Houghに
HED 輪郭 インターネットから3M SD1.5 A100 300時間 Cannyよりも曖昧な感じの輪郭?
User Sketching 人間っぽいスケッチ HEDのやつを調整したものとインターネットから500kペア Canny Edge A100 150時間
Openpifpaf ポーズ インターネットから80k SD2.1 RTX 3090TI 400時間 これだけ SD2.1ベース。だけどなぜか公開されてない。画像の制約として、人体が一定割合含まれている条件があった。
OpenPose ポーズ インターネットから200k SD1.5 A100 300時間 こちらも人体が一定割合含まれている条件がある。この条件で200kあつめるのたいへんそう...
Semantic Segmentaion (COCO) 塗り分け COCOから164k (BLIPでキャプション) SD1.5 RTX 3090TI 400時間 BLIPでキャプションが付け直されている
Semantic Segmentation (ADKE20K) 塗り分け ADE20Kから164k (BLIPでキャプション) SD1.5 A100 200時間 こちらも
Depth (Midas) 深度 インターネットから3M SD1.5 A100 500時間
Depth (小規模) 深度 上のやつから絞って200k 他の条件書いてないけど多分おなじ?
Normal Maps ノーマルマップ DIODEから25,452ペア (BLIPでキャプション) SD1.5 A100 100時間 なぜかこれだけ数が具体的に書いてある
Normal Maps (拡張) ノーマルマップ Midasで深度とってからノーマルマップを作成した Normal Maps A100 200時間 上のモデルからの継続
Cartoon Line Drawing 線画 インターネットから取得して線画を抽出し、人気度が高いものから順に1Mペア Waifu Diffusion (多分1.4?) A100 300時間 WDベース!!これはおそらくDanbooruから線画を抽出した?1Mペアは大変だ... 線画の抽出に使われたのはこれ
PlatPlat

学習時の注意

ちょっと先の方でも述べたが、論文によるとプロンプトの50%を空文字に置き換えたことで、いい感じになったぜと書いてある。

ただし、イラストにおいては50%は高すぎるので 0%~10% にするとよいらしいというのがここに書いてある。

あと、バッチサイズは下げてはならないらしい。「突然収束する現象」はある一定の段階で必ず起こるため、バッチサイズを下げてステップ数を増やすよりも、バッチサイズを上げてステップ数を下げるべき、とControlNetの作者が言っている

PlatPlat

ワッ...ワァ....

ControlNet v1.1

https://github.com/lllyasviel/ControlNet-v1-1-nightly

v1.0から継続でトレーニングされたものや、新しい種類のモデルが追加された。

また、v1.5まではモデルのアーキテクチャが変わらないことが保証された。(ほんとにv1.5シリーズまで出すの!?)

モデル名の命名のガイドラインが作られた。(今までどれがどれなのかわかりずらかったw)

PlatPlat

v1.1のモデル

Depth

Midas、Leres、Zoeの3つのアノテーターと、256/384/512解像度の画像を使って学習された

Normal Maps

BAE のノーマルマップの手法を使ったもの (前回のMidasベースのやつはゴミなので破棄されたそう)

Canny・MLSD・Scribble (継続)

データセットは以前と同じで、前回のモデルから追加でA100で200時間学習された (贅沢な計算資源よ...)

Soft Edge

HEDのデータであんまりいい感じじゃないのがあったからやりなおした感じ?(よくわからん)

Segmentation

前回は別々だった COCO と ADE20K を一緒に学習したもの。ControlNetにとっては、色分けのルールなんか細かいことで、異なるルールの区分けを一緒に学習することで使いやすくなったらしい。

Openpose

前回は全身 (body) のみだったが、顔 (face)、全身 (body)、手 (hand) も一緒に学習して一緒に使えるようになったモデル。

Lineart

実写的なスケッチ。

Anime Lineart

ベースモデルに何使ってんだよ!?!?!?!????
イラスト線画。Anime2Sketch を使っている。

Shuffle

プロンプトで指示しながら reimagine 的な感じができたり?色の雰囲気を指定したり?する?スタイル転移的な?

InstructPix2Pix

本家のInstructPix2Pixよ... あいつはいいやつだったよ...
ControlNet でどのモデルでも、いい感じにPix2Pixできるようになった。
データセットは本家が公開しているものと、キャプションをちょっと加工して使ったらしい。

Inpaint

50%ランダムなマスクと50%の意図的なマスク?(ここ何言ってるのかわからなかった)で学習して、インペイントができるようにしたモデル。
ヒント画像作るのが楽そうなのでわりと気軽にトレーニングできそう。

Tile (未完成)

タイル(小さく区切って)で拡散して高解像度で出力するのに、プロンプトの影響をいい感じに伝わるように?作られたモデル
4k画像が200k枚つかわれている (どこから集められるんだろう...)

PlatPlat

実はあの円のやつは思いっきり間違えたことをやっていた。

あの ./models/cldm_v21.yaml は、デフォルトで用意されている SDv2.1 向けの設定ファイルなのだが、これは base モデル向けであり、768解像度のモデル向けではない。少し修正する必要があるのだ。

具体的にはこのようにする

model:
  target: cldm.cldm.ControlLDM
  params:
+  parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    control_key: "hint"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False
    only_mid_control: False

    control_stage_config:
      target: cldm.cldm.ControlNet
      params:
        use_checkpoint: True
        image_size: 32 # unused
        in_channels: 4
        hint_channels: 3
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    unet_config:
      target: cldm.cldm.ControlledUnetModel
      params:
        use_checkpoint: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"

そう、v-prediction を有効にしなければならない。これをせずに学習すると、あらゆる画像を茶色にする ControlNet が完成する (一敗)。

ステップ 0 時点で出力される画像が、無駄に汚い茶色だったり、意味の分からない画像であった場合は yaml ファイルが間違っている。

なぜなら、ControlNet モデルで追加される zero convolution のレイヤーとは別に、もとからあるネットワーク構造は破壊されないため、かならずキャプションに沿った何かしら意味のある画像が生成されるはずだからだ。

意味不明な画像が出てきたら、その時点でそのモデルは壊れている。

今後 768 解像度で ControlNet を学習しようとしている人がいたら気を付けてほしい...

このスクラップは2023/12/13にクローズされました