The purpose of this article is to work through, with equations, the following passage from the Multi-head Latent Attention (MLA) section of the DeepSeek V2 paper. In particular, it organizes, in equation form, which weights are absorbed and how they are absorbed.
(I might turn this into a more detailed explanatory article later.)
In addition, during inference, since W^{UK} can be absorbed into W^Q, and W^{UV} can be absorbed into W^{O}, we even do not need to compute keys and values out for attention.
This article assumes that you have an understanding of the Multi-head Latent Attention (MLA) described in the DeepSeek V2 paper. https://arxiv.org/abs/2405.04434
Note that I will not cover here what happens to this claim once RoPE is introduced.
The basic argument remains the same, though; I may write an article about that eventually.
That said, the paper leaves very wide gaps between the lines here, so I have described MLA in detail; I would be happy if this helps someone's understanding.
Also, I wanted to write an article that current AI cannot write.
(Even if you ask something like o1 pro for an explanation, it likely won't provide one this detailed.)
Additionally, I have also written a paper commentary article on the recently popular DeepSeek-R1.
Unlike this article, I tried to write it so that those who are not familiar with theory would find it interesting! So please take a look if you'd like. https://zenn.dev/asap/articles/34237ad87f8511
Organizing MLA with Equations
Scaled Dot-Product Attention is expressed as follows:
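In standard notation (writing d_k for the key dimension, which per head below is d_h):

```latex
\text{SDPA}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \mathrm{softmax}\!\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d_k}}\right)\mathbf{V}
```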
Based on the above, let's examine the derivation for MLA.
Introduction of Symbols
MLA is primarily aimed at compressing the KV cache and queries.
Therefore, we will introduce symbols to handle compressed KV and query representations.
Compressed KV Representation and KV
Here, we describe the compressed KV representation and the original KV.
In MLA, we introduce the compressed KV representation \mathbf{C}^{KV} \in \mathbb{R}^{T \times d_c}. T is the total number of tokens, and d_c is the dimensionality of the compressed KV representation (512 dimensions in the paper).
This compressed representation is created as follows:
(The following transformation is applied to each token. In other words, parallel processing is performed for the number of tokens.)
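In the paper's notation:

```latex
\mathbf{c}_t^{KV} = W^{DKV}\mathbf{h}_t
```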
Here, t denotes the features of the t-th token.
Also, \mathbf{h_t} \in \mathbb{R}^{d} is the output of the previous layer at the t-th token, where d is the output dimensionality of the previous layer (5120 dimensions in the paper).
Furthermore, W^{DKV} \in \mathbb{R}^{d_c \times d} is the down-projection matrix. \mathbf{c}_t^{KV} \in \mathbb{R}^{d_c} is the compressed KV representation at the t-th token.
And \mathbf{C}^{KV} \in \mathbb{R}^{T \times d_c} is a matrix that summarizes \mathbf{c}_t^{KV} for all tokens.
Using the compressed KV representation \mathbf{C}^{KV}, we can restore the keys and values.
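Per token, following the paper:

```latex
\mathbf{k}_t = W^{UK}\mathbf{c}_t^{KV}, \qquad \mathbf{v}_t = W^{UV}\mathbf{c}_t^{KV}
```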
W^{UK}, W^{UV} \in \mathbb{R}^{d_hn_h \times d_c} are the up-projection matrices.
They restore all key and value representations for the multi-head attention from the compressed KV. d_h is the dimensionality per head (128 dimensions in the paper), and n_h is the number of heads (128 heads in the paper). \mathbf{K}, \mathbf{V} \in \mathbb{R}^{T \times d_hn_h} are the uncompressed key and value representations, respectively.
Compressed Query Representation and Q
Next, the queries are also compressed; the paper does this to reduce the activation memory during training.
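As with the KV side, per token:

```latex
\mathbf{c}_t^{Q} = W^{DQ}\mathbf{h}_t
```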
Also, W^{DQ} \in \mathbb{R}^{d_c' \times d} is the down-projection matrix. d_c' is the dimensionality of the compressed query representation (1536 dimensions in the paper). \mathbf{c}_t^{Q} \in \mathbb{R}^{d_c'} is the compressed query representation at the t-th token. \mathbf{C}^{Q} \in \mathbb{R}^{T \times d_c'} is the compressed query representation for the entire sequence of tokens.
Using this compressed query representation \mathbf{C}^{Q}, the original query can be restored.
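Per token:

```latex
\mathbf{q}_t = W^{UQ}\mathbf{c}_t^{Q}
```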
W^{UQ} \in \mathbb{R}^{d_hn_h \times d_c'} is the up-projection matrix.
It restores all query representations for the multi-head attention from the compressed query representation. \mathbf{Q} \in \mathbb{R}^{T \times d_hn_h} is the uncompressed query representation.
Difference between regular MHA (Multi-Head Attention) and MLA
In regular MHA, Q, K, and V are each expressed as follows:
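Per token, the weight matrices are applied directly to the hidden state:

```latex
\mathbf{q}_t = W^{Q}\mathbf{h}_t, \qquad \mathbf{k}_t = W^{K}\mathbf{h}_t, \qquad \mathbf{v}_t = W^{V}\mathbf{h}_t
```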
Note that \mathbf{v}_t \in \mathbb{R}^{d_hn_h}, W^{V} \in \mathbb{R}^{d_hn_h \times d}, and \mathbf{V} \in \mathbb{R}^{T \times d_hn_h}
As you can see, they are constructed in almost exactly the same way.
So, let's look at the difference between MHA and MLA, focusing on the key representation.
Key representation in MLA
What is important are the following three equations described in the "Compressed KV Representation and KV" section.
Therefore, if training goes well, equations (12) and (13) suggest that W^{K} \approx W^{UK}W^{DKV}.
In this case, how much the approximation degrades is determined by the rank of the matrix W^{K} relative to d_c, the inner dimension of the factors W^{UK} and W^{DKV}.
If the rank of W^{K} happens to be no greater than d_c, the product W^{UK}W^{DKV} can represent W^{K} exactly, so with proper training no degradation occurs.
Additionally, while W^{K} \in \mathbb{R}^{d_hn_h \times d} has d_hn_h \times d parameters, when split into W^{UK} \in \mathbb{R}^{d_hn_h \times d_c} and W^{DKV} \in \mathbb{R}^{d_c \times d}, the number of parameters becomes d_hn_h \times d_c + d_c \times d.
Since d_c is much smaller than both d and d_hn_h, the factorized form can approximate W^{K} with far fewer parameters.
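Plugging in the paper's dimensions makes the saving concrete (a quick sanity check; the variable names are mine):

```python
# Parameter counts for the key projection, using the paper's dimensions
d = 5120      # hidden size of the previous layer
d_c = 512     # compressed KV dimension
d_h = 128     # dimension per head
n_h = 128     # number of heads

full_WK = d_h * n_h * d                 # W^K : (d_h * n_h) x d
factored = d_h * n_h * d_c + d_c * d    # W^{UK} plus W^{DKV}

print(full_WK)   # 83886080
print(factored)  # 11010048
```

So the factorization uses roughly one eighth of the parameters of the full key projection.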
In deep learning, it has often been observed that approximating with fewer trainable parameters can actually improve accuracy (as seen with CNNs in image processing), so the fact that MLA outperformed MHA here may partly be explained from that perspective as well.
What does "be absorbed" mean??
Now, let's verify that the following content from the paper actually holds.
In addition, during inference, since W^{UK} can be absorbed into W^Q, and W^{UV} can be absorbed into W^{O}, we even do not need to compute keys and values out for attention.
Absorption of W^{UK}
If we focus on a specific head i, its attention output can be expressed as follows:
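In the notation above:

```latex
\text{SDPA}_{(i)} = \mathrm{softmax}\!\left(\frac{Q_{(i)}K_{(i)}^\top}{\sqrt{d_h}}\right)V_{(i)}
```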
Note that \text{SDPA}_{(i)} \in \mathbb{R}^{T \times d_h}.
SDPA stands for Scaled Dot-Product Attention.
The part
W^{UK} can be absorbed into W^Q
relates to Q_{(i)}K_{(i)}^\top, so we will focus on that part.
However, since we are examining this head by head, the symbols used so far are divided as follows: Q_{(i)} \in \mathbb{R}^{T \times d_h}, K_{(i)} \in \mathbb{R}^{T \times d_h}, W^{UQ}_{(i)} \in \mathbb{R}^{d_h \times d_c'}, W^{UK}_{(i)} \in \mathbb{R}^{d_h \times d_c}
The following remain unchanged as they are used across heads: \mathbf{C}^{Q} \in \mathbb{R}^{T \times d_c'}, \mathbf{C}^{KV} \in \mathbb{R}^{T \times d_c}
Now, in MLA, Q_{(i)}K_{(i)}^\top can be written out as follows:
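Substituting Q_{(i)} = \mathbf{C}^{Q}(W^{UQ}_{(i)})^\top and K_{(i)} = \mathbf{C}^{KV}(W^{UK}_{(i)})^\top:

```latex
Q_{(i)}K_{(i)}^\top
= \mathbf{C}^{Q}(W^{UQ}_{(i)})^\top \left(\mathbf{C}^{KV}(W^{UK}_{(i)})^\top\right)^\top
= \mathbf{C}^{Q}(W^{UQ}_{(i)})^\top W^{UK}_{(i)} (\mathbf{C}^{KV})^\top
= \mathbf{C}^{Q}\left((W^{UK}_{(i)})^\top W^{UQ}_{(i)}\right)^\top (\mathbf{C}^{KV})^\top
```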
In the above, W^{UK}_{(i)} \in \mathbb{R}^{d_h \times d_c} and W^{UQ}_{(i)} \in \mathbb{R}^{d_h \times d_c'} are weight matrices that were learned separately during training.
Therefore, during inference, we can pre-calculate (W^{UK}_{(i)})^T W^{UQ}_{(i)} \in \mathbb{R}^{d_c \times d_c'} from these weight matrices and combine them into a single new weight matrix.
By doing so, we can calculate Q_{(i)}K_{(i)}^\top in one go without explicitly restoring the keys and queries from the compressed KV representation \mathbf{C}^{KV} and the compressed query representation \mathbf{C}^{Q}.
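The equivalence is easy to verify numerically. Below is a minimal NumPy sketch with toy dimensions (all names here are mine, not from any official code):

```python
import numpy as np

# Toy dimensions (the paper uses d_c=512, d_c'=1536, d_h=128; smaller here for speed)
T, d_c, d_cp, d_h = 4, 8, 12, 6
rng = np.random.default_rng(0)

C_Q = rng.standard_normal((T, d_cp))     # compressed query representation
C_KV = rng.standard_normal((T, d_c))     # compressed KV representation
W_UQ = rng.standard_normal((d_h, d_cp))  # per-head query up-projection
W_UK = rng.standard_normal((d_h, d_c))   # per-head key up-projection

# Naive path: restore Q_(i) and K_(i), then take their product
Q_i = C_Q @ W_UQ.T
K_i = C_KV @ W_UK.T
scores_naive = Q_i @ K_i.T

# Absorbed path: precompute (W^{UK})^T W^{UQ} once; Q_(i), K_(i) never materialized
W_absorbed = W_UK.T @ W_UQ               # shape (d_c, d_c')
scores_absorbed = C_Q @ W_absorbed.T @ C_KV.T

assert np.allclose(scores_naive, scores_absorbed)
```

Both paths give the same T x T score matrix, but the absorbed path works directly on the compressed representations.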
This seems to be the meaning of the words "W^{UK} is absorbed into W^{UQ} (W^Q)" during inference.
Absorption of W^{UV}
The part

W^{UV} can be absorbed into W^{O}

refers to the stage where, after the attention map is formed and multiplied with the values, the per-head outputs are combined into one using W^{O}. Since this part is a bit involved, let's work through it step by step.
First, let's represent the Attention map part as Atn_{(i)}. Specifically, it is as follows:
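With Atn_{(i)} denoting the softmax part:

```latex
\text{Atn}_{(i)} = \mathrm{softmax}\!\left(\frac{Q_{(i)}K_{(i)}^\top}{\sqrt{d_h}}\right)
```

Using V_{(i)} = \mathbf{C}^{KV}(W^{UV}_{(i)})^\top, head i's contribution to the multi-head output (before summing over heads) can be rewritten as:

```latex
\text{Atn}_{(i)} V_{(i)} (W^{O}_{(i)})^\top
= \text{Atn}_{(i)}\,\mathbf{C}^{KV}(W^{UV}_{(i)})^\top (W^{O}_{(i)})^\top
= \text{Atn}_{(i)}\,\mathbf{C}^{KV}\left(W^{O}_{(i)} W^{UV}_{(i)}\right)^\top
```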
In the above, W^{O}_{(i)} \in \mathbb{R}^{d \times d_h} and W^{UV}_{(i)} \in \mathbb{R}^{d_h \times d_c} are weight matrices that were learned separately during training.
Therefore, during inference, we can pre-calculate W^{O}_{(i)} W^{UV}_{(i)} \in \mathbb{R}^{d \times d_c} from these weight matrices and combine them into a single new weight matrix.
By doing so, we can calculate \text{MHA}(Q, K, V) in one go without explicitly restoring the values from the compressed KV representation \mathbf{C}^{KV}.
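This equivalence can also be checked numerically; a minimal NumPy sketch with toy dimensions (names are mine):

```python
import numpy as np

# Toy dimensions (the paper uses d=5120, d_c=512, d_h=128; smaller here for speed)
T, d, d_c, d_h = 4, 10, 8, 6
rng = np.random.default_rng(1)

Atn_i = rng.random((T, T))               # attention map of head i (post-softmax stand-in)
C_KV = rng.standard_normal((T, d_c))     # compressed KV representation
W_UV = rng.standard_normal((d_h, d_c))   # per-head value up-projection
W_O = rng.standard_normal((d, d_h))      # head i's slice of the output projection

# Naive path: restore V_(i), attend, then project with W^O_(i)
V_i = C_KV @ W_UV.T
out_naive = Atn_i @ V_i @ W_O.T

# Absorbed path: precompute W^O_(i) W^{UV}_(i); V_(i) never materialized
W_absorbed = W_O @ W_UV                  # shape (d, d_c)
out_absorbed = Atn_i @ C_KV @ W_absorbed.T

assert np.allclose(out_naive, out_absorbed)
```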
This seems to be the meaning of the words "W^{UV} is absorbed into W^{O}" during inference.
Summary
Summary of "Absorption of W^{UK}"
In the section "Absorption of W^{UK}," the following equation holds based on equation (15):
From this, we can see that during inference, by pre-computing the new weight matrices (W^{Q}_{(i)})' and (W^{O}_{(i)})' and storing them in the model, MLA can be computed with fewer operations, without explicitly computing the queries, keys, and values.
I believe the above is the meaning of the following statement in the paper:
In addition, during inference, since W^{UK} can be absorbed into W^Q, and W^{UV} can be absorbed into W^{O}, we even do not need to compute keys and values out for attention.
My Questions
MLA seems to be a method that places great emphasis on improving memory efficiency, primarily through the compression of the KV cache and queries.
On the other hand, it seems to me that when matrices are absorbed to reduce the number of computations, the number of parameters in the post-absorption matrices becomes quite large.
For example, consider the following equation:
(W^{Q}_{(i)})' = (W^{UK}_{(i)})^T W^{UQ}_{(i)}
where (W^{Q}_{(i)})' \in \mathbb{R}^{d_c \times d_c'}, W^{UK}_{(i)} \in \mathbb{R}^{d_h \times d_c}, and W^{UQ}_{(i)} \in \mathbb{R}^{d_h \times d_c'}.
In this case, substituting the numbers from the paper (d_c = 512, d_c' = 1536, d_h = 128), the number of parameters for (W^{Q}_{(i)})' is 786,432, for W^{UK}_{(i)} is 65,536, and for W^{UQ}_{(i)} is 196,608.
That is, the number of parameters before absorption is 65,536 + 196,608 = 262,144, while the number of parameters after absorption becomes 786,432.
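A quick check of the arithmetic above:

```python
# Per-head parameter counts, with the paper's dimensions
d_c, d_c_prime, d_h = 512, 1536, 128

absorbed = d_c * d_c_prime               # (W^Q_(i))'
separate = d_h * d_c + d_h * d_c_prime   # W^{UK}_(i) plus W^{UQ}_(i)

print(absorbed)  # 786432
print(separate)  # 262144
```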
This would increase the number of parameters held at inference time, which seems to conflict with MLA's selling point of improved memory efficiency. Given that, is this absorption actually implemented in practice? And if it is, is the resulting loss of memory efficiency negligible?
This is a question that I haven't fully settled yet, but I would be happy if someone knowledgeable could explain it to me.
Note that I do not believe the idea is "instead of forming one large absorbed matrix, keep the two matrices separate and multiply by them one after another." If that were the approach, introducing RoPE would pose no problem:
you could simply apply the rotation matrix between the two matrix multiplications, and there would be no need to deliberately separate the RoPE-applied queries/keys from those without RoPE, right?
To anyone knowledgeable, I would be very happy if you could let me know!
Summary
Thank you for reading!
I hope this is helpful for anyone who wanted to verify the "be absorbed" part of MLA.
Also, I would be happy if I could get answers to my questions.
If we apply RoPE for the keys \mathbf{k}_t^C, W^{UK} in Equation 10 will be coupled with a position-sensitive RoPE matrix. In this way, W^{UK} cannot be absorbed into W^Q any more during inference, since a RoPE matrix related to the currently generating token will lie between W^Q and W^{UK} and matrix multiplication does not obey a commutative law.
DeepSeek-V2 is trained based on the HAI-LLM framework (High-flyer, 2023), an efficient and light-weight training framework developed internally by our engineers.
Discussion
It does not appear to be implemented. Q and K are computed individually and then their product is taken.
Subjectively, I think the overhead becomes negligible once you feed in a few hundred tokens.
However, with the approach you often see these days of projecting everything at once and then splitting (compute Q, K = X \cdot \text{concat}(W_Q, W_K), then compute QK^\top), the number of matrix multiplications can be reduced more easily, so deliberately enlarging the low-rank parameter matrices did not strike me as a particularly good method...
proj+split:
As an aside, this merging of QK and of VO appears frequently in analysis research, since it simplifies the object of study. See equations 2 and 3 of https://arxiv.org/abs/2405.00208
First of all, thank you for the reply.
I did not know that the PyTorch implementation of DeepSeek-V2 could be viewed there, and the "project together, then split" method also made me think "of course!" once you pointed it out; I had never looked at an actual implementation, so this was all new to me.
Thank you.
Indeed, looking at this PyTorch implementation, it appears to compute K and V straightforwardly once (the matrix absorption is not implemented).
However, that seems to contradict the following passage from section 2.1.3 of the DeepSeek-V2 paper.
As stated there, applying RoPE to Q and K makes the matrix absorption impossible, which is exactly why section 2.1.3 of the paper deliberately processes the RoPE-applied parts of Q and K separately from the parts without RoPE.
(This handling is also implemented in the PyTorch version of the code.)
However, if K and V are computed once from the compressed representation C_t^{KV}, as in the PyTorch implementation you linked, then RoPE can be handled without any problem (simply compute K as usual, apply RoPE, and then compute QK^\top), so there should be no need to devise special handling just to accommodate RoPE.
On the other hand, section 3.1.3 of the DeepSeek-V2 paper states the following.
Given this, it is quite likely that DeepSeek's internal training, and the inference behind the API, use their in-house framework rather than PyTorch.
That framework may be implemented in the matrix-absorption form, which would explain why RoPE was a problem there.
You may well be right about that.
The more tokens are held in the KV cache, the larger the absolute amount of memory saved by compression, so there should be a point at which the size of a single weight matrix becomes negligible.
I took a quick look at this as well!
I had been wondering whether this matrix absorption could also be done in an ordinary Transformer, so this cleared up that question!
Update: in the DeepSeek-V3 code on GitHub,
there is an absorb branch. I am starting to think that internally the API runs the absorbed version.