🧌

Titansを理解したい..タン

2025/02/27に公開

tl:dr

Google Researchが2024年12月31日に公開したTitans: Learning to Memorize at Test Timeという論文について理解するために,解説.
勉強のために書いてるので「ここ間違ってそうだな」と思われたら,コメントお願いします!


1. はじめに

LLMで広く使われているTransformer の Self Attention 機構 は,すべてのトークン間の依存関係を計算するため,入力が長くなると計算量が入力長の2乗に比例して増加するという課題がある.
この論文では、Attention を短期記憶とみなし、長期依存性は ニューラル長期記憶モジュール で補完する新たな手法を提案.このモジュールは,推論時にもオンラインで記憶更新が可能なため,長いcontextも効率よく扱うことができるらしい.


2. 論文の背景と課題

2.1 Transformer と Attention の基本

Transformer では,入力系列 x \in \mathbb{R}^{N \times d_{in}} から,以下のようにクエリ Q、キー K、バリュー V を生成する.

Q = xW_Q,\quad K = xW_K,\quad V = xW_V,

ここで,W_Q, W_K, W_V は学習パラメータ
Causal Attention の場合,出力 y_i の式は以下のようになる.

y_i = \frac{\sum_{j=1}^{i}\exp\Bigl(\frac{Q_i^\top K_j}{\sqrt{d_{in}}}\Bigr)V_j}{\sum_{\ell=1}^{i}\exp\Bigl(\frac{Q_i^\top K_\ell}{\sqrt{d_{in}}}\Bigr)}.

この方法は全トークン間の依存関係を捉えることができるが,入力長 N に対して計算量が O(N^2) となるため,長いシーケンスでは非効率的.

2.2 Attentionの効率化とその限界

長いシーケンスに対応するため、線形 Attention が提案されている.
ソフトマックス関数をカーネル関数 \phi に置き換えることで,以下のような式で表される.

y_i = \frac{\phi(Q_i)^\top \Bigl(\sum_{j=1}^{i}\phi(K_j)V_j\Bigr)}{\phi(Q_i)^\top \Bigl(\sum_{\ell=1}^{i}\phi(K_\ell)\Bigr)},

この方法を取ることで,共通部分の再利用をすることができ,計算コストが大幅に削減される.ただ,情報を固定サイズの状態に圧縮するため,非常に長い文脈を保持するのには限界がある.

2.3 人間と既存モデルの記憶の仕組みの比較

以下の表にある通り,HopFieldやLSTM,Transformersは人間の脳に触発されて開発されてきたけど,搭載している記憶の要素,独立して動作するか,学び続けるかという3点のどれかが不足している

記憶・学習の要素 人間の脳 既存のAIモデル
短期記憶 柔軟に直近の情報を保持し、状況に応じた迅速な判断が可能 AttentionやLSTMの隠れ状態など、短期情報は保持するが文脈ウィンドウに依存する
長期記憶 長期間にわたり豊富な情報を蓄積し、経験に基づいた判断が可能 長期依存性の情報は圧縮・固定長の表現に依存し、十分な抽象化が難しい
メタ記憶 状況に応じてどの記憶を利用するか、記憶自体を制御・調整できる 基本的に特定のタスクに合わせた学習は行うが、記憶そのものの制御機能は限定的
記憶モジュールの独立性と連携性 複数の記憶システム(短期、長期、メタ記憶)が独立しつつ連携して動作する 一般的に一種類の表現(例:隠れ状態やAttention行列)に依存し、独立性や連携が不足
学習と抽象化の能力 経験から積極的に学習し、重要情報を抽象化して柔軟に活用できる 入力データから学習するが、過去の情報の抽象化や動的な記憶更新には限界がある

3. Titansの構造

Titansの論文の肝は, 推論時にもオンラインで記憶更新を行う長期記憶モジュール の設計をし,実験,検証をしたこと.長期記憶モジュールは,基本的には複数層のMLPで構成され,Transformerのようにkeyvalueを使って処理をする.
 長期記憶の更新方法についてざっくり言うと,モデルに対する入力の勾配が大きいと言うことは,モデルが予期しないような入力であり,モデルが持ってない知識や考え方に驚いている(Surprise).そのような情報は重要度が高いから記憶しておこうと言う感じ

長期記憶モジュールの詳細

Surprise(驚き)に基づく記憶更新

この論文では、損失関数 \ell の勾配を用いて、入力 x_t がどれほど予期せぬものであるか(すなわち「Surprise」)を評価し、その大きさに応じて更新を行います。
最も単純な式は以下の通り.

M_t = M_{t-1} - \theta_t \nabla \ell(M_{t-1}; x_t) \quad \text{(式8)}
  • M_t: 時刻 t におけるメモリ状態
  • \nabla \ell(M_{t-1}; x_t): 入力 x_t に対する損失の勾配(=「驚き」)
  • \theta_t: 更新の強さを制御するパラメータ

ただ,この式をそのまま使うと問題が生じてしまう.

具体例
少々雑ですが,例えば入力が以下のようなものを考える.
※著者の解釈に基づいたイメージ

あなたは衛星開発プロジェクトに関わっていた.打ち上げ当日,ニュースを見ていると,自分が開発に携わった衛星を搭載したロケットが司令破壊され,宇宙に届くことなく海の藻屑となったことを知った.
そして1ヶ月後――その原因が「宇宙人の攻撃」だったと判明する.

ロケットの打ち上げ成功率は高いので,「ロケットが司令破壊」になったという事実にまずは驚き,この時損失関数 \ell の勾配 \nabla \ell(M_{t-1}; x_t) は非常に大きい値を取る.
 この「Surprise」が非常に大きい状態で勾配更新が行われたとき,その後では\nabla\ellが小さくなってしまうため,「宇宙人の攻撃」と言う大きな情報が十分に反映されない可能性がある.

Surpriseの分割

前述のような連続する入力の変化を効果的に捉えるために,驚き(Surprise)を過去の驚き(Past Surprise)と瞬間的な驚き(Momentary Surprise)を分け,更新は以下の 2 段階で行うようにする.

\begin{aligned} M_t &= M_{t-1} + S_t \quad \text{(式9)}\\[2mm] S_t &= \eta_t S_{t-1} - \theta_t \nabla \ell(M_{t-1}; x_t) \quad \text{(式10)} \end{aligned}
  • S_t: 時刻 t における驚き度
  • \eta_t: 時間と共に過去の驚き具合を減衰させる(0 \le \eta_t \le 1
  • \theta_t: その時点の勾配の重み付け

具体例

  • 「司令破壊」
     通常の予測(ロケットは成功するはず)と異なるため, \nabla \ell(M_{t-1}; x_t)は大きな値をとる.その結果,更新項 S_t にはこの瞬間の驚きが反映され,メモリ M_t は式9により前時刻の記憶 M_(t-1) に更新項 S_t が加えられる.

  • 「宇宙人の攻撃」
     「司令破壊」による驚き S_(t-1)\eta_t により一定程度保持され,その上で新たな「宇宙人の攻撃」に対する瞬間的な驚き- \theta_t \nabla \ell(M_{t-1}; x_t)が加わり,更新項S_tが再計算される.

これにより,司令破壊と宇宙人の攻撃の両方の情報がバランスよく記憶に組み込まれるらしい

損失の計算方法

ざっくり言うと,Transformersのように入力からkeyvalueを生成し,記憶モジュールがその対応関係をどれだけ正確に再現しているかの損失を計算.

キーとバリューの生成

入力x_tを元に,記憶モジュールに情報を保存,検索するためのキー,バリューを生成.

k_t = x_tW_K,\quad v_t = x_tW_V \quad(式11)
  • x_t
    時刻 (t) の入力(例:ロケット打ち上げのシナリオで「司令破壊」と思っていたが,実は「宇宙人の攻撃」だったという情報)

  • (W_K) と (W_V) :学習可能な重み行列で、入力から特徴を抽出し、それぞれキーとバリューに変換

    • key (k_t) :後に記憶から関連情報を検索するための「検索キー」として機能
    • value (v_t) :実際に記憶すべき情報(「宇宙人襲来」という重大な情報など)

連想記憶を学習する際のLoss

「正しいkeyvalueの対応関係」を記憶できるように記憶モジュール内のMLPのパラメータを更新する.

\ell (M_{t-1};x_t)=\|M_{t-1}(k_t) - v_t\|_2^2 \quad(式12)
  • M_{t-1}(k_t) :時刻 t-1 までに構築された記憶モジュール M_{t-1} に、現在のkey k_t を入力して得られるvalue
    これは、「過去の記憶から、現在の入力に対応する情報がどの程度再現できるか」を示す.

  • v_t :式11で生成された実際のvalue
    これとの間の誤差(L2ノルムの二乗)が小さくなるように,記憶モジュールの学習を行う.

忘却メカニズム(Forgetting)の導入

Context長が長い入力が与えられた時にメモリが不要な情報を忘れるようにすることで,記憶容量を管理したい!
過去の情報について,パラメータalphaを入れることで,常に細心かつ重要な情報のみを効率的に記憶するようにした

M_t = (1 - \alpha_t) M_{t-1} + S_t \quad \text{(式13)}
  • \alpha_t \in [0,1]: 忘却ゲートのパラメータ
    • \alpha_t \to 0 の場合:過去情報をほぼそのまま保持
    • \alpha_t \to 1 の場合:過去情報を大幅に忘却

長期記憶モジュールのアーキテクチャ

長期記憶モジュールのアーキテクチャは,L_M \geqq 1層の単純なMLPで構成され,L_M \geqq 2層にすると,より表現力が高くなり性能が向上するが,計算量とのトレードオフがある

記憶の取得

長期記憶モジュールに保存された記憶を取得する方法は,単純に重み更新を行わないフォーワードパスを使用するだけ.
 入力に基づいて作成されたquery(検索文みたいなもん)を元に,記憶モジュールから適切な情報を検索し,記憶を取得する.

y_t = M^*(q_t)
  • q_t
    queryベクトルで,入力 x_t を線形層 W_Q で変換して得る(例:q_t = x_t W_Q
    このqueryは、記憶モジュールに対して「どの情報が欲しいか」を示す

  • M^*(q_t)
    記憶モジュール M にクエリ q_t を入力して、対応する記憶情報(value)を取得する操作
    この操作は,推論時に重み更新を行わず,ただ記憶から情報を引き出す


4. Titans アーキテクチャの種類

長期記憶モジュールを既存のモデルに統合する方法は3種類

4.1 Memory as a Context (MAC)

長期記憶をコンテキスト情報として扱う方式で,簡単に言えば,入力と長期記憶モジュールから抽出した記憶を連結して一つの大きなコンテキストとして扱うもの.
具体的な流れは以下の通り.

  • セグメント分割: 入力系列 x \in \mathbb{R}^{N \times d_{in}} を固定サイズのセグメント S(i) に分割

  • 記憶からの情報取得: 現在のセグメント S(t) に対して,前時刻の長期記憶 M_{t-1} を利用し,query q_t = S(t)W_Q を生成して対応する情報 h_t を取得

  • h_t = M_{t-1}^*(q_t) \quad \text{(式21)}
  • Persistent Memoryの付加: タスク固有の知識を持つ固定パラメータ列 [p_1, p_2, \ldots, p_{N_p}] を先頭に結合し,拡張された入力 \tilde{S}(t) を作成

    \tilde{S}(t) = \bigl[\, p_1, p_2, \ldots, p_{N_p} \,\,||\,\, h_t \,\,||\,\, S(t) \bigr] \quad \text{(式22)}
  • Attention による統合と記憶更新: 拡張された入力 \tilde{S}(t) に通常の Attention を適用し、出力 y_t を得た後、次時刻の記憶 M_t を更新

利点
過去の重要な情報が明確に反映される
欠点

入力の分割方法によっては,適切に過去の情報が取り出せない

4.2 Gated Memory (MAG)

MAG では,短期記憶と長期記憶を別々に処理し,非線形なゲート機構で統合することで,必要な時に必要に応じて長期記憶の情報を取得するようになり,計算量の面で効率的.

  • 短期記憶: スライディングウィンドウ Attention(SW-Attn)を用いて局所的な依存関係を捉える.
  • 長期記憶: ニューラル長期記憶モジュールの出力を利用

統合は,次の式で行う.

\begin{aligned} y &= \text{SW-Attn}^*(x) \quad \text{(式27)}\\[2mm] o &= y \otimes M(x) \quad \text{(式28)} \end{aligned}

ここで,\otimes は非線形なゲート操作を意味し,短期記憶と長期記憶を効果的に融合する.
利点
新しい情報を忘れにくい.
重要な情報が残りやすい
欠点
構造が複雑なので学習の難易度が高い

4.3 Memory as a Layer (MAL)

MAL では,長期記憶モジュールを「層」として直接組み込む.
実装が一番簡単な分,性能が高いわけではないらしい.

  1. パーシステントメモリの付加: 入力 x に対して,固定パラメータ列 [p_1, p_2, \ldots, p_{N_p}] を結合します.

    x' = \bigl[\, p_1, p_2, \ldots, p_{N_p} \,\,||\,\, x \bigr] \quad \text{(式29)}
  2. 記憶層の適用: 拡張された入力 x' に対して,長期記憶層 M(x) を適用し,出力 y を得る.

    y = M(x) \quad \text{(式30)}
  3. 最終出力の生成: その後,スライディングウィンドウ Attention を適用し,最終出力 o を生成

    o = \text{SW-Attn}(y) \quad \text{(式31)}

利点
構造がシンプル
欠点
細かい情報が抜け落ちやすい

5. 実験

実験については以下の表を参照.簡潔に言うと,長いシーケンスを扱う際はTitansはものすごく性能がいいらしい

セクション 実験条件・設定 結果
実験設定 - モデルの種類:MAC、MAG、MAL,記憶モジュールのみ(LMM)
- パラメータサイズ:170M, 340M, 400M, 760M
- 学習データ:15B/30Bトークン(FineWeb‐Edu)
- ベースライン:Transformer++、RetNet、GLA、Mambaなど
言語モデリング&常識推論 - 評価指標:perplexity,accuracy
- 実験スケール:340M~760Mパラメータ
LMMやMAC, MAG, MALが、既存モデルに対して低ppl・高accuracy
Needle in a Haystack - タスク:長い文章中から「針」に相当する重要情報の抽出
- シーケンス長:2K, 4K, 8K, 16K(RULERベンチマーク、S-NIAHタスク)
Titans(特にMAC)がシーケンス長が長くなっても高精度を維持し、一貫した性能
BABILong Benchmark - タスク:長文内に分散する複数の事実を用いた推論
- 設定:Few-shot推論とFine-tuningしたモデルで比較
Titans(MAC)がGPT-4等の大規模モデルを凌駕
長期記憶モジュールの層増加の効果 - 比較対象:メモリ深度 𝐿_M = 1, 2, 3, 4
- 同じ学習方法ででパープレキシティとトレーニングスループットを評価
メモリ層を増加するとpplは改善されるが,トレーニングスループットは低下する,性能と効率のトレードオフ

今後の展望

  • より大規模なモデルや,時系列予測,ゲノミクスなど他分野への応用
  • 忘却機構やSurprise Momemtumのパラメータ調整により,さらに柔軟で効果的な記憶更新が実現できるかも?

意見,感想

googleさん,pytorchでの実装の公開待ってます!
非公式実装はGithubで見つけました

https://github.com/lucidrains/titans-pytorch

驚き度(Surprise)について

「モデルに対する,入力の勾配の大きさ自体を指標にすること」自体は,真新しいことではなくて,LLMでもカリキュラム学習の分野で,データを学習させる順番を決める際に勾配の大きさの逆順で学習させると上手くいくという論文もあったりする.ただ,それを「Surprise」と表現した上で,実際に長期記憶モジュールを実装して実証したのはさすがGoogle先生と思った.

ひとりごと

Associative Memory:連想記憶,persistent memory:永続的な記憶で翻訳は合ってるのだろうか?初めて聞いた...
忘却ゲートで,過去の情報全てに対して(1- alpha)をかけるのは,結構雑な気がしてたけど上手くいくんだな
一旦,日本語Titansモデル作ってみようかな
(実験でどのパラメータのモデルと比較しているかとか,それぞれのモジュールのパラメータはどのくらいかとかが書かれてないので,再現実装でモデル作るのはその辺の情報が出揃ってからの方が良さそう)

Discussion