🐱

AI界を席巻する「Transformer」をゆっくり解説(4日目) ~Model Architecture編 2~

12 min read

AI界を席巻する「Transformer」を解説するシリーズ4日目です。

Attention Is All You Needの論文PDFはこちら

  • 1日目: Abstract
  • 2日目: Introduction / Background
  • 3日目: Model Architecture 1
  • 4日目: Model Architecture 2
  • 5日目: Model Architecture 3
  • 6日目: Why Self-Attention
  • 7日目: Training
  • 8日目: Results / Conclusion
  • 9日目: Source Code

シリーズ過去記事は一番下にリンク貼ってます。
それではみていきましょう。本論文の中核となる章なので、3回にわけて行います。

Model Architecture

3.2 Attention

Attentionに入ります。Attentionは注意とか注目とかいった意味の単語で、そう訳されてることも多いですが、英語のまま使うことにしてます。

An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors.

Attentionは、クエリ、キー、キーとペアになる値、そして出力値で構成されていて、クエリと、キーとキー値を変換するマッピング装置のようなものだ。出力値はすべてベクトルである。

それぞれが何なのかは追々でてきます。中身はベクトルですが、基本的に機械学習の中身はベクトル、それらをまとめた行列で計算するからです。

The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.

出力値は、キー値と重みの加重和で計算されている。
それぞれのキー値に対する重みは、クエリとそのクエリに対応するキーを使った変換関数から計算される。

詳しくは次をみていきましょう。

3.2.1 Scaled Dot-Product Attention

We call our particular attention "Scaled Dot-Product Attention" (Figure 2). The input consists of queries and keys of dimension d_k, and values of dimension d_v.

ここでAttentionをより具体的に、「Scaled Dot-Product Attention」と呼ぶことにする。実際のScaled Dot-Product Attentionの中を図示したのが下記。

入力は、クエリ Q、キー K、キー値 V。キー K の次元は d_k 、キー値 V の次元は d_v である。d はdimentionのdで次元のことです。

ここは定義式、つまり著者が決めた定義なので、そのまま受け入れます。

We compute the dot products of the query with all keys, divide each by \sqrt{d_k}, and apply a softmax function to obtain the weights on the values.

このScaled Dot-Product Attentionの中で、まず、クエリ Q とキー K のドット積、つまりベクトルの内積を計算して、ベクトルの各要素を \sqrt{d_k} で割ります。この計算結果をsoftmax関数に当てはめて、あとでキー値 V と内積を計算するための重みを計算します。

In practice, we compute the attention function on a set of queries simultaneously, packed together into a matrix Q.

実際には、単語ごとに存在する一連のベクトルのクエリ Q を1つの行列 Q としてまとめてしまって、同時に計算します。

行列とベクトルの違いはあまり意識する必要はありませんが、ベクトルの集合体が行列だと理解するといいと思います。下記の記事が特に参考になりました。

https://www.headboost.jp/what-is-matrix/
https://jp.mathworks.com/help/matlab/ref/dot.html

この行列の内積を計算するレイヤーを表現してるのが、図に2回出てくる「MatMul」です。MatMulはMatrix Multiplicationの略で、つまり行列の内積の英語を省略した言葉です。

MatMulについて参考にしたサイト

https://docs.scipy.org/doc/numpy-1.15.1/reference/generated/numpy.matmul.html

The keys and values are also packed together into matrices K and V . We compute the matrix of outputs as:

キー K とキー値 V についてもクエリ Q と同様に行列 KV にまとめます。ここまで説明した内容を数式で表すと下記のようになります。

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V

The two most commonly used attention functions are additive attention [2], and dot-product (multiplicative) attention.

よく使われるAttentionの例としては、「Additive Attention」と「Dot-Product Attention」の2つがあります。

Dot-product attention is identical to our algorithm, except for the scaling factor of \sqrt{\frac{1} {d_k}}.

Dot-Product Attentionは、本論文と基本的に同じで、違いは、標準化のための \sqrt{\frac{1}{d_k}} がないところだ。その違いがあるので、本論文では、Scaled Dot-Product Attentionと呼ぶことにしている。

scaleとは恐らくpythonのscale関数から来ており、スケールを統一して、正規化・標準化されたという意味で、scaledが使われています。

こちらの記事も参考にしました。

https://note.nkmk.me/python-list-ndarray-dataframe-normalize-standardize/

これで下記の図のScaleの部分も説明されました。

Additive attention computes the compatibility function using a feed-forward network with a single hidden layer.

Additive Attentionは、単一の隠れ層を持つFeed-Forward Networkを使って変換関数を計算します。

Additive Attentionの参考論文は載ってませんでしたが、以前からあるAttentonの一種のようです。

While the two are similar in theoretical complexity, dot-product attention is much faster and more space efficient in practice, since it can be implemented using highly optimized matrix multiplication code.

この2つのAttentionは、理論的な複雑さでは似ているが、経験的にDot-Product Attentionの方がはるかに計算が早く、計算時に使用するメモリなどの容量の効率性も高いことが分かっている。それは、Dot-Product Attentionは、より最適化された行列の内積を内部的に実装しているからだ。

While for small values of d_k the two mechanisms perform similarly, additive attention outperforms dot product attention without scaling for larger values of d_k [3].

一方で、キー K の次元 d_k が小さいうちは、上記の2つのAttentionは同じくらいの性能だが、d_k が大きくなってきた場合に、d_k による正規化がなければ、Additive Attentionの方がDot-Product Attentionを性能で上回ることも知られている。

やや矛盾してるようにも思えますが、あまり論点ではないので、すすみます。

We suspect that for large values of d_k, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients ^4. To counteract this effect, we scale the dot products by \sqrt{\frac{1}{d_k}}.

著者のGoogle Brainを中心としたメンバー達は、まさにこのキー K の次元 d_k が大きく、Dot-Product Attentionの計算結果が異常に大きくなり、softmax関数を勾配消失させるのでは、と考えた。そのため、この事象を回避するため、\sqrt{\frac{1}{d_k}} で正規化することにした。

だとしてもなぜ正規化する時の値が \sqrt{\frac{1}{d_k}} なのか?と思っていたら、注釈がありました。

^4 To illustrate why the dot products get large, assume that the components of q and k are independent random variables with mean 0 and variance 1. Then their dot product, q·k = \sum_{i=1}^{d_k}, has mean 0 and variance dk.

なぜDot-Product Attentionの計算結果が大きくなってしまうのかというと、まずクエリ Q とキー K はそれぞれ平均値0、分散1の独立した変数である。それらの行列の内積 q·k = \sum_{i=1}^{d_k} は平均値0、分散 d_k であるため。

とあります。分散とは、色々な要素があった時の平均から、各要素が実際どれだけ離れているのか集合してるのかを表す指標であり、平均と各要素の差分の2乗を要素ごとに計算して、それを平均したものが分散です。

この分散は平均からどれだけ離れているかを指標化するために、途中で2乗しているので、最終的にこの分散の平方根をとったものが標準偏差です。なので、分散の平方根である、\sqrt{d_k} で割ることで、正規化になるということですね。

標準偏差と分散については数学や統計学の話になりますが、詳しくはこちらの記事などを参考にしてください。

https://best-biostatistics.com/summary/sd-variance.html

3.2.2 Multi-Head Attention

Multi-Head Attentionの説明。なぜMultiなのかというと、簡単に言うと、上述したAttention、具体的にはScaled Dot-Product Attentionを複数使ってるからです。

Instead of performing a single attention function with d_{model}-dimensional keys, values and queries, we found it beneficial to linearly project the queries, keys and values h times with different, learned linear projections to d_k, d_k and d_v dimensions, respectively.

なぜ複数使うのかというと、次元 d_{model} で出来たキー K、キー値 V、クエリQ を使ったSingle-Attention関数をただ使うよりも、クエリ、キー、キー値をそれぞれ h 回Linear関数を使って、次元 d_kd_kd_k、に減らした方が効率がいいとわかったため。下記の図の下にあるレイヤーのことです。

計算を行う際に、次元数は減らした方が高速になります。学習においては計算時間はかなり重要な要素であるため、次元を減らします。人間が認識できる次元は空間的には3次元までですが、3次元の空間の計算を行うよりも、2次元の平面の計算の方が簡単なのは感覚的にわかると思いますし、それ以上に1次元であればもっと簡単です。

元々の次元 d_{model} は512と定義されていますから、これを小さくできればかなり効率がよくなるということです。linerly projectやlinear projectionsが普通に翻訳すると出てきにくいですが、linearが線形変換用のLinear関数やLinearレイヤーで、projectはベクトルや線形代数における射影のことです。

射影や射影による次元削減、Linear関数に関しては、特にこちらの記事はわかりやすかったです。

https://www.hellocybernetics.tech/entry/2017/06/15/072248#効果
https://qiita.com/aya_taka/items/4d3996b3f15aa712a54f

On each of these projected versions of queries, keys and values we then perform the attention function in parallel, yielding d_v-dimensional output values.

上記のように射影して次元を減らしたクエリ、キー、キー値に対して、今回定義したAttention関数を並列に実行し、次元 d_v のアウトプットを得る。図の真ん中の紫色のレイヤーのことです。

These are concatenated and once again projected, resulting in the final values, as depicted in Figure 2.

上記で h 回出力されたアウトプットをConcatレイヤーで連結して、再度Linearレイヤーで別次元に射影することで、最終的なMuti-Head Attentionの出力値が得られます。

Concatは英語のConcatnateの略語で連結するという意味です、Linuxですとしょっちゅう使うcatコマンドがここから来ています。1ファイル開くだけだと感じませんが、複数ファイルくっつけて開くことも出来ますよね。

Concatレイヤーはこちらの記事も参考にしました。

https://qiita.com/kotai2003/items/7dd746d7e4118b7a44de

Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this.

Muti-Head Attentionを使うことで、今回のモデルは、異なる要素にある単語がもつそれぞれ異なる部分ベクトル空間におけるを同時に読みに行くことが可能になっている。Single-Attentionでは平均化によってそれができない。

Subspaceは線形代数やベクトルにおける部分空間のことです。

これらを数式で表現すると、

\begin{aligned} \mathrm{MultiHead}(Q, K, V) &= \mathrm{Concat}(\mathrm{head_1}, ...,\mathrm{head_h})W^O \\ \mathrm{where} \ \mathrm{head_i} &= \mathrm{Attention}(QW_i^Q,KW_i^K,VW_i^V) \end{aligned}

と定義します。

Where the projections are parameter matrices W_i^Q \in \mathbb{R^{d_{model} \times d_k}}, W_i^K \in \mathbb{R^{d_{model} \times d_k}}, W_i^V \in \mathbb{R^{d_{model} \times d_v}} \ \mathrm{and} \ W^O \in \mathbb{R^{hd_v \times d_{model}}}.

Linearレイヤーにおける射影関数は、このレイヤーの重みパラメーター行列であり、クエリ、キー、キー値、またConcatレイヤーの出力値に対してそれぞれ、W_i^Q \in \mathbb{R^{d_{model} \times d_k}}, W_i^K \in \mathbb{R^{d_{model} \times d_k}}, W_i^V \in \mathbb{R^{d_{model} \times d_v}}, W^O \in \mathbb{R^{hd_v \times d_{model}}} である。\mathbb{R} は実数の集合を意味します。

こちらの記事も参考にしています。

https://zenn.dev/wsuzume/articles/b0b3a51cac5d7fe4555b

In this work we employ h = 8 parallel attention layers, or heads. For each of these we use d_k = d_v = d_{model}/h = 64. Due to the reduced dimension of each head, the total computational cost is similar to that of single-head attention with full dimensionality.

本論文では、h = 8 を並列のMuti-Head Attentionのレイヤー数として採用。その中のそれぞれのAttentionにおいて、d_{model} = 512 なので、d_k = d_v = d_{model}/h = 64 である。この次元削減によって、トータルの計算コストは、元々の次元 d_{model} のSingle-Attentionと同等まで削減出来ている。

再度、図と数式を示します。

\begin{aligned} \mathrm{MultiHead}(Q, K, V) &= \mathrm{Concat}(\mathrm{head_1}, ...,\mathrm{head_h})W^O \\ \mathrm{where} \ \mathrm{head_i} &= \mathrm{Attention}(QW_i^Q,KW_i^K,VW_i^V) \end{aligned}

whereは「ただし」の意味です。このMulti-Head Attentionを行列 Q K V の複数の入力と出力からなる1つの関数とみなすと、こう書けます。

\mathrm{MultiHead}(Q, K, V) = f(x)

この f(x) は、Concatレイヤーをまた1つの関数と見ると、Concatレイヤーに対する h=8 個の複数の入力と1つの出力からなる関数に対して、さらにLinear関数がかかった形、とみなせるので、

\begin{aligned} \mathrm{MultiHead}(Q, K, V) &= f(x) \\ &= \mathrm{Linear}(x) \\ &= xW + b \\ &= \mathrm{Concat}(x)W^O + b \end{aligned}

と表現されます。ここでLinear関数で線形変換を行う際に、バイアス b は使わなくてもよいので、本論文では 0 ベクトルで定義されてます。重み W の上付き文字の OO 乗ではなく、単にConcat関数のアウトプットの O だと思われます。

また、通常、線形変換は y = Wx + b と、W が左側に来るのですが、定義上はどちらでも可能で、今回の場合は、右からかけないと、Multi-Head Attentionの出力としての行列の次元が d_{model} に戻らないために、右側からかけるモデルになっています。つまり、次元を変換する場合においては、基本的に右側からかけないといけない、ということですね。行列の内積においては AB \neq BA なので、捕捉しました。

おわり

AI界を席巻する「Transformer」を解説するシリーズ4日目は以上です。ここまで来れば峠は越えたも同然(?)かもしれません。次回はAttentionのおさらいからです。

感想や要望・指摘等は、本記事へのコメントか、TwitterのリプライやDMでもお待ちしております!

https://twitter.com/hnishio0105/status/1390359715041857537?s=20

また、結構な時間を費やして書いていますので、投げ銭・サポートの程、よろしくお願いいたします!

シリーズ関連記事はこちら

https://zenn.dev/attentionplease/articles/2d4b2b55ba396e
https://zenn.dev/attentionplease/articles/c2dba490ccba3f
https://zenn.dev/attentionplease/articles/5b4133a4956578
https://zenn.dev/attentionplease/articles/5510331c45e16a
https://zenn.dev/attentionplease/articles/1a01887b783494
https://zenn.dev/attentionplease/articles/4e09c41d7a85db
https://zenn.dev/attentionplease/articles/d0d7f4e406b4ed
https://zenn.dev/attentionplease/articles/75316e7ad65cfc