Colab で DreamBooth する
DreamBooth
Google が公開したすっごく強いファインチューニングの手法です。
簡単にどんなものか説明すると、最低 4枚 の画像を与えるとそこに写っているものの他のシチュエーションの画像を生成できるようになります。
DreamBooth の公式サイトから引用
上の画像のようにたった4枚の犬の画像から、泳いでいたり犬小屋の中にいる画像を生成することができます。
DreamBooth の強みはなんと言ってもその精度にあり、他の手法と比べて高い精度で生成することができます。
こんな感じで従来の手法と比べると凄まじいレベルの生成になっていることがわかると思います。
私には仕組み等は全くわからないので知りたい場合は 実際の論文 や他の記事を参照してください。
この記事では一番シンプルな使い方を説明しますが、普通に Colab 見れば大体わかるので慣れてる人は読まなくても大丈夫だと思います。
Colab で動かす
今回は DreamBooth で仙台の伊達政宗の銅像を学習させてみます。
準備
元々の DreamBooth は動かすのに VRAM 40 GB ほどが必要だったのですが、最近 Colab の VRAM 16GB でも動かせるようにしたものが有志によって公開されたのでそれを使います。
この DreamBooth では VRAM を 12.5GB ほど使うので、もし Free プランで VRAM 8GB のものを引いてしまった場合は実行することができません。(そのケースがあるかわからないですが...)。私は Colab Pro で動かしました。
レポの「Open in Colab」をクリックします。
ファイル > ドライブにコピーを保存 から Google Drive にコピーを保存してから使用するのがおすすめです。
編集 > ノートブックの設定 から「ハードウェアアクセラレータ」に「GPU」が選択されていることを確認します。(「ランタイムの仕様」は「標準」でもおそらく大丈夫です)。
実行
パラメーターの設定
上から順に実行していきます。途中まではいつもの StableDiffusion と同じです。
ここで Hugging Face へのログインを求められるので、表示されるリンクにアクセスして API トークンを取得し、入力します。
その後の部分からはファインチューニングの設定になります。
CLASS_NAME = "guy" # just a general name for class like dog for dog images.
覚えさせたい物の一般的な名称を入れます。犬なら dog
、人物なら person
や man
、girl
などを入れます。
MODEL_NAME = "CompVis/stable-diffusion-v1-4"
INSTANCE_DIR = "/content/data/sks"
!mkdir -p $INSTANCE_DIR
CLASS_DIR = f"/content/data/{CLASS_NAME}"
OUTPUT_DIR = "/content/models/sks" # sks is a rare identifier, feel free to replace it.
MODEL_NAME
はファインチューニングするモデルです。StableDiffusion がデフォルトですが、他のモデルでも理論上可能です。Waifu Diffusion は動作を確認しました。
INSTANCE_DIR
は学習用の画像を入れるパスです。sks
は適当になんでも好きな名前に置き換えて良いです。
CLASS_DIR
は学習するときに発生する画像が入るパスです。特にいじる必要はないです。
OUTPUT_DIR
はファインチューニングしたモデルが出力されるパスです。こちらも同様に sks
は自由に置き換えて良いです。
今回は伊達政宗像を学習させるので以下のようになりました。
CLASS_NAME = "statue" # just a general name for class like dog for dog images.
MODEL_NAME = "CompVis/stable-diffusion-v1-4"
INSTANCE_DIR = "/content/data/masamune"
!mkdir -p $INSTANCE_DIR
CLASS_DIR = f"/content/data/{CLASS_NAME}"
OUTPUT_DIR = "/content/models/masamune" # sks is a rare identifier, feel free to replace it.
# Upload your images by running this cell OR you can use the file manager on the left panel to upload to INSTANCE_DIR
import os
from google.colab import files
import shutil
uploaded = files.upload()
for filename in uploaded.keys():
dst_path = os.path.join(INSTANCE_DIR, filename)
print(f'move {filename} to {dst_path}')
shutil.move(filename, dst_path)
これは説明にもあるようにファイル選択画面を表示して画像をアップロードすることができます。が、左のファイルビューワーを開いてドラッグ&ドロップした方が早いのでわざわざ使う必要はないです。
画像のアップロード
今回は伊達政宗像を学習させます。
「せんぴく」 という無料で使える写真素材配布サイトがあったのでそちらからお借りしました。
使うのは以下の 4 つの画像です。
学習開始
学習をするブロックです。変更する必要があるのは instance_prompt
と class_prompt
です。
例と同じような雰囲気で書いていきます。
今回は以下のようになりました。
# remove --use_8bit_adam flag if you got more than 18 GB VRAM.
!accelerate launch train_dreambooth.py \
--pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \
--instance_data_dir=$INSTANCE_DIR \
--class_data_dir=$CLASS_DIR \
--output_dir=$OUTPUT_DIR \
--with_prior_preservation \
--instance_prompt="photo of masamune {CLASS_NAME}" \
--class_prompt="photo of a {CLASS_NAME}" \
--resolution=512 \
--use_8bit_adam \
--train_batch_size=1 \
--gradient_accumulation_steps=1 \
--learning_rate=5e-6 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=200 \
--max_train_steps=800
他のパラメーターはいじっていませんがお好みで変更してください。
プロンプトを変更したら実行します。
学習には 1 時間ほどかかるので気長に待ちます。
生成
学習が完了したらついに生成の時間です。
3つ目のブロックのプロンプト等を編集していくつか生成してみます。
prompt = ["photo of illuminated masamune statue, at night"] * batch_size
こんな感じ。
生成された画像
(シードは適当に変えたりしています)
photo of illuminated masamune statue, at night
夜にライトアップされた政宗像の写真
ちゃんとライトアップされた。実際のライトアップはもっとカラフル。あとなぜか馬が若干ロバっぽくなってる。
photo of large explosion behind masamune statue, explosion explosion explosion, huge fire fire fire
政宗像の後ろで爆発してる写真、爆発爆発爆発、巨大な炎炎炎
なぜかこっちに背中向けているけど割とそれっぽくてビビってる。よく見ると色々おかしな点はあるが、ぱっと見だとかなりそれっぽい。(語彙力)
illustration of masamune ukiyoe
政宗の浮世絵
それっぽい。片足を上げた馬や兜の三日月みたいなやつなどが再現されている。
Vincent Van Gogh's masamune oil painting
ゴッホの政宗の油絵
雰囲気掴んだままゴッホ風の油絵になった。こちらもちゃんと特徴を掴んでいる。
終わり
学習元画像が青空に台座、木が映っていたせいでその属性も強くなってしまったため、プロンプトで打ち消すのが大変でした。(海に沈めたり宇宙に飛ばしたりしたかったが、木や青空が強すぎてできなかった)。
学習時のパラメーターを変更したり、Stable Diffusion Web UI でネガティブプロンプトを使ったりして回避することができるかもしれないです。
逆に、生成された画像がかなり学習元画像の特徴を掴んでいたのがわかったと思います。DreamBooth のサイトにもあるように、これを使って生成された画像は従来よりも判別が難しい画像を生成できるので、悪用された時に大変だなと感じました。(これは他のものでも同じだが、特に詳しくないと知らないものについての生成は、全然知識のない人から見たら判別がつかないレベルだと思います)。
Discussion