Open5
MLA vs GQA

paper: TransMLA: Multi-Head Latent Attention Is All You Need, Feb 2025

用語
- Grouped Query Attention (GQA): K/Vのprojectionの計算において、headをグループ化し、グループごとに共通のweight matrixを使用する[2]。性能を犠牲にせずに推論スピードを上げる手法としてGemma 2から採用された[4-5]。
- Multi-head Latent Attention (MLA): K/Vのprojectionの計算において、LoRAのような感じでlow-rankの表現に落とし込むことでKV-cacheの容量を削減する。DeepSeek-V2から導入された[3]。
背景
- この論文ではGQAとMLAの関係性を指摘し、「MLAの方がGQAよりも表現力が高く、アーキテクチャの設計として優れている」ことを示そうとするもの
論文の主張
- GQAは等価なMHAに変換できるが、その逆は一般的には成立しない
- KV-cacheの次元が同じ場合、MHAがGQAよりもわずかなパラメータ数の増加で優れた表現力を持つ
- 2の主張は実験結果1からも裏付けられた
実験結果
- W_K, W_Vの重み以外をfreezeしてfine-tuningした場合、MHAの方がGQAよりもPPL/downstream taskで優位だった
論文の主張の妥当性について
- MHAにおいて、本来のhead数を
とすると、GQAは縮退したhead数n _ h (gemma 2ではn _ k )を持ち、足りない分はrepeatで補う(同じ重みを共有するのと等価)n _ h = 8/16/32, n _ k = 2 - GQAを等価なMHAで表現すると、そのrank
は高々r であるn _ k d _ h - したがってGQAをSVDで分解した場合、以下のように2線形変換の積で表現できる。これは形式的にMLAと同じである(
の代わりにn _ k d _ h を用いているのみ)。r \le n _ k d _ h
- よってGQAが表現する線形空間はMLAが表現する線形空間に含まれる。
→確かに線形変換の包含関係としてはそうかもしれないが、両者のrankがどうなるかは学習結果次第なので一般的にどちらが表現力が高いとかは言えないのでは?
反論
- 実験1で
を固定するのはフェアではない。以下の(1)式が示すように、W _ Q との積を計算することでMHAのW _ Q と同等な操作がW _ K^b 側で表現できる可能性があり、W _ Q も学習する設定だとHMAは性能においてGQAより優れているとは言えない可能性がある。W _ Q
- あと、モデル全体でのパラメータ数の増加が微々たるものであっても、実験1の設定だとtrainableなパラメータ数がMLAがGQAの2倍なのでそもそも両者を比較することはフェアでない。
所感
- 少なくともこの論文の実験結果から直ちに「MHAがGQAより優れている」とは言えない
- GQAからMHAへの変換が可能というのは面白い
-
も学習する前提でもMHAがGQAより優位なのか気になるW _ Q

Reference
- [1] TransMLA: Multi-Head Latent Attention Is All You Need, https://arxiv.org/abs/2502.07864
- [2] GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints, https://arxiv.org/abs/2305.13245v3
- [3] DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model, https://arxiv.org/abs/2405.04434
- [4] Gemma explained: What’s new in Gemma 2, https://developers.googleblog.com/en/gemma-explained-new-in-gemma-2/
- [5] Gemma 2: Improving Open Language Models at a Practical Size, https://storage.googleapis.com/deepmind-media/gemma/gemma-2-report.pdf

参考:

DeepSeek-V2の論文[3]ではより現実的な設定でAblationを実施していて、この結果によるとMLAがMQA/GQAと比べてかなりパフォーマンスが高い結果になっている(7B model x 1.33T token)。
さらに、MHA<MLAという結果も示している(これはrankを下げたことによる正則化の影響か?)。