🧲

Attentionの演算解釈についてのメモ

2024/03/24に公開

背景

Large Language Model(LLM)に使われているAttentionは、トークンの埋め込みベクトルの類似度(内積)をとっている。しかし、行列計算として見たときに、内積が何を意味しているのかが理解できていないと、個人的に感じている。この背景の元、本記事では、演算の意味に着目して、Attentionを解釈することを目指す。

仮定

本記事では、以下の仮定を置いている。

  • LLMに使用されている、Attentionを仮定している。
  • シンプルにバッチサイズは1だと仮定する。ただ、バッチサイズが2以上の場合も、同じ処理が、それぞれバッチに対して並列に適用されるだけである。
  • Attentionは、シングルヘッドだと仮定する

Attentionの概観

まずは、諸々の記号を定義する。文章を構成するトークンの長さを10、隠れ層の次元数を512とする。

  • X: Input Matrix (10, 512)
  • Q: Query Matrix (10,512)
  • K: Key Matrix (10,512)
  • V: Value Matrix (10, 512)

ここで、任意の文章をトークン化した、Xをインプットとすると、Attentionの処理は、以下のPythonコードで表現できる。

def compute_attention(X)
	# X (shape: (10, 512))
	Q = QueryLayer(X) # shape: (10, 512)
	K = KeyLayer(X) # shape: (10, 512)
	V = ValueLayer(X)# shape: (10, 512)
	# dot_productは、行列同士の積を計算する関数
	attention_weights = softmax(dot_product(Q, K.T) / sqrt(H), dim=-1) # shape: (10, 10)
	
	attention_maps =  dot_product(attention_weights, V.T) # shape: (10, 512)
	
	new_X = LayerNorm(X + attention_maps) # shape: (10, 512)
	return new_X

計算の解釈

dot_product(Q, K.T)

この演算は、行列Qの各トークンに対応するベクトルを、行列Kの各トークンに対応するベクトルとで内積を計算している。A = dot_product(Q, K.T)とすると、Aは、(10, 10)のサイズを持つ行列である。

理解のために、この行列計算をベクトル単位の計算に落とし込む。例として、0番目のトークンに対応するqueryベクトル(Q[0, :])と、全てのトークンのkeyベクトル(K[:, 0])との内積計算を考えてみる。

dot_product(Q[0, :], K[:, 0]) = A[0, 0] 
dot_product(Q[0, :], K[:, 1]) = A[0, 1]
dot_product(Q[0, :], K[:, 2]) = A[0, 2]  
dot_product(Q[0, :], K[:, 3]) = A[0, 3]
…
dot_product(Q[0, :], K[:, 10]) = A[0, 9] 

ということは、Aの要素A[i, j]は、i番目のトークンとj番目のトークンの内積の値を格納していると解釈できる。こうして、attention_weightsが計算される。

dot_product(attention_weights, V.T)

次は、attention_weightsはと行列Vの内積を取るということは、どういうことなのかを考えてみたい。行列同士の内積計算は、attention_weightsの各行とValue Matrixの各列のベクトル積を取っていく。したがって、行列を構成するベクトルに視点を向ける。

2つの行列の列・行ベクトルは以下のように考えられる

  • 行列attention_weightsの一行目のベクトル(10, 1) ⇒ 0番目のトークンと全てのトークンの内積をSoftmaxで正規化した値からなるベクトル
  • 行列Vの転置行列V.Tの一列目のベクトル(1, 10) ⇒ V.Tは(512, 10)のサイズだったので、このベクトルは、全てのトークンの512次元のベクトルの1次元目の要素からなるベクトル。
    そして、計算は以下のように行われる。
dot_product(attention_weights[0, :], V.T[:, 0]) = attention_maps[0, 0]
dot_product(attention_weights[0, :], V.T[:, 1]) = attention_maps[0, 1]
dot_product(attention_weights[0, :], V.T[:, 2]) = attention_maps[0, 2]
dot_product(attention_weights[0, :], V.T[:, 3]) = attention_maps[0, 3]
…
dot_product(attention_weights[0, :], V.T[:, 511]) = attention_maps[0, 511]

そして、問題は、attention_maps[0, 0]は何なのかということである。attention_maps[0, 0]は、0番目のトークンの512次元ベクトルの1次元目の要素である。0番目のトークンに着目すると、0番目のトークンと全てのトークンの内積の値で、次元を重み付けして、足す。それが、0番目のトークンの1次元目の要素になる。それを、512次元分行う。その計算を全てのトークン分繰り返すと、サイズ(10, 512)を持つ行列attention_mapsになる。

LayerNorm(X + attention_maps)

最後に、attention_mapsにLayerNormレイヤーを適用して、インプットXに足す。このAttentionでやっていることは、任意のトークンと全てのトークンの内積に応じて、インプットXの各要素の値の上げ下げを行う。それは、下流タスクでの予想確率の高さに影響を与える。こうして、Neural Networkは、数値の世界で、言語的なコンテクストを"理解"している。

Discussion