🎉

Stable Diffusionからの概念消去①:ESD (論文)

2024/03/23に公開

Erasing Concepts from Diffusion Models (ICCV2023)

PlatさんESDをLoRAでできるようにしていた記事を以前投稿されていました. この記事ではその論文について細かく見ていきます.

図や表はことわりのない限り論文からの引用です.

書籍情報

Rohit Gandikota, Joanna Materzynska, Jaden Fiotto- Kaufman, and David Bau. Erasing concepts from diffusion models. In Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), pages 2426–2436, 2023.

関連リンク

TL;DR

  • Diffusion Modelsから特定の概念を消去する研究
  • fine-tuningのみで概念の消去が可能で, データセットのフィルタリングも不要
  • パラメータを更新するので推論時の工夫がなくてもできる

前提

この記事を読む人は概ね理解していると思いますが, 簡単に前提となる知識に触れておきます.

Diffusion Models

一般に拡散モデルは, 「ノイズを徐々に加える過程と徐々に取り除く過程」と表現されます. 実装上は, サンプリングされたガウシアンノイズから始めて T 回の除去ステップを経ます. モデルは時刻 t で加えられたノイズ \varepsilon_t を予測し, これを引き算することで画像 x_t を生成します. x_T は初期ノイズ, x_0 が最終的に生成される画像に対応します. このプロセスはマルコフ遷移確率として以下のように定式化されます.

p_{\theta}(x_{T:0})=p(x_T)\prod_{t=T}^1p_{\theta}(x_{t-1}\mid x_t)

Latent Diffusion Models

簡単に言えば, 先ほどのプロセスを潜在空間で行ったのがLatent Diffusion Models (LDM)です. 潜在空間への射影はVAEのエンコーダー \mathcal{E} が行います. そして, 以下の目的関数を最小化するように学習します.

\mathcal{L}=\mathbb{E}_{z_t\in\mathcal{E}(x),t,c,\varepsilon\sim\mathcal{N}(0, 1)}[\|\varepsilon-\varepsilon_{\theta}(z_t, c, t)\|_2^2]

ここで, x は画像, z=\mathcal{E}(x) です. c は一般に条件ですが, ここではテキストによる条件を仮定します.

生成を行う際は, テキストに合致した条件生成ができるようにclassifier-free guidanceを用います. これはguidance scale \alpha>1 を用いて

\tilde{\varepsilon}_\theta(z_t, c, t)=\varepsilon_\theta(z_t, t) + \alpha(\varepsilon_\theta(z_t, c, t)-\varepsilon_\theta(z_t, t))

で生成画像を求めます. 最後にVAEのデコーダー \mathcal{D} を用いてピクセル空間に射影します.

手法

ある概念を消す方法として2つの方法があります.

  1. 消す概念を含むデータを除去してスクラッチで学習させる
  2. 学習済みモデルをfine-tuningする

1の手法は単純ですが, スクラッチ学習はお金と時間がかかりますし, 消したい概念が出たらまたやり直さなければならないデメリットがあります. そのため, この論文では2の手法について考え, Stable Diffusionを用いて実験します. Stable Diffusionはテキストエンコーダー \mathcal{T}, 拡散モデル (U-Net) \theta^*, デコーダー \mathcal{D} がありますが, この論文では新しいパラメータ \theta を求めます.

大まかな方針は, 対象の概念によって記述される画像 x が生成される確率をファクター \eta>0 によってスケーリングした尤度によって減少させます.

p_{\theta}(x)\propto\dfrac{P_{\theta^*}(x)}{P_{\theta^*}(c|x)^\eta}

c は消したい概念, P_{\theta^*}(x) は元のモデルによって生成される分布を表します. ベイズの定理から

P(c|x)=\dfrac{P(x|c)P(c)}{P(x)}

なので, 対数確率の勾配は

\nabla\log P_{\theta^*}(x)-\eta(\nabla\log P_{\theta^*}(x|c)-\nabla\log P_{\theta}(x))

と書けます. P(c) の部分は x で微分するので定数扱いです. 学習済み拡散モデルがスコアを推定できることを用いれば, 上の式は

\varepsilon_{\theta}(x_t, c, t)\leftarrow\varepsilon_{\theta^*}(x_t, t)-\eta[\varepsilon_{\theta^*}(x_t, c, t)-\varepsilon_{\theta}(x, t)]

となります. この式は, \varepsilon_{\theta} が負のguidance scaleによって誘導されます. すなわち, 本来進む方向とは逆方向に誘導されることになります. このようにして, fine-tuningされたモデルは概念から遠ざけられます.

ここで, 訓練のプロセスを図で見てみます.

これを見ると一目瞭然です. この図や定式化からもわかるように, モデル (U-Net)が2つ必要です.

パラメータ選択の重要性

さて, 先ほどの式

\varepsilon_{\theta}(x_t, c, t)\leftarrow\varepsilon_{\theta^*}(x_t, t)-\eta[\varepsilon_{\theta^*}(x_t, c, t)-\varepsilon_{\theta}(x, t)]

の効果はどのパラメータを更新するかに依存します. 主に, cross-attentionを更新するのか, cross-attention以外を更新するのかで区別されます. cross-attentionはテキストの条件を強く反映するからです (これはテキストがcross attentionによって条件付けされることと関係があると思います).

異なるプロンプトから生成された画像を比較します. すると, self-attentionはプロンプトに関係なく車の特徴に寄与していますが, cross-attentionは単語の有無が影響しています.

そのため, 名前付きのartistic styleを消去するなどの, 消去をプロンプトに制御し, 特定のものにする必要がある場合は, cross attentionを更新するESD-xを提案しています.

さらに, NSFWのヌードなどのグローバルな概念を消去し, プロンプトのテキストと独立している場合, cross-attentionでない層を更新するESD-uを提案しています. 更新に用いる式

\varepsilon_{\theta}(x_t, c, t)\leftarrow\varepsilon_{\theta^*}(x_t, t)-\eta[\varepsilon_{\theta^*}(x_t, c, t)-\varepsilon_{\theta}(x, t)]

では \eta がハイパーパラメータとして設定可能です. 簡単のために各手法の最後に \eta の値をつけてそれぞれESD-x-\eta, ESD-u-\etaとしています. \eta=1 のときは単にESD-x, ESD-uと表記します.

どこを更新するかで比較した結果を見てみます. ここでは"Van Gogh style"を消去しています.

self attentionを更新する場合, グローバルな消去が行われるために他のstyleでも大きく変化していることがわかります. 一方で, cross attentionのみを更新する場合 (1番右), Van Goghは消しつつも他のスタイルは構図を含めて維持できていることがわかります.

実験

では, 実験結果を見ていきます. その前に訓練条件などを確認します.

条件

Stable Diffusion1.4を使用します. バッチサイズは1で学習率は 10^{-5} の設定でAdamを用いて1000回更新します. それ以外の設定は書かれていないので公式実装を確認すると, PyTorchのデフォルトを使用していました.

先ほどの議論から, 更新するパラメータが重要であることがわかりました. 実験では次のようにしています.

  • ESD-x: cross attentionを更新
  • ESD-u: Stable DiffusionのU-Netのcross attention以外を更新

ベースラインは以下の通りです.

  • SD (特に何もしていないStable Diffusionです)
  • SLD (Safe Latent Diffusion)
  • SD-Nagetive-prompt: (Negative promptを用いたものです)

styleの消去

Kelly McKernan, Thomas Kinkade, Tyler Edlin, Kilian Eng と, 『亜人』の5つのスタイルについて実験を行います. まずは, SLDとの比較をみます. main paperには3つの例しかありませんが, Appendixには5つのスタイル消去の例が掲載されています.

注目点は2つです.

  • スタイルが消去されていてかつ構図やその他の内容は維持されている
  • 他のスタイルへの影響が軽微である

これだけでは著者らの主観的な判断のため, User Studyを行います. 細かい設定は省略します (論文には調査に用いたUIを含めて細かく記述されています)が, 結果を見てみます. 13人の参加者は, 実験画像が5つの実際の作品と同じアーティストによって制作されたものであるという確信度を5段階評価で回答しています.

本物の場合は3.85です (最も左). ですが, それに類似したアーティストの場合は3.16, Stable Diffusionが生成したものは3.21であり, AIが生成した作品は類似する本物の作品よりも高く評価されていることを示しています.

次に, 消去する手法同士の比較です. ESDは最もよく概念消去ができていることがわかります (b). また, 他のスタイルも維持していることがわかります (c). それぞれの数値についてはあまり意味がないのでここでは省略しますが, 提案手法の方が定性的評価で優れていることがわかります.

暗黙的な概念の消去

論文では, NSFWなどの概念消去の例をみています. 論文では実験結果の生成例の画像がないために, 定性的な比較はここではできません. 定量的比較はされていますが, この記事では省略します. 論文では5.2に当たる部分です.

Objectの消去

スタイルが消去できることは確認しました. では, オブジェクトはどうでしょうか. ここでは10個のESD-uを用いて調べます. ImageNetのサブクラスから1つずつクラス名を消し去ります. 各クラスで500枚の画像を生成し, ResNet50を用いてtop-1 accuracyを確認します.

結果を見てみると, 概念を消去したクラスではaccuracyが大幅に減少していることがわかります. Churchは難しく, 高いaccuracyを示していますが, それ以外は概ね消去できていると言えそうです. それ以外の生成例でのaccuracyも全体として15%ほど低下していますが, 依然として高いaccuracyといえます. 実際の生成例も見てみます. 論文のmain paperで触れられていた部分をみます.

まずはChurchです. これを見てみると確かにOriginal SDから画像は変化していますが, 依然としてChurchであると言えるかなと思います. この画像であれば分類モデルがChurchと分類するのも納得です.

また, French Hornが他のクラスに大きな影響を与えたとされていますので, それを見てみます. この結果を見ると, 他のクラスへの影響は全体的に軽微ですが, French Hornの場合は例えばゴルフボールが消失してしまったりといった影響が確認できます. ただ, 他の多くの例でもGolf Ballは緑色になっています. もちろん緑のGolf Ballは存在しますが, Original SDは白色なので色の変化を軽微な変化と言っていいのかという点は気になります.

\eta の影響

\eta の値を1, 3, 10と変化させたときの性能を比較しています. 表の形式にはなっていませんので, ここで表に変換しておきます.

\eta "nudity"の出現率 \downarrow 10 Imagenette Accuracyの低下率 \downarrow FID COCO-30k \downarrow
1 0.17 0.07 13.68
3 0.12 0.14 17.27
10 0.08 0.34 記載なし

\eta を減らすと他の概念への影響は抑えられますが, 消去の効果は薄くなります. これはトレードオフの関係のようです. なお, \eta を大きくするとFIDは悪くなります. しかし, CLIP Scoreは同程度です.

Limitations

手法の限界を確認します. 先ほどChurchが消せないという話がありましたが, それと同様に対象の概念が消去できなかったり, 逆に他の概念に大きく影響を与えたりということがあります.

図の上段では, Rembrandt style, Van Gogh styleをそれぞれ消去したときに, 他の概念に影響が出ていることがわかります. また, 下段ではChurchとParachuteがそれぞれ完全に消えていないことがわかります.

まとめ

  • Diffusion Modelsから特定の概念を消去する研究
  • fine-tuningのみで概念の消去が可能で, データセットのフィルタリングも不要
  • パラメータを更新するので推論時の工夫がなくてもできる

思ったこと

  • U-Netの更新は一般にGPUメモリを要求し, 本手法ではU-Netが2つ登場するが, どれほどの時間とメモリを要するのかが不明だった
  • この手の定量評価は非常に難しい気がしていて, 例えばUser Studyを用いているのにUser Studyと評価の良し悪しが乖離するFIDを使っていたりと苦しんでいるように見えます (この点については以下の論文が参考になります).

https://openaccess.thecvf.com/content/CVPR2023/html/Otani_Toward_Verifiable_and_Reproducible_Human_Evaluation_for_Text-to-Image_Generation_CVPR_2023_paper.html

  • 一般に追加学習をすると破滅的忘却が起こるとされています. それを防ぐ策を講じていないにもかかわらず, 他の概念への影響が少ないのは画像ドメイン特有の性質でしょうか.

参考文献

Discussion