Zenn
🤖

Stable Diffusionからの概念消去⑱:ConceptPrune (論文)

2025/03/16に公開

ConceptPrune: Concept Editing in Diffusion Models via Skilled Neuron Pruning (ICLR2025)

引き続きICLR2025採択論文をチェックします. 今までの概念消去の研究は誤差逆伝播または閉形式におけるパラメータの変更がメインでした. 今回は枝刈りを取り入れた手法をみていきます.

書籍情報

Ruchika Chavhan, Da Li, and Timothy Hospedales. Conceptprune: Concept editing in diffusion models via skilled neuron pruning. In The Thirteenth International Conference on Learning Representations, 2025.

関連リンク

モチベーション

枝刈りは通常モデルの軽量化に用いられる手法です. モデルの出力に対する寄与が小さい (i.e. 不要な)パラメータなどを削除することで軽量化を図ります. 逆にLLMなどを含め, モデルにとって重要なパラメータも存在しそうだということがわかっています. 日本語圏ではSuper-weightなどが話題になったりしましたが, そのような感じです. 著者らは以下の問いを提示しています.

Can we remove an undesired concept from a pre-trained DM by simply finding neurons specific to this concept, and pruning them?

そして答えはYesであると主張しています. この論文ではそれを実現する手法についてみていきます.

LDMのFFN

ここで対象にするのはLDM, すなわちStable Diffusionです. Stable DiffusionのU-Netには2つのReNet sblockが2つのTransformer blockにサンドイッチされる形式です. Transformer blockはself-attention, cross-attention, GEGLU付きFFNから構成されています. 既存研究の多くの場合はここのcross-attentionに目を向けて概念消去を行います. 一方でNLPの分野ではFFNに知識が蓄積されていそうだとか, 言われていたりします. そこで, この研究ではFFNに対象を向けます.

FFN layer ll のtimestep t におけるprompt p を含んだ入力を \bold{z}_{t}^l(p)\in\mathbb{R}^{d\times N} と表すことにします. N はlatent tokenです. Stable DiffusionのFFNは以下のように書けます.

\begin{align*} &\bold{h}_t^l(p)=\sigma(\bold{W}^{l,1}\cdot\bold{z}_t^l(p)) \\ &\bold{z}_t^{l+1}(p)=\bold{W}^{l,2}\cdot\bold{h}_t^l(p) \end{align*}

です. ここで, \bold{W}^{l, 1}\in\mathbb{R}^{d\times d'}, \bold{W}^{l, 2}\in\mathbb{R}^{d'\times d} はそれぞれ1・2番目の線形層の重み行列です. \sigma はGEGLUです. 以降では, \bold{W}^{l, 2}[i, :]\bold{W}^{l, 2}i 番目の行, \bold{W}^{l, 2}[i,j]\bold{W}^{l, 2}ij 列の要素を指すこととします.

枝刈りの戦術: Wanda

まず, LLMにおける枝刈り戦術のひとつであるWandaについて簡単に確認します. というのも, 提案手法ではこれを拡散モデルに適用しようとしているからです.

https://openreview.net/forum?id=PxoFut3dWW

まず, 線形層を \bold{W}\in\mathbb{R}^{d_{out}\times d_{in}}, 入力を \bold{X}\in\mathbb{R}^{B\times d_{in}} とします. B はデータの数です. 重みの絶対値が小さいものから順番に枝刈りを行うmagnitude-based pruningとは異なり, ニューロンの活性化に対する重みと特徴量の大きさの結合効果を推定します. 各重みの重要度はその大きさと対応する入力特徴の次元ごとの \ell_2 normとの要素ごとの積として計算されます.

\bold{S}(\bold{W}, \bold{X})=|\bold{W}|\odot(\bold{1}^{d_{out}}\cdot\|\bold{X}\|_2)\in\mathbb{R}^{d_{out}\times d_{in}}

ここで, |\cdot| は絶対値をとるものです. これによって求めたWanda Score \bold{S}(\bold{W}, \bold{X})[i,:] の低い k% の重みを0にすることで枝刈りを実現します. すると, 各行で k% の重みが消えるので \bold{W} がスパースになります. この \bold{W} の要素は"weight neurons" (重みニューロン)と呼ばれ, 通常のニューロンとは区別します. Wandaでは特徴量ノルム行列の計算にはcalibration setのみを使用し, 1回の順伝播のみ行うので非常に簡単に枝刈りを行うことができます. 以降では特定の概念を消去するために各行の上位 k% の重みニューロンを枝刈りする方法についてみていきます.

LDMにおけるSkilled Neuronの特定

\mathcal{P}^*=\{p_1^*, \ldots, p_M^*\}\mathcal{P}=\{p_1, \ldots, p_M\} の2つのcalibration promptの集合を考えます. p_i^*, p_i はそれぞれtarget conceptとrefernce conceptを表します. 例えばtarget conceptとしてGogh styleを考えると, p_i^* は"a <object> in Van Gogh style"となり, p_i は"a <object>"です. ここでのobjectはcatやdogといった一般的なカテゴリを表します.

target conceptとrefrence promptの集合に対応するニューロン活性を収集します. これはFFNの式

\begin{align*} &\bold{h}_t^l(p)=\sigma(\bold{W}^{l,1}\cdot\bold{z}_t^l(p)) \\ &\bold{z}_t^{l+1}(p)=\bold{W}^{l,2}\cdot\bold{h}_t^l(p) \end{align*}

によって得られます. これを, \bold{H}_t^l(\mathcal{P}^*)=[\bold{h}_t^l(p_1^*)^\top, \ldots, \bold{h}_t^l(p_M^*)^\top]\bold{H}_t^l(\mathcal{P})=[\bold{h}_t^l(p_1)^\top, \ldots, \bold{h}_t^l(p_M)^\top] という行列の形にします. ここで \bold{H}_t^l(\mathcal{P}^*), \bold{H}_t^l(\mathcal{P})\in\bold{R}^{(M*N)\times d'} です. この処理では1 promptにつき1回の順伝播が必要です.

両方のニューロン活性を収集した後, FFNの式における線形重み \bold{W}^{l,2} の重要度スコアをWanda Scoreと同様に計算します.

\begin{align*} \bold{S}(\bold{W}^{l,2}, \bold{H}_t^l(\mathcal{P}^*))&=|\bold{W}^{l,2}|\odot(\bold{1}^d\cdot\|\bold{H}_t^l(\mathcal{P}^*)\|_2) \\ \bold{S}(\bold{W}^{l,2}, \bold{H}_t^l(\mathcal{P}))&=|\bold{W}^{l,2}|\odot(\bold{1}^d\cdot\|\bold{H}_t^l(\mathcal{P})\|_2) \end{align*}

以降では簡単のために重要度スコアをそれぞれ \bold{S}_t^l(\mathcal{P}), \bold{S}_t^l(\mathcal{P}^*) と表すことにします. 重要度スコアを計算したのち, target concept promptとreference promptにそれぞれ対応する重要度スコアを比較することでskilled neuronを特定します.

重要度スコアでは, Wandaと同様に重み行列全体ではなく, 重み行列の各行における重要度スコアを考えます. ここでは行 \bold{W}^{l,2}[i,:] に対してtarget conceptを生成するtop-k% の重みニューロンを以下で定義します.

\bold{I}_{t}^l(\mathcal{P}^*)[i, j]=\begin{cases} 1 & \bold{S}_t^l(\mathcal{P}^*)[i, j]\in \text{top-}k\text{\% of } \bold{S}_t^l(\mathcal{P}^*)[i,:] \\ 0 & \text{otherwise} \end{cases}

これによって, \bold{I}_{t}^l(\mathcal{P}^*)\mathcal{P}^* のためのmaskとなります. \mathcal{P}^* には \mathcal{P} と比較して追加の不要なtarget conceptが含まれているので, \bold{I}_{t}^l(\mathcal{P}^*) はtarget conceptとreference conceptの両方を生成する重要なニューロン集合となります. ここからはこれらのニューロンをフィルタリングすることでtarget conceptとreference conceptを個別に生成するために分離させます. ここで, これまで出てきたskilled neuronについて定義します.

\bold{W}^{l, 2} でcharacterizedされた線形層に対して 重みニューロン \bold{W}^{l, 2}[i, j]\bold{I}_{t}^l(\mathcal{P}^*)[i, j]==1 \text{かつ} \bold{S}_t^l(\mathcal{P}^*)[i, j]>\bold{S}_t^l(\mathcal{P})[i, j] を満たす場合に時刻 t におけるskilled neuronと定義する.

これは, ある重みニューロンが \bold{W}^{l, 2} の行内でtarget prompt \mathcal{P}^* に対するWanda Scoreのtop-k%に位置する場合に, そのニューロンが不要なtarget conceptまたはreference conceptの生成に寄与していると考えます. target conceptのWanda Scoreがreference conceptのそれを上回る場合にはtarget conceptの生成に影響を与えていると考えます.

その後, \bold{W}^{l, 2} に対して時間依存のbinary mask \bold{M}^l_t を形成します.

\begin{align*} \bold{M}^l_t=\begin{cases} 1, & \text{重みニューロン} \bold{W}^{l, 2}[i,j] \text{がskilled neuronのとき} \\ 0, & \text{そうでないとき} \end{cases} \end{align*}

\bold{M}^l_t\bold{I}^l_t の部分集合で, target conceptによって強く活性化されるニューロンのみが保持されます.

ここまでのskilled neuronは時間依存の設定での話でした. 一方でDiffPrune

https://openreview.net/forum?id=d4f40zJJIS

では, 枝刈りを特定のtimestepにおける相対的な重要度スコアを集約することで可能であることを示しています. これを利用して, 最初の \hat{t} 回のdenoisingを用いてbinary maskを求めます. このことから枝刈り後の重み行列 \hat{\bold{W}}^{l, 2} は論理和 \vee と否定 \neg を用いて

\hat{\bold{W}}^{l, 2}=\bold{W}^{l, 2}\odot(\neg(\vee_{t=T,T-1,\dots,T-\hat{t}}\bold{M}^l_t))

と書けます. 学習済みの拡散モデル f_{\theta} の全ての重みのうち, \bold{W}^{l, 2} のみが \hat{\bold{W}}^{l, 2} に変更されます.

実験

いつも通り実験設定を確認した後に結果を見ます.

実験設定

普段とは異なり, Stable Diffusion 1.5を用います. 16個のFFNがあります. FFNのうち2番目の線形層を枝刈りします \bold{W}^{l, 2} のことです. これは色々なablationの結果から選ばれているようです. Appendixではablationの結果が示されていますが, LLMでの研究同様2番目が適しているとのことです.

skilled neuronの選定には2つのハイパーパラメータ k, \hat{t} を決める必要があります. ここではまず各概念に対して k\in[0.5, 5] で選び, それに応じて \hat{t} を決める方法で決定します. 概ねの場合, \hat{t}=10 で十分であることがAppendixの表から示されています. これはstyleやobjectはdenoiseの初期段階で形成されるというDDPM以降の研究で何度も示された事実と一致しています (論文では「示唆されている」程度の記述です).

比較手法はlightweight approachとしてUCE, FMN (Forget-Me-Not), MACE, Receler, SPMです. 他にもConcept Ablation (CA), ESD, Selective Amnesia (SA), Scissorhands (SH), AdvUnlearnを用います. ただ, これらはfine-tuningなどが必要な重めの手法なので間接的な比較にとどまるようです. 「ICLRに出す論文としては少なくないか?」と思ったのですが, 査読者全員から同じように指摘されていました.

結果: styleの消去

Van Gogh, Claude Monet, Pablo Picasso, Leonardo Da Vinci, Salvador Daliの5 styleを考えます. これらのアーティスト名を付加したpromptを50個ChatGPTに生成させ, CLIP Similarity とCLIP Scoreを測定します. 一般的なCLIP Scoreの使い方ではないようにみえますがrefrenceや定式化が示されていないので扱い方がわかりにくいです. 文章では前者を「生成画像とpromptの類似度を測定している」, 後者を「概念編集後の生成画像がpromptとどの程度一致するかを評価し, 類似度がOriginal SDを超えた場合にペナルティを課したもの」と説明されています. 結果を見てみます.

提案手法は消去性能が優れているというような結果でしょうか. 文章では無関係な概念の保持についても言及されていますが優位に立てる結果ではないでしょう (competitiveでも怪しいような気がします). 定性的な比較もします.

確かにこの結果を見ると概念消去に関しては性能の優位性がありそうです. 例えばUCEやACはVan Gogh特有の星の描き方が出ていたりしますし, FMNに関してはDa Vinciの生成例がかなり微妙です. cherry pickの可能性を含めても定量評価と併せて考えることで消去については優位性を認めることができます.

結果: NSFWの消去

もうお馴染みになっているI2Pを使用します. Stable Diffusion 1.5からの減少率を示します.

減少率なので右までバーが伸び切っているといい指標です. かなりの削減率 (94.1%)になっています. ESDは88%, UCEは85.5%なのでかなりの改善です. なぜ2番目のSHが触れられていないのかは謎です.

結果: objectの消去

まずは単一概念の消去結果を見ます. Imagenetteを用います. 各クラスにつき500枚を生成してtop-1 accuracyを見ます.

他の手法と比較しても他の概念への影響が少なく削除できているように見えます. ただboldの付け方も適当で, 印象操作っぽく見えます.

続いて複数概念の結果を見ます. binary maskは各概念に対して個別に生成されますが, skilled neuronの結合をとることで複数概念消去を実現します.

結果を見た感じだと正直同じくらいかなというように思えます. 論文では"ConceptPrune demonstrates comparable erasing performance while excelling at retaining unrelated concepts."と述べられていますがこの結果からそのように主張するのは無理があるのではと思います. せいぜい「同程度の消去能力, 概念保持能力だが攻撃には強い」程度でしょう. 実験設定もよくわからないのでこれ以上は何も言えない気がします.

結果: 攻撃に対する堅牢性

まずはstyleとobjectに対する堅牢性を確認します. ここではUnlearnDiffAtkとConcept Inversionを用いて確認します. まずはUnlearnDiffAtkです.

styleやobjectの消去においてはASRは全体的に低いことがわかりますが, AdvUnlearnには劣ります. 少なくとも消去性能ではSoTAではないですが, 論文のロジックではLight-weightの中ではSoTAと言いたいのだと思います. Concept Inversionを用いた結果は

のようになっています. コロコロ比較対象が変わっているのが非常に不親切ですね. 少なくともこの比較群の中ではいい性能だと言えます.

続いてNSFWに対する堅牢性を見ます. これまではwhite-boxの攻撃手法でしたがここではblack-boxの攻撃手法も含みます.

Black-boxであるRing-A-BellやMMAはOriginal SDからの減少率を示しています. これを見る限りだとblack-boxに対しては強いですがwhite-boxに対してはそこまで強くないように見えます. 限定的な結果かなと思います.

結果: skilled neuronの密度

skilled neuronの密度を測定します.

これを見るとすべてのケースでFFN重み行列の3%未満であることがわかります. これは, 概念生成が非常に小さな部分空間に依存していることを示唆していると著者らは主張しています. FFNに対する割合は3%未満ですが, これを拡散モデル全体に対する割合にすると0.12%未満なので非常に小さいと言えます.

思ったこと

  • 非常に重大な懸念として, 関連研究と比較手法の欠如があります. これは査読者にも指摘されていましたが, この研究が初めての枝刈りによる概念消去の研究というわけではありません. 例えばNeurIPS2024 Workshop for SafeGenAi採択のP-ESDなどが挙げられます (実際に査読者が示した3つの論文のうちひとつはこれです). 著者らはコードが非公開であることを理由に実験を行っていませんが, ならば実験設定を揃えて別の実験を行うべきであると考えられます. ちなみに (査読者が挙げなかった論文で)他にはLocoEditのSection 7で似たような実験が行われています. こちらはコードも公開されているので比較が可能です.

  • skilled neuronの密度ですが, これはグラフを見ればわかるように概念依存です. この辺りの言及が欲しいなと思いました

  • styleやNSFWといった, ESDではESD-uで更新する概念の種類の場合, skilled neuronの密度が山の形をしている点が興味深いです.

  • 先述の通り, 比較対象がコロコロ変化するのは恣意的な比較に思えます. ICLRはそれでもいいのでしょうか (個人的には包括的な比較が望ましいです)...

参考文献

  • Ruchika Chavhan, Da Li, and Timothy Hospedales. Conceptprune: Concept editing in diffusion models via skilled neuron pruning. In The Thirteenth International Conference on Learning Representations, 2025.
  • Mingjie Sun, Zhuang Liu, Anna Bair, and J Zico Kolter. A simple and effective pruning approach for large language models. In The Twelfth International Conference on Learning Representations, 2024.

Discussion

ログインするとコメントできます