🔥

Stable Diffusionからの概念消去⑰:Concept Pinpoint Eraser (論文)

2025/03/08に公開

Concept Pinpoint Eraser for Text-to-image Diffusion Models via Residual Attention Gate (ICLR2025)

引き続きICLR2025採択論文を確認します. 概念消去では多くの場合Cross Attentionが更新されます. これはESDなどの論文でそれがいいとわかったからですが, その効果などはあまり検証されていません.

書籍情報

Byung Hyun Lee, Sungjin Lim, Seunggyu Lee, Dong Un Kang, and Se Young Chun. Concept pinpoint eraser for text-to-image diffusion models via residual attention gate. In The Thirteenth International Conference on Learning Representations, 2025.

関連リンク

現時点 (2025/03/08)ではarXivにはないです.

TL; DR

著者らがまとめている貢献を並べます.

  • cross attention layerの更新は幅広い概念保持の能力に欠けることを数学的に示した
  • attention anchoring loss, ロバストな訓練戦略によるResidual Attention Gates (ResAGs)を持つConcept Pinpoint Eraser (CPE)を提案
  • 実験で多様な概念保持を達成

Cross Attentionの更新だけでは不十分

この論文の主張のひとつ目をみていきます. まずはCross Attention (以降CA)が何なのかを定義する必要があります.

モデルの l 番目のblockにおける H-headの線形射影を \theta^l=\cup_{h=1}^H\{\boldsymbol{W}_q^{l, h},\boldsymbol{W}_k^{l, h},\boldsymbol{W}_v^{l, h},\boldsymbol{W}_o^{l, h}\} と表すことにします. 毎回このように書くのは大変なので, 問題のない範囲で添字などが省略されます. \bold{E}\in\mathbb{R}^{d\times m} をtext embeddingとします (d は特徴量の次元, m はトークン列の長さ). このときCAは以下のように定義できます.

\sigma(\cdot) をsoftmaxとする. H-headのCA layerは \theta=\cup_{h=1}^H\{\boldsymbol{W}_q^h,\boldsymbol{W}_k^h,\boldsymbol{W}_v^h,\boldsymbol{W}_o^h\} とquery token \boldsymbol{z}\in\mathbb{R}^{d_1} によってパラメータ化され, 以下のように定義される.

\tau(\boldsymbol{z}, \bold{E})=\sum_{h=1}^H\boldsymbol{W}_o^h\boldsymbol{W}_v^h\bold{E}\cdot\sigma\left(\dfrac{(\boldsymbol{W}_k^h\bold{E})^\top\boldsymbol{W}_q^h\boldsymbol{z}}{\sqrt{m}}\right)

ここで, \boldsymbol{W}_q^h\in\mathbb{R}^{\frac{d_2}{H}\times d_1}, \boldsymbol{W}_k^h\in\mathbb{R}^{\frac{d_2}{H}\times d}, \boldsymbol{W}_v^h\in\mathbb{R}^{\frac{d_2}{H}\times d}, \boldsymbol{W}_o^h\in\mathbb{R}^{d_1\times \frac{d_2}{H}}

CA layerの更新のみの変化による概念保持の課題を調べるために, CA layerの出力の変化が元の出力から拡散モデルの出力の変化を引き起こすと仮定します. この仮定の下でまずはCA layerにおける線形射影が変化した際の出力の上限を調査します. 次に, CA layerの出力の変化の期待される上限を導出することによって, 概念維持に対する上限を制限することには限界があることを示し, 拡散モデルの忘却を引き起こす可能性があることを示します.

簡単な場合としてUCEやMACEのように \boldsymbol{W}_k^h, \boldsymbol{W}_v^h が更新されるケースを考えます. 上限が書けることを定理1によって示します.

定理1: \tilde{\boldsymbol{W}}_k^h, \tilde{\boldsymbol{W}}_v^h をそれぞれ更新された後の \boldsymbol{W}_k^h, \boldsymbol{W}_v^h の重みとします. \|\bold{E}\|_2\leq M_1\|\boldsymbol{z}\|_\infty\leq M_2 を仮定し, 線形変換 \boldsymbol{W} のLipschitz定数を L_{\boldsymbol{W}} と書くとき,

\|\tau(\boldsymbol{z}, \bold{E}; \tilde{\boldsymbol{W}}_k^h, \tilde{\boldsymbol{W}}_v^h)-\tau(\boldsymbol{z}, \bold{E}; \boldsymbol{W}_k^h, \boldsymbol{W}_v^h)\|_2\leq\sum_{h=1}^H[C_1^h\|\Delta\boldsymbol{W}_k^h\bold{E}\|_F+C_2^h\|\Delta\boldsymbol{W}_v^h\bold{E}\|_F]

です. ここで, \Delta\boldsymbol{W}=\tilde{\boldsymbol{W}}-\boldsymbol{W}, C_1^h=\dfrac{M_1M_2\sqrt{m-1}}{m\sqrt{m}}L_{\boldsymbol{W}_o^h\boldsymbol{W}_v^h}L_{\boldsymbol{W}_q^h}, C_2^h=L_{\boldsymbol{W}_o^h} です.

証明

f_1(\boldsymbol{X})=\boldsymbol{W}_o^h\boldsymbol{X}, f_2(\boldsymbol{X})=\sigma\left(\dfrac{\boldsymbol{X}^\top\boldsymbol{W}_o^h\boldsymbol{z}}{\sqrt{m}}\right) とします. このとき重みが変化した際の拡散モデルの出力の違いは

\begin{align*} &\|\tau(\boldsymbol{z}, \bold{E}; \tilde{\boldsymbol{W}}_k^h, \tilde{\boldsymbol{W}}_v^h)-\tau(\boldsymbol{z}, \bold{E}; \boldsymbol{W}_k^h, \boldsymbol{W}_v^h)\|_2 \\ &=\left\|\sum_{h=1}^Hf_1(\tilde{\boldsymbol{W}}_v^h\bold{E})f_2(\tilde{\boldsymbol{W}}_k^h\bold{E})-\sum_{h=1}^Hf_1(\boldsymbol{W}_v^h\bold{E})f_2(\boldsymbol{W}_k^h\bold{E})\right\|_2 \\ &\leq\sum_{h=1}^H\|f_1(\tilde{\boldsymbol{W}}_v^h\bold{E})f_2(\tilde{\boldsymbol{W}}_k^h\bold{E})-f_1(\boldsymbol{W}_v^h\bold{E})f_2(\boldsymbol{W}_k^h\bold{E})\|_2 \\ &\leq\sum_{h=1}^H[\|(f_1(\tilde{\boldsymbol{W}}_v^h\bold{E})-f_1(\boldsymbol{W}_v^h\bold{E}))f_2(\tilde{\boldsymbol{W}}_k^h\bold{E})+f_1(\tilde{\boldsymbol{W}}_v^h\bold{E})(f_2(\tilde{\boldsymbol{W}}_k^h\bold{E})-f_2(\boldsymbol{W}_k^h\bold{E}))\|_2] \\ &\leq\sum_{h=1}^H[\|f_2(\tilde{\boldsymbol{W}}_k^h\bold{E})\|_{\infty}\|(f_1(\tilde{\boldsymbol{W}}_v^h\bold{E})-f_1(\boldsymbol{W}_v^h\bold{E}))\|_2+\|f_1(\tilde{\boldsymbol{W}}_v^h\bold{E})\|_2\|(f_2(\tilde{\boldsymbol{W}}_k^h\bold{E})-f_2(\boldsymbol{W}_k^h\bold{E}))\|_2] \end{align*}

\|f_2(\tilde{\boldsymbol{W}}_k^h\bold{E})\|_{\infty}\leq1 です. これは, softmaxの最大値が1だからです. また, \|f_1(\tilde{\boldsymbol{W}}_v^h\bold{E})\|_2\leq L_{\boldsymbol{W}_o^h\boldsymbol{W}_v^h}\|\bold{E}\|_2\leq L_{\boldsymbol{W}_o^h\boldsymbol{W}_v^h}M_1 です. そのため,

\|(f_1(\tilde{\boldsymbol{W}}_v^h\bold{E})-f_1(\boldsymbol{W}_v^h\bold{E}))\|_2\leq L_{\boldsymbol{W}_o^h}\|\Delta\boldsymbol{W}_{v}^h\bold{E}\|_F

であり,

\begin{align*} \|(f_2(\tilde{\boldsymbol{W}}_k^h\bold{E})-f_2(\boldsymbol{W}_k^h\bold{E}))\|_2&=\left\|\sigma\left(\dfrac{(\tilde{\boldsymbol{W}}_k^h\bold{E})^\top\boldsymbol{W}_q^h\boldsymbol{z}}{\sqrt{m}}\right)-\sigma\left(\dfrac{(\boldsymbol{W}_k^h\bold{E})^\top\boldsymbol{W}_q^h\boldsymbol{z}}{\sqrt{m}}\right)\right\| \\ &\leq\dfrac{\sqrt{m-1}}{m\sqrt{m}}\|(\tilde{\boldsymbol{W}}_k^h\bold{E})^\top\boldsymbol{W}_q^h\boldsymbol{z}-(\boldsymbol{W}_k^h\bold{E})^\top\boldsymbol{W}_q^h\boldsymbol{z}\|_2 \\ &\leq\dfrac{\sqrt{m-1}}{m\sqrt{m}}L_{\boldsymbol{W}_q^h}\|\boldsymbol{z}\|_{\infty}\|\tilde{\boldsymbol{W}}_k^h\bold{E}-\boldsymbol{W}_k^h\bold{E}\|_2 \\ &\leq\dfrac{M_2\sqrt{m-1}}{m\sqrt{m}}L_{\boldsymbol{W}_q^h}\|\Delta\boldsymbol{W}_k^h\bold{E}\|_F \end{align*}

となります. 最初の不等式ではsoftmaxのLipschitz constantが \dfrac{\sqrt{m-1}}{m} であるという既存研究の成果を用いています.

https://openreview.net/forum?id=H1aIuk-RW

以上から定理が従います.

\bold{E}_{\text{tar}}, \bold{E}_{\text{rem}} をそれぞれtarget concept, remaining conceptのtext embeddingとします. すると定理1から, target conceptに関する分布のシフトを達成するためには \|\Delta\boldsymbol{W}\bold{E}_{\text{tar}}\| が十分に大きいことが必要であることを示唆しています. そうでなければ CA layerの出力は十分に変化しないとも言えます. 一方で定理1の不等式の上限が大きくても, 維持したい概念のCA layerの出力変化は保証されません. 従って, 維持したい概念に対するCA layerの出力の変化を抑えるためには \|\Delta\boldsymbol{W}\bold{E}_{\text{rem}}\| を最小化することが望ましいと言えます. すると, 任意の \|\Delta\boldsymbol{W}\bold{E}_{\text{rem}}\| をどの程度抑えられるのかという問題が発生します. これを明らかにするために, text embeddingに対するガウス混合モデルを用いて \|\Delta\boldsymbol{W}\bold{E}_{\text{rem}}\| の期待値を解析します. ここで, 定理2を紹介します.

定理2: \bold{E}_{\text{rem}}=[\boldsymbol{e}_1, \ldots, \boldsymbol{e}_i, \ldots,\boldsymbol{e}_m] とします. 維持したい概念のtext embeddingがガウス混合分布に従うと仮定したとき, すなわち p(\boldsymbol{e}_i)=\displaystyle\sum_{k=1}^K\pi_k\mathcal{N}(\boldsymbol{e}_i;\boldsymbol{\mu}_k^i, \sigma_k^2\boldsymbol{I}) かつ \displaystyle\sum_{k=1}^K\pi_k=1 です. \boldsymbol{\mu}_k=[\boldsymbol{\mu}_k^1,\ldots,\boldsymbol{\mu}_k^m] のとき

\mathbb{E}_{\bold{E}_{\text{rem}}}[\|\Delta\boldsymbol{W}\bold{E}_{\text{rem}}\|_F^2]=C_3\|\Delta\boldsymbol{W}\|_F^2+\sum_{k=1}^K\pi_k\|\Delta\boldsymbol{W}\boldsymbol{\mu}_k\|_F^2

です. ただし C_3=\displaystyle\sum_{k=1}^K\pi_k\sigma_k^2 です.

証明

\boldsymbol{e}_i が混合ガウス分布から得られると仮定すると

\begin{align*} \mathbb{E}_{\bold{E}_{\text{rem}}}[\|\Delta\boldsymbol{W}\bold{E}_{\text{rem}}\|_F^2] &=\mathbb{E}_{\bold{E}_{\text{rem}}}[\|\Delta\boldsymbol{W}[\boldsymbol{e}_1,\ldots,\boldsymbol{e}_m]\|_F^2] \\ &=\mathbb{E}_{\bold{E}_{\text{rem}}}[\|[\Delta\boldsymbol{W}\boldsymbol{e}_1,\ldots,\Delta\boldsymbol{W}\boldsymbol{e}_m]\|_F^2] \\ &=\mathbb{E}_{\bold{E}_{\text{rem}}}\left[\displaystyle\sum_{i=1}^m\|\Delta\boldsymbol{W}\boldsymbol{e}_i\|_2^2\right] \\ &=\sum_{i=1}^m\mathbb{E}_{\boldsymbol{e}_i}[\boldsymbol{e}_i^\top\Delta\boldsymbol{W}^\top\Delta\boldsymbol{W}\boldsymbol{e}_i] \\ &=\sum_{i=1}^m\sum_{k=1}^K\pi_k\mathbb{E}_{\boldsymbol{e}_i\sim\mathcal{N}(\boldsymbol{\mu}_k^i, \sigma_k^2\boldsymbol{I})}[\boldsymbol{e}_i^\top\Delta\boldsymbol{W}^\top\Delta\boldsymbol{W}\boldsymbol{e}_i] \end{align*}

です. 対称行列の \bold{B} を用いた2次形式 \boldsymbol{e}_i^\top\bold{B}\boldsymbol{e}_i の期待値は

\mathbb{E}_{\boldsymbol{e}_i\sim\mathcal{N}(\boldsymbol{\mu}_k^i, \sigma_k^2\boldsymbol{I})}[\boldsymbol{e}_i^\top\bold{B}\boldsymbol{e}_i]=\sigma_k^2\mathrm{tr}(\bold{B})+(\boldsymbol{\mu}_k^i)^\top\bold{B}\boldsymbol{\mu}_k^i

と書ける事実を用いると,

\mathbb{E}_{\boldsymbol{e}_{i}\sim\mathcal{N}(\boldsymbol{\mu}_{k}^i, \sigma_k^2\boldsymbol{I})}[\boldsymbol{e}_i^\top\Delta\boldsymbol{W}^\top\Delta\boldsymbol{W}\boldsymbol{e}_i]=\mathrm{tr}(\Delta\boldsymbol{W}^\top\Delta\boldsymbol{W})\sum_{k=1}^K\pi_k\sigma_k^2+\sum_{k=1}^K\pi_k(\boldsymbol{\mu}_k^i)^\top\Delta\boldsymbol{W}^\top\Delta\boldsymbol{W}\boldsymbol{\mu}_k^i

です. ここから定理が従います.

この式における C_3\|\Delta\boldsymbol{W}\|_F^2 はCA layerのみを更新する際のジレンマを強調しています. target conceptを消去するために \|\Delta\boldsymbol{W}\|_F を大きくすると, 維持したい概念に対する上限も大きくなるということです. \Delta\boldsymbol{W} を固定すると, \sigma_k^2 を減らさない限り上限は変わりません. さらに, \bold{E}_{\text{rem}} のモードが \Delta\boldsymbol{W} のnull spaceにない場合は上限がさらに悪化します. これは, target conceptを消去するためにCA layerをfine-tuningなどで更新することで他の概念の維持ができない可能性があり, 結果としてモデルの出力が変化することを示唆しています.

Residual Attention Gate (ResAG)

これまでの結果を踏まえて著者らは非線形モジュールを追加することを提案しています. f(\bold{E})=\bold{V}_k\in\mathbb{R}^{m\times m} をサンプルの概念に適応する埋め込み依存の射影とします. このとき以下の系が従います.

系1: \Delta\boldsymbol{W}f(\bold{E})\Delta\boldsymbol{W}\bold{E} の代わりに使うとき, 定理2における式は

\mathbb{E}_{\bold{E}_{\text{rem}}}[\|\Delta\boldsymbol{W}f(\bold{E}_{\text{rem}})\|_F^2]=\|\Delta\boldsymbol{W}\|_F^2\sum_{k=1}^K\pi_k\sigma_k^2\|\bold{V}_k\|_F^2+\sum_{k=1}^K\pi_k\|\Delta\boldsymbol{W}\boldsymbol{\mu}_k\bold{V}_k\|_F^2
証明

定理2と同様にすると

\mathbb{E}_{\bold{E}_{\text{rem}}} [\|\Delta\boldsymbol{W}f(\bold{E}_{\text{rem}})\|_F^2]= \sum_{k=1}^K\pi_k\sum_{i=1}^m\mathbb{E}_{\bold{E}_{\text{rem}}\sim\mathcal{N}(\boldsymbol{\mu}_k, \sigma_k^2\boldsymbol{I})}[(\bold{E}_{\text{rem}}\bold{V}_k^i)^\top\Delta\boldsymbol{W}^\top\Delta\boldsymbol{W}\bold{E}_{\text{rem}}\bold{V}_k^i]

です. \bold{V}_{k}^{i}\bold{V}_{k}i 番目のcolumnです. \bold{E}_{\text{rem}}\bold{V}_k^i\mathcal{N}(\boldsymbol{\mu}_k\boldsymbol{V}_k^i, \sigma_k^2\|\bold{V}_k^i\|_2^2\boldsymbol{I}) なので,

\sum_{i=1}^m\mathbb{E}_{\bold{E}_{\text{rem}}\sim\mathcal{N}(\boldsymbol{\mu}_k, \sigma_k^2\boldsymbol{I})}[(\bold{E}_{\text{rem}}\bold{V}_k^i)^\top\Delta\boldsymbol{W}^\top\Delta\boldsymbol{W}\bold{E}_{\text{rem}}\bold{V}_k^i]=\sigma_k^2\|\bold{V}_k\|_F^2\|\Delta\boldsymbol{W}\|_F^2+\|\Delta\boldsymbol{W}\boldsymbol{\mu}_k\bold{V}_k\|_F^2

となります.

さて, この式ですが, もし f(\bold{E}) がembeddingの概念を検出し, 維持する概念に対しては \|\bold{V}_k\|_F^2 を0に抑えつつtarget conceptに対しては大きくなるような \bold{V}_k を見つけることができれば, target conceptを消去しつつ他の概念を維持することが可能であることを示しています. これをさらに明確にするために次の命題のような例を考えてみます.

命題1: \alpha(\bold{E})\geq0 のときに f(\bold{E})=\alpha(\bold{E})\boldsymbol{I} とする. また, \mathcal{D}_{\text{tar}} をtarget conceptの埋め込みの分布とする. \alpha(\bold{E}) がtarget conceptの埋め込みが得られる分布を分類できるものと仮定します. すなわち

\alpha(\bold{E})=\begin{cases} 1 & \bold{E}\sim\mathcal{D}_{\text{tar}} \\ 0 & \text{otherwise} \end{cases}

このとき, 系1の式は \bold{E}_{\text{rem}} に対して0になり, \mathbb{E}_{\bold{E}_{\text{rem}}}[\|\Delta\boldsymbol{W}f(\bold{E}_{\text{rem}})\|_F^2] を抑制することが可能になります.

ところが, このような \alpha(\bold{E}) を見つけることはできません. その上, 概念は単なる独立したトークンではなくてトークン列なので f(\bold{E}) はembeddingにおけるトークン間の関係に基づいて設計されるべきです. 例えば, target conceptに"Bill Clinton"を選んで"Bill Murray"を維持したい場合を考えます. このとき, 両者は"Bill"という同じトークンを含んでいますが, f(\bold{E}) は前者にのみ活性化する必要があります. これを実現するには f(\bold{E})[1] を設計する際に3つの要素を考慮する必要があります.

  1. text embedding内の概念の関係性を理解すること
  2. target conceptを含むembeddingを正確に識別すること
  3. target conceptの出力を抑制しつつ, 他の概念に対しては出力を維持すること

では, どうやって f(\bold{E}) を作ればいいでしょうか. 著者らが着目したのはAttention Gateです. Attention Gateは埋め込み同士の関係に選択的に焦点を当て, 学習した重要度を動的に調整します. これによって不要な詳細をフィルタリングできます. 具体的には, 各target concept c に対して個別のattention gateモジュール f_c を学習し, 推論次にそれらを統合して複数の概念を処理します. これを著者らはResidual Attention Gate (ResAG)と名付けています. 以下に構造を示します.

式で書くと

f_c(\bold{E})=\boldsymbol{A}_cS(\bold{v}_c^\top\bold{E}\boldsymbol{A}_c),\quad \boldsymbol{A}_c=\sigma\left(\dfrac{(\boldsymbol{U}_{1, c}\bold{E})^\top(\boldsymbol{U}_{2,c}\bold{E})}{\sqrt{m}}\right)

となります. \boldsymbol{U}_{1,c},\boldsymbol{U}_{2,c}\in\mathbb{R}^{s_1\times d}, \bold{v}_c\in\mathbb{R}^d で, s_1\ll d です. S(\cdot)\in\mathbb{R}^{m\times m} は対角行列で, その対角成分は入力に対するsigmoidの出力です. この構造によって, f(\bold{E}) はtarget conceptとremaining conceptを識別し, target conceptのみに修正を加えることが可能になるようです. \boldsymbol{U}_{1,c},\boldsymbol{U}_{2,c}, \bold{v}_c は全てのCA layerで共有されるのでtarget conceptを消去するのに必要なパラメータ数を減らすことが可能です (CAのパラメータはtext embeddingに対して影響を与えないので同じものを使っても問題ないです). さらに, \Delta\boldsymbol{W}_c はLoRAとして定義します. 分解すると \Delta\boldsymbol{W}_c=\boldsymbol{U}_{4,c}^\top\boldsymbol{U}_{3,c} です. 出力の変化は \boldsymbol{U}_{4,c}^\top\boldsymbol{U}_{3,c}f(\bold{E})_c になります.

複数概念を消去する際には下図に示すように, それぞれのtarget conceptに対して学習したResAGを統合します. 選び方は c^*=\argmax_{c}\{S(\bold{v}_c^\top\bold{E}\boldsymbol{A}_c)_{ii}\} によって求めたトークンを加算します. これは各トークンごとにゲート値が最も高いtarget concept c^* のResAGのみを追加することを意味します.

損失関数

まず, target conceptを消去するために用いる損失を示します. 基本的な考え方は今までと同様で, 別の概念 \bold{E}_{\text{sur}} になるようにします.

\mathcal{L}_{\text{era}}(\mathcal{E}_{\text{tar}}, \mathcal{E}_{\text{sur}})=\mathbb{E}_{(\bold{E}_{\text{tar}}, \bold{E}_{\text{sur}})}\|(\boldsymbol{W}\bold{E}_{\text{tar}}+R_{\text{tar}}(\bold{E}_{\text{tar}}))-(\boldsymbol{W}\bold{E}_{\text{sur}}-\eta\boldsymbol{W}(\bold{E}_{\text{tar}}-\bold{E}_{\text{sur}}))\|^2

ここで, 概念 c に対して R_c(\bold{E})=\boldsymbol{U}_{4, c}^\top\boldsymbol{U}_{3, c}\bold{E}f_c(\bold{E}) です. 大体ESDなどと同じですが, 既存研究ではこれをstable diffusionの出力に対して行なっていたのに対し, 提案手法ではkey/valueの射影の出力対して行います. key/valueの添字は消しています.

続いて, remaining conceptに対する損失を示します. 事前に定義したremaining conceptsに対して定理1で出てきた不等式

\|\tau(\boldsymbol{z}, \bold{E}; \tilde{\boldsymbol{W}}_k^h, \tilde{\boldsymbol{W}}_v^h)-\tau(\boldsymbol{z}, \bold{E}; \boldsymbol{W}_k^h, \boldsymbol{W}_v^h)\|_2\leq\sum_{h=1}^H[C_1^h\|\Delta\boldsymbol{W}_k^h\bold{E}\|_F+C_2^h\|\Delta\boldsymbol{W}_v^h\bold{E}\|_F]

の上限を最小化するように学習を行います. この式に含まれる \Delta\boldsymbol{W}\bold{E} はResAGで置き換えることが可能です. remaining conceptを含むtext embeddingの集合を \bold{E}_{\text{rem}}[2] として, 以下の損失を最小化します.

\mathcal{L}_{\text{rem}}(\mathcal{E}_{\text{rem}}) = \mathbb{E}_{\bold{E}_{\text{rem}}}\|R_{\text{tar}}(\bold{E}_{\text{rem}})\|_F

remaining conceptですが, 先行研究から少し変えています. target conceptのドメインに対してLLMで類似概念をたくさん用意し, すこからtarget conceptに最も類似するいくつかの概念を選んで使用します. これらのみを用いると過学習の危険があるのでノイズの注入とランダム補間を行います.

ロバストな消去

ResAGはtarget conceptをピンポイントで消去しますが, 非線形性によってtarget conceptのインスタンス数が限定的な場合に広範囲な消去を達成することが難しくなるので, 直接学習させることが困難になります. この問題が発生すると, ResAGのロバスト性や消去性能が低下してしまう可能性があるので対策を講じます.

そこで, Robust erasure via Adversarial Residual Embeddings(RARE)を提案しています. 順序は以下のようになります.

  1. 普通にResAGを学習する
  2. 学習したResAGに対して敵対的な埋め込みを \bold{E}_{\text{tar}} に加える形で学習する
  3. 再度ResAGを学習する

2番目についてみていきます. \mathcal{E}_{\text{adv}}=\{\bold{E}_{\text{adv}}^1, \ldots, \bold{E}_{\text{adv}}^N\} を学習可能な敵対的埋め込みの集合とします. \bold{E}_n'=\bold{E}_{\text{tar}}+\bold{E}_{\text{adv}}^n に対して

\min_{\mathcal{E}_{\text{adv}}}\dfrac{1}{N}\sum_{n=1}^N\mathbb{E}_{\mathcal{E}_{\text{tar}}}\|\boldsymbol{W}\bold{E}_n'+R_{\text{tar}}(\bold{E}_n')-\boldsymbol{W}\bold{E}_{\text{tar}}\|_F

で学習を行います. 最終的な損失は

\min_{R_{\text{tar}}}\mathcal{L}_{\text{era}}(\mathcal{E}_{\text{tar}}, \mathcal{E}_{\text{sur}})+\dfrac{1}{N}\sum_{n=1}^N\mathcal{L}_{\text{era}}(\mathcal{E}_{\text{tar}}+\bold{E}_{\text{adv}}^n, \mathcal{E}_{\text{sur}})+\lambda\mathcal{L}_{\text{rem}}(\mathcal{E}_{\text{rem}})

です.

実験

では, 実験で性能を確認していきます. 比較手法は以下の通りです.

  • celebrities and artistic styles erasure: FMN, ESD-x, ESD-u, UCE, MACE
  • adversarial prompt: AdvUnlearn

original SDとしてStable Diffusion 1.4を使い, 画像生成ではDDIMを50 stepsで生成します.

Celebrities Erasure

50の概念を同時に消します. 5つのprompt templateと5つのseedを使い, 1250枚の画像を生成して評価します. remaining conceptsに関しては, 100人のcelebritiesと100 styles, 64のword2vecのcharactersの3ドメインを考えます. 各ドメインのremaining conceptに対して5つのpromptと5つのseedで25枚の画像を生成します. また, remaining conceptにはMSCOCO-30kも一緒に使います. ResAGのモジュールのrankは s_1=16, \boldsymbol{U}_{3,c}, \boldsymbol{U}_{4,c} のrank s_2 は1とします. また, \eta=0.3, \lambda=10^5 です. 敵対的埋め込みに対する学習は N=16 個を5段階にわたって行います.

結果を見てみます. 論文ではFigure 5と記されていますが, 話の流れからFigure 3の間違いであることは明らかなのでFigure 3の結果を示します.

比較手法含めて多くの場合, Andrew Garfieldを消せています (論文では全てとしていますがFMNは微妙なところです). 一方で, 維持すべき概念の結果を見ると多くの場合失敗していることがわかります. 定量評価をみてみます. お馴染みの評価指標に加えてGIPHY Celebrity Detector (GCD)で分類した際のtop-1 accuracyをACCとして示します.

これを見ると, UCEと提案手法が非常に高い消去性能 (ACCが1%以下)を達成しています. ただ, 生成結果の比較を見ても他の定量評価の結果からもわかるように, remaining conceptsの維持性能は低い結果となっています. 逆にFMNは概念維持の観点ではかなりいい感じですがそもそも消去性能が低いという結果になっています. MACEはanchorとしてremaining concept (celebrity)やMSCOCOを使うのでそれらに対しては維持性能が高いですが, 他のドメインでは忘却が発生しています.

Artistic Styles Erasure

100 styleの同時消去をします. ほとんどの設定はcelebritiesと同じですが, \lambda=10^4, \eta=0.5 とします. また, 敵対的埋め込みに対する学習は16個を10段階にわたって行います. 早速結果を確認します. 論文ではFigure 3と書かれていますがどう見てもFigure 4です.

FMNはあまり消せている感じがないです. その他の手法は消せていそうです. ESUやUCEは他のstyleの維持に失敗しています. また, MACEはcelebrityの維持に失敗しており, 消去と維持のどちらも成功しているのは提案手法のみということになっています. 続いて定量評価を確認します.

概ね結果はcelebritiesの場合と同じであることがわかります.

Explicit Contents Erasure

NSFWを消去します. I2Pを用いて生成を行ういつもの評価です. 設定は s_1=64, s_2=4, \eta=3.0, \lambda=10^4 です. 敵対的埋め込みに対する学習は64個を20段階にわたって行います. 定量評価を確認します.

提案手法は全体の検出数が最小です. MSCOCO-30kのFIDはFMNとMACEに劣っていますが, FMNは消去性能が非常に悪く, MACEはleakの可能性があるのを念頭に置くと, 提案手法は良い性能であると言えます.

Robust Erasure

続いてattackの手法に対する性能をみます. ここではRing-A-BellとUnlearnDiffAtkを使います. 評価指標はASRです.

結果を見るとすごい性能に見えます. 完全に攻撃を防いでるケースも多くあり, 非常に優秀です. 実際の攻撃例を示してみます.

これからもわかるように, UnlearnDiffAtkの攻撃を防げています.

Ablation

ablationをします. なぜかFIDはやめてMSCOCO-1kでのKIDに評価指標が変わっていることに注意です. ResAGがないというのは \boldsymbol{U}_{3,c}, \boldsymbol{U}_{4,c} の学習の有無です.

ablationなので基本的には「これを無くしたら性能が悪くなった」という話になります. s_1 の値は1から16に増やすと性能が上がることが確認できたが, それを超えて増やしても性能への影響は限定的です. なお, s_2 では以下に示すように1が最も良い結果となっています.

一方でこれまでみた結果では例えばNSFW消去では s_2=4 なので概念の種類依存と言えそうです. そのことへの言及は確認できず (この部分はAppendixに書かれているので仕方ないですが, 査読での指摘もないです), 少し恣意的に見えます. \eta, \lambda も値を変えてやっていますが同じことが言えます. あくまでcelebritiesの消去でのみ有効なablationです.

思ったこと

  • ミスが多いなと思います. notationやFigureのリンクミスなどが目立ちますが査読での指摘がなく, 査読者は読んだのでしょうか...
  • それに関して結構読みにくいです. 定理2で登場する K は, 論文では R で書かれており, ここでは定数として用いているにもかかわらず, 以降では写像としての R が登場し, 添字の有無で判別できるにしても優しくないなと思いました.
  • 提案手法の前段階にあった定理2の \bold{E}_{\text{rem}} が混合ガウス分布に従うという仮定が満たされているかは結局確認できませんでした. 消去結果が成功しているからいいだろという考えなのでしょうか.
  • cross attentionの分析はとてもいいと思いましたが, この手法に限らず何かを取り付けてそれを更新する形で消去を行うと取り付けたモジュールを外せば概念が復活するので消去したと言えるのか...という意見もあるようです. これはBlack-boxでの消去なのかWhite-boxでの消去なのかという問題設定に依存する話なので最初に述べておくべきなのかもしれません.

参考文献

  • Byung Hyun Lee, Sungjin Lim, Seunggyu Lee, Dong Un Kang, and Se Young Chun. Concept pinpoint eraser for text-to-image diffusion models via residual attention gate. In The Thirteenth International Conference on Learning Representations, 2025.
脚注
  1. 論文では V(\bold{E}) と書かれていますがこのあと V が登場しない上に \bold{V} とboldの有無はありつつも記号が被っているので f の間違いだと思います. ↩︎

  2. 論文中では \bold{E}_{\text{anc}} ですが, おそらく同じだと思います. ↩︎

Discussion