目的
DeepSeek V2論文のMulti-head Latent Attention (MLA)における、下記表現を理解するために「数式で」整理をすることが目的の記事です。
主に、「どの重みを」「どのように吸収するのか」を「数式で」整理します。
(そのうち詳細な解説付きの記事にするかも)
In addition, during inference, since W U K W^{UK} W U K can be absorbed into W Q W^Q W Q , and W U V W^{UV} W U V can be absorbed into W O W^{O} W O , we even do not need to compute keys and values out for attention.
本記事は、DeepSeek V2の論文中に記載されているMulti-head Latent Attention (MLA)について理解されていることが前提になります。
https://arxiv.org/abs/2405.04434
なお、RoPE導入後も成立するという主張に関しては、「今回は」取り扱いません。
ただ、基本的な主張は変わらないです。そのうち記事にするかもです。
とはいえ、論文中では行間がすごく広くなっているMLAに関して、詳細に記述しているので、誰かの理解の助けになれば嬉しいです。
あと、今のAIでは書けない記事を書きたかった。
(o1 proとかに解説を依頼したとしても、ここまで詳細には解説してくれないはずです)
なお、最近流行りのDeepSeek-R1に関しても論文解説記事を書いています。
本記事とは異なり、理論に詳しくない方が読んで面白い!と思っていただけるように記事を書いたつもりなので、ぜひ読んでいただけると嬉しいです。
https://zenn.dev/asap/articles/34237ad87f8511
数式でMLAの整理
Scaled Dot-Product Attentionは下記で表現されます。
Attention ( i ) ( Q ( i ) , K ( i ) , V ( i ) ) = softmax ( Q ( i ) K ( i ) ⊤ d h ) V ( i )
\text{Attention}_{(i)}(Q_{(i)}, K_{(i)}, V_{(i)}) = \text{softmax} \left( \frac{Q_{(i)}K_{(i)}^\top}{\sqrt{d_h}} \right) V_{(i)}
Attention ( i ) ( Q ( i ) , K ( i ) , V ( i ) ) = softmax ( d h Q ( i ) K ( i ) ⊤ ) V ( i )
上記を前提に、MLAの式展開を考察していきます。
記号の導入
MLAは主にKVキャッシュとクエリの圧縮を目的にしています。
したがって、圧縮されたKV表現と圧縮されたクエリ表現を取り扱うために、記号を導入していきます。
圧縮KV表現とKV
ここでは、圧縮KV表現と、元のKVについて記載していきます。
MLAにおいてKVの圧縮表現C K V ∈ R T × d c \mathbf{C}^{KV} \in \mathbb{R}^{T \times d_c} C K V ∈ R T × d c を導入します。
T T T は全体のtoken数、d c d_c d c は圧縮KV表現の次元数(論文では512次元)です。
この圧縮表現は下記のより作成されます。
(各tokenごとに下記の式変形が適用されます。つまりtoken数分だけ並列処理されます)
c t K V = W D K V h t
\begin{equation}
\mathbf{c}_t^{KV} = W^{DKV} \mathbf{h_t}
\end{equation}
c t K V = W DK V h t
C K V = [ ( c 1 K V ) T ( c 2 K V ) T ⋮ ( c T K V ) T ]
\begin{equation}
\mathbf{C}^{KV} =
\begin{bmatrix}
(\mathbf{c}_1^{KV})^T \\
(\mathbf{c}_2^{KV})^T \\
\vdots \\
(\mathbf{c}_T^{KV})^T
\end{bmatrix}
\end{equation}
C K V = ( c 1 K V ) T ( c 2 K V ) T ⋮ ( c T K V ) T
ただしt t t は、あるt t t token目の特徴量であることを示します。
また、h t ∈ R d \mathbf{h_t} \in \mathbb{R}^{d} h t ∈ R d は、前の層の、あるt t t token目での出力で、d d d は前の層の出力次元数です(論文中では5120次元)
また、W D K V ∈ R d c × d W^{DKV} \in \mathbb{R}^{d_c \times d} W DK V ∈ R d c × d はdown-projection行列です。
c t K V ∈ R d c \mathbf{c}_t^{KV} \in \mathbb{R}^{d_c} c t K V ∈ R d c は、あるt t t token目での、KVの圧縮表現です。
そして、C K V ∈ R T × d c \mathbf{C}^{KV} \in \mathbb{R}^{T \times d_c} C K V ∈ R T × d c はc t K V \mathbf{c}_t^{KV} c t K V を全てのtokenでまとめた行列になります。
KVの圧縮表現C K V \mathbf{C}^{KV} C K V を利用すると、キーとバリューを復元できます。
K = ( W U K ( C K V ) T ) T
\begin{equation}
\mathbf{K} = \left(W^{UK} (\mathbf{C}^{KV})^T\right)^T
\end{equation}
K = ( W U K ( C K V ) T ) T
V = ( W U V ( C K V ) T ) T
\begin{equation}
\mathbf{V} = \left(W^{UV} (\mathbf{C}^{KV})^T\right)^T
\end{equation}
V = ( W U V ( C K V ) T ) T
W U K , W U V ∈ R d h n h × d c W^{UK},W^{UV} \in \mathbb{R}^{d_hn_h \times d_c} W U K , W U V ∈ R d h n h × d c はup-projection行列です。
圧縮されたKVから、マルチヘッドの全てのキー、バリュー表現を復元します。
d h d_h d h は1ヘッドあたりの次元数(論文では128次元)、n h n_h n h はマルチヘッドのヘッド数(論文では128個)
K , V ∈ R T × d h n h \mathbf{K}, \mathbf{V} \in \mathbb{R}^{T \times d_hn_h} K , V ∈ R T × d h n h は、それぞれ圧縮されていないキー、バリュー表現です。
圧縮クエリ表現とQ
続いて、クエリに関しても学習時のパラメータ削除のために圧縮します。
c t Q = W D Q h t
\begin{equation}
\mathbf{c}_t^{Q} = W^{DQ} \mathbf{h_t}
\end{equation}
c t Q = W D Q h t
C Q = [ ( c 1 Q ) T ( c 2 Q ) T ⋮ ( c T Q ) T ]
\begin{equation}
\mathbf{C}^{Q} =
\begin{bmatrix}
(\mathbf{c}_1^{Q})^T \\
(\mathbf{c}_2^{Q})^T \\
\vdots \\
(\mathbf{c}_T^{Q})^T
\end{bmatrix}
\end{equation}
C Q = ( c 1 Q ) T ( c 2 Q ) T ⋮ ( c T Q ) T
また、W D Q ∈ R d c ′ × d W^{DQ} \in \mathbb{R}^{d_c' \times d} W D Q ∈ R d c ′ × d はdown-projection行列です。
d c ′ d_c' d c ′ は圧縮クエリ表現の次元数(論文では1536次元)
c t Q ∈ R d c ′ \mathbf{c}_t^{Q} \in \mathbb{R}^{d_c'} c t Q ∈ R d c ′ は、あるt t t token目での、クエリの圧縮表現です。
C Q ∈ R T × d c ′ \mathbf{C}^{Q} \in \mathbb{R}^{T \times d_c'} C Q ∈ R T × d c ′ は、token全体でのクエリの圧縮表現になります。
このクエリの圧縮表現C Q \mathbf{C}^{Q} C Q を利用すると、元のクエリを復元できます。
Q = ( W U Q ( C Q ) T ) T
\begin{equation}
\mathbf{Q} = \left(W^{UQ} (\mathbf{C}^{Q})^T\right)^T
\end{equation}
Q = ( W U Q ( C Q ) T ) T
W U Q ∈ R d h n h × d c ′ W^{UQ} \in \mathbb{R}^{d_hn_h \times d_c'} W U Q ∈ R d h n h × d c ′ はup-projection行列です。
クエリの圧縮表現から、マルチヘッドの全てのクエリ表現を復元します。
Q ∈ R T × d h n h \mathbf{Q} \in \mathbb{R}^{T \times d_hn_h} Q ∈ R T × d h n h は、それぞれ圧縮されていないクエリ表現です。
通常のMHA(Multi-Head Attention)のQKVと、 MLAの違い
通常のMHAにおいて、QKVはそれぞれ下記で表現されます。
MHAでのキー表現
k t = W K h t
\begin{equation}
\mathbf{k}_t = W^{K} \mathbf{h_t}
\end{equation}
k t = W K h t
K = [ ( k 1 ) T ( k 2 ) T ⋮ ( k T ) T ]
\begin{equation}
\mathbf{K} =
\begin{bmatrix}
(\mathbf{k}_1)^T \\
(\mathbf{k}_2)^T \\
\vdots \\
(\mathbf{k}_T)^T
\end{bmatrix}
\end{equation}
K = ( k 1 ) T ( k 2 ) T ⋮ ( k T ) T
したがって、
K = [ ( W K h 1 ) T ( W K h 2 ) T ⋮ ( W K h T ) T ]
\begin{equation}
\mathbf{K} =
\begin{bmatrix}
(W^{K} \mathbf{h_1})^T \\
(W^{K} \mathbf{h_2})^T \\
\vdots \\
(W^{K} \mathbf{h_T})^T
\end{bmatrix}
\end{equation}
K = ( W K h 1 ) T ( W K h 2 ) T ⋮ ( W K h T ) T
なお、k t ∈ R d h n h \mathbf{k}_t \in \mathbb{R}^{d_hn_h} k t ∈ R d h n h 、W K ∈ R d h n h × d W^{K} \in \mathbb{R}^{d_hn_h \times d} W K ∈ R d h n h × d 、K ∈ R T × d h n h \mathbf{K} \in \mathbb{R}^{T \times d_hn_h} K ∈ R T × d h n h
以下はほぼ同じなので、隠しておきます。興味あればご覧ください。
MHAでのクエリ表現
MHAでのクエリ表現
q t = W Q h t
\begin{equation}
\mathbf{q}_t = W^{Q} \mathbf{h_t}
\end{equation}
q t = W Q h t
Q = [ ( q 1 ) T ( q 2 ) T ⋮ ( q T ) T ]
\begin{equation}
\mathbf{Q} =
\begin{bmatrix}
(\mathbf{q}_1)^T \\
(\mathbf{q}_2)^T \\
\vdots \\
(\mathbf{q}_T)^T
\end{bmatrix}
\end{equation}
Q = ( q 1 ) T ( q 2 ) T ⋮ ( q T ) T
したがって、
Q = [ ( W Q h 1 ) T ( W Q h 2 ) T ⋮ ( W Q h T ) T ]
\begin{equation}
\mathbf{Q} =
\begin{bmatrix}
(W^{Q} \mathbf{h_1})^T \\
(W^{Q} \mathbf{h_2})^T \\
\vdots \\
(W^{Q} \mathbf{h_T})^T
\end{bmatrix}
\end{equation}
Q = ( W Q h 1 ) T ( W Q h 2 ) T ⋮ ( W Q h T ) T
なお、q t ∈ R d h n h \mathbf{q}_t \in \mathbb{R}^{d_hn_h} q t ∈ R d h n h 、W Q ∈ R d h n h × d W^{Q} \in \mathbb{R}^{d_hn_h \times d} W Q ∈ R d h n h × d 、Q ∈ R T × d h n h \mathbf{Q} \in \mathbb{R}^{T \times d_hn_h} Q ∈ R T × d h n h
MHAでのバリュー表現
MHAでのバリュー表現
v t = W V h t
\begin{equation}
\mathbf{v}_t = W^{V} \mathbf{h_t}
\end{equation}
v t = W V h t
V = [ ( v 1 ) T ( v 2 ) T ⋮ ( v T ) T ]
\begin{equation}
\mathbf{V} =
\begin{bmatrix}
(\mathbf{v}_1)^T \\
(\mathbf{v}_2)^T \\
\vdots \\
(\mathbf{v}_T)^T
\end{bmatrix}
\end{equation}
V = ( v 1 ) T ( v 2 ) T ⋮ ( v T ) T
したがって、
V = [ ( W V h 1 ) T ( W V h 2 ) T ⋮ ( W V h T ) T ]
\begin{equation}
\mathbf{V} =
\begin{bmatrix}
(W^{V} \mathbf{h_1})^T \\
(W^{V} \mathbf{h_2})^T \\
\vdots \\
(W^{V} \mathbf{h_T})^T
\end{bmatrix}
\end{equation}
V = ( W V h 1 ) T ( W V h 2 ) T ⋮ ( W V h T ) T
なお、v t ∈ R d h n h \mathbf{v}_t \in \mathbb{R}^{d_hn_h} v t ∈ R d h n h 、W V ∈ R d h n h × d W^{V} \in \mathbb{R}^{d_hn_h \times d} W V ∈ R d h n h × d 、V ∈ R T × d h n h \mathbf{V} \in \mathbb{R}^{T \times d_hn_h} V ∈ R T × d h n h
見るとわかるように、それぞれほぼ同じように構成されます。
なので、キー表現に着目して、MHAとMLAの違いを見ていきます。
MLAにおけるキー表現
重要なのは、「圧縮KV表現とKV」の章で記載した下記の3式です。
c t K V = W D K V h t
\mathbf{c}_t^{KV} = W^{DKV} \mathbf{h_t}
c t K V = W DK V h t
C K V = [ ( c 1 K V ) T ( c 2 K V ) T ⋮ ( c T K V ) T ]
\mathbf{C}^{KV} =
\begin{bmatrix}
(\mathbf{c}_1^{KV})^T \\
(\mathbf{c}_2^{KV})^T \\
\vdots \\
(\mathbf{c}_T^{KV})^T
\end{bmatrix}
C K V = ( c 1 K V ) T ( c 2 K V ) T ⋮ ( c T K V ) T
K = ( W U K ( C K V ) T ) T
\mathbf{K} = \left(W^{UK} (\mathbf{C}^{KV})^T\right)^T
K = ( W U K ( C K V ) T ) T
それぞれ式番号(1)(2)(3)に対応しています。
これをまとめていきます。
式(2)(3)から
K = ( W U K [ c 1 K V , c 2 K V , ⋯ , c T K V ] ) T
\begin{equation}
\mathbf{K} = \left(W^{UK} [\mathbf{c}_1^{KV}, \mathbf{c}_2^{KV}, \cdots ,\mathbf{c}_T^{KV}]\right)^T
\end{equation}
K = ( W U K [ c 1 K V , c 2 K V , ⋯ , c T K V ] ) T
式(1)(11)から
K = ( W U K [ W D K V h 1 , W D K V h 2 , ⋯ , W D K V h T ] ) T
\mathbf{K} = \left(W^{UK} [W^{DKV} \mathbf{h_1} ,W^{DKV} \mathbf{h_2} ,\cdots ,W^{DKV} \mathbf{h_T}]\right)^T
K = ( W U K [ W DK V h 1 , W DK V h 2 , ⋯ , W DK V h T ] ) T
= [ W U K W D K V h 1 , W U K W D K V h 2 , ⋯ , W U K W D K V h T ] T
= [W^{UK}W^{DKV} \mathbf{h_1} ,W^{UK}W^{DKV} \mathbf{h_2} ,\cdots ,W^{UK}W^{DKV} \mathbf{h_T}]^T
= [ W U K W DK V h 1 , W U K W DK V h 2 , ⋯ , W U K W DK V h T ] T
となるため、これまでの書き方と表記を揃えると
K = [ ( W U K W D K V h 1 ) T ( W U K W D K V h 2 ) T ⋮ ( W U K W D K V h T ) T ]
\begin{equation}
\mathbf{K} =
\begin{bmatrix}
(W^{UK}W^{DKV} \mathbf{h_1})^T \\
(W^{UK}W^{DKV} \mathbf{h_2})^T \\
\vdots \\
(W^{UK}W^{DKV} \mathbf{h_T})^T
\end{bmatrix}
\end{equation}
K = ( W U K W DK V h 1 ) T ( W U K W DK V h 2 ) T ⋮ ( W U K W DK V h T ) T
なお、念のため記載すると、K ∈ R T × d h n h \mathbf{K} \in \mathbb{R}^{T \times d_hn_h} K ∈ R T × d h n h 、W U K W D K V h t ∈ R d h n h W^{UK}W^{DKV} \mathbf{h_t} \in \mathbb{R}^{d_hn_h} W U K W DK V h t ∈ R d h n h です。
ちなみに、MHAのキー表現は下記でした。
K = [ ( W K h 1 ) T ( W K h 2 ) T ⋮ ( W K h T ) T ]
\begin{equation}
\mathbf{K} =
\begin{bmatrix}
(W^{K} \mathbf{h_1})^T \\
(W^{K} \mathbf{h_2})^T \\
\vdots \\
(W^{K} \mathbf{h_T})^T
\end{bmatrix}
\end{equation}
K = ( W K h 1 ) T ( W K h 2 ) T ⋮ ( W K h T ) T
したがって、適切に学習がなされた場合、式(12)(13)から、W K = W U K W D K V W^{K} = W^{UK}W^{DKV} W K = W U K W DK V となります。
このとき、W K W^{K} W K の行列のランクやW U K W^{UK} W U K やW D K V W^{DKV} W DK V の行列の自由度d c d_c d c によって、どの程度劣化するかが決まります。
万が一、W K W^{K} W K の行列のランクが、W U K W^{UK} W U K やW D K V W^{DKV} W DK V の行列の自由度d c d_c d c よりも小さい場合は、適切に学習がなされることにより、劣化が生じません。
ついでに、W K ∈ R d h n h × d W^{K} \in \mathbb{R}^{d_hn_h \times d} W K ∈ R d h n h × d は、パラメータ数がd h n h × d d_hn_h \times d d h n h × d なのに対して、W U K ∈ R d h n h × d c W^{UK} \in \mathbb{R}^{d_hn_h \times d_c} W U K ∈ R d h n h × d c とW D K V ∈ R d c × d W^{DKV} \in \mathbb{R}^{d_c \times d} W DK V ∈ R d c × d に分割した場合、パラメータ数がd h n h × d c + d c × d d_hn_h \times d_c + d_c \times d d h n h × d c + d c × d となります。
ここで、d d d はd c d_c d c と比較して非常に大きいため、前者のパラメータよりも、少ないパラメータにて近似が可能です。
深層学習の世界において、しばしば学習可能なパラメータ数を少なくして近似することで、精度の向上が見られたことがあり(画像処理でいうCNNの世界)、今回MHAよりもMLAが性能が良い結果になったのは、その観点も効いているのかもしれません。
be absorbedってどういうこと??
さて、実際にここから、論文中の下記の内容が成立することを確かめます。
In addition, during inference, since W U K W^{UK} W U K can be absorbed into W Q W^Q W Q , and W U V W^{UV} W U V can be absorbed into W O W^{O} W O , we even do not need to compute keys and values out for attention.
W^{UK}の吸収
単一ヘッドを前提とした場合、あるヘッドi i i に着目すると、Attentionは下記のように表すことができます。
SDPA ( i ) ( Q ( i ) , K ( i ) , V ( i ) ) = softmax ( Q ( i ) K ( i ) ⊤ d h ) V ( i )
\begin{equation}
\text{SDPA}_{(i)}(Q_{(i)}, K_{(i)}, V_{(i)}) = \text{softmax} \left( \frac{Q_{(i)}K_{(i)}^\top}{\sqrt{d_h}} \right) V_{(i)}
\end{equation}
SDPA ( i ) ( Q ( i ) , K ( i ) , V ( i ) ) = softmax ( d h Q ( i ) K ( i ) ⊤ ) V ( i )
なお、SDPA ( i ) ∈ R T × d h \text{SDPA}_{(i)} \in \mathbb{R}^{T \times d_h} SDPA ( i ) ∈ R T × d h です。
SDPAはScaled Dot-Product Attentionのことです。
今回の
W U K W^{UK} W U K can be absorbed into W Q W^Q W Q
の部分は、Q ( i ) K ( i ) ⊤ Q_{(i)}K_{(i)}^\top Q ( i ) K ( i ) ⊤ に関連する部分なので、その部分に着目します。
ただし、ヘッドごとに見ていくため、これまでの記号は分割されて下記のようになります。
Q ( i ) ∈ R T × d h Q_{(i)} \in \mathbb{R}^{T \times d_h} Q ( i ) ∈ R T × d h 、K ( i ) ∈ R T × d h K_{(i)} \in \mathbb{R}^{T \times d_h} K ( i ) ∈ R T × d h 、W ( i ) U Q ∈ R d h × d c ′ W^{UQ}_{(i)} \in \mathbb{R}^{d_h \times d_c'} W ( i ) U Q ∈ R d h × d c ′ 、W ( i ) U K ∈ R d h × d c W^{UK}_{(i)} \in \mathbb{R}^{d_h \times d_c} W ( i ) U K ∈ R d h × d c
下記は、ヘッドごとに共通で利用するため変わりません。
C Q ∈ R T × d c ′ \mathbf{C}^{Q} \in \mathbb{R}^{T \times d_c'} C Q ∈ R T × d c ′ 、C K V ∈ R T × d c \mathbf{C}^{KV} \in \mathbb{R}^{T \times d_c} C K V ∈ R T × d c
さて、MLAにおいてQ ( i ) K ( i ) ⊤ Q_{(i)}K_{(i)}^\top Q ( i ) K ( i ) ⊤ は下記のように書き下せます。
Q ( i ) K ( i ) ⊤ = ( W ( i ) U Q ( C Q ) T ) T W ( i ) U K ( C K V ) T
Q_{(i)}K_{(i)}^\top = \left(W^{UQ}_{(i)} (\mathbf{C}^{Q})^T\right)^T W^{UK}_{(i)} (\mathbf{C}^{KV})^T
Q ( i ) K ( i ) ⊤ = ( W ( i ) U Q ( C Q ) T ) T W ( i ) U K ( C K V ) T
= C Q ( W ( i ) U Q ) T W ( i ) U K ( C K V ) T
= \mathbf{C}^{Q} (W^{UQ}_{(i)})^T W^{UK}_{(i)} (\mathbf{C}^{KV})^T
= C Q ( W ( i ) U Q ) T W ( i ) U K ( C K V ) T
= ( ( W ( i ) U K ) T W ( i ) U Q ( C Q ) T ) T ( C K V ) T
\begin{equation}
= \left((W^{UK}_{(i)})^T W^{UQ}_{(i)} (\mathbf{C}^{Q})^T \right)^T (\mathbf{C}^{KV})^T
\end{equation}
= ( ( W ( i ) U K ) T W ( i ) U Q ( C Q ) T ) T ( C K V ) T
上記において、W ( i ) U K ∈ R d h × d c W^{UK}_{(i)} \in \mathbb{R}^{d_h \times d_c} W ( i ) U K ∈ R d h × d c 、W ( i ) U Q ∈ R d h × d c ′ W^{UQ}_{(i)} \in \mathbb{R}^{d_h \times d_c'} W ( i ) U Q ∈ R d h × d c ′ は学習時は分割して学習された重み行列になります。
そこで、推論時には、これらの重み行列から事前に、( W ( i ) U K ) T W ( i ) U Q ∈ R d c × d c ′ (W^{UK}_{(i)})^T W^{UQ}_{(i)} \in \mathbb{R}^{d_c \times d_c'} ( W ( i ) U K ) T W ( i ) U Q ∈ R d c × d c ′ を計算しておき、新しい一つの重み行列にすることもできます。
そうすると、圧縮KV表現C K V \mathbf{C}^{KV} C K V と圧縮クエリ表現C Q \mathbf{C}^{Q} C Q から明示的にキーとクエリを復元することなく、一撃でQ ( i ) K ( i ) ⊤ Q_{(i)}K_{(i)}^\top Q ( i ) K ( i ) ⊤ を計算できます。
これが推論時に、「W U K W^{UK} W U K はW U Q W^{UQ} W U Q (W Q W^{Q} W Q )に吸収される」という言葉の意味だと思われます。
W^{UV}の吸収
今回の、
W U V W^{UV} W U V can be absorbed into W O W^{O} W O
の部分は、Attentionマップを作成し、バリューとの行列積を計算した後、各Headごとの出力をW O W^{O} W O で一つにまとめる部分です。この部分は少し入り組んでいるので、少し整理していきます。
まず、Attentionマップの部分をA t n ( i ) Atn_{(i)} A t n ( i ) とまとめます。
具体的には下記です。
Atn ( i ) ( Q ( i ) , K ( i ) ) = softmax ( Q ( i ) K ( i ) ⊤ d h )
\begin{equation}
\text{Atn}_{(i)}(Q_{(i)}, K_{(i)}) = \text{softmax} \left( \frac{Q_{(i)}K_{(i)}^\top}{\sqrt{d_h}} \right)
\end{equation}
Atn ( i ) ( Q ( i ) , K ( i ) ) = softmax ( d h Q ( i ) K ( i ) ⊤ )
ここで、Atn ( i ) ∈ R T × T \text{Atn}_{(i)} \in \mathbb{R}^{T \times T} Atn ( i ) ∈ R T × T です。
すると、Scaled Dot-Product Attentionは下記のように表せます。
SDPA ( i ) = Atn ( i ) V ( i )
\begin{equation}
\text{SDPA}_{(i)} = \text{Atn}_{(i)} V_{(i)}
\end{equation}
SDPA ( i ) = Atn ( i ) V ( i )
その場合、MHA(Multi-Head Attention)は、下記のように表現されます。
MHA ( Q , K , V ) = W O [ SDPA ( 1 ) T SDPA ( 2 ) T ⋮ SDPA ( n h ) T ]
\begin{equation}
\text{MHA}(Q, K, V) = W^{O}
\begin{bmatrix}
\text{SDPA}_{(1)}^T \\
\text{SDPA}_{(2)}^T \\
\vdots \\
\text{SDPA}_{(n_h)}^T
\end{bmatrix}
\end{equation}
MHA ( Q , K , V ) = W O SDPA ( 1 ) T SDPA ( 2 ) T ⋮ SDPA ( n h ) T
ここで、W O ∈ R d × d h n h W^{O} \in \mathbb{R}^{d \times d_hn_h} W O ∈ R d × d h n h 、S D P A ( i ) T ∈ R d h × T {SDPA}_{(i)}^T \in \mathbb{R}^{d_h \times T} S D P A ( i ) T ∈ R d h × T であり、
[ SDPA ( 1 ) T SDPA ( 2 ) T ⋮ SDPA ( n h ) T ] ∈ R d h n h × T
\begin{equation}
\begin{bmatrix}
\text{SDPA}_{(1)}^T \\
\text{SDPA}_{(2)}^T \\
\vdots \\
\text{SDPA}_{(n_h)}^T
\end{bmatrix}
\in \mathbb{R}^{d_hn_h \times T}
\end{equation}
SDPA ( 1 ) T SDPA ( 2 ) T ⋮ SDPA ( n h ) T ∈ R d h n h × T
となります。
ここで、ちょっとトリッキーですが、W O W^{O} W O を下記のように分解していきます。
MHA ( Q , K , V ) = [ W ( 1 ) O , W ( 2 ) O , ⋯ , W ( n h ) O ] [ SDPA ( 1 ) T SDPA ( 2 ) T ⋮ SDPA ( n h ) T ]
\begin{equation}
\text{MHA}(Q, K, V) = [W^{O}_{(1)}, W^{O}_{(2)}, \cdots, W^{O}_{(n_h)}]
\begin{bmatrix}
\text{SDPA}_{(1)}^T \\
\text{SDPA}_{(2)}^T \\
\vdots \\
\text{SDPA}_{(n_h)}^T
\end{bmatrix}
\end{equation}
MHA ( Q , K , V ) = [ W ( 1 ) O , W ( 2 ) O , ⋯ , W ( n h ) O ] SDPA ( 1 ) T SDPA ( 2 ) T ⋮ SDPA ( n h ) T
このとき、W ( i ) O ∈ R d × d h W^{O}_{(i)} \in \mathbb{R}^{d \times d_h} W ( i ) O ∈ R d × d h となります。
このとき、行列積の性質として下記が成立します。
MHA ( Q , K , V ) = [ W ( 1 ) O , W ( 2 ) O , ⋯ , W ( n h ) O ] [ SDPA ( 1 ) T SDPA ( 2 ) T ⋮ SDPA ( n h ) T ] = W ( 1 ) O SDPA ( 1 ) T + W ( 2 ) O SDPA ( 2 ) T + ⋯ + W ( n h ) O SDPA ( n h ) T
\begin{equation}
\begin{aligned}
\text{MHA}(Q, K, V) &=[W^{O}_{(1)}, W^{O}_{(2)}, \cdots, W^{O}_{(n_h)}]
\begin{bmatrix}
\text{SDPA}_{(1)}^T \\
\text{SDPA}_{(2)}^T \\
\vdots \\
\text{SDPA}_{(n_h)}^T
\end{bmatrix} \\
&= W^{O}_{(1)} \text{SDPA}_{(1)}^T \\
&\quad + W^{O}_{(2)} \text{SDPA}_{(2)}^T \\
&\quad + \cdots + W^{O}_{(n_h)} \text{SDPA}_{(n_h)}^T
\end{aligned}
\end{equation}
MHA ( Q , K , V ) = [ W ( 1 ) O , W ( 2 ) O , ⋯ , W ( n h ) O ] SDPA ( 1 ) T SDPA ( 2 ) T ⋮ SDPA ( n h ) T = W ( 1 ) O SDPA ( 1 ) T + W ( 2 ) O SDPA ( 2 ) T + ⋯ + W ( n h ) O SDPA ( n h ) T
式(17)を展開して、下記のように書きます。
MHA ( Q , K , V ) = W ( 1 ) O ( Atn ( 1 ) V ( 1 ) ) T + W ( 2 ) O ( Atn ( 2 ) V ( 2 ) ) T + ⋯ + W ( n h ) O ( Atn ( n h ) V ( n h ) ) T
\begin{equation}
\begin{aligned}
\text{MHA}(Q, K, V) &= W^{O}_{(1)} \left(\text{Atn}_{(1)}V_{(1)}\right)^T \\
&\quad + W^{O}_{(2)} \left(\text{Atn}_{(2)}V_{(2)}\right)^T \\
&\quad + \cdots + W^{O}_{(n_h)} \left(\text{Atn}_{(n_h)} V_{(n_h)}\right)^T
\end{aligned}
\end{equation}
MHA ( Q , K , V ) = W ( 1 ) O ( Atn ( 1 ) V ( 1 ) ) T + W ( 2 ) O ( Atn ( 2 ) V ( 2 ) ) T + ⋯ + W ( n h ) O ( Atn ( n h ) V ( n h ) ) T
このとき、式(4)から下記にように式変形できます。
!
式(4)は下記です。
V = ( W U V ( C K V ) T ) T
\mathbf{V} = \left(W^{UV} (\mathbf{C}^{KV})^T\right)^T
V = ( W U V ( C K V ) T ) T
なお、C K V ∈ R T × d c \mathbf{C}^{KV} \in \mathbb{R}^{T \times d_c} C K V ∈ R T × d c 、W U V ∈ R d h n h × d c W^{UV} \in \mathbb{R}^{d_hn_h \times d_c} W U V ∈ R d h n h × d c 、V ∈ R T × d h n h \mathbf{V} \in \mathbb{R}^{T \times d_hn_h} V ∈ R T × d h n h
すると、各Headごとに分解すると下記にようになります。
V ( i ) = ( W ( i ) U V ( C K V ) T ) T
\mathbf{V}_{(i)} = \left(W^{UV}_{(i)} (\mathbf{C}^{KV})^T\right)^T
V ( i ) = ( W ( i ) U V ( C K V ) T ) T
= C K V ( W ( i ) U V ) T
= \mathbf{C}^{KV} (W^{UV}_{(i)})^T
= C K V ( W ( i ) U V ) T
なお、C K V ∈ R T × d c \mathbf{C}^{KV} \in \mathbb{R}^{T \times d_c} C K V ∈ R T × d c 、W ( i ) U V ∈ R d h × d c W^{UV}_{(i)} \in \mathbb{R}^{d_h \times d_c} W ( i ) U V ∈ R d h × d c 、V ( i ) ∈ R T × d h \mathbf{V}_{(i)} \in \mathbb{R}^{T \times d_h} V ( i ) ∈ R T × d h
MHA ( Q , K , V ) = W ( 1 ) O ( Atn ( 1 ) C K V ( W ( 1 ) U V ) T ) T + W ( 2 ) O ( Atn ( 2 ) C K V ( W ( 2 ) U V ) T ) T + ⋯ + W ( n h ) O ( Atn ( n h ) C K V ( W ( n h ) U V ) T ) T
\begin{aligned}
\text{MHA}(Q, K, V) &= W^{O}_{(1)} \left(\text{Atn}_{(1)}\mathbf{C}^{KV} (W^{UV}_{(1)})^T \right)^T \\
&\quad + W^{O}_{(2)} \left(\text{Atn}_{(2)}\mathbf{C}^{KV} (W^{UV}_{(2)})^T \right)^T \\
&\quad + \cdots + W^{O}_{(n_h)} \left(\text{Atn}_{(n_h)} \mathbf{C}^{KV} (W^{UV}_{(n_h)})^T \right)^T
\end{aligned}
MHA ( Q , K , V ) = W ( 1 ) O ( Atn ( 1 ) C K V ( W ( 1 ) U V ) T ) T + W ( 2 ) O ( Atn ( 2 ) C K V ( W ( 2 ) U V ) T ) T + ⋯ + W ( n h ) O ( Atn ( n h ) C K V ( W ( n h ) U V ) T ) T
MHA ( Q , K , V ) = W ( 1 ) O W ( 1 ) U V ( C K V ) T Atn ( 1 ) T + W ( 2 ) O W ( 2 ) U V ( C K V ) T Atn ( 2 ) T + ⋯ + W ( n h ) O W ( n h ) U V ( C K V ) T Atn ( n h ) T
\begin{equation}
\begin{aligned}
\text{MHA}(Q, K, V) &= W^{O}_{(1)} W^{UV}_{(1)} (\mathbf{C}^{KV})^T \text{Atn}_{(1)}^T \\
&\quad + W^{O}_{(2)} W^{UV}_{(2)} (\mathbf{C}^{KV})^T \text{Atn}_{(2)}^T \\
&\quad + \cdots + W^{O}_{(n_h)} W^{UV}_{(n_h)} (\mathbf{C}^{KV})^T \text{Atn}_{(n_h)}^T \\
\end{aligned}
\end{equation}
MHA ( Q , K , V ) = W ( 1 ) O W ( 1 ) U V ( C K V ) T Atn ( 1 ) T + W ( 2 ) O W ( 2 ) U V ( C K V ) T Atn ( 2 ) T + ⋯ + W ( n h ) O W ( n h ) U V ( C K V ) T Atn ( n h ) T
上記において、W ( i ) O ∈ R d × d h W^{O}_{(i)} \in \mathbb{R}^{d \times d_h} W ( i ) O ∈ R d × d h 、W ( i ) U V ∈ R d h × d c W^{UV}_{(i)} \in \mathbb{R}^{d_h \times d_c} W ( i ) U V ∈ R d h × d c は学習時は分割して学習された重み行列になります。
そこで、推論時には、これらの重み行列から事前に、W ( i ) O W ( i ) U V ∈ R d × d c W^{O}_{(i)} W^{UV}_{(i)} \in \mathbb{R}^{d \times d_c} W ( i ) O W ( i ) U V ∈ R d × d c を計算しておき、新しい一つの重み行列にすることもできます。
そうすると、圧縮KV表現C K V \mathbf{C}^{KV} C K V から明示的にバリューを復元することなく、一撃でMHA ( Q , K , V ) \text{MHA}(Q, K, V) MHA ( Q , K , V ) を計算できます。
これが推論時に、「W U V W^{UV} W U V はW O W^{O} W O に吸収される」という言葉の意味だと思われます。
まとめて
「W^{UK}の吸収」のまとめ
「W U K W^{UK} W U K の吸収」の章にて、式(15)から下記の式が成立しています。
Q ( i ) K ( i ) ⊤ = ( ( W ( i ) Q ) ′ ( C Q ) T ) T ( C K V ) T
Q_{(i)}K_{(i)}^\top = \left((W^{Q}_{(i)})' (\mathbf{C}^{Q})^T \right)^T (\mathbf{C}^{KV})^T
Q ( i ) K ( i ) ⊤ = ( ( W ( i ) Q ) ′ ( C Q ) T ) T ( C K V ) T
ただし、まとめた式(吸収した式)に、( W ( i ) Q ) ′ (W^{Q}_{(i)})' ( W ( i ) Q ) ′ という名前をつけています。
( W ( i ) Q ) ′ (W^{Q}_{(i)})' ( W ( i ) Q ) ′ は下記で定義されています。
( W ( i ) Q ) ′ = ( W ( i ) U K ) T W ( i ) U Q
(W^{Q}_{(i)})' = (W^{UK}_{(i)})^T W^{UQ}_{(i)}
( W ( i ) Q ) ′ = ( W ( i ) U K ) T W ( i ) U Q
ただし、( W ( i ) Q ) ′ ∈ R d c × d c ′ (W^{Q}_{(i)})' \in \mathbb{R}^{d_c \times d_c'} ( W ( i ) Q ) ′ ∈ R d c × d c ′ 。
また、W ( i ) U K ∈ R d h × d c W^{UK}_{(i)} \in \mathbb{R}^{d_h \times d_c} W ( i ) U K ∈ R d h × d c 、W ( i ) U Q ∈ R d h × d c ′ W^{UQ}_{(i)} \in \mathbb{R}^{d_h \times d_c'} W ( i ) U Q ∈ R d h × d c ′
その上で、Scaled Dot-Product Attentionにおける、Attentionマップの式は下記のようになります。
Atn ( i ) ( Q ( i ) , K ( i ) ) = softmax ( ( ( W ( i ) Q ) ′ ( C Q ) T ) T ( C K V ) T d h )
\text{Atn}_{(i)}(Q_{(i)}, K_{(i)}) = \text{softmax} \left( \frac{\left((W^{Q}_{(i)})' (\mathbf{C}^{Q})^T \right)^T (\mathbf{C}^{KV})^T}{\sqrt{d_h}} \right)
Atn ( i ) ( Q ( i ) , K ( i ) ) = softmax d h ( ( W ( i ) Q ) ′ ( C Q ) T ) T ( C K V ) T
「W^{UV}の吸収」のまとめ
「W U V W^{UV} W U V の吸収」の章にて、式(23)から下記が成立しています。
MHA ( Q , K , V ) = ( W ( 1 ) O ) ′ ( C K V ) T Atn ( 1 ) T + ( W ( 2 ) O ) ′ ( C K V ) T Atn ( 2 ) T + ⋯ + ( W ( n h ) O ) ′ ( C K V ) T Atn ( n h ) T
\begin{aligned}
\text{MHA}(Q, K, V) &= (W^{O}_{(1)})' (\mathbf{C}^{KV})^T \text{Atn}_{(1)}^T \\
&\quad + (W^{O}_{(2)})' (\mathbf{C}^{KV})^T \text{Atn}_{(2)}^T \\
&\quad + \cdots + (W^{O}_{(n_h)})' (\mathbf{C}^{KV})^T \text{Atn}_{(n_h)}^T \\
\end{aligned}
MHA ( Q , K , V ) = ( W ( 1 ) O ) ′ ( C K V ) T Atn ( 1 ) T + ( W ( 2 ) O ) ′ ( C K V ) T Atn ( 2 ) T + ⋯ + ( W ( n h ) O ) ′ ( C K V ) T Atn ( n h ) T
ただし、まとめた式(吸収した式)に、( W ( i ) O ) ′ (W^{O}_{(i)})' ( W ( i ) O ) ′ という名前をつけています。
( W ( i ) O ) ′ (W^{O}_{(i)})' ( W ( i ) O ) ′ は下記で定義されています。
( W ( i ) O ) ′ = W ( i ) O W ( i ) U V
(W^{O}_{(i)})' = W^{O}_{(i)} W^{UV}_{(i)}
( W ( i ) O ) ′ = W ( i ) O W ( i ) U V
ただし、( W ( i ) O ) ′ ∈ R d × d c (W^{O}_{(i)})' \in \mathbb{R}^{d \times d_c} ( W ( i ) O ) ′ ∈ R d × d c 。
また、W ( i ) O ∈ R d × d h W^{O}_{(i)} \in \mathbb{R}^{d \times d_h} W ( i ) O ∈ R d × d h 、W ( i ) U V ∈ R d h × d c W^{UV}_{(i)} \in \mathbb{R}^{d_h \times d_c} W ( i ) U V ∈ R d h × d c 。
全体のまとめ
Multi-head Latent Attention (MLA)を計算するにあたり、必要なのは下記に全て示ました。
Q ( i ) K ( i ) ⊤ = ( ( W ( i ) Q ) ′ ( C Q ) T ) T ( C K V ) T
Q_{(i)}K_{(i)}^\top = \left((W^{Q}_{(i)})' (\mathbf{C}^{Q})^T \right)^T (\mathbf{C}^{KV})^T
Q ( i ) K ( i ) ⊤ = ( ( W ( i ) Q ) ′ ( C Q ) T ) T ( C K V ) T
MHA ( Q , K , V ) = ( W ( 1 ) O ) ′ ( C K V ) T Atn ( 1 ) T + ( W ( 2 ) O ) ′ ( C K V ) T Atn ( 2 ) T + ⋯ + ( W ( n h ) O ) ′ ( C K V ) T Atn ( n h ) T
\begin{aligned}
\text{MHA}(Q, K, V) &= (W^{O}_{(1)})' (\mathbf{C}^{KV})^T \text{Atn}_{(1)}^T \\
&\quad + (W^{O}_{(2)})' (\mathbf{C}^{KV})^T \text{Atn}_{(2)}^T \\
&\quad + \cdots + (W^{O}_{(n_h)})' (\mathbf{C}^{KV})^T \text{Atn}_{(n_h)}^T \\
\end{aligned}
MHA ( Q , K , V ) = ( W ( 1 ) O ) ′ ( C K V ) T Atn ( 1 ) T + ( W ( 2 ) O ) ′ ( C K V ) T Atn ( 2 ) T + ⋯ + ( W ( n h ) O ) ′ ( C K V ) T Atn ( n h ) T
ここからわかることとして、推論時には( W ( i ) Q ) ′ (W^{Q}_{(i)})' ( W ( i ) Q ) ′ と( W ( i ) O ) ′ (W^{O}_{(i)})' ( W ( i ) O ) ′ という新しい重み行列を事前に計算してモデルにおいておくことで、キーバリュークエリを明示的に計算する必要がなく、少ない計算回数にて MLAを計算することが可能だとわかりました。
以上が、論文中の
In addition, during inference, since W U K W^{UK} W U K can be absorbed into W Q W^Q W Q , and W U V W^{UV} W U V can be absorbed into W O W^{O} W O , we even do not need to compute keys and values out for attention.
だと考えています。
私の疑問点
MLAは主にKVキャッシュの圧縮や、クエリの圧縮による、メモリ効率を向上させることに非常に重きを置いた手法に見えています。
一方で、計算回数を減らすために、行列を吸収させる場合、吸収後の行列のパラメータ数は非常に多くなるように思います。
例えば下記の式で考えます。
( W ( i ) Q ) ′ = ( W ( i ) U K ) T W ( i ) U Q
(W^{Q}_{(i)})' = (W^{UK}_{(i)})^T W^{UQ}_{(i)}
( W ( i ) Q ) ′ = ( W ( i ) U K ) T W ( i ) U Q
ただし、( W ( i ) Q ) ′ ∈ R d c × d c ′ (W^{Q}_{(i)})' \in \mathbb{R}^{d_c \times d_c'} ( W ( i ) Q ) ′ ∈ R d c × d c ′ であり、また、W ( i ) U K ∈ R d h × d c W^{UK}_{(i)} \in \mathbb{R}^{d_h \times d_c} W ( i ) U K ∈ R d h × d c 、W ( i ) U Q ∈ R d h × d c ′ W^{UQ}_{(i)} \in \mathbb{R}^{d_h \times d_c'} W ( i ) U Q ∈ R d h × d c ′ です。
このとき、論文中の数字(d c = 512 , d c ′ = 1536 , d h = 128 d_c = 512, d_c' = 1536, d_h = 128 d c = 512 , d c ′ = 1536 , d h = 128 )を代入すると、( W ( i ) Q ) ′ (W^{Q}_{(i)})' ( W ( i ) Q ) ′ のパラメータ数は786 , 432 786,432 786 , 432 、W ( i ) U K W^{UK}_{(i)} W ( i ) U K のパラメータ数は65 , 536 65,536 65 , 536 、W ( i ) U Q W^{UQ}_{(i)} W ( i ) U Q のパラメータ数は196 , 608 196,608 196 , 608 となります。
すなわち、吸収前のパラメータ数は65 , 536 + 196 , 608 = 262 , 144 65,536 + 196,608 = 262,144 65 , 536 + 196 , 608 = 262 , 144
吸収後のパラメータ数は786 , 432 786,432 786 , 432 となってしまいます。
これでは、推論時の必要パラメータが多くなってしまうため、 MLAの利点であるメモリ効率向上と相反する形に見えます。
という意味で、この吸収は実装されているのか?されていないのか?
されている場合は、メモリ効率の低減は、無視ができる程度なのか?そうでないのか?
この辺り、まだ私の中でもあまり固まっていない疑問ですが、詳しい方に教えていただけると嬉しいです。
なお、「吸収して大きな行列を作っているのではなく、2つの行列のまま2回行列を掛け算している」という考えは、私はないと思っております。なぜからその方法であればRoPEを導入した際に問題にならないからです。
その場合、普通に2つの行列演算の間に回転行列をかければいいだけのはずで、わざわざRoPEを適用するQuery・Key、適用しないQuery・Keyを分ける必要はないですよね?
ぜひ詳しい方!教えていただけると嬉しいです!
まとめ
読んでくださってありがとうございました!
MLAの「be absorbed」の部分の検証をしたいと思っていた方の参考になれば幸いです。
また、ぜひ私の疑問点にも、ご回答をもらえると嬉しいです。
Attentionについてのおすすめ書籍
機械学習エンジニアのためのTransformers ―最先端の自然言語処理ライブラリによるモデル開発
Hugging Faceの開発者が書いた本ですので、信頼できます。
当然のことながら、Attentionレイヤーの詳細な解説に関しても記載してあり、理論・基礎を学ぶ上でも非常におすすめの書籍です。
大規模言語モデル入門
大規模言語モデル入門Ⅱ〜生成型LLMの実装と評価
よく紹介させていただいておりますが、こちらの書籍は、LLMのファインチューニングから、RLHF、RAG、分散学習にかけて、本当に幅広く解説されており、いつも参考にさせていただいております。
特に1冊目は、TransformerのAttentionレイヤーの数式に関してもわかりやすく説明されておりおすすめです。
Discussion
実装はされていないようでした.個別に Q, K を計算してその積を計算していました.
主観ですが,数百トークンを入力したときは無視できるのかなと思います.
しかし,最近よく見かける,まとめて proj して split する方法を使えばより簡単に行列積の演算回数は減らせるので,low-rank のパラメータをわざわざ大きくするのはあまり良い方法には感じませんでした...Q,K=X⋅concat(WQ,WK) をした後に QKT を計算
proj+split:
余談ですが,分析研究では対象をシンプルにするためにこの QK, VO の合体が頻出します.https://arxiv.org/abs/2405.00208 の式 2, 3
まず、ご回答いただきありがとうございます。
DeepSeek-V2のPyTorch実装がここから見れるということや、「まとめて proj して split する方法」も言われてみれば「確かに!」でしたが、実際の実装を見たことがなかったので初めて知りました!
ありがとうございます。
確かに、こちらのPyTorch実装を見ると、愚直にKVを一度計算しているように見えます(行列を吸収するのは実装されていない)
しかしながら、そう考えるとDeepSeek-V2論文の下記記載部分(2.1.3)と矛盾します。
ここに記載の通り、「RoPEをQとKに適応すると、行列吸収ができなくなる」ので、わざわざ論文2.1.3にて、RoPEを適用したQKと適用していないQKを分けて処理をしています。
(この処理は、PyTorch版のコードでも実装されています)
しかしながら、記載いただいたPyTorch実装のように、一度圧縮表現CtKV からK Vを計算するのであれば、RoPEを導入したとしても問題なく計算可能(普通にKを計算した後に、RoPEを適用し、あらためてQKT を計算すれば良い)なので、わざわざRoPEを適用するために、特殊な処理を考える必要はないはずです。
一方で、DeepSeek-V2論文3.1.3にて下記の記載がございます。
これを見ると、DeepSeek社内部での学習や、APIで呼び出された時の内部の推論に関してはPyTorchではなく、独自フレームワークを利用している可能性が高いです。
そちらのフレームワークでは、もしかしたら行列吸収の形で実装されており、だから、RoPEの部分で問題があったのかもしれません。
これはおっしゃる通りかもですね。
KVキャッシュされるtoken数が増えるほど、圧縮されるメモリの絶対量は増えそうですので、1パラメータの重み行列の大きさを無視できるタイミングはありそう。
こちらも軽く見させていただきました!
普通のTransformerでもこの行列吸収はできるよな?とは思っていたのでその疑問解消できました!
新情報:github にある DeepSeek-V3 のコードに
absorb
の分岐がありました. APIの内部では吸収の方で動いているような気がしてきました