💁🏻

Transformer: アテンションの計算式の意味を数理的に理解する

2022/07/31に公開

はじめに

Transformerにおけるアテンションの計算式は、scaleを無視すると以下のように計算される[1]

\text{output} := \text{softmax}(QK^\top) \tag{1}V

この計算が数理的にどのような意味を持つのかについて考察する。

記法

  • 以下の議論では、表記を簡単にするため、「Xの埋め込みベクトルのシーケンス」を単に「Xのシーケンス」と表現する。

考察

まず、式(1)の計算は以下の2つのパートに分割できる:

  1. アテンションスコアの計算
  2. 特徴量の選択

1. アテンションスコアの計算

\text{softmax}(QK^\top)の部分である。ここで、Q, Kはそれぞれ(n, d)次元のベクトルとする。nはシーケンス長で、dは埋め込みベクトルの次元である。すなわち、Q, Kの行方向はトークンのシーケンスを表し、列方向は埋め込みベクトルを表す。この時、Q, Kの行ベクトル、すなわちトークンごとの埋め込みベクトルの集合をそれぞれ\{q_i\}, \{k_j\}とおくと、

QK^\top = \begin{bmatrix} \langle q_1, k_1 \rangle & \cdots & \langle q_1, k_m \rangle \\ \vdots & \ddots & \vdots \\ \langle q_n, k_1 \rangle & \cdots & \langle q_n, k_m \rangle \\ \end{bmatrix} \tag{2}

とかける(\langle x, y \rangleはベクトルx, yの内積を表す)。つまり、この行列は2つのトークン列\{q_i\}, \{k_j\}のすべてのペアに対してアテンションスコア[2]を計算したものである。アテンションスコアが意味するところは、Q, K, Vのそれぞれに何を代入するかによって変わる。

まずself-attentionの場合、Q=XW^Q, K=XW^K, V=XW^Vである。ここでXは「エンコーダ(あるいはデコーダ)の入力シーケンス」である。この時、式(2)の行列は入力シーケンス中の各トークンペアの間アテンションスコアを総当たりで計算していることを意味する(ただし、maskが指定してある場合は部分的な関係にのみ注目する)。これは例えば、"I have a dog."というシーケンスを入力した場合、(I, have, a, dog, .)のそれぞれのトークンに対するアテンションスコアを計算することを意味する。

また、source-attentionの場合、Q=YW^Q, K=MW^K, V=MW^Vである。ここでMは「エンコーダの出力シーケンス」、Yは「デコーダの入力シーケンス」である。この時、式(2)の行列はエンコーダが解釈した入力シーケンスのコンテクストとデコーダの入力シーケンスのアテンションを総当たりで計算したものを表す(ただし、maskが指定してある場合は総当たりでなく一部の関係のみについて計算する。

最後に、softmaxをとる。最終的な計算結果は、シーケンス中のどのトークンに注目すべきかを表す正規化されたスコア(全シーケンスで合計すると1になる)を表す。

2. 特徴量の選択

アテンションスコアを行列Aで表すと、\text{softmax}(QK^\top)V=AVとなる。この式はトークンの埋め込みベクトルのシーケンスの重み付き和を計算する。重みが大きいトークンはより強く注目され、小さいトークンは弱く注目されることを意味する。

\begin{bmatrix} a(q_1, k_1) & \cdots & a(q_1, k_m) \\ \vdots & \ddots & \vdots \\ a(q_n, k_1) & \cdots & a(q_n, k_m) \end{bmatrix} \begin{bmatrix} v_1 \\ \vdots \\ v_m \\ \end{bmatrix} = \begin{bmatrix} \sum_i^m{a(q_1, k_i)v_i} \\ \vdots \\ \sum_i^m{a(q_n, k_i)v_i} \end{bmatrix} \tag{3}

Query, Key, Valueの名称の由来について

Query, Key, Valueの名称の由来について考察してみる。(3)式によると、

  • 計算結果はQと同じ次元を持つ
  • アテンションスコアはQとKの各行ベクトルのペアについての関係を表現する
  • KとVの行ベクトルは1:1に対応している(添字が一致している)
  • 計算結果はVの行ベクトルの線形結合で表される

という性質を持つので、

  • 目的はQのそれぞれのトークンについての情報を取り出すこと
  • KにはQの行ベクトルごとにVのどの行ベクトルに着目すればよいかの情報がエンコードされている(アテンションスコアはKとQの内積で計算する)。
  • Vには最終的に出力として伝達する情報の材料(=ベクトルの基底)がエンコードされている

という解釈が成り立つ。

つまり、(3)式を一言で説明すると「Query(=問い合わせ)の各トークンに対応する、Value(=値)の重要度をKey(=鍵)を使って取り出し、Vの行ベクトルの線形結合として出力したもの」と解釈できる。

補足

  • Q: XやYを直接入力せず、線形変換を作用させたものを入力するのはなぜか?
    • A: 理由は2つある。
      • 1つ目の理由は、トークンの埋め込み表現に対する依存度を下げるため。例えば、self-attentionにおいて、KとQにともにXを入力したとする。このとき、アテンションスコアは単なるXの行ベクトルどうしの内積、すなわち類似度に近い指標となるが、これだとアテンションに「似た意味の(=埋め込みが近い)トークンに注意を向ける」という機能しか持たせることができない。線形変換を作用させることで、元のトークンの埋め込みベクトルとは異なる空間にマッピングされるため、元のトークンの埋め込み表現に対する依存度を下げる効果がある。
      • 2つ目の理由は埋め込みベクトルの次元を変換したいため。テキスト[1:1]では\mathbb{R}^{d_\text{model}} \rightarrow \mathbb{R}^{d_k}の写像となっている。トークンの埋め込みベクトルの次元とアテンションを計算する際の埋め込みベクトルの次元は一致している必要はない。これは例えば、次元を削減することでアテンションヘッドを増やしても計算量を一定に保つ効果がある。テキストでは512次元->64次元に圧縮している。
  • (3)式はQの行ベクトルどうしの関係は表現しないことに注意。したがって、デコーダにおいてself-attentionとsource-attentionのモジュールがそれぞれ必要になる。source-attentionモジュールはQすなわちデコーダの入力シーケンス中のトークンどうしの関係は記述できないからである。
脚注
  1. https://nlp.seas.harvard.edu/2018/04/03/attention.html ↩︎ ↩︎

  2. ここで扱われているattentionはdot productによるattentionで、以下の論文で紹介されている: Effective Approaches to Attention-based Neural Machine Translation ↩︎

GitHubで編集を提案

Discussion