Attentionの演算解釈についてのメモ
背景
Large Language Model(LLM)に使われているAttentionは、トークンの埋め込みベクトルの類似度(内積)をとっている。しかし、行列計算として見たときに、内積が何を意味しているのかが理解できていないと、個人的に感じている。この背景の元、本記事では、演算の意味に着目して、Attentionを解釈することを目指す。
仮定
本記事では、以下の仮定を置いている。
- LLMに使用されている、Attentionを仮定している。
- シンプルにバッチサイズは1だと仮定する。ただ、バッチサイズが2以上の場合も、同じ処理が、それぞれバッチに対して並列に適用されるだけである。
- Attentionは、シングルヘッドだと仮定する
Attentionの概観
まずは、諸々の記号を定義する。文章を構成するトークンの長さを10、隠れ層の次元数を512とする。
- X: Input Matrix (10, 512)
- Q: Query Matrix (10,512)
- K: Key Matrix (10,512)
- V: Value Matrix (10, 512)
ここで、任意の文章をトークン化した、Xをインプットとすると、Attentionの処理は、以下のPythonコードで表現できる。
def compute_attention(X)
# X (shape: (10, 512))
Q = QueryLayer(X) # shape: (10, 512)
K = KeyLayer(X) # shape: (10, 512)
V = ValueLayer(X)# shape: (10, 512)
# dot_productは、行列同士の積を計算する関数
attention_weights = softmax(dot_product(Q, K.T) / sqrt(H), dim=-1) # shape: (10, 10)
attention_maps = dot_product(attention_weights, V.T) # shape: (10, 512)
new_X = LayerNorm(X + attention_maps) # shape: (10, 512)
return new_X
計算の解釈
dot_product(Q, K.T)
この演算は、行列Qの各トークンに対応するベクトルを、行列Kの各トークンに対応するベクトルとで内積を計算している。A = dot_product(Q, K.T)とすると、Aは、(10, 10)のサイズを持つ行列である。
理解のために、この行列計算をベクトル単位の計算に落とし込む。例として、0番目のトークンに対応するqueryベクトル(Q[0, :])と、全てのトークンのkeyベクトル(K[:, 0])との内積計算を考えてみる。
dot_product(Q[0, :], K[:, 0]) = A[0, 0]
dot_product(Q[0, :], K[:, 1]) = A[0, 1]
dot_product(Q[0, :], K[:, 2]) = A[0, 2]
dot_product(Q[0, :], K[:, 3]) = A[0, 3]
…
dot_product(Q[0, :], K[:, 10]) = A[0, 9]
ということは、Aの要素A[i, j]は、i番目のトークンとj番目のトークンの内積の値を格納していると解釈できる。こうして、attention_weightsが計算される。
dot_product(attention_weights, V.T)
次は、attention_weightsはと行列Vの内積を取るということは、どういうことなのかを考えてみたい。行列同士の内積計算は、attention_weightsの各行とValue Matrixの各列のベクトル積を取っていく。したがって、行列を構成するベクトルに視点を向ける。
2つの行列の列・行ベクトルは以下のように考えられる
- 行列attention_weightsの一行目のベクトル(10, 1) ⇒ 0番目のトークンと全てのトークンの内積をSoftmaxで正規化した値からなるベクトル
- 行列Vの転置行列V.Tの一列目のベクトル(1, 10) ⇒ V.Tは(512, 10)のサイズだったので、このベクトルは、全てのトークンの512次元のベクトルの1次元目の要素からなるベクトル。
そして、計算は以下のように行われる。
dot_product(attention_weights[0, :], V.T[:, 0]) = attention_maps[0, 0]
dot_product(attention_weights[0, :], V.T[:, 1]) = attention_maps[0, 1]
dot_product(attention_weights[0, :], V.T[:, 2]) = attention_maps[0, 2]
dot_product(attention_weights[0, :], V.T[:, 3]) = attention_maps[0, 3]
…
dot_product(attention_weights[0, :], V.T[:, 511]) = attention_maps[0, 511]
そして、問題は、attention_maps[0, 0]は何なのかということである。attention_maps[0, 0]は、0番目のトークンの512次元ベクトルの1次元目の要素である。0番目のトークンに着目すると、0番目のトークンと全てのトークンの内積の値で、次元を重み付けして、足す。それが、0番目のトークンの1次元目の要素になる。それを、512次元分行う。その計算を全てのトークン分繰り返すと、サイズ(10, 512)を持つ行列attention_mapsになる。
LayerNorm(X + attention_maps)
最後に、attention_mapsにLayerNormレイヤーを適用して、インプットXに足す。このAttentionでやっていることは、任意のトークンと全てのトークンの内積に応じて、インプットXの各要素の値の上げ下げを行う。それは、下流タスクでの予想確率の高さに影響を与える。こうして、Neural Networkは、数値の世界で、言語的なコンテクストを"理解"している。
Discussion