📘

Stable Diffusionからの概念消去⑧:LocoEdit (論文)

2024/05/27に公開

On Mechanistic Knowledge Localization in Text-to-Image Generative Models (ICML2024)

以前, DiffQuickFixの記事を書きました.

https://zenn.dev/fmuuly/articles/fbe5e21ead7aa9

この論文と同じ著者による新作の論文が今回紹介するLocoGenです. DiffQuickFixの後続研究にあたるため, DiffQuickFixの知識は前提とします. 図は論文からの引用です.

書籍情報

Samyadeep Basu, Keivan Rezaei, Priyatham Kattakinda, Vlad I Morariu, Nanxuan Zhao, Ryan A. Rossi, Varun Manjunatha, and Soheil Feizi. On mechanistic knowledge localization in text-to-image generative models. In Ruslan Salakhutdinov, Zico Kolter, Katherine Heller, Adrian Weller, Nuria Oliver, Jonathan Scarlett, and Felix Berkenkamp, editors, Proceedings of the 41st International Conference on Machine Learning, volume 235 of Proceedings of Machine Learning Research, pages 3224–3265. PMLR, 21–27 Jul 2024.

関連リンク

GitHubのリンクはありますが, リポジトリは空です.

Causal Tracing

DiffQuickFixの論文 (以後「前回」とします)ではCausal Mediation Analysisを用いて知識のありそうな場所を特定していました. それはROMEなどのLLMに対する分析手法とは異なっています. 今回の論文ではROMEなどでも行われているCausal Tracingを行います.

今回の分析手法と前回の分析手法との違いを表したのが下の図になります.

流れとしては, U-Netのcross-attentionの部分集合を見つけ, そのkeyとvalueの行列が変更されたときに, スタイルといった視覚的属性がどのように変化するかを確認します. 中間層に関与することで, より直接的な介入の効果を測定することができます.

まずは, 前回と同様の手法を用いて結果を観察します. 簡単に前回の分析手法を述べるとレイヤー単位で壊したり復元したりします.

U-Net

さっそく結果を見てみます.

いわゆるLatent Diffusion Models (LDMs)では幅広く分布していることがわかります. Stable Diffusion 1.5でも同じだったことを思い出すと, 前回と同様の結果と言えます. しかしながら, モデルによってそれは異なっていて, SD-XLでは後ろの層に集中しています. また, ImagenベースのOSSであるDeepFloydでは1箇所だけやや緑が見えますが全体として知識は分布していないことがわかります.

これは実際に生成してみるとわかります. まず, Stable Diffusion 2.1の結果です. オブジェクトという属性は, 様々な層に分布していて, モデル編集では関連するすべての層を編集する必要があります. また, レイヤーの構成要素に非線形性が存在するために閉形式では更新できません.

続いて, SD-XLの結果をみます. これもSD-2.1と同様に様々な部分に分散していますが, それほど密に分布しているわけではありません.

最後に, DeepFloydを確認します. まったくモデル復元ができず, 因果関係がありません. これではU-Netでモデルを編集することができません.

Text Encoder

前回はStable Diffusion 1.xと2.xは同じ結果でした. では, 今回はどうでしょう.

この結果からは, DeepFloydとSD-XLのText Encoderにはこれまであった固有の因果関係が存在しないことがわかります. そこで, Text Encoderの比較を行うと, 以下の表になります. ここから推測できるのは, 単独のCLIP Text Encoderが用いられたときのみにユニークな因果状態が生じるということです.

Model Text Encoder
SD-1.5 CLIP-ViT-L-p14
SD-2.1 OpenCLIP-ViT-H
SD-XL OpenCLIP-ViT-G & CLIP-ViT-L
DeepFloyd T5-XXL

LocoGen: Towards Mechanistic Knowledge Localization

これまで確認した知識局在化の汎化性の欠如を考慮したより汎用的な手法としてLocoGenを提案しています.

Knowledge Control in Cross-Attention Layers

推論過程ではclassifier-free guidanceを用いて条件なし・ありのスコアを取り込みます. 式にすると

\hat{\varepsilon}(\boldsymbol{z}_t, \boldsymbol{c}, t)=\varepsilon_\theta(\boldsymbol{z}_t, \boldsymbol{c}, t)+\alpha(\varepsilon_\theta(\boldsymbol{z}_t, \boldsymbol{c}, t)-\varepsilon_\theta(\boldsymbol{z}_t, t))

です. DDIMで \boldsymbol{z}_t を更新し, \boldsymbol{z}_0 を得るとします. このモデル \varepsilon_\theta(\boldsymbol{z}_t, \boldsymbol{c}, t)\texttt{Clean Model} と呼び, 生成された画像を I_{\mathrm{clean}} とします.

テキストによる条件は, \{C_l\}_{l=1}^M で示されるcross-attention layerを用いて組み込まれます. この層にはkeyとvalueの行列 \{W_l^K, W_l^V\}_{l=1}^M が含まれます. 一般にはここで指定されるテキスト埋め込み \boldsymbol{c} はすべての層で同じです. しかし, 視覚属性に対する制御点を局在化するために, cross-attentionの部分集合に対して \boldsymbol{c}\boldsymbol{c}' で置き換えて生成例を確認します.

イメージとしての話は以上ですが, しっかり数式にしておきます. cross-attention layerの部分集合 C'\subset\{C_l\}_{l=1}^M が, \boldsymbol{c} を入力とする他のcross-attention layerと異なる入力 \boldsymbol{c}' を受け取るとします. このとき, 「モデルが変更された入力 (altered input)を受け取る」と呼びます. これらの層を「制御層 (controlling layers)」と名前をつけます. \boldsymbol{z}_T が初期ノイズとして与えられたとき, 通常の推論 (先ほどの推論の式)を用いて生成された画像を I_{\mathrm{altered}} とします. 変更された入力を持つモデルは以下のように推論し, \texttt{Altered Model} と呼びます.

\hat{\varepsilon}(\boldsymbol{z}_t, \boldsymbol{c}, \boldsymbol{c}', t)=\varepsilon_\theta(\boldsymbol{z}_t, \boldsymbol{c}, \boldsymbol{c}', t)+\alpha(\varepsilon_\theta(\boldsymbol{z}_t, \boldsymbol{c}, \boldsymbol{c}', t)-\varepsilon_\theta(\boldsymbol{z}_t, t))

例えば, 特定のartistに対応するstyleの知識が格納されている層を見つけることを考えます. すると, \{C_l\}_{l=1}^M-C' は"An <object> in the style of <artist>" に対応する埋め込みを受け取ります. C' は"An <object> in the style of painting"に対応する埋め込みを受け取ります. これらの入力で生成された画像が特定のstyleを持たない場合は C' のいずれかの層にそのstyleに関する知識があることが推測できます.

アルゴリズム

目標は異なる視覚属性に対する制御層 C' を見つけることです. 集合 |C'|=m のカーディナリティはハイパーパラメータで, C' の探索空間は指数関数増加 (全部で 2^M-1 あるため)です. なので, 隣接するcross-attentionに限定します. すなわち C'=\{C_l\}_{l=j}^{j+m-1} とします. jj\in[1, M-m+1] です.

その次に, m をどのように設定するかを考えます. ここでは m\in[1, M] で反復的にハイパーパラメータの探索を行います. 各反復において, 隣接する m 個のcross-attention layerが特定の視覚的属性の生成に関与するかどうかを調べます. 最終的なゴールは特定の視覚的属性に対して制御層が存在するような最小の m を見つけることです.

特定の属性にLogoGenを適用するには, 特定の属性を含む入力プロンプトの集合 T=\{T_i\}_{i=1}^N と, T_i から特定の属性情報を取り除いたプロンプトの集合 T'=\{T_i'\}_{i=1}^N を取得します. \boldsymbol{c}_iT_i の埋め込みとし, \boldsymbol{c}_i'T_i' の埋め込みとします.

m が与えられたとき, 制御層の可能性がある M-m+1 個の候補を全探索します. 各々について, N 枚のaltered imageを生成します. このとき, i 番目の画像は m 個の選ばれた層には \boldsymbol{c}_i' を条件としています.

その後, styleとobjectに関しては, T_i とのCLIP Scoreを測定し, factに対しては T_i' とのCLIP Scoreを測定します. styleとobjectについてはCLIP Scoreの低下が属性の除去を表します. styleとobjectのCLIP Scoreの平均が最小でfactの平均が最大のものを特定します. これをアルゴリズムにすると以下のようになります.

特定結果

実際にOSSモデルを用いた結果をみます. 定量結果をみます. ここから, SD-v1-5とSD-v2-1においてはstyleは l=8, m=2 となります. objectとfactはSD-v1-5では l=6, m=2 であることがわかります. 2つのモデルは同じようなU-Netのアーキテクチャで同じようなデータで学習しています (SD-v2-1はSD-v1-5で用いたデータをフィルタリングしたもので訓練しています). しかし, 知識分布は異なっていることから, text encoderが原因であると考えられます.

また, この結果にはないですが, Open-Journeyは概ねSD-v1-5と同じ結果です. Open-JourneyはSD-v1-5をfine-tuningする形で学習されており, Mid-Journeyの生成画像をもとに学習していることを踏まえると, 訓練設定やデータセットの違いではなく, (text encoderやU-Netなどの) モデルアーキテクチャと密接に関連していることが伺えます.

SD-XLの結果にいく前に定性的な結果も見ます.

特定の層に介入することで概念の影響を小さくできています. 例えば, SD-v1-5やOpen-JourneyではLayer 8に介入することでVan Gogh特有の夜空の描き方 (starry night)の要素が消えていることがわかります.

続いてSD-XLです. 先ほどのCLIP Scoreのグラフの (c)と, 下の生成例をみます.

styleとfactは l=45, m=3 で制御できています. また, objectに関しては l=15, m=5 で少し m が大きいです. SD-XLはcross attentionが70個ありますが, まとめると様々な属性はごく一部の層が担っていて, それによってのみ制御できるということが言えます.

しかし, DeepFloydではそれは異なります. 例えばfactという属性を鑑みても, “The British Monarch”は l=6, m=3 ですが “The President of the United States”では l=12 になります.

LocoEdit: Editing to Ablate Concepts

これまでの観察の結果を踏まえて, 概念消去の手法が提案されています. このような解析をこれまでしたのは閉形式で更新するためなので, 当然閉形式で更新を行います.

まず, LocoGenのアルゴリズムの結果として特定の属性に関与するレイヤーの集合が得られます. それを C_{\mathrm{loc}}=\{\hat{W}_l^K, \hat{W}_l^V\}_{l=1}^m とします. これらの重み行列を前回同様に更新することが目標です. 例えば Van Gogh styleであれば painting styleといった感じです.

入力プロンプト T_{\mathrm{orig}}=\{T_i^o\}_{i=1}^N を用います. これは特定の視覚属性を特徴とするプロンプトで構成されます. 同時に, T_{\mathrm{target}}=\{T_i^t\}_{i=1}^N も用意します. ここで, T_i^tT_i^o と同じですが, フォーカスしている属性が含まれないプロンプトです. \boldsymbol{c}_i^o, \boldsymbol{c}_i^t\in\mathbb{R}^d をそれぞれ T_i^o, T_i^t のlast subject tokenの埋め込みとします. \boldsymbol{c}_1^o, \boldsymbol{c}_2^o, \ldots,\boldsymbol{c}_N^o をスタックさせたものを \boldsymbol{X}_{\mathrm{orig}}\in\mathbb{R}^{N\times d} とし, \boldsymbol{c}_i^t についても同様にしたものを \boldsymbol{X}_{\mathrm{target}}\in\mathbb{R}^{N\times d} とします. 層 l\in[1, m] に対して以下の最適化問題を解くことで更新します.

\min_{W_l^K}\|\boldsymbol{X}_{\mathrm{orig}}W_l^K-\boldsymbol{X}_{\mathrm{target}}\hat{W}_l^K\|_2^2+\lambda_K\|W_l^K-\hat{W}_l^K\|_2^2

\lambda_K は正規化の係数です. \boldsymbol{Y}_{\mathrm{orig}}=\boldsymbol{X}_{\mathrm{orig}}W_l^K とすると, 上の最適化問題の解は

W_l^K=(\boldsymbol{X}_{\mathrm{orig}}^T\boldsymbol{X}_{\mathrm{orig}}W_l^K+\lambda_K I)^{-1}(\boldsymbol{X}_{\mathrm{orig}}^T\boldsymbol{Y}_{\mathrm{target}}+\lambda_K\hat{W}_l^K)

です.

結果

まず, 定性的な結果を確認します.

この結果から, styleであったりobjectといったものが消えていることがわかります. これはCLIP Scoreの結果からも裏付けられます.

編集後のモデルは編集前のモデルと比較してstyleとobjectに関するCLIP Scoreが減少し (すなわち要素が消え), factsに関してはCLIP Scoreが上昇している (正しく知識更新された)ことがわかります. しかし, 手法がどれほど有効かはモデルによって異なっていて, 例えばstyleを例にとると, SD-v1-5とOpen-JourneyのグループとSD-v2-1とSD-XLのグループでは減少幅が異なっています. 定性的にはどちらのグループも消せていますが, 定量的には差が出るというのが現状です.

しかし, Deep-Floydでは課題が残ります.

Deep-FloydではT5-XXLが用いられますが, これは双方向のattentionを持ったencoderです. これはCLIP Text Encoderとは異なります. 閉形式での更新手法はlast subject tokenを埋め込むことに依存していているので, last subject token以降のトークンが重要な情報を持つ場合には不適当です.

On Neuron-Level Model Editing

これまではlayerの重み行列を更新していました. ここからはニューロンレベルで編集を行う可能性の探求をします. LocoGenで特定した層に対して, keyとvalueの埋め込みの活性化層でニューロンの選択的dropがstyle要素を効果的に除去できるかどうかを検証することが目的です.

ここでは特定のstyleの生成に関与するニューロンを特定します. 具体的には特定のstyleを含むプロンプトと, 含まないプロンプトを比較する際に, 顕著に変動するニューロンを特定します.

まず, 特定のstyle (例:ゴッホ)を特徴とする N_1 個のプロンプトを収集します. これらのプロンプトのlast subject tokenの埋め込みを \boldsymbol{c}_1, \boldsymbol{c}_2,\ldots,\boldsymbol{c}_{N_1} (\boldsymbol{c}_i\in\mathbb{R}^d) として集めます. また, 特定のstyleを含まない N_2 個のプロンプトを収集し, \boldsymbol{c}'_1, \boldsymbol{c}'_2,\ldots,\boldsymbol{c}'_{N_1} (\boldsymbol{c}'_i\in\mathbb{R}^d) とします. 次に, keyまたはvalueの行列 W\in\mathbb{R}^{d\times d'} について, 入力プロンプトの埋め込みを考えます. すなわち \{z_i\}_{i=1}^{N_1}\cup\{z'_i\}_{i=1}^{N_2} を考えます. ここで, z_i=\boldsymbol{c}_iW および z'_i=\boldsymbol{c}′_iW とし, z_i, z'_i\in\mathbb{R}^{d'} であることに注意します.

さらに, これらの d' 個のニューロンのそれぞれについて, 特定のstyleを含む入力プロンプトとそれを含まない入力プロンプトとの間で活性化の統計的な差異を評価します. 具体的には2つの活性化グループ z_1, z_2, \ldots, z_{N_1} および z'_1, z'_2, \ldots, z'_{N_2} 内の各ニューロンについてz-scoreを計算します. ニューロンはz-scoreの絶対値でランクづけされ, 入力プロンプトに特定の概念が存在するかどうかに応じて活性化に顕著な違いを示すトップのニューロンが特定されます. 生成中にこれらのニューロンをdropoutし, 特定のstyleが除去されるかどうかを確認します.

結果をみてみます.

ニューロンレベルでの編集は効果的であることがわかります. これは, 特定のstyleに関する知識が少数のニューロンにさらに局在化されていることを示唆しています. 特に, 編集するニューロンの数が増えるとより強く概念除去に反映されていることがわかります. CLIP Scoreを測定してみると効果的な除去が確認できます.

まとめ

  • 詳細な分析によってtext-to-image diffusion modelsにおけるU-Net内部の知識の分布を調査
  • それに基づく閉形式編集手法LogoEditの提案
  • U-Net内部の知識の分布はtext-to-imageのアーキテクチャによって様々である
  • Stable Diffusion系ではモデルごとに差はあれど, ある視覚属性について一部のレイヤーの隣接数個に集まる

思ったこと

  • 今回も40ページを超える論文で, すごい量だと思いました.
  • DiffQuickFixの際に思ったことがここで検証されていて, 個人的には安心しました.
  • やはりタイトルでGeerative ModelsといいつつDiffusion Modelsしか検証していないのはあまりいいものではないように思えます.
  • DeepFloydでもできるようになるといいなと思います.
  • 論文の時期的にできていないだけだと思いますがStable Diffusion3のようなCLIPとT5を混合したアーキテクチャでの実験結果が気になります.
  • アーキテクチャによって知識分布が変わるのはLLMではあまりみられない傾向だと思います (自分が知らないだけかもしれませんが). 複数モジュールが関わり合っているからと考えられますがMultimodalな言語モデルでの研究もヒントになりそうです.

参考文献

  • Samyadeep Basu, Keivan Rezaei, Priyatham Kattakinda, Vlad I Morariu, Nanxuan Zhao, Ryan A. Rossi, Varun Manjunatha, and Soheil Feizi. On mechanistic knowledge localization in text-to-image generative models. In Ruslan Salakhutdinov, Zico Kolter, Katherine Heller, Adrian Weller, Nuria Oliver, Jonathan Scarlett, and Felix Berkenkamp, editors, Proceedings of the 41st International Conference on Machine Learning, volume 235 of Proceedings of Machine Learning Research, pages 3224–3265. PMLR, 21–27 Jul 2024.

Changelog

  • PMLR公開によるリンクの更新とそれに伴う文章の削除, 参考文献および書籍情報の更新 (2024/08/26)

Discussion