⛩️

【論文解説】Gated Linear Attentionで実現する超効率的な長文処理

2025/01/30に公開

はじめに

最近の人工知能や自然言語処理の世界では、「Transformers (トランスフォーマー)」 というモデルが大活躍しています。ChatGPT などの対話型AIや、大規模な言語モデルのほとんどがこの仕組みを使っています。しかし、トランスフォーマーは 「入力の長さが長い」 ほどメモリや計算量が大きくなり、扱うのが非常に大変です。

今回紹介する論文は、

Gated Linear Attention Transformers with Hardware-Efficient Training

というタイトルで、「長い入力をもっと効率よく扱いつつ、そこそこの精度も落とさないTransformer」を提案しています。名前は難しそうですが、やっていることを「やさしく」言うと、

  1. 計算が速くてメモリも少なく済む仕組みを取り入れた『リニアアテンション』という方法
  2. そこに、RNN(LSTM など)でおなじみの『ゲート(忘却ゲート)』を入れて、精度を上げる工夫
  3. GPU上で本当に高速に動かすための細かな最適化テクニック

を使って、「長い文章(長い入力)に対しても効率よく学習できる」モデルを作りましたよ、という内容になります。

そもそも「アテンション」とは何か?

「トランスフォーマー」というモデルでは、アテンション (attention) という仕組みが中心にあります。これは超ざっくり言うと、

いま処理している単語が、過去のどの単語に注目すると良いか?

を計算する仕組みです。具体的には、1単語あたりの「クエリ(質問)」と、各単語が持っている「キー(鍵)」との相性を計算し、その重みで「バリュー(情報)」を混ぜ合わせる、という方法をとっています。
これが「ソフトマックス・アテンション」と呼ばれるものなのですが、入力(文章)の長さを L とすると、計算量やメモリ消費が O(L^2)(入力長の2乗に比例) になり、1万文字の入力などを扱うと非常に重い処理になってしまいます。

リニアアテンションのアイデア

ここで登場するのが リニアアテンション (Linear Attention) と呼ばれる手法です。ソフトマックスによる重み付けを止めて、より計算しやすい内積(線形演算)に置き換えることで、理論上は

  • 推論(実際の生成)をするとき: 系列長 L に対して線形時間 O(L)
  • 学習時も部分的には高速化しやすい

といわれています。

しかし問題は、リニアアテンションはソフトマックス・アテンションに比べて精度が落ちがち という点です。長い文章を「本当に」うまく理解・生成できるのか? と聞かれると微妙な結果が多かったのです。

忘却ゲート(ゲーティング)を加えた新手法

論文タイトルの「Gated Linear Attention」とは、「リニアアテンションに、ゲート(忘却ゲート)を加えて賢くしたもの」 という意味です。RNN(たとえば LSTM)の「どの情報を保持して、どれを忘れるか」を制御する フォゲットゲート に相当する仕組みを加えて、長い文脈でも情報を適切に扱えるようにしています。

具体的に言うと、

従来のリニアアテンションは、過去から積み上げてきた(行列状の)重みを加算し続けるだけ

入力に応じて「ここは忘れていい」「ここはちゃんと残そう」という行列ゲートを掛け算

という設計をします。
結果として

  • 単に「全部ちょっとずつ薄くするスカラーの減衰」よりも、もっと細かく制御 できる
  • 過去の不要な情報を消しつつ、必要なところは保持できる

ようになり、性能が大きく向上します。

なぜ高速で学習できるのか?:チャンク分割 & IO最適化

ただし、いくら理論上「リニア (線形)」だと言っても、実装が下手だと結局は遅いです。GPU のメモリの読み書きが多すぎたりして、ソフトマックスのほうがよほど高速に動く...という話がこれまでありました。

そこで論文では、FLASHAttention という既存の高速実装(ソフトマックス・アテンションをめちゃくちゃ速くする手法)の考え方を参考にしながら、

  1. チャンク分割: 長い文章を「チャンク」という小さな塊に分けて計算
  2. GPUの高速演算ユニット(Tensor Core)を最大限利用: なるべく行列乗算の形にまとめる
  3. 不要な中間結果を保存せず、必要なら再計算(あえて書き戻さないことでメモリ転送を減らす)

などの手法を組み合わせて、高速かつ低メモリで計算できるよう最適化しています。

つまり論文が言いたいのは、

「リニアアテンション」を考えても、ちゃんと実装しないと実際には速くならない。
が、チャンク分割や再計算など、ハードウェア寄りの最適化をうまく組み合わせることで、
本当に高速で学習できる『Gated Linear Attention』が動くようになった!

というわけです。

実験結果

論文では中規模(3億〜10億パラメータ級)ぐらいのモデルを作り、ソフトマックスの人気モデル(LLaMA系)や、ほかのリニアアテンション系モデル(RetNet, Mamba など)と比較しています。

  • 精度(パープレキシティや各種ベンチマーク):
    RetNet・Mamba と同等または上回り、さらにソフトマックスに近い性能を達成。
  • 学習速度(GPU上のトークンスループット):
    リニアアテンションでありながら、既存実装より速く、ソフトマックスに匹敵するか、場合によっては上回る速度を発揮。
  • 長い文脈に対する強さ:
    例として、学習時に 2k トークンしか扱わなかったのに 18k 〜 20k トークンのテストでそこそこ高い性能を保つ、あるいは 8k や 24k の長さでも大きな追加コストなく学習可能である、など「長文に強い」一面もみせる。

また、「大事な情報を後で正しく思い出す力が必要なタスク(リコールが大事なタスク)」でも、行列ゲートが効いているのか、従来のリニア系モデルよりも良い成績を出しているということです。

まとめ:この論文の「すごい」ところ

  1. 「リニアアテンションは速いけど精度イマイチ」という弱点を、入力に依存するゲート(忘却機構)で補った

    • ゲートによって、過去の情報をより上手に取捨選択できる → ソフトマックス・アテンションに迫る性能や、長期記憶が必要なタスクでの有利性を得やすい。
  2. GPU上で本当に速く動く実装を追求した

    • 単なる理論上のアイデアにとどまらず、細かい最適化(IO-awareness, チャンク分割, テンソルコアの活用など)で、
      ソフトマックス最適化の代表例「FlashAttention-2」と同等か、それ以上の速度を出せる場合を示した。
  3. 長い文章に対して効率的&スケーラブル

    • 短い文脈だけで訓練しても、もっと長い文脈に extrapolation(汎化)しやすい様子を示した。
    • さらに本当に長い 8k, 24k といった文脈長でも学習がまわりやすい。

このように、トランスフォーマーの大きな課題「長文を処理するときの計算量・メモリ量」を解決する一手段として、ゲート機構つきのリニアアテンションを、高速実装込みで具体的に示した点が非常に注目されています。

もし「リニアアテンションって、ほんとに使い物になるの?」と疑問を持っている人がいれば、「ゲート + ハードウェア最適化で、やればできるんだ!」という手応えが得られる論文になっています。

興味があれば

  • もっと数式的な詳細や、具体的なアルゴリズム(どうタイル分割し、どう再計算し、どこでゲートをかけているかなど)を知りたい人は、ぜひ論文本文を読むと面白いです。
  • RNN のゲート機構や LSTM/GRU の歴史を少しでも知っていると「ああ、なるほど、リニアアテンション版の LSTM みたいな感じね」と理解が深まるでしょう。
  • 「FlashAttention」の仕組みも、公式実装や他の解説サイトで学んでおくと、GPU最適化やチャンク分割のエッセンスが分かりやすいかと思います。

いずれにせよ、今回のポイントは「長い文章をもっとラクに扱える大規模言語モデル」への一歩として、ゲート付きリニアアテンション + 高速実装が提案された、ということでした。

今後ますます巨大化する自然言語処理モデルにおいて、こうした「計算削減しつつ性能を維持する」アイデアは注目度が高まっていくと思われます。ぜひ、この分野がどのように進化していくか、引き続きウォッチしてみてください。

上級者向けの解説

ここからは、本論文「Gated Linear Attention Transformers with Hardware-Efficient Training」の数理的なさらなる詳細に迫ります。
上級者を対象に少し深い内容まで触れていきます。

背景と問題設定

なぜリニアアテンションか?

トランスフォーマーの標準的なソフトマックス・アテンションは、系列長 L に対して計算コスト・メモリコストが O(L^2) となる問題があります。長いコンテキストを扱おうとすると、学習や推論に非常に時間やメモリを要するため、より効率的な “サブ二乗時間” のモデルが望まれてきました。

一方で、リニアアテンション(Linear Attention) は

{exp}(q_t k_i^\top)

のようなソフトマックスの指数演算を、単純な内積ベースのカーネル

k(x, y) = \langle \phi(x),\, \phi(y) \rangle

に置き換えることで、理論上は系列長 L に対して線形オーダー O(L) での推論(自動回帰生成)が可能になります。しかし、標準的な「パラレル形式」の実装(トランスフォーマーの学習でよく使われる並列処理形態)では、依然として O(L^2) のコストが必要です。そのため、学習時は並列化を最大限に活かし、推論時にリニア(線形時間)を実現するには「チャンク分割(Chunkwise)+再帰形式(RNN形式)」というアプローチが有力とされています。

なぜゲーティング (Gating) が必要か?

リニアアテンションは高速ですが、ソフトマックスアテンションと比べると精度面で不利なことが、これまでの研究で指摘されてきました。その要因のひとつとして、ゲート(忘却ゲート / フォゲットゲート) がないことが挙げられます。

LSTM や GRU など従来の RNN が優れた長期依存学習を可能にする重要要素は「ゲート機構」です。ところが、リニアアテンションをそのまま RNN 的に見ると、行列サイズの「高速重み (Fast Weights)」を足し合わせていく更新式のみで、入力依存の減衰やゲートがありません。

最近の一部のリニアアテンション研究(RetNet, TransNormer など)では、定数(学習パラメータだがスカラー)による減衰 をかけるアイデアが成果を上げています。しかし、スカラーだけではなく、「入力に依存する行列状のゲート(ゲーティング)」を導入できれば、RNN のフォゲットゲートに近い、より柔軟な情報制御ができるのでは、というのが本論文のモチベーションです。

論文の主な貢献

1. ハードウェア効率を考慮したリニアアテンションの実装

  • 既存のリニアアテンション実装は理論的には O(L) またはサブ二乗 O(L_d) などになり得るものの、実際の GPU 上ではデータのロード・メモリ転送などがネックになり、そこまで高速にならない問題がありました。

  • 本論文では、FLASHAttention 系列で培われた “IO-awareness”(入出力の最適化)を取り入れて、FLASHLINEARATTENTION という実装を開発し、Chunkwise の手法や再計算(recomputation)を組み合わせることで、高速かつメモリ効率の良い学習を実現しています。実験では、やや短めの系列長(例: 1k トークン)でも、標準のソフトマックス用の FlashAttention-2 より速いレイヤー速度を出せることを示しています。

2. 行列ゲート(2 次元ゲート)を導入したリニアアテンション(Gated Linear Attention; GLA)

  • 単純なスカラー減衰(RetNet など)よりも表現力を高めるために、入力トークンごとに α_t を計算し、それを行列形式のゲートとして適用するモデルを提案しています。

  • このゲートのパラメータ化には、低次元に写した特徴からシグモイドなどを通して得るなど、パラメータを増やしすぎず、かつゲートを行列レベルで適用できる設計を行っています。

3. チャンク分割 + 二段階タイル化での安定かつ半精度演算の活用

ゲートつきのリニアアテンションでは、指数や除算が絡むため単純に半精度行列演算(Tensor Core)に落としにくい部分があります。しかし、

  • チャンク分割(Chunkwise)に加え、チャンクの中をさらに小分割(「サブチャンク」と呼ぶ)することによって、ほとんどの部分を半精度の行列演算(Tensor Core)で計算し、
  • 個別の小領域(サブチャンク内部の因子)のみフル精度で安全に計算する
    という「二段階のタイル化」アイデアで、高速演算と数値安定性を両立させています。

4. 中規模(340M / 1.3B パラメータ)言語モデル実験での有効性実証

  • LLaMA 系列の「Transformer++」(ソフトマックスアテンションを用いた強力なベースライン)、あるいは最近登場したリニアアテンション系モデル(RetNet, Mamba)と比較して、GLA Transformer が同規模・同データで競合以上の性能、あるいは同程度の性能を出せることを示しています。
  • 特に長文(文脈長 20k を超えるテストなど)での性能劣化が小さく、「学習時 2k 長のコンテキストでしか訓練していなくても 18k 〜 20k に extrapolation(拡張)」が可能となる点を強調しています。
  • また、サブ二乗時間で学習を行えるため、長いコンテキストでの大規模学習が比較的やりやすいことも利点として挙げています。

5. リコール(記憶)を要するタスクでの有利性

  • 単に言語モデルとしての perplexity や Zero-shot 性能以外に、長文の情報を思い出す必要がある「Recall-intensive」な合成タスクやデータ抽出系タスクで、線形時間モデルが全般的に苦戦しがちな中、GLA が比較的良好な結果を示したことを報告しています。
  • また「ゲートの行列化による大きめの『隠れ状態』保持」が、Mamba(スカラーゲート)などより長期依存・リコールに有利だと示唆しています。

モデルとアルゴリズムの概要

1. リニアアテンションと再帰形式

  • 通常の(ソフトマックス)自己注意は

    o_t = \frac{\sum_{i=1}^t \exp(q_t k_i^\top)\, v_i}{\sum_{i=1}^t \exp(q_t k_i^\top)}

    のように計算しますが、リニアアテンションでは、これを

    o_t \;=\; q_t \left(\sum_{i=1}^t k_i^\top\, v_i \right)

    に置き換えます(単純化のために正規化を省いている形)。これによって再帰的な更新が可能になります。

  • 推論時には、

    S_t = S_{t-1} + k_t^\top v_t, \quad o_t = q_t\,S_t

    と RNN 的な定式化ができ、系列長に対して線形時間でオートレグレッシブ生成が可能です。

2. Chunkwise Parallel Form(チャンク分割並列)

  • ただし、学習時に全長 L について逐一再帰を回すと並列性が低くなります。そこで、Chunkwise(系列を複数の「チャンク」に分割)することにより、
    1. 各チャンク内部では並列計算(行列乗算)でまとめて処理
    2. チャンク間の依存のみ再帰的に順序づけて処理
      とする方法が考案されてきました。
  • この手法により、学習全体の計算量が (LC\,d) 程度に抑えられ、かつ GPU のテンソルコアを活用しやすい「行列乗算による大きなバッチ演算」が可能になる、という利点があります。

3. FLASHLINEARATTENTION:IO-Aware 実装

  • 既存の実装では、チャンク分割をしていてもテンソルの読み書き(HBM/SRAM 間のロードやストア)が過剰に発生し、想定ほど高速にならない問題がありました。
  • 本論文では、
    • タイル化 して同じブロックを何度もロードせずに済むようにする
    • シーケンス長方向の並列化SM(Streaming Multiprocessor)占有率 を考慮
    • 中間状態 {S1,…,St} などを一時的にメインメモリ(HBM)に書き出す方法と、再計算戦略を組み合わせることで、ソフトマックス用の最先端実装「FlashAttention-2」よりも短い系列長 (∼1k) でも高速になることを示しています。

4. ゲーティング付きリニアアテンション(GLA)

  • 上述のリニアアテンションに、行列ゲート G_t を導入し、

    S_t = G_t \odot S_{t-1} \;+\; k_t^\top v_t

    の形に拡張したものを提案しています。ここで G_t\rm{sigmoid} などを通して得る行列で、入力トークン x_t から求められます。

    • スカラーゲートを用いた Mamba/RetNet などの先行研究に比べ、1 トークンに対して全要素にわたる細かい抑制や忘却の制御が可能になります。
    • ただしゲート行列全体を巨大にするとパラメータが大幅に増えるので、低ランク分解を使ったパラメータ設計(例: 小さな中間次元を介して d\times16\times d_k 程度に抑える)などの工夫を取り入れています。
  • さらにこのゲート入りリニアアテンションを、上記と同様のチャンク分割+二段階タイル化+I/O 最適化によって効率的に計算できる「FLASHLINEARATTENTION」の拡張版を実装し、GLA Transformer として報告しています。

実験結果・考察

  1. 中規模言語モデル (340M / 1.3B パラメータ) の学習
    • データセット: SlimPajama (約 100B トークンのサブセット)
    • Baseline:
      • Transformer++ (ソフトマックスアテンション、LLaMA の改良版)
      • RetNet (グローバルな減衰)、Mamba (スカラーのデータ依存ゲート)
    • 結果:
      • Perplexity(WikiText, LAMBADA など)や常識推論タスクなどで、RetNet より精度が高く、Mamba や Transformer++ に近いレベルの性能。
      • 特に長い文脈での extrapolation(例: 学習時 2k でもテスト時 18k〜20k まで大きく伸ばす)能力が良好。Mamba, RetNet もある程度は伸びるが、ゲートの設計差などで挙動に違いあり、GLA のほうが安定。
      • 訓練スループット(tokens/s)は Mamba より高く、ほぼソフトマックストランスフォーマーに近い。
  2. リコール能力が必要なタスク
    • 合成タスク(Multi-Query Induction など)や、抽出型 QA など「ある区間で出てきた情報を後で正確に呼び出す必要がある」ようなタスクで、ゲート付きリニアアテンションが他のリニア系モデルより優位になりやすいことを確認。
    • これは、ゲート行列(行列サイズの隠れ状態)によって保持できる情報量が多いからではないかと示唆しています。
  3. 長文学習
    • 8k や 24k の長さで学習し、そのまま長文推論を行うケースでも、GLA は効率よく学習を回せる。
    • 24k 学習はトランケート BPTT(2k 分割)による再帰伝播を行っているため、そこまでメモリコストが増えずに 24k の学習が可能。
    • 長文で学習するとさらに長文推論が安定し、PG19 など大規模コーパスのテストで良好な結果が得られた。

限界と今後の展望

  • 本研究は 1.3B パラメータ規模までの実験であり、10B〜100B 超級の超大型モデルへ拡張した場合の挙動はまだ未知数。ただし理論的には、チャンク分割や行列ゲート実装を工夫することで、大規模化してもソフトマックス・トランスフォーマーより学習効率が良くなる可能性が示唆される。
  • モデルの安定性やメモリ使用量、学習レシピ(ハイパーパラメータ)に関しては、ソフトマックス・アテンションほど細かく最適化された実績がまだ少ないため、さらなる検討の余地がある。
  • リニアアテンションは、自然言語だけでなく、画像やマルチモーダルなど超長系列を扱う場面でも有用と見られ、GLA のようなゲーティング機構付きリニアアテンションの汎用的応用が期待される。

まとめ

  • 本論文の核心は、(1) リニアアテンションをハードウェア(GPU)上でいかに実際に高速に動かすか、(2) 忘却ゲートや減衰要素をデータ依存の行列として設計し、ソフトマックス・アテンションに迫る表現力を実現するか、という点にあります。
  • 提案手法である FLASHLINEARATTENTION は、短い系列長(~1k)でも高度に最適化された FlashAttention-2 より高速になり得るほど I/O を考慮した実装であり、Mamba/RetNet といった既存の線形アテンションモデルを上回る精度および学習効率を示しています。
  • ゲーティングを入れたリニアアテンション GLA は、(i) 長文推論への安定した拡張性、(ii) Recall が必要なタスクへの強さ、(iii) 適切な行列サイズでの効率的な実装、を兼ね備えた有力なモデルとして位置づけられます。

以上が本論文の概要と主な貢献です。ソフトマックス・アテンションの二乗オーダーを回避しつつ、ゲーティングによる高い表現力を得た点、そして何より実際に高速かつ大規模学習に耐える実装を提供している点が本研究の大きな意義といえます。

Discussion