👌

数式を追う!Transformerにおけるattention

2023/12/15に公開

本記事は、「LabBaseテックカレンダー Advent Calendar 2023」 15日目の記事です。
https://qiita.com/advent-calendar/2023/labbase

はじめに

株式会社LabBaseでインターンをしている佐藤拓真と申します。

  • 本記事では、attentionについて、特にTransformerにおいて使用されている形 に注目して、数式を追った理論的な理解をすること を目指します。
    • 数式やその解説は、『自然言語処理の基礎』[岡崎, 2022] に大幅に依拠しています。
    • プログラミング言語による実装は取り扱いません。
    • Transformerにおけるattention以外の主要技術は取り扱いません。
    • Transformerを使用した主要な事前学習済モデル(BERT, GPT等)については扱いません。
  • 自然言語のみによる平易な解説というより、数式など理論的な側面に注目を当てます。
    • 数式など理論的な側面の理解のために、平易な自然言語を道具として使います。
    • 必要な数学の知識として、大学の理工系学部1,2年程度の線形代数の知識を前提しています。

Transformerとattentionの概要

昨今、大規模言語モデル(LLM, Large Language Model)が隆盛を見せ、ChatGPT等のサービスが世界に極めて強いインパクトをもたらしています。

大規模言語モデルの成功にもっとも寄与している技術要素の一つとして、Transformerが挙げられます。Transformerは"Attention is All You Need"[Vaswani et.al., 2017]において発表されたニューラルネットワークアーキテクチャです[1]。2023年現在主流の事前学習済みモデル(pre-trained model)においては、ほとんどの場合Transformer(ないしその派生アーキテクチャ)が使用されています。たとえば、「ChatGPT」は名前の通り"GPT"という大規模事前学習済みモデルを使用したサービスですが、GPTは"Generative Pre-trained Transformer"を意味しています。

Transformerはさまざまな技術的要素から構成されますが、その中でももっとも中心的かつ重要といえる技術要素の1つとして、attention mechanism(注意機構) が挙げられます[2]。attentionは、大雑把にひとことで表現すると、「入力におけるどの部分に注目して処理を行うかを計算する仕組み」といえます。

Transformerにおいては、attentionを更に洗練させて、self-attentionやQKV attention、multi-head attentionという技術が使用されています。以下の章においては、基本的なattentionについて整理したあと、これらの技術についても解説を行います。

attention

attention mechanism(注意機構)は、直観的には、seq2seq(系列変換)モデルにおいて、「encoderに対する入力系列」と「decoderからの出力系列」の間にある「注意」の関係を計算するメカニズムとして説明されます。例えば、"I love you"という入力系列を「私は貴方を愛している」という系列に変換する翻訳タスクを考える時、出力における「愛している」というトークンは、入力における"you"というトークンに注目していることが期待されます。attentionは、この注目の度合いを計算しているようなイメージかと思います。

ここからは数式を追います。

入力系列を\bm{X}=(\bm{x}_1,...,\bm{x}_I)、出力系列を\bm{Y}=(\bm{y}_1,...,\bm{y}_I)とします。ここで、Iは入力系列長、Jは出力系列長を表します。入力系列のi番目のトークンは\bm{x}_i、出力系列のj番目のトークンは\bm{y}_jと表現されます。ただし、\bm{x}_i, \bm{y}_jは、それぞれd次元ベクトルとします。

つぎに、encoderからの出力ベクトル\bm{h}_iとします。ただし、\bm{h}_iも、encoderへの入力トークンのベクトル表現と同じく、d次元ベクトルとします。「decoderからの最終的な出力ベクトル」ではなく、あくまでencoderが出力する(decoderに引き渡される)特徴ベクトルであることに注意してください。\bm{h}_i\bm{x}_iに対応するということになります。次に、このencoderが出力した特徴ベクトルを並べたベクトル列\bm{H}を考えます。\bm{H}d\times I行列ということになります。

ここで、出力系列におけるトークン\bm{y}_jを予測するために必要な特徴ベクトル\bm{z}_jを考えます。この\bm{z}_jd次元ベクトルとします。\bm{z}_jは、decoderの最終層の出力と考えることができます。ここで、attention mechanismは、この\bm{z}_jを再構築して新しい特徴ベクトル\hat{\bm{z}_j}を作成します\hat{\bm{z}_j}は、以下のように計算されます。再構築後の\hat{\bm{z}_j}も、下式によりd次元ベクトルとなります(計算で確かめることをおすすめします)。

\hat{\bm{z}_j}=\bm{H}\bm{a}\\ \bm{a}=\textsf{softmax}(\bm{a'})\\ \bm{a'}=\bm{H}^{\top}\bm{z}_j

ただし、ここで、行列に対するsoftmax関数は以下のように定義されるものとします。

\textsf{softmax}(\bm{A})=\bm{A'}\odot\textsf{reciprocal}(\bm{1}\bm{A'})\\ \bm{A'}=\exp(\bm{A})\\ (ただし、\textsf{reciprocal}関数は、入力された行列の各要素を逆数に変換する関数。)

以上がattention mechanismの行っている計算です。特徴ベクトル\bm{z}_jを再構築して新しい特徴ベクトル\hat{\bm{z}_j}が作成されました。

self-attention

前章のattention mechanismにおいては、decoderの出力である特徴ベクトル\bm{z}_jを再構築するために、encoderの出力ベクトルを並べた行列である\bm{H}を使用していました。言い換えれば、\bm{z}_jを再構築するために\bm{H}を参照していました。

これに対して、self-attention mechanismでは、「再構築したいベクトル」と「再構築する際に参照するベクトル」の情報源が同じ になります。これが"self-attention(自己注意)"というタームの所以です。

attentionとself-attentionにおいて、数式に現れる違いはこの点のみになります。よって、行われている基本的な操作は、通常のattentionと同じです。実際に、self-attention mechanismの数式を見てみましょう。記号や関数の意味/指示対象は、すべて前章と同じです。

まず、encoder側のself-attention mechanismは以下の式で表現されます。

\hat{\bm{h}_i}=\bm{H}\bm{a}\\ \bm{a}=\textsf{softmax}(\bm{a'})\\ \bm{a'}=\bm{H}^{\top}\bm{h}_i

\bm{h}_jを再構築して\hat{\bm{h}_j}を作成していますが、その際に参照しているのも\bm{H}です。通常のattentionで\bm{z}_jを再構築するために\bm{H}を参照していたことと対比してみましょう。attention mechanismにおいて、情報源が同じであり、「自己注意」をしていることがわかります。

次にdecoder側のself-attention mechanismは以下の式で表現されます。

\hat{\bm{z}_j}=\bm{Z}\bm{a}\\ \bm{a}=\textsf{softmax}(\bm{a'})\\ \bm{a'}=\bm{Z}^{\top}\bm{z}_j

説明すべきことは、encoder側のself-attentionと全く変わりません。\bm{h}_i, \bm{H}\bm{z}_j, \bm{Z}に変わっただけです。ここでも\bm{h}_iを再構築するために\bm{Z}という同一の情報源を参照しており、「自己注意」が行われていますね。

query-key-value attention(QKV attention)

前章まででattentionとself-attentionの説明を行いました。Transformerでは、その少し発展した形であるquery-key-value attention(QKV attention, QKV注意機構) が使われています。

query-key-valueとはどれもデータベースに関連したタームですが、それがなぜattentionやTransformerの文脈において突然登場したのかと困惑するかもしれません。
実は、通常のattention mechanismも、データベースへの問い合わせにたとえて説明されることがあります。encoderの出力ベクトルを並べた列(行列)\bm{H}=(\bm{h}_1,...,\bm{h}_I)を「データベース」decoderの出力系列の各位置jにおける\bm{z}_jを「クエリ」とみなし、データベースに対してクエリを投げることで新しいベクトル\hat{\bm{z}_j}を得る、と考えるのです。

QKV attentionは、この考え方の拡張と捉えることができます。QKV attentionは、「key-value方式のデータベース」が存在し、そのデータベースに対してqueryで問い合わせを行っているようなイメージです。

といっても、QKV attentionは多少複雑なので、実際の式を先に見ていただき、そこから説明を行うほうがよいかと思います。ということでまずはじめに、self-attentionではなく、encoder-decoder間で情報を受け渡す場合のQKV attentionの数式を示します。

まず、記号 \bm{W^{(Q)}},\bm{W^{(K)}},\bm{W^{(V)}}\in \mathbb{R^{d\times d}}を導入します。これらは、encoderやdecoderが各位置で計算しているベクトル\bm{h}_i\bm{z}_iを、それぞれQueryベクトル、Keyベクトル、Valueベクトルに変換するd\times d行列です。これらが、QKV attentionにおけるモデルパラメータとなります

このとき、queryベクトル、keyベクトル、valueベクトルは、それぞれ以下のように計算されます。
ただし、\bm{K}はkeyベクトルを並べた行列で、\bm{K}=(\bm{k}_1,...,\bm{k}_I)です。
\bm{V}はvalueベクトルを並べた行列で、\bm{V}=(\bm{v}_1,...,\bm{v}_I)です。

\bm{q}_j=\bm{W}^{(Q)}\bm{z}_j\\ \bm{K}=\bm{W}^{(K)}\bm{H}\\ \bm{V}=\bm{W}^{(V)}\bm{H}\\

上で通常のattentionについて「decoderの出力系列の各位置jにおける\bm{z}_jを「クエリ」とみなし、データベースに対してクエリを投げるイメージ」と説明しました。QKV attentionについても、クエリ\bm{q}_jはベクトルです。decoderの出力系列中のトークンベクトル\bm{z}_jをモデルパラメータ\bm{W}^{(Q)}で調整したものがqueryになっているだけなので、「decoderの出力系列中のベクトルをクエリとしてデータベースに投げる」という部分は変わりませんね。

「データベース」は、上述したようにkeyとvalueで構成されています。一般にkeyとvalueはペアとなってkey-valueの構造をなすわけですが、ここにおいても同様に、\bm{K}\bm{V}における同じi番目のベクトルが、1つのkey-valueペア(\bm{k}_i, \bm{v}_i)となります。なお、\bm{q}_i, \bm{k}_i, \bm{v}_iはd次元ベクトル、\bm{K},\bm{V}d\times I行列となります。

「データベース」に「クエリ」を投げて結果を得る操作は以下になります。この操作によって、\bm{q}_iを再構築して\hat{\bm{q}_i}が作成されます。

\hat{\bm{q}_i}=\bm{V}\bm{a}\\ \bm{a}=\textsf{softmax}(\bm{a'})\\ \bm{a'}=c\bm{K}^{\top}\bm{q}\\ (ただし、c=\frac{1}{\sqrt{d}}とする。つまり、cはスカラーの定数である。)

\bm{q}_iはdecoderの出力系列中のベクトル\bm{z}_jをモデルパラメータで調整したものなので、「decoderの出力系列中の特徴ベクトルを再構成して新たな特徴ベクトルを得る」というattention mechanismの根幹の操作は変わっていないですね。定数cが新たに付いていることで、前章までのattentionと違うもののように見えてしまうかもしれませんが、これはただの調整のための定数なので、本質的な違いはありません(多分)[3]

以上で、encoder-decoder間でのQKV attentionにおいても、通常のattentionと同じく、decoderの出力系列中の特徴ベクトルを再構成して新たなベクトルを得ることができました。

次に、encoder, decoderそれぞれ単体の中でのself-attentionにおいてQKV attentionを採用した場合の式を見てみましょう。といっても、query, key, valueの求め方が変わるだけなので、その式だけを示します。その後の\hat{\bm{q}_i}の求め方は全く同じです。

(encoder)\\ \bm{q}_i=\bm{W}^{(Q)}\bm{h}_i\\ \bm{K}=\bm{W}^{(K)}\bm{H}\\ \bm{V}=\bm{W}^{(V)}\bm{H}\\ (decoder)\\ \bm{q}_j=\bm{W}^{(Q)}\bm{z}_j\\ \bm{K}=\bm{W}^{(K)}\bm{Z}\\ \bm{V}=\bm{W}^{(V)}\bm{Z}\\

self-attentionなので、encoderにおいては情報源がすべて\bm{H}に、decoderにおいては情報源がすべて\bm{Z}になっているだけですね。

ちなみに、以上のようなQKV attentionを用いたattentionのことを[Vaswani, 2017]では"Scaled Dot-Product Attention"と呼んでいます。Scaled Dot-Product Attentionは短く以下の一行の式で表現されています(内容としては上述の式たちと全く同じ)。ただし、クエリベクトルを行列表記にした\bm{Q}=(\bm{q}_i,...,\bm{q}_T)を導入します(Tについては脚注[4]を参照)。また、d_kは上の式たちにおけるdと同じです。

\mathrm{Attention(Q, K, V)=softmax(\frac{QK^{\top}}{\sqrt{d_k}})V}

multi-head attention

上述したQKV attentionを複数用いるために、Transformerではmulti-head attentionという仕組みを採用しています。"multi-head"における"head"はそれぞれのattention機構を指しているため、単に「1つのモデルの中でattention機構を複数使うための仕組み」という程度の意味になるかと思います。

複数のattention機構(=head)を使うことのメリット(とそれに伴うデメリット)については後に回して、数式を確認します。ここでは、用いるheadの数をH_{num}と表記します。例えば、H_{num}=16は「16つのhead(attention機構)を用いる」という意味です。encoderの出力ベクトルの列\bm{H}と混同しないよう注意してください。全く異なるものです。

まず、必要な記号を導入します。

i\in\{1,...,H_{num}\}\\ \bm{W}_i^{(Q)}\in \mathbb{R}^{d\times d_k}\\ \bm{W}_i^{(K)}\in \mathbb{R}^{d\times d_k}\\ \bm{W}_i^{(V)}\in \mathbb{R}^{d\times d_v}\\ \bm{W}_i^{(O)}\in \mathbb{R}^{H_{num}d\times d_v}\\ ただし、\\ d_k=d_v=\frac{d}{H_{num}}

\bm{W}_i^{(Q)}, \bm{W}_i^{(K)}, \bm{W}_i^{(V)}, \bm{W}_i^{(O)}はすべてモデルパラメータです。
また、ここでのiは系列中のトークンの番号ではなく、headの番号を表していることに注意してください。前章までで使用していたiとは意味が異なります。

このとき、multi-head attention mechanismは以下の式で表現されます。

\hat{\bm{Q}}=\mathrm{MultiHead}(\bm{Q, K, V})=\mathrm{Concat}(\hat{\bm{Q_1}},...,\bm{\hat{Q}_{H_{num}}})\bm{W}^{(O)}\\ \hat{\bm{Q_i}}=\mathrm{Attention(W_i^{(Q)}\bm{Q}, W_i^{(K)}\bm{K}, W_i^{(V)}\bm{V})}

ただし、\mathrm{Concat}は、行列を横方向に連結する操作を表します。

前章までのとおり、\mathrm{Attention}を用いてクエリベクトル列\bm{Q}を新しいベクトル列\hat{\bm{Q_i}}に再構築し、その再構築した後の\hat{\bm{Q_1}},...,\hat{\bm{Q_{H_{num}}}}\mathrm{Concat}して、複数のheadからの出力を結合した最終的な特徴ベクトル列\hat{\bm{Q}}を得ているというわけですね。
ステップを複数挟んでいるので混乱しますが、「\bm{Q}を再構築して\hat{\bm{Q}}を得る」というattentionの本質的な操作としては変わっていません。

さて、ここで、先程飛ばしていた「複数のhead(attention機構)を組み合わせて使用することのメリット」について言及したいと思います。

[岡崎, 2022]は、headの複数化の利点を、softmax関数の性質に基づいて説明しています。softmaxは、入力ベクトルの中の1つの要素を1に近い値に、それ以外の要素を0に近い値に変換する傾向があります。そのため、attention機構を1つだけ用いた場合、注目する観点が1つに偏りがちであり、この問題を解決するために複数のattention機構を用いて情報をうまく「混ぜ合わせる」ことがmulti-head attentionには期待されていると述べられています。一方で、同文献では\bm{H}_{num}の値が増大するほど計算量も増大する、いわば「性能と速度(あるいはメモリ量)のトレードオフ」があることにも言及しており、\bm{H}_{num}を適切に選択する必要性が指摘されています。

[Lewis, 2022]でも、この説明と同じことが述べられています。

しかし、なぜ複数のアテンションヘッドが必要なのでしょうか。その理由は、1つのヘッドのソフトマックスでは、類似性の一面にしか注目しない傾向があるからです。複数のヘッドを持つことで、モデルは一度に複数の側面に注目できます。(p.72)

おわりに

本記事では、Transformerにおいて使用されているattentionについて、その数式を追いました。

LabBaseアドベントカレンダー2023、明日は @takahashik0422さんの記事になります!ぜひご期待ください!

参考文献

  • [岡崎, 2022]岡崎直観・荒瀬由紀・鈴木潤・鶴岡慶雅・宮尾祐介『自然言語処理の基礎』、オーム社、2022年
  • [岡野原, 2023]岡野原大輔『大規模言語モデルは新たな知能か』、岩波書店、2023年
  • [黒橋, 2023]黒橋禎夫『三訂版 自然言語処理』、放送大学教育振興会、2023年
  • [斎藤, 2018]斎藤康毅『ゼロから作るDeep Learning②──自然言語処理編』、オライリー・ジャパン、2018年
  • [Cheng, 2016]Cheng Jiampreng, Li Dong, and Mirella Lapata. "Long Short-Term Memory-Networks for Machine Reading", EMNLP, 2016
  • [Lewis, 2022]Lewis Tunstall, Leandro von Werra, and Thomas Wolf(中山光樹訳)『機械学習エンジニアのためのTransformers──最先端の自然言語処理ライブラリによるモデル開発』、オライリー・ジャパン、2022年
  • [Vaswani, 2017]Ashish Vaswani et.al.,"Attention Is All You Need", arXiv, 2017
脚注
  1. 現論文における定義はともかくとして、現在"Transformer"と言った場合、その語が厳密に何を指しているかは、文脈や語の使用者によって異なる場合があります。本記事では、岡崎, 2022に従い、encoderとdecoderをあわせた系列変換モデルの形式のことを"Transformer"と呼ぶこととします。 ↩︎

  2. 本記事では扱いませんが、TransformerにおいてAttentionと並ぶ中心的な技術要素として、positional encodingやlayer normalizeationなどもしばしば挙げられます。 ↩︎

  3. 元論文([Vaswani, 2017])においては、このc=\frac{1}{\sqrt{d}}は"scaling factor"と呼ばれています。このscaling factorは、dが大きい場合においてドット積が大きくなり、softmax関数が極めて小さい勾配しかもたないような領域に押し込まれることを避けるために導入されているようです。論文中では[Cheng, 2016]がreferされていました。正直私はあまりよくわかっていません。 ↩︎

  4. encoder単体、decoder単体のself-attention mechanismにおいてはそれぞれS=T=I, S=T=Jです。encoderの情報を参照するdecoderのattention mechanismにおいてはS=I, T=Jです。 ↩︎

Discussion