😶‍🌫️

Stable Diffusion から特定の概念を忘れさせる学習を行ってみる

2023/06/29に公開

TL;DR

  • ESD の手法で LoRA を学習してみたらそれっぽい感じのことができたよ
  • VRAM 8GB で余裕で学習できるようになったよ (元は20GB要求)
  • LoRA として保存できるようになったので重みの取り回しが良くなったよ
  • マイナス適用によって、概念を削除するだけでなく強調することもできたよ

一度でも画像生成 AI に触ったことがあると、より楽しんで読めると思います。

論文とかどうでもいいから学習方法知りたい! という方は 実際に学習してみる へどうぞ!

今回作成したもの

コード:

https://github.com/p1atdev/LECO

モデルなど:

https://huggingface.co/p1atdev/leco/tree/main

前提

Stable Diffusion とは、Stability AI らが公開したオープンソースの画像生成 AI であり、テキストによる指示で様々な画像を生成することができる。

本来の Stable Diffusion は、実写画像や海外風のイラストを出力することが得意だが、アジア系の写真やアニメイラスト[1]を出力することは苦手である。

Stable Diffusion による日本人の写真 Stable Diffusion によるイラスト
photo of cute japanese girl stable diffusion anime girl by stable diffusion

そこで、Stable Diffusion はオープンソースなので、個人や団体、企業が自分で用意したデータセットを使って出力を調整することができる。たとえば、Waifu Diffusion[2] (しばしば WD と略される) ではアニメイラストや日本人の画像を学習することで、アニメイラストや日本人っぽい写真の生成を可能にしている。

Waifu Diffusion による日本人の写真 Waifu Diffusion によるアニメイラスト
photo of a girl by waifu diffusion illustration of a girl by waifu diffusion

このように、一般に行われる追加学習は新たな概念を覚えるために行われることが多いが、逆に既に覚えている概念を忘れさせる(出ないようにする)微調整をする手法も存在する。

Erasing Concepts from Diffusion Models (以下ESD) では、モデル自身の知識を利用し、追加のデータセット無しでアートスタイルや特定のオブジェクトなどを出ないように学習させることができる。

ESD の仕組みを応用すると、単純に概念を忘れるだけでなく、その概念を出しやすくしたり、他の概念と入れ替えたり、もしくは混ぜ合わせるということも可能となる[3]。ESD から派生した ConceptMod では、ロスの取り方を工夫することで追加のデータセットを使わずに出力の傾向を調整することを実現している[4]

ただし、ESD のそのままの手法では学習に最低 20GB の VRAM が要求され[5][6]一般的な家庭用 GPU では少し学習が難しいという課題が存在する。

(ConceptMod では Stable Diffusion を直接調整した後にベースモデルとの差分を LoRA として抽出することで LoRA を配布しているようである。[7])

そのため、LoRA の軽量な学習手法に ESD の学習手法を取り入れることで、VRAM 8GB で学習可能にすることを目指す。

Erasing Concepts from Diffusion Models

https://erasing.baulab.info/

https://github.com/rohitgandikota/erasing

https://arxiv.org/abs/2303.07345

ESD についての軽い解説。

特徴

先程述べたように、ESD では特定の概念を消去することが可能である。

忘れたい概念以外への影響が少なく済む [8] ほか、Stable Diffusion モデル全体を学習することは行わず、かつモデル自身の知識を利用するので追加のデータセットが必要ないため、既存の追加学習手法に比べてコストが軽く済む。

また、モデルの重みを変更することで出力を制限するため、重みを共有する場合においても制限を回避して忘れさせた概念を出せるようにすることは簡単ではない。

生成サービスを運営する場合においても、NSFW コンテンツ[9]や、特定の著作権・商標で保護された何かが出てほしくない場合に概念を忘却させることで、問題が発生する懸念を軽減させることができるかもしれない。


概念を忘却させている例。Edited Model が ESDのモデル。画像はプロジェクトページより

画像のように、他の概念への影響を抑えて NSFW 画像やアートスタイル、オブジェクトを出ないように調整できていることがわかる。

仕組み

ここは自信がないのであんまり解説できません....

全体的な方針としては、削除したい概念のトリガーワード (van gogh など) と反対方向に (画像に映らないように) デノイズするよう学習を進める感じです。(すごいざっくり)

ESDの実際の例


概念を消去したモデルそれぞれの出力比較。画像はプロジェクトページより

この画像の見方としては、ヨコが生成時のプロンプトで、タテがそれぞれの概念を削除したモデルとなっている。青の点線で囲まれたものが、概念を削除したモデルでその概念について生成したときの出力となっており、それ以外の出力がオリジナルのSDと異なる場合は、他の概念に影響を与えてしまっていることを示している。

Stable Diffusion のプロンプトでよく見かける(?) Thomas Kinkade に注目すると、Thomas Kinkade の概念を削除したモデルでは、Thomas Kinkade を指定してもきちんとそのスタイルにならないようになっている(写真が出てきてしまっている)ことがわかる。また、それに加えてほかのアーティスト名の出力はそこまで大きく変化していないため、Thomask Kinkade の概念に絞って削除できていることがわかる。

このように、Stable Diffusion で名前を入れたらそのアートスタイルが出てくるような一部の有名なアーティストであれば、SDモデルからその概念を削除することも可能である。(大規模なデータセットから取り除くことができなくても、学習後にその概念を削除することができる可能性がある。)

Low-rank Adaptation for Fast Text-to-Image Diffusion Fine-tuning

https://github.com/cloneofsimo/lora

こっちは他の記事などを参考にするとわかりやすいと思います。

https://speakerdeck.com/koharite/lun-wen-jie-shuo-lora-low-rank-adaptation-of-large-language-models
https://hoshikat.hatenablog.com/entry/2023/05/05/013600
https://note.com/gcem156/n/n6e0178ac6978

本当に超ざっくりまとめると、軽量に追加学習する手法の一つです。(元は大規模言語モデル向けの追加学習手法でしたが、Stable Diffusion にも応用されました)

余談

ESD の論文中で紹介されている学習手法は ESD-x と ESD-u の2つあるのですが、これはそれぞれの Cross Attention と Self Attention を学習します。今回運が良かったのか必然的なのかはわからないのですが、LoRA は Attention の中の構造を軽量に学習するため、既存の LoRA の構造を変えることなく実装が可能でした。

ちなみに、ESD-x で学習される Cross Attention は、プロンプトによるトリガーに強く影響されやすく、ESD-u で学習される Self Attention はプロンプトの影響を受けなくても画像の出力に影響を与えるらしいです。[10]
今回はどちらも一緒に学習しているはずですが... うまく動作しているかはわからない...!

合体

これらを合体させてみます。

そして出来上がったものがこちらになります。(最初に載せたものと同じリンク)

https://github.com/p1atdev/LECO

実際に学習してみる

真面目くさい話は終わりです。ここからは楽しい実践編です。

VRAM 8GB 以上が必要になります。

(2023/6/29) 追記: Colab で使える簡易的な学習ノートブックを作成したので、そっちを使ってもよいです。
Colab: https://colab.research.google.com/github/p1atdev/LECO/blob/main/train.ipynb

まずは必要なものを落としてきて用意します。

git clone https://github.com/p1atdev/LECO
cd LECO

ここでは conda コマンドで環境を作っています(おすすめです)。

conda create -n leco python=3.10
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
pip install xformers 
pip install -r requirements.txt

./examples 下にある設定ファイルをそのまま使ってゴッホ画風の削除を実際に試すことができます。

python ./train_lora.py --config_file "./examples/config.yaml"
`config.yaml` の中身の解説

中身はこんな感じになっていると思います。

prompts_file: "./prompts.yaml"

pretrained_model:
  name_or_path: "stabilityai/stable-diffusion-2-1" # you can also use .ckpt or .safetensors models
  v2: true # true if model is v2.x
  v_pred: true # true if model uses v-prediction

network:
  type: "lierla" # or "c3lier"
  rank: 4
  alpha: 1.0

train:
  precision: "bfloat16"
  noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a"
  iterations: 500
  lr: 1e-4
  optimizer: "AdamW"
  lr_scheduler: "constant"

save:
  name: "van_gogh"
  path: "./output"
  per_steps: 200
  precision: "bfloat16"

logging:
  use_wandb: false
  verbose: false

other:
  use_xformers: true
  • prompts_file では、学習ターゲットとなる概念のプロンプトなどを入れておくファイルです (後述)。
  • pretrained_model は学習のベースモデルです。v1、v2、v2-768解像度に対応しており、それぞれオプションをつけたり消したりすることで設定できます。(diffusers 形式でも .ckpt、.safetensors でも OK)
  • network: kohya氏のLoRAと同じフォーマットで作成するため、その際のネットワーク構造になります。c3lier は LoCon のことです[11]rankalpha はそのフォーマットでの使い方と同じなので、詳細は kohya-ss/sd-scripts を参照です。
  • train: ハイパーパラメーターなどです。
    • precision は学習時に使う小数の精度です。float16bfloat16float32 が使えますが、bfloat16 推奨です。
    • noise_scheduler は、学習の際に使うノイズスケジューラーです。ddim しかテストしてないので違いはわからないです。
    • iterations は学習ステップ数です。500~1000ぐらいで十分だと思います[12]
    • lr は学習率で、よくわからなかったら適当に 1e-4 でいいです。 1e-5 では何も学習されませんでした。
    • optimizer はオプティマイザーで、AdamAdamWLion などが使えます。
    • lr_scheduler は学習率のスケジューラーで、constantlinearcosine などが使えます。
  • save は保存に関する設定です。
    • name は保存される LoRA のファイル名になります。van_gogh_200steps.safetensors みたいになります。
    • path は保存するフォルダの名前になります。
    • per_steps は保存するステップ数の間隔です。
  • logging はログです。
    • use_wandb を true で wandb にログを送信します。(事前に pip install wandbwandb login が必要)
    • verbose はデバッグ用のログを有効にします。不要です。
  • other は他のやつです。
    • use_xformers はオンにしたほうが VRAM 消費が減ると思うのでオンの方がいいと思います。
`prompts.yaml` の中身

こんな感じになっています。

- target: "van gogh" # what word for erasing the positive concept from
  positive: "van gogh" # concept to erase
  unconditional: "" # word to take the difference from the positive concept
  neutral: "" # starting point for conditioning the target
  action: "erase" # erase or enhance
  guidance_scale: 1.0
  resolution: 512
  batch_size: 2

よくわからない場合は、 targetpositive は基本的に揃えて、消したい概念のプロンプトを入れます。また、neutralunconditional"" にしておきます。

  • action は今は eraseenhance を用意してあります。普通に概念を消去したいときは erase を使います。enhance ではロスの取り方が異なり、ターゲットの概念を他の概念に近づけるような調整が可能です。(例: 1girl1girl, cat ears にする、など)
  • guidance_scale は基本的には 1.0 のままでよいと思います。
  • resolution は学習時の解像度となります。これは、生成時にどの解像度で生成するかで変える必要があります。(768x768以上で頻繁に生成する場合、ここを768にしないと効果が薄い場合があります。)
  • batch_size はバッチサイズで、VRAM 8GB なら 2~3 がよいかと思います。

また、 prompts.yaml では、同時に複数の概念について扱うことができます。

例えば、

- target: "realistic"

- target: "instagram"

- target: "real life"

- target: "octane render"

というように、複数の概念の消去を一度の学習で行うことができます。

これだけで学習が始まります。わざわざ画像を用意する必要はありません。(Prompt is All You Need)

デフォルトでは SDv2.1-768 解像度のモデルで学習されるので、変更したい場合は config.yaml を編集してください。

以下は実際にこのコードを使って学習された LoRA の例になります↓

ゴッホ画風削除 (SDv1.5)

oil painting of van gogh by himself
oil painting of van gogh by himself (ゴッホの自画像の油絵)

詳細なプロンプト
oil painting of van gogh by himself
Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 3870472781, Size: 512x512, Model hash: cc6cb27103, Model: v1-5-pruned-emaonly, Clip skip: 2, AddNet Enabled: True, AddNet Module 1: LoRA, AddNet Model 1: van_gogh_4_last(db68853d039b), AddNet Weight A 1: -1.0, AddNet Weight B 1: -1.0, Script: X/Y/Z plot, X Type: AddNet Weight 1, X Values: "-1, 0, 1", Version: v1.3.0

この画像の見方を説明しておくと、左の AddNet Weight 1: -1.0 というのは、今回のゴッホ画風を消去する LoRA を -1.0 強度で適用している(つまり本来の効果の真逆の効果になる)ということを表しており、中央では 0.0 強度(つまり適用していないのと同じ)、右では 1.0 強度で適用(本来の効果)しているという状態です。それぞれの4枚は、左上、右上、右下、左下でそれぞれ共通のシード値(似た構図の絵が出てくる)となっています。

画像を見たら分かるように、0.0 ではいかにもそれっぽいゴッホの油絵が出てきましたが、1.0 強度では「誰...?この人...?」という感じになっており、ゴッホ概念が消去されているのがわかります。

さらに、逆の効果が発生する -1.0 では星月夜[13]のようテクスチャのまでついたゴッホの自画像(と風景)が出てきており、ゴッホの詰め合わせとでも言えるような画像になりました。

他の概念への影響 も見てみましょう。

painting of scenery by monet
painting of scenery by monet (モネが描いた風景の絵)

詳細なプロンプト
painting of scenery by monet
Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 1284787312, Size: 512x512, Model hash: cc6cb27103, Model: v1-5-pruned-emaonly, Clip skip: 2, AddNet Enabled: True, AddNet Module 1: LoRA, AddNet Model 1: van_gogh_4_last(db68853d039b), AddNet Weight A 1: -1.0, AddNet Weight B 1: -1.0, Script: X/Y/Z plot, X Type: AddNet Weight 1, X Values: "-1, 0, 1", Version: v1.3.0

微妙に変化がある(本当はよろしくない)のですが、ゴッホと比べたらかなり軽微な影響と見ることができそうです。LoRA として学習すると適用強度を後から柔軟に変更できるので、それで調整するのもよさそうです。

この LoRA のダウンロードリンク (Civitai & Hugging Face)

モナリザ削除(討伐?) (SDv2.1-768)

mona lisa
mona lisa with jewelry (モナリザとジュエリー)

詳細なプロンプト
mona lisa with jewelry
Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 3630495347, Size: 512x512, Model hash: 832eb50c0c, Model: v2-1_768-ema-pruned, Clip skip: 2, AddNet Enabled: True, AddNet Module 1: LoRA, AddNet Model 1: mona_lisa2_last(393beb35c4b1), AddNet Weight A 1: -1.0, AddNet Weight B 1: -1.0, Script: X/Y/Z plot, X Type: AddNet Weight 1, X Values: "-1, 0, 1", Version: v1.3.0

強度 1.0 ではモナリザを指定してもモナリザの代わりに混沌が生成されるようになりました。また、0.0 では中途半端なモナリザが出てきましたが、-1.0 ではいい感じのモナリザが出てくるようになりました(ジュエリーは無視された)。

1.0 の崩壊具合がすごいですが、他の概念への影響はどのようになっているでしょうか。

cute cat
photo of a cute cat (かわいい猫の写真)

詳細なプロンプト
photo of a cute cat
Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 900866192, Size: 512x512, Model hash: 832eb50c0c, Model: v2-1_768-ema-pruned, Clip skip: 2, AddNet Enabled: True, AddNet Module 1: LoRA, AddNet Model 1: mona_lisa2_last(393beb35c4b1), AddNet Weight A 1: -1.0, AddNet Weight B 1: -1.0, Script: X/Y/Z plot, X Type: AddNet Weight 1, X Values: "-1, 0, 1", Version: v1.3.0

どれもちゃんとかわいい猫ですね!軽微な差は発生していますが、ちゃんと元のモデルの知識を保っているのがわかります。

この LoRA のダウンロードリンク (Civitai & Hugging Face)

実写概念消去 (WD1.5 beta3)

最初の方に述べたように WD1.5 では実写画像とアニメイラストの両方が生成できるのですが、アニメイラストだけを生成したい場合にはうっかり実写画像が出てきてしまうということは避けたいものです。

そこで、 realisticreal lifeinstagram [14]の3つのトリガーワードの効果を薄めてみました。

まずはスタイルのワード realisticanime などを指定せずに比較してみます。

the yellow girl
スタイルを指定していない場合の出力

強度 0.0 では実写画像が出てきてしまっていますが、強度 1.0 ではきちんとイラストになっているのがわかると思います。

今度はプロンプトに realistic, instagram を入れて意地悪したらどうなるのか見てみましょう。

bob cut girl

詳細なプロンプト
masterpiece, best quality, exceptional, best aesthetic,, 1girl, aqua eyes, baseball cap, blonde hair, closed mouth, earrings, green background, hat, hoop earrings, jewelry, looking at viewer, shirt, short hair, simple background, solo, upper body, yellow shirt,
Negative prompt: worst quality, low quality, bad aesthetic, oldest, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry
Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 2867636749, Size: 768x768, Model hash: d38e779546, Model: wd-beta3-base-fp16, Clip skip: 2, AddNet Enabled: True, AddNet Module 1: LoRA, AddNet Model 1: unreal_6_many_prompts_200steps(fff5917285da), AddNet Weight A 1: -1.0, AddNet Weight B 1: -1.0, Script: X/Y/Z plot, X Type: AddNet Weight 1, X Values: "-1, 0, 1", Version: v1.3.0

今回は上段が強度 0.0 (適用なし)、下段が強度 1.0 (適用あり)となっており、一番上の real lifeinstagram は、その単語がプロンプトに含まれているかどうかを表します。real life, instagram であれば、real life, instagram がプロンプトの中に含まれており、空欄の場所はこれらの単語は何も含まれてないことを意味します。

詳細なプロンプト
real life, masterpiece, best quality, exceptional, best aesthetic, 1girl, school uniform, blue hair, bob cut,
Negative prompt: worst quality, low quality, bad aesthetic, oldest, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, jpeg artifacts, signature, watermark, username, blurry
Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 532013706, Size: 768x768, Model hash: d38e779546, Model: wd-beta3-base-fp16, Clip skip: 2, Script: X/Y/Z plot, X Type: Prompt S/R, X Values: "real life, instagram,\"real life, instagram\",\" \"", Y Type: AddNet Weight 1, Y Values: "0, 1", Version: v1.3.0

例えば instagram の項では先頭の real lifeinstagram に置き換わります。

適用あり(下段)を適用なし(上段)と比べたら、実写要素はかなり軽減されていることがわかると思います。ただ、完璧に消しきれているわけではないため、下段では real life などを指定すると、何も指定していない場合と比べて顔つきが変わってしまっており、他のプロンプトとの相性によっては普通に実写画像が出現してしまうこともあります

既に画像生成の環境が整っている場合は、実際に自分で生成してみると LoRA の効果を実感しやすいかと思います。

この LoRA のダウンロードリンク (Hugging Face)

他に考えられる応用例

今思いつくだけでも、他にこんな感じのものができる可能性があります (うまくいくかどうかはわからないですが...)

  • 低クオリティ概念を出なくする (崩れた人体やピンボケなど)
  • トリガーワードが被っている概念の学習前の事前処理 (katsushika hokusai という名前のキャラを学習したいのに katsushika hokusai が既に他の概念で使われている場合など)
  • 短いワードにいろいろな概念を詰め込む (1girl で出てくる画像を 1girl, red hair, cat ears にしてしまうなど)
  • 指定してないのについてくる要素を消す (apple で iPhone 出てくるとか、bowl cut でボウルが出てくるのを調整するなど)
  • 弱い画風のトリガーワードを強化する/軽減する

自由な発想と工夫によって面白い LoRA ができそうですね。

課題

上の例でも見られた通り、他概念への影響が少なからず発生してしまうことがあったり、必ずしもすべての概念を綺麗に消しきれるとも限らないという点があります。

先程の挙げたモナリザ消去 LoRA ではあまり綺麗な出力にならなかったり、実写概念消去 LoRA もあらゆるプロンプトの組み合わせに対応できるわけでもありませんでした。また、論文中では car の概念を消去することができていましたが、今回試した限りではうまく消去することができませんでした。

完璧に消し去るにはさらなる工夫が必要そうです。

感想

軽量にトレーニング可能+画像のデータセットを用意する必要がないので、気軽にいろんな実験ができて非常に楽しいです。

元々は最後に紹介した、実写概念を出せなくする LoRA を作りたくて始めたのですが、いざ作ってみると他にもいろいろな使い道が思いついたので、一見めちゃくちゃナンセンスで意味のなさそうに見えるものでも、時や場面が変わるだけで急に有用になる、というのを ControlNet の登場ぶりに体感しました。(これまで超謎だなと思っていたポーズ推定だとか深度推定の技術がここまで重要になるとは想像もできなかった...)

謝辞

このプロジェクトは、以下のプロジェクトの存在なくしては成立しませんでした。これら一連のプロジェクトに対して深く感謝の意を表明します。

特に kohya氏の sd-scripts は大いに参考にさせてもらったほか、直接様々な有益なアドバイスをいただいたこと、心より感謝申し上げます。

https://github.com/rohitgandikota/erasing

https://github.com/cloneofsimo/lora

https://github.com/kohya-ss/sd-scripts

https://github.com/ntc-ai/conceptmod

脚注
  1. 「アニメ」と言っているが、ここでは動画のアニメのスタイルではなく、我々がイラストと言われて思い浮かべるようなものを指す。 ↩︎

  2. 安定版は v1.4。この記事で使用しているのはベータ版の v1.5 beta3 となる。 ↩︎

  3. https://github.com/ntc-ai/conceptmod ↩︎

  4. https://civitai.com/tag/conceptmod?sort=Newest ↩︎

  5. https://github.com/ntc-ai/conceptmod/tree/main#new---train-or-animate-on-runpod ↩︎

  6. https://ntcai.xyz/articles/use_runpod_to_train_with_conceptmod/ ↩︎

  7. https://github.com/ntc-ai/conceptmod/tree/main#new---train-or-animate-on-runpod ↩︎

  8. もちろん限界はあり、完璧にどの概念でも忘れられることはないし、他の概念に影響が出てしまうものもある。https://erasing.baulab.info/ の Limitations を参照 ↩︎

  9. NSFWコンテンツをそもそも学習するのを避けたSDv2.x系と比較して、ESDのほうがNSFWコンテンツの出現率を下げることに成功したとのこと。https://erasing.baulab.info/ の Erasing nudity を参照 ↩︎

  10. https://arxiv.org/pdf/2303.07345.pdf#figure.caption.2 ↩︎

  11. https://github.com/kohya-ss/sd-scripts#naming-of-lora ↩︎

  12. 論文中で取り上げられている実験は、すべて学習率 1e-5、Adamオプティマイザーで 1000 ステップで行われていました。今回は LoRA なので学習率は少し高めにする必要がありますが、ステップ数は 1000 もあれば体感十分です。 ↩︎

  13. https://ja.wikipedia.org/wiki/星月夜 ↩︎

  14. real lifeinstagram は Waifu Diffusion 1.5 における、実写画像を意味する強力なトリガーワード ↩︎

GitHubで編集を提案
AIものづくり研究会

Discussion