Titansを理解したい..タン
tl:dr
Google Researchが2024年12月31日に公開したTitans: Learning to Memorize at Test Timeという論文について理解するために,解説.
勉強のために書いてるので「ここ間違ってそうだな」と思われたら,コメントお願いします!
1. はじめに
LLMで広く使われているTransformer の Self Attention 機構 は,すべてのトークン間の依存関係を計算するため,入力が長くなると計算量が入力長の2乗に比例して増加するという課題がある.
この論文では、Attention を短期記憶とみなし、長期依存性は ニューラル長期記憶モジュール で補完する新たな手法を提案.このモジュールは,推論時にもオンラインで記憶更新が可能なため,長いcontextも効率よく扱うことができるらしい.
2. 論文の背景と課題
2.1 Transformer と Attention の基本
Transformer では,入力系列
ここで,
Causal Attention の場合,出力
この方法は全トークン間の依存関係を捉えることができるが,入力長
2.2 Attentionの効率化とその限界
長いシーケンスに対応するため、線形 Attention が提案されている.
ソフトマックス関数をカーネル関数
この方法を取ることで,共通部分の再利用をすることができ,計算コストが大幅に削減される.ただ,情報を固定サイズの状態に圧縮するため,非常に長い文脈を保持するのには限界がある.
2.3 人間と既存モデルの記憶の仕組みの比較
以下の表にある通り,HopFieldやLSTM,Transformersは人間の脳に触発されて開発されてきたけど,搭載している記憶の要素,独立して動作するか,学び続けるかという3点のどれかが不足している
記憶・学習の要素 | 人間の脳 | 既存のAIモデル |
---|---|---|
短期記憶 | 柔軟に直近の情報を保持し、状況に応じた迅速な判断が可能 | AttentionやLSTMの隠れ状態など、短期情報は保持するが文脈ウィンドウに依存する |
長期記憶 | 長期間にわたり豊富な情報を蓄積し、経験に基づいた判断が可能 | 長期依存性の情報は圧縮・固定長の表現に依存し、十分な抽象化が難しい |
メタ記憶 | 状況に応じてどの記憶を利用するか、記憶自体を制御・調整できる | 基本的に特定のタスクに合わせた学習は行うが、記憶そのものの制御機能は限定的 |
記憶モジュールの独立性と連携性 | 複数の記憶システム(短期、長期、メタ記憶)が独立しつつ連携して動作する | 一般的に一種類の表現(例:隠れ状態やAttention行列)に依存し、独立性や連携が不足 |
学習と抽象化の能力 | 経験から積極的に学習し、重要情報を抽象化して柔軟に活用できる | 入力データから学習するが、過去の情報の抽象化や動的な記憶更新には限界がある |
3. Titansの構造
Titansの論文の肝は, 推論時にもオンラインで記憶更新を行う長期記憶モジュール の設計をし,実験,検証をしたこと.長期記憶モジュールは,基本的には複数層のMLPで構成され,Transformerのように
長期記憶の更新方法についてざっくり言うと,モデルに対する入力の勾配が大きいと言うことは,モデルが予期しないような入力であり,モデルが持ってない知識や考え方に驚いている(Surprise).そのような情報は重要度が高いから記憶しておこうと言う感じ
長期記憶モジュールの詳細
Surprise(驚き)に基づく記憶更新
この論文では、損失関数
最も単純な式は以下の通り.
-
: 時刻M_t におけるメモリ状態t -
: 入力\nabla \ell(M_{t-1}; x_t) に対する損失の勾配(=「驚き」)x_t -
: 更新の強さを制御するパラメータ\theta_t
ただ,この式をそのまま使うと問題が生じてしまう.
具体例
少々雑ですが,例えば入力が以下のようなものを考える.
※著者の解釈に基づいたイメージ
あなたは衛星開発プロジェクトに関わっていた.打ち上げ当日,ニュースを見ていると,自分が開発に携わった衛星を搭載したロケットが司令破壊され,宇宙に届くことなく海の藻屑となったことを知った.
そして1ヶ月後――その原因が「宇宙人の攻撃」だったと判明する.
ロケットの打ち上げ成功率は高いので,「ロケットが司令破壊」になったという事実にまずは驚き,この時損失関数
この「Surprise」が非常に大きい状態で勾配更新が行われたとき,その後では
Surpriseの分割
前述のような連続する入力の変化を効果的に捉えるために,驚き(Surprise)を過去の驚き(Past Surprise)と瞬間的な驚き(Momentary Surprise)を分け,更新は以下の 2 段階で行うようにする.
-
: 時刻S_t における驚き度t -
: 時間と共に過去の驚き具合を減衰させる(\eta_t )0 \le \eta_t \le 1 -
: その時点の勾配の重み付け\theta_t
具体例
-
「司令破壊」
通常の予測(ロケットは成功するはず)と異なるため, は大きな値をとる.その結果,更新項\nabla \ell(M_{t-1}; x_t) にはこの瞬間の驚きが反映され,メモリS_t は式9により前時刻の記憶M_t に更新項M_(t-1) が加えられる.S_t -
「宇宙人の攻撃」
「司令破壊」による驚き はS_(t-1) により一定程度保持され,その上で新たな「宇宙人の攻撃」に対する瞬間的な驚き\eta_t が加わり,更新項- \theta_t \nabla \ell(M_{t-1}; x_t) が再計算される.S_t
これにより,司令破壊と宇宙人の攻撃の両方の情報がバランスよく記憶に組み込まれるらしい
損失の計算方法
ざっくり言うと,Transformersのように入力から
キーとバリューの生成
入力
-
x_t
時刻 (t) の入力(例:ロケット打ち上げのシナリオで「司令破壊」と思っていたが,実は「宇宙人の攻撃」だったという情報) -
:学習可能な重み行列で、入力から特徴を抽出し、それぞれキーとバリューに変換(W_K) と (W_V) -
:後に記憶から関連情報を検索するための「検索キー」として機能key (k_t) -
:実際に記憶すべき情報(「宇宙人襲来」という重大な情報など)value (v_t)
-
連想記憶を学習する際のLoss
「正しい
-
:時刻M_{t-1}(k_t) までに構築された記憶モジュールt-1 に、現在のM_{t-1} key を入力して得られるk_t value
これは、「過去の記憶から、現在の入力に対応する情報がどの程度再現できるか」を示す. -
:式11で生成された実際のv_t value
これとの間の誤差(L2ノルムの二乗)が小さくなるように,記憶モジュールの学習を行う.
忘却メカニズム(Forgetting)の導入
Context長が長い入力が与えられた時にメモリが不要な情報を忘れるようにすることで,記憶容量を管理したい!
過去の情報について,パラメータ
-
: 忘却ゲートのパラメータ\alpha_t \in [0,1] -
の場合:過去情報をほぼそのまま保持\alpha_t \to 0 -
の場合:過去情報を大幅に忘却\alpha_t \to 1
-
長期記憶モジュールのアーキテクチャ
長期記憶モジュールのアーキテクチャは,
記憶の取得
長期記憶モジュールに保存された記憶を取得する方法は,単純に重み更新を行わないフォーワードパスを使用するだけ.
入力に基づいて作成された
-
q_t
ベクトルで,入力query を線形層x_t で変換して得る(例:W_Q )q_t = x_t W_Q
この は、記憶モジュールに対して「どの情報が欲しいか」を示すquery -
M^*(q_t)
記憶モジュール にクエリM を入力して、対応する記憶情報(q_t )を取得する操作value
この操作は,推論時に重み更新を行わず,ただ記憶から情報を引き出す
4. Titans アーキテクチャの種類
長期記憶モジュールを既存のモデルに統合する方法は3種類
4.1 Memory as a Context (MAC)
長期記憶をコンテキスト情報として扱う方式で,簡単に言えば,入力と長期記憶モジュールから抽出した記憶を連結して一つの大きなコンテキストとして扱うもの.
具体的な流れは以下の通り.
-
セグメント分割: 入力系列
を固定サイズのセグメントx \in \mathbb{R}^{N \times d_{in}} に分割S(i) -
記憶からの情報取得: 現在のセグメント
に対して,前時刻の長期記憶S(t) を利用し,M_{t-1} query を生成して対応する情報q_t = S(t)W_Q を取得h_t -
h_t = M_{t-1}^*(q_t) \quad \text{(式21)} -
Persistent Memoryの付加: タスク固有の知識を持つ固定パラメータ列
を先頭に結合し,拡張された入力[p_1, p_2, \ldots, p_{N_p}] を作成\tilde{S}(t) \tilde{S}(t) = \bigl[\, p_1, p_2, \ldots, p_{N_p} \,\,||\,\, h_t \,\,||\,\, S(t) \bigr] \quad \text{(式22)} -
Attention による統合と記憶更新: 拡張された入力
に通常の Attention を適用し、出力\tilde{S}(t) を得た後、次時刻の記憶y_t を更新M_t
利点
過去の重要な情報が明確に反映される
欠点
入力の分割方法によっては,適切に過去の情報が取り出せない
4.2 Gated Memory (MAG)
MAG では,短期記憶と長期記憶を別々に処理し,非線形なゲート機構で統合することで,必要な時に必要に応じて長期記憶の情報を取得するようになり,計算量の面で効率的.
- 短期記憶: スライディングウィンドウ Attention(SW-Attn)を用いて局所的な依存関係を捉える.
- 長期記憶: ニューラル長期記憶モジュールの出力を利用
統合は,次の式で行う.
ここで,
利点
新しい情報を忘れにくい.
重要な情報が残りやすい
欠点
構造が複雑なので学習の難易度が高い
4.3 Memory as a Layer (MAL)
MAL では,長期記憶モジュールを「層」として直接組み込む.
実装が一番簡単な分,性能が高いわけではないらしい.
-
パーシステントメモリの付加: 入力
に対して,固定パラメータ列x を結合します.[p_1, p_2, \ldots, p_{N_p}] x' = \bigl[\, p_1, p_2, \ldots, p_{N_p} \,\,||\,\, x \bigr] \quad \text{(式29)} -
記憶層の適用: 拡張された入力
に対して,長期記憶層x' を適用し,出力M(x) を得る.y y = M(x) \quad \text{(式30)} -
最終出力の生成: その後,スライディングウィンドウ Attention を適用し,最終出力
を生成o o = \text{SW-Attn}(y) \quad \text{(式31)}
利点
構造がシンプル
欠点
細かい情報が抜け落ちやすい
5. 実験
実験については以下の表を参照.簡潔に言うと,長いシーケンスを扱う際はTitansはものすごく性能がいいらしい
セクション | 実験条件・設定 | 結果 |
---|---|---|
実験設定 | - モデルの種類:MAC、MAG、MAL,記憶モジュールのみ(LMM) - パラメータサイズ:170M, 340M, 400M, 760M - 学習データ:15B/30Bトークン(FineWeb‐Edu) - ベースライン:Transformer++、RetNet、GLA、Mambaなど |
|
言語モデリング&常識推論 | - 評価指標:perplexity,accuracy - 実験スケール:340M~760Mパラメータ |
LMMやMAC, MAG, MALが、既存モデルに対して低ppl・高accuracy |
Needle in a Haystack | - タスク:長い文章中から「針」に相当する重要情報の抽出 - シーケンス長:2K, 4K, 8K, 16K(RULERベンチマーク、S-NIAHタスク) |
Titans(特にMAC)がシーケンス長が長くなっても高精度を維持し、一貫した性能 |
BABILong Benchmark | - タスク:長文内に分散する複数の事実を用いた推論 - 設定:Few-shot推論とFine-tuningしたモデルで比較 |
Titans(MAC)がGPT-4等の大規模モデルを凌駕 |
長期記憶モジュールの層増加の効果 | - 比較対象:メモリ深度 𝐿_M = 1, 2, 3, 4 - 同じ学習方法ででパープレキシティとトレーニングスループットを評価 |
メモリ層を増加するとpplは改善されるが,トレーニングスループットは低下する,性能と効率のトレードオフ |
今後の展望
- より大規模なモデルや,時系列予測,ゲノミクスなど他分野への応用
- 忘却機構やSurprise Momemtumのパラメータ調整により,さらに柔軟で効果的な記憶更新が実現できるかも?
意見,感想
googleさん,pytorchでの実装の公開待ってます!
非公式実装はGithubで見つけました
驚き度(Surprise)について
「モデルに対する,入力の勾配の大きさ自体を指標にすること」自体は,真新しいことではなくて,LLMでもカリキュラム学習の分野で,データを学習させる順番を決める際に勾配の大きさの逆順で学習させると上手くいくという論文もあったりする.ただ,それを「Surprise」と表現した上で,実際に長期記憶モジュールを実装して実証したのはさすがGoogle先生と思った.
ひとりごと
Associative Memory:連想記憶,persistent memory:永続的な記憶で翻訳は合ってるのだろうか?初めて聞いた...
忘却ゲートで,過去の情報全てに対して
一旦,日本語Titansモデル作ってみようかな
(実験でどのパラメータのモデルと比較しているかとか,それぞれのモジュールのパラメータはどのくらいかとかが書かれてないので,再現実装でモデル作るのはその辺の情報が出揃ってからの方が良さそう)
Discussion