🐬

Multi head Attentionとは (Transformerその2)

2024/03/14に公開

今回は前回に引き続き、TransformerのMulti head Attentionについて解説していきます。

Transformer

初めに、Transformerのアーキテクチャは次のようになっています。

そしてこの中のMulti-head Attentionのアーキテクチャは次のようになっています。

今回はこのMulti-head Attentionのさらに元になっているSingle-Head Attentionから始めようと思います。

Single-Head Attention

Single-Head Attentionとは、Scaled Dot-Product AttentionにLinear層を追加したものです。

前の記事で解説したScaled Dot-Product Attentionは、重みなどのパラメータを持つ層がなかったため、単純な単語埋め込み(ベクトル)の類似度のみを計算していました。

学習パラメータを持たないScaled Dot-Product Attentionの表現能力を広げるために、モデルへの各入力の直前に学習パラメータを持つLinear層を追加することで、入力されるベクトルの特徴空間に依存しない表現力を獲得できます。

しかし、現実での言語には、複数の意味を持つ言葉が存在します。このような単語でもSingle-Head Attentionはその意味を平均化して扱ってしまうため、表現力が不足することになります。

Multi-Head Attention

その課題を解決するためにMulti-Head Attentionは、Single-Head Attentionを多数並列に配置することで、さまざまな注意表現の学習を可能にしました。

Q,Kを用いてVのどこに着目すべきかを把握する注意機構(Attention)であることは変わりませんが、並列にAttention機構を配置することで、より複雑な入力の関係を認識することができるようになります。

やっていることはScaled Dot-Product Attentionから変わっていませんが、より汎用的に入力の関係を認識できるようになっています。

Positional Encoding

Positional Encoding(位置エンコーディング)とは、入力情報の位置関係をモデルに伝えるために用いられる手法です。
TrasformerはAttentionの後、feedforwardレイヤによって学習が行われますが、ここではRNNなどの順番を考慮するモデルが使用されていないため、他の場所で入力の位置情報を教えてあげる必要があります。この役割を担うのがPositional Encodingです。
(ここにRNNなどを使用してしまうと、シーケンシャルな処理が必要となりTransformerの大きな利点である並列計算による計算効率化が失われます)

Positional Encodingは、ニューラルネットワークにおけるバイアス項に似ています。入力の順番に対して微小な値を追加する(バイアスをかける)ことで、学習を繰り返すうちにモデルがその微小な差異を重要な情報として取り込み、その差異によって同じようなベクトルの単語であっても、位置によって持つ意味が大きく変わることを理解します。
例えば、同じ単語ベクトル run = [1,1,2,2]でも、文の前半を示すバイアスを受けた[1.01,1.99,2.01,2.99]と、文の後半を示すバイアスを受けた[1.99,1.01,2.99,2.01]では意味が異なる可能性があることを学習できます。

-- 数式 ----
Positional Encodingは以下の数式を単語ベクトルに加算することで行われます。入力ベクトルの要素indexが偶数であるか奇数であるかによって、適用される式が変わります。

偶数の場合: PE(pos, 2i) = sin(\dfrac{pos}{10000^{2i/d_{model}}})
奇数の場合: PE(pos, 2i+1) = cos(\dfrac{pos}{10000^{2i/d_{model}}})

pos: トークン(単語など)の位置。例[Hello world]なら、Helloのpos=0、worldのpos=1
i: トークン埋め込み後のベクトルにおけるインデックス値
d_{model}: トークン埋め込み後のベクトル次元数


意味の理解

ここからはこちらの記事を参考にこの式の意味を理解していきます。

まず、単語ベクトルに位置情報を与えるにはどのような方法があるか考えてみます。

単純増加

最も簡単なのは、単語ベクトルの順番に応じて1,2,3..と数字を加えていく手法です。
例えば、前述の例(run=[1,1,2,2])に単純増加による位置エンコーディングを適用すると、 run run runという文は [[2,2,3,3],[3,3,4,4],[4,4,5,5]]となります。
これにより、同じ文字でも位置によって異なる情報を持てるようなります。しかし、この手法ではベクトルの値が大きく変化するため、ベクトルの意味そのものが変わってしまいます。

加算値の制限

ここでsinやcosを使用することで、加算する値を-1から1までの範囲に限定できます。しかし、これらは周期性を保つため、通常のままでは同じ位置情報が繰り返されてしまいます。そこで、周波数を非常に低くすることで、繰り返しが起きないようにします。

コード
import matplotlib.pyplot as plt
import numpy as np

# Define the frequency
frequency = 0.01# Very low frequency

# Generate x values
x = np.linspace(0, 100, 1000)

# Calculate y values using the sin function
y = np.sin(frequency * x)

# Plot the function
# Plot the function with adjusted x range and customize y-ticks
plt.figure(figsize=(10, 6))
plt.plot(x, y)
plt.title('Very Low Frequency Sin Wave')
plt.xlabel('x')
plt.ylabel('sin(x)')
plt.grid(True)

# Customize y-ticks to better illustrate the sine wave's amplitude within this range
plt.yticks(np.linspace(-1, 1, 11))
plt.xticks(np.linspace(0, 100, 11))

plt.show()

これにより、微小な値を追加できるようになります。
例:[[1.01,1.01,2.01,2.01],[1.02,1.02,2.02,2.02],[1.03,1.03,2.03,2.03]]

sinとcosの併用

しかし、ベクトルに全て0.01が足されたとしても変化が非常に微小であり、上手く認識されない可能性があります。そこで、cosを併用することで、sinの値が小さい時にはcosの値が大きく付与されるようになります。

これで、より表現力が高まります。
例:[[1.01,1.99,2.01,2.99],[1.02,1.98,2.02,2.98],[1.03,1.97,2.03,2.97]]

-- 式をもう一度考える ----
 これまでの理解を踏まえて、再度式を見てみましょう。
iが偶数の場合: PE(pos, 2i) = sin(\dfrac{pos}{10000^{2i/d_{model}}})
iが奇数の場合: PE(pos, 2i+1) = cos(\dfrac{pos}{10000^{2i/d_{model}}})

非常に低い周波数(1/10000^{2i/d_{model}})で、posの値(入力のインデックス値)に対してsinとcosの計算が行われていることがわかります。これを埋め込み後のベクトルに加算することで、モデルは入力の位置による違いを認識できるようになります。


今回はここまでになります。
読んでいただきありがとうございました。

参考

[1]https://arxiv.org/pdf/1706.03762.pdf
[2]https://developers.agirobots.com/jp/multi-head-attention/

Discussion