🦁

Mamba内部の知識とその編集(COLM2024)

2024/10/09に公開

Locating and Editing Factual Associations in Mamba (COLM2024)

COLM (Conference on Language Modeling)という新設の会議があります.

https://x.com/COLM_conf/status/1749881065055924363

新設なので今年が最初 (第1回)なのですが, そこに採択された論文について見てみます.

タイトルは"Locating and Editing Factual Associations in Mamba"です. これはNeurIPS2020に採択された"Locating and Editing Factual Associations in GPT"を意識したものでしょう.

https://papers.nips.cc/paper_files/paper/2022/hash/6f1d43d5a82a37e89b0665b33bf3a182-Abstract-Conference.html

関連リンク

はじめに

論文の内容に入る前に, タイトルの元ネタとなっている"Locating and Editing Factual Associations in GPT"ではどのようなことをしているのかを軽く確認します. ここでは非常に簡単に述べるので詳しくは論文の方を参照していただければと思います.

タイトルにあるように言語モデルとしてGPTを用いて, 事実に関する事柄 (いわゆる知識)がモデル内部でどのように結びつき, 想起されているかを調べ, その編集方法について提案したのがこの論文です. 因果推論を用いて, 特定の知識に関連する部分を特定します. それを表したのが下の図です.


Locating and Editing Factual Associations in GPTより引用

この結果から, 真ん中あたりのMLP (FFN)に知識が蓄積されていそうだということが[1], 最後の方のself attentionでその知識がトークンにコピーされていそうだ, ということがわかりました. それでMLPを編集しようという (ROME)のがこの論文での提案手法です.

日本語では言語処理学会 (NLP2024)の『言語構造に制約されない大規模言語モデルの知識編集』という論文が詳細な説明をわかりやすくしています.

https://www.anlp.jp/proceedings/annual_meeting/2024/pdf_dir/P10-27.pdf

ROMEの論文ではGPT2-XLを対象としていますが, 多くの場合他のモデルに拡張できます. 日本語記事だと以下のようなものがあります

https://zenn.dev/ohtaman/articles/llm_finetune_lora

https://tech.preferred.jp/ja/blog/llm-fine-tuning-for-domain-knowledge/

以上のことから, 非常に大雑把に言うと, GPT2だけではなく, さらに大きなTransformer decoder言語モデルにもこの話が当てはまることがわかります. ではMambaではどうなのかというのがこの論文での話です.

Mamba

MambaとTransformerではアーキテクチャが大きく異なります. Mambaは同じくCOLM2024に採択されている, Transformerを代替しうると注目されている状態空間モデル (State Space Model, 以降SSM)の発展系です.

Transformerは推論時に系列長の2乗に比例する計算量 (すなわち \mathcal{O}(N^2)) を要するため, Transformerの登場以来これは非常に大きな問題です. 特に長い系列予測を要求される時系列予測では多くの計算量改善手法が登場していますが, 改善するために導入したモジュールなどに起因して余計に時間がかかるようになってしまっています[2].

Transformerの計算量が大きいならTransformerをやめればいいのですが, 性能が非常に悪化します. 特に, Transformer以前のRNNなどは性能が悪い上に勾配消失などの問題を抱えています (LSTMやGRUで解消されます). さらに, 逐次処理であるために並列処理ができないという非常に大きな欠点を抱えています (この欠点を解消したのがTransformerで, Attention Is All You NeedのIntroの最終段落に成果として書かれるほどです).

Transformerのいいところをそのままに推論時の計算量を削減したのがSSMになります. 詳しくはそれぞれの論文を参照してほしいのですが, LSSL, S4, H3, HyenaのようなSSMベースのモデルはSSMの計算の畳み込みの部分の高速化にフーリエ変換を用いています. これをよりsoftとhardの面から最適化したものがMambaです (H3などでもFlashConvのような最適化が使われていますがそれとは別です).

アーキテクチャとしてはH3とGated MLPの組み合わせのような構造をしています.


Mamba: Linear-Time Sequence Modeling with Selective State Spacesより引用

この論文では, 以下のような構造を使用します. notationもこの図に従います.

なお, Mambaに関しては以下の記事などが参考になります.

https://qiita.com/peony_snow/items/649ecb307cd3b5c10aa7

Locating Key States for Factual Recall

では, 論文の内容に入っていきます.

まず, 言語モデルが知っている事実 (s, r, o) を選びます. ここで, r は subject entity s とobject entity o を関連づける「関係」です. 例えば, s=Michael\ Jordan, r=professionally\ played, o=basketball となります.

この正しい事実予測に対する各状態の寄与を推定するために, 3つの異なる実行にわたってモデルの活性化を収集します.

  1. clean run
  2. corrupted run
  3. patched run

clean run (G)

clean runでは, 単にプロンプトを入力します. 例えば, x=(s, r)=Michael\ Jordan\ professionally\ played です. あとで解析するためにclean runで得られた隠れ状態の \{h_i^{(\ell)}, a_i^{(\ell)}, s_i^{(\ell)}, g_i^{(\ell)}\mid i\in[1, T], \ell\in[1, L]\} を保存します.

corrupted run (G^*)

s を別の事実の s^* (Pele) とスワップします. これは, 言語モデルが異なる回答 o^* (soccer) をするようにします. すなわちプロンプト x^*=(s^*, r) を与えます. この入れ替えは既存研究で行われている, 摂動を用いる手法で, ガウシアンノイズで壊すROMEで起こりうる, 領域外の状態をモデルの計算に導入することを回避できます.

patched run (G^*[\leftarrow h_i^{(\ell)}])

x^* を与えるところまではcorrupted runと同じです. h_i^{(\ell)} をcleanな状態と置き換えることで介入を行います. パッチを当てた状態は, それに依存するすべての状態を変化させる可能性があります. 以下の図がわかりやすいです.

分析結果

さて, p(o), p^*(o), p^*[\leftarrow h_i^{(\ell)}](o) をそれぞれ G, G^*, G^*[\leftarrow h_i^{(\ell)}] において正しい答え o に割り当てられた確率とします. h_i^{(\ell)} の寄与を測定するために indirect\ effect (IE)を以下のように定めます.

\mathrm{IE}_{h_i^{(\ell)}}=\dfrac{p^*[\leftarrow h_i^{(\ell)}](o)-p^*(o)}{p(o)-p^*(o)}

以下の図では, Relations Datasetの400の事実について, h_i^{(\ell)} を復元した場合の間接効果をプロットしています.

last tokenについて, late siteで高いIEが確認できます. これはそこでcleanな h_i^{(\ell)} を復元することで G からの計算の大部分が復元できるので自然なことです. しかし, early site (last subject tokenにおける中間部分のレイヤー)でも高いIEが確認できます. これは, ROMEの論文での実験結果と同じもので, GPTでも見られた結果です.

他の変数についても見てみます.

o_i^{(\ell)} は先ほどの h_i^{(\ell)} と非常に似た結果です. また, (c)ではSelective-SSMの出力 s_i^{(\ell)} が後ろの層でのみ高いIEを持ち, GPTのattentionと同じような挙動が確認できます.

自己回帰のTransformerとこの結果を比較するために, 同じ規模のTransformer言語モデル (Pythia-2.8b)でも実験を行いました.

これと比較すると, MambaとTransformerで異なる点が浮かび上がります. それは, TransformerはMLPの出力がearly siteに効果を持ち, late siteでは効果がないです (図b, 同じ部分は色の変化がありますが, TransformerのMLPのように支配的ではないということだと思います). しかし, Mambaはそれに対応する結果が得られません. ここから, 「Mambaのどのパラメーターが事実の想起を媒介するのか?」という疑問が生じます.

その疑問を解消するために, ROMEの論文でやっていることと同じことをMambaでも再現します. すなわち, 因果グラフから特定の経路を切断し, その影響を観察することで経路特有の効果を探ります. ここでは, 事実の想起の際に W_g, Conv+SSM, W_o によって処理される状態 g_i, s_i, o_i の影響を理解することに興味があります.

まず, G^* においてトークン位置 i で全ての s_i 経路からの寄与をキャッシュし, s^*=\{s_i^{*(\ell)}\mid\ell\in[1, L]\} とします.

次に, G^*[\leftarrow h_i^{(\ell)}] では G からキャッシュされた h_i^{(\ell)} を対応する状態に復元します. ただし, 追加の変更を加えます. 具体的には, s_i 経路の寄与を理解したいので, それらの経路は切断し, G^* からキャッシュされた s_i^* をpatchingします. 同様の実験を g_i, o_i についても行います. なお, o_i^{(\ell)} を切断すると s_i^{(\ell)} および g_i^{(\ell)} も切断されることには注意が必要です.

図にすると以下のようになります.

Relations Datasetから400件をランダムサンプリングした結果を示します.

紫色のバーと緑色、赤色、青色のバーとのギャップを見ることでグラフを読み取ります. 大きいギャップはそれぞれ W_g, Conv+SSM, W_o が強い媒介的役割を果たしていることを示します. last subject tokenの初期の層ではConv+SSMと W_g が強い役割を果たしていますが, W_o はそれ以上の働きを示しています.

後ろの層でもやはり W_o が大きな働きを示します. この結果は, 事実予測において W_o が重要な役割を担っていること, 初期の層ではMambaではそれぞれのパラメータの役割がTransformerほど分離されていないことがわかります.

Editing Facts With ROME

ここではROMEをMambaに適用することを考え, 事実の編集ができるかどうかを確かめます. ROMEは任意の線形変換を連想記憶として考えることができ, keyの集合 \mathcal{K}=[k_1|k_2\ldots] を対応するvalue \mathcal{V}=[v_1|v_2|\ldots] にマッピングし, これを用いてTransformer LMの編集を行います. Mambaの線形変換の集合に対してこれを行なって結果を観察します.

入力はプロンプト x=(s, r) で, ここで s はsubject entity (Emmanuel\ Macron), r はrelation (is\ the\ President\ of\ ) です. counterfactual object o^*(England) も受け取り, 正しいobject o(France) を置き換えることを目的とします.

ROMEは層 \ell のlast subject tokenのMLPのdown projection行列 W_{down}^{(\ell)} にrank 1の更新を生成します. この行列は連想記憶の役割を果たしており, W_{down}^{(\ell)} への入力をkey k^* として考えます. 勾配降下法にしたがってvalue v^* を計算し, W_{down}^{(\ell)} の出力とすることでモデルが o^* を出力するようにします.

Mambaには3つの射影行列があります. Conv+SSM経路である W_a^{(\ell)}, gating経路である W_g^{(\ell)}, MambaBlockの最後の出力となる W_o^{(\ell)} です.

編集性能を測るために, CounterFact datasetを用います. 20Kの (s, r, o\rightarrow o^*) で構成されています. o はプロンプト x=(s, r) に対する正しい答え, o^* は新しい答えです. 実験では2000例を用いて行います. 評価指標はROMEの元論文と同じものを用います. 具体的には以下の3指標の調和平均を計算します.

  1. Efficacy (ES): effective とは, 編集ののちに言語モデルが p(o^*)>p(o) と確率を割り当てることを言います. Efficacyは, effectiveな編集の割合を反映します.
  2. Generalization (PS): 編集が成功した場合, (s, r) を言い換えたプロンプトに対しても一貫性のある回答をすることが期待されます. 各リクエスト (s, r, o\rightarrow o^*) に対して, 異なる言い換えを行った x_p\sim \mathcal{P}_r(s) も一緒に p(o)>p(o^*) をチェックします.
  3. Specifity (NS): 編集は特定の \mathcal{P}_r(s) に対して行われるべきであり, o^* に隣接したsubjectである s_n に対しては何の変化がないことが期待されます. これを評価するには p(o_n)>p(o^*) を測定することでできます.

結果を確認します.

この結果を見ると, 初期層から中間層にかけてをROMEで編集を行うと高いスコア (S)が得られています. これはTransformerの言語モデルでの観察結果と同じです. しかし, 編集の性能は場所によることが確認できます. 例えば, W_g^{(\ell)}W_o^{(\ell)} に関しては, (S)と(PS)について43番目の層以降は大きく低下していることがわかります. これは先ほどのIEを計測した実験結果と一貫したものです. また, W_a^{(\ell)} に関しては, 初期層では (PS)が低い結果になっています. 一方で, 初期層でも W_g^{(\ell)}W_o^{(\ell)} を編集すると高い (PS)が得られます.

では, MambaにROMEを適用する際に正しい場所はどこでしょう? 間接効果の図を再掲します.

この図では, W_g^{(\ell)} が適当に見えます. g_i 状態の因果効果はtransformerのMLPの挙動と同様に, last subject tokenに集中しているからです. これと一致するのは, transformerの W_{down}^{(\ell)} が残差接続を通してのみattention moduleに入ることが, W_g^{(\ell)} の場合はConv+SSMと対応しているというアーキテクチャ上の事実です. 実際に, ROMEは W_g^{(\ell)} を修正することで事実をうまく挿入できることがわかりました.

一方, 先ほどの結果では中間層のgatingは (ES)と(PS)が急激に低下します. これはいくつかの層で W_g^{(\ell)} が信頼できない部分である可能性があります. それに加えて先ほどの実験では, W_o^{(\ell)} が最高性能であることがわかります. このことは, 状態 o_i が状態 g_i より強い因果影響がlast subject tokenにある事実と一致します.

このことから著者らは W_o^{(\ell)} の強いパフォーマンスは, 初期・中間層と後ろの層とを分ける役割に起因すると仮説を立てました.

Linearity of Relation Embeddin (LRE)

activation patchingによって, 言語モデルのどの部分に事実が位置しているかを特定することができます. 我々はプロンプト x=(s, r) が与えられた際に, 言語モデルがどのようにこの情報を抽出しているのかを理解することに興味があります. これまでの実験から, Mambaの初期・中間層と後ろの層で役割が明らかに分かれていることが示されています.

Transformerの言語モデルでは, subject entityの表現 \bold{s} は, last subject tokenの位置において, 初期から中間層におけるMLPによって媒介される強化プロセス (enrichment process)を経ます. このプロセスではsubject entity \bold{s} に関連する様々な事実や属性が \bold{s} に付加されます.

この事実は以下の論文によるものです.

https://aclanthology.org/2023.emnlp-main.751/

次に, last subject tokenの位置で, attention moduleはenrichedな \bold{s} に対してクエリを実行し, プロンプト x=(s, r) に対する答えを抽出します. 既存研究では, 特定の関係 r に対してenrichedば \bold{s} に対して実行されるクエリ操作を LM computation F の一階テイラー展開を用いて以下のように近似します.

F(\bold{s}, r)\approx \beta \mathrm{J}_\rho\bold{s}+b

ここで

\mathrm{J}=\mathbb{E}_{\bold{s}_i,r}\left[\dfrac{\partial F}{\partial\bold{s}}|_{(\bold{s}_i,r)}\right], \qquad b=\mathbb{E}_{\bold{s}_i,r}\left[F(\bold{s}, r)-\dfrac{\partial F}{\partial\bold{s}}|_{(\bold{s}_i,r)}\right]

\beta はスカラー, \rho\mathrm{J} のrankです.

著者らはこれを用いてMambaの事実関係のデコードの複雑さを探求します. \beta, \rho, \ell (enriched \bold{s} を抽出するレイヤー)はハイパーパラメータですが, これらはグリッドサーチによって探索します.

\mathrm{J}, b を5サンプルで平均して計算した結果を示します. faithfulnessはLM computation F(\bold{s}, r) を単純なアフィン変換である LRE(\bold{s}) に置き換えた場合に正しく取得できる事実 (s, r, o) の割合を表します.

赤い線はrandom choiceです. 結果を詳しくみます. 26ある事実のうち, 10件の事実のみが線形LREで50%以上を達成できています. 比較のためにPythia-2.8bでも同じ実験をしたところ, 11件でした. MambaよPyhtiaの両方でLREはユニークな回答の数が大きい関係に対しては良好な忠実度を達成できないことがわかります. この結果はLLaMAでの先行研究の結果と一致し, Transformer言語モデルと同様に, Mambaにおける事実知識もrelationごとに異なる方法で表現されている可能性があることを示唆しています.

Attention Knock-out in Mamba?

attention moduleはTransformer言語モデル内の異なるトークン位置間の情報の流れを媒介しています. Attention "Knock-out"実験では, 特定のエッジ (k 番目のトークンから q 番目のトークン)を通る情報が特定のattention headを介してブロックされることで, そのエッジを通る重要な情報が流れているかを理解しています. これは, いわゆる因果媒介分析の一種と見ることができます. Mambaでは過去のトークンの情報は 状態 s_i に保持されていて, Conv+SSMで処理されます. MambaでもAttention "Knock-out"実験と同じようなことを行って事実情報の移動を見ます.

実際には, 非線形な操作であるConv+Selective-SSMによって, 完全に情報を取り除くことは困難です. しかし, Conv+SSMの操作で k 番目のトークンから全ての将来のトークンへの情報の伝播を平均消去によってブロックすることは可能です. 具体的には, あるレイヤー \ell に対して, a_k^{(\ell)} を, a_k^{(\ell)}=\mathbb{E}[a^{(\ell)}] とします. ここで, \mathbb{E}[a^{(\ell)}] はWikiText-103から収集した10000トークンを用いて計算された a^{(\ell)} の平均です. この介入は厳密にはエッジの切断と等価ではないですが, それを考慮しても実験結果はTransformer言語モデルで見られたものと類似していることが確認できます.

Relations datasetsから6つの事実関係にわたって700の事実をランダムに取ってきます. これらのデータに対して特定のレイヤー \ell の周囲10レイヤーのwindow内で, subject, non-subject, prompt-lastのトークン位置の情報伝播をブロックします. 特定のレイヤートークン位置 \ell -k でのConv+SSM情報の流れをブロックする効果は, p(o) がどれだけ変化したか (相対変化)によって測定できます. 具体的には

\dfrac{p(o\mid a_k^{(\ell)}\coloneqq\mathbb{E}[a^{(\ell)}])-p(o)}{p(o)}

です. さて, 結果を見てみます.

著者らはここから3つの結論を導いています.

Mambaは初期~中間層で関係特有の情報を将来のトークンに伝播する

紫色の線をみます. ここから, 初期~中間層でnon-subjectの情報の流れをブロックすることで p(o) を最大で50%減少させることができることがわかります. non-subjectトークンは関係 r を特定するために使用されますので, Mambaは初期~中間層で関係特有の情報を将来のトークンに伝播すると考えられます.

やや後半の層では事実想起に重要な役割を果たしている

subject情報の流れである緑色の線を見ると, 2つの谷があることがわかります.

  1. 最初の谷は初期層にありますが, これはMambaが初期初期層で全ての主語トークンから情報を集約して複数のトークンからなるsubject entity s を認識する必要があるので, 当然の結果です.
  2. しかし, 43~48層に見られる谷は, Mambaがその層でConv+SSM経路を利用してsubjectから後続のトークンへ重要な情報を伝播していることが示唆されます. これは以前の結果と一致しており, その層の s_i 状態が高い間接効果を示していることから, 事実を想起する際に重要な役割を果たしていることが示唆されます.

last subject tokenの除去だけでは不完全

last subject token情報をブロックした際の結果である青い線を見ます. 初期層では後続の層が損失を補うことがわかります. しかし, 20層あたりで谷があり, Mambaがその時点までに完全なsubject entityを認識して関連する連想 (enrichment)を想起することが示唆されます. 特に, enrichmentに重要な役割を果たしていると仮定している o_i, s_i の状態についてのactivation patchingの結果もその領域で強い間接効果を示しています. 30層目以降では青と緑の線は同じ動きをしますが, p(o) の変化が小さいのはlast subject tokenの除去だけではすべてのsubject情報を取り除くことができないと考えられるからです. 例えば, Eiffel Towerという文字列を考えると, Eiffel (トークン化されるとE, iff, el)はTowerよりも情報量が多いです.

まとめ

  • ROMEがMambaに適用できるかの検討
  • 全体としてMambaとTransformerでは同じような考察ができる
  • 著者らは自己回帰型の言語モデリングというタスクが事実想起パターンを誘発するのではと考えている

参考文献

  • Meng, K., Bau, D., Andonian, A., and Belinkov, Y. Locating and editing factual associations in gpt. In Koyejo, S., Mohamed, S., Agarwal, A., Belgrave, D., Cho, K., and Oh, A. (eds.), Advances in Neural Information Processing Systems, volume 35, pp. 17359–17372. Curran Associates, Inc., 2022.
  • Albert Gu and Tri Dao. Mamba: Linear-time sequence modeling with selective state spaces. In First Conference on Language Modeling, 2024.
  • Arnab Sen Sharma, David Atkinson, and David Bau. Locating and editing factual associations in mamba. In First Conference on Language Modeling, 2024.
脚注
  1. もう少し正しくいうと, 因果推論でわかることは「事実の想起において何かしらの大きな関与をしていると予想される.」ということです. ↩︎

  2. 詳細は, Are Transformers Effective for Time Series Forecasting?などを参照してください ↩︎

Discussion