🐙

Transformerを納得できるところまで

2023/01/04に公開

はじめに

(今更ですが)ChatGPT が世を席巻しているということで,忘備録として書きます.
元論文の他にもわかりやすい記事が多数ありますので,参考にさせていただきました.
私自身 NLP 専門ではないので,明らかな間違いがございましたらご指摘下さい.

記事の方針と全体像

記事の方針

Deep Learning の基礎を知っている方がある程度納得できるような記事にしたいと考えています.そのために次の2点を意識して書きました:

  1. 理論的に納得するためには,要所技術の背景知識等があった方がよいと感じています.そのため,知識が足りないと感じたら,それを補いに行けるような記事を「補足」部分に示しています.
  2. 流れを納得するために,行列の形状を頭に浮かべながら流れを追うと具体的に行っていることがわかりやすいと感じたので,節ごとに付しています.

全体像

  • 畳み込み層や再帰構造を用いず,翻訳タスクの SoTA を達成し,現在では自然言語処理の他に画像系タスクにおいても使用されている最強手法
  • Attention を軸とした Encoder-Decoder 構造であり,分割すると次の4つから構成されます:
  1. Embeddings & Positional Encoding
  2. Encoder
  3. Decoder
  4. Linear & Softmax


[1] より一部引用

1 Embeddings & Positional Encoding

まず,transformer全体の中のどの部分かを示します.


[1] より一部引用

1.1 Embedding

単語埋め込みを行う層のこと.単語をそのままコンピュータで扱うことはできないので,いい感じにベクトル化する必要があります.最も簡単には,one-hot ベクトルで各単語を識別する(つまり語彙ごとに固有のIDを割り振る)ことが可能ですが,「各単語の関連性などを考慮したベクトル化ができない」「単語表現に必要となる空間が広すぎる」 等の理由から,一般的には似た単語は近いベクトルになるように低次元空間に射影する等の工夫が行われます.本論文ではこの手法については詳しく述べられていませんでしたが,学習済みの線形層を通しているとの記述があったため,行列により単語間の関連性をいい感じに保てる低次元空間に写像していることが推測できます.
行列の形状は,(LoS, N) \rightarrow (LoS, d_{model}) となっており,LoS は 入力する文の長さ,そして N 次元ベクトルで各単語を表現していますが,線形層を通して d_{model} 次元まで落ちていることがわかります.ただし,N は単語数,d_{model} は1単語を表現するのに必要な次元数(本論文では 512 に設定)となっています.

補足

1.2 Positional Encoding

transformer は,再帰構造や畳み込みが存在しないため,sequence 内での位置をモデルが知る術がありません.したがって,相対 or 絶対位置情報を別途補う必要があります.transformer では次の式を(絶対)位置情報として付加することを提案しています.

PE_{(pos, 2_i)} = sin(pos/10000^{2i/d_{model}}) \\ PE_{(pos, 2_{i+1})} = cos(pos/10000^{2i/d_{model}})

文字が3つ見当たりますが,まず,d_{model} は1単語を表すための次元数で,本論文内では512 に設定されている定数です.pos が単語の位置(position),i が1単語内の次元であり,これらは変数になります.
 つまり,sequence の何番目か(文中の何単語目か)と,分散表現(1単語のベクトル表現)の何次元目かによって値を加算させているという処理に相当します.このため,次元は変わらず (LoS, d_{model})です.

補足

2. Encoder

いよいよ本質に迫ってきました.Encoder の全体像を示します.
重要なのは,Multi-Head Attention の部分です.ここを理解すれば Encoder だけでなく, Decoder も理解したも同然です.

[1] より一部引用

2.1 Attention

Attention は最も簡単に説明すると,どこに注目すべきかを学習できる機構 です.Attention と一口に言っても種類があり,今回使用しているのは中でも (i) Dot-Product Attention と呼ばれるものを Scaling(正規化) し,(ii) Multi-Head 化した (iii) Self-Attention です.なんだかルー大柴さんみたいになってますが,ご容赦ください.順に説明していきます.

[1] より一部引用

(i) Scaled Dot-Product Attention

Attention の意味と行列の形状を意識しながら話を追っていきます.式は次の通りです.

\rm{Attention} (Q, K, V) = \rm{softmax} (\frac{QK^T}{\sqrt{d_k}})V

Q, K, Vのそれぞれの役割をまず認識しましょう.Q (query) は問い合わせです.それに対して,K (key) は,問い合わせに対して何らかの検索をします.V (value) はその検索結果に基づいて重み付けした適切な回答を返すイメージです.
 では,この認識を頭に置いて上式を理解していきましょう.一旦,Q, K, Vはすべて同じ大きさの行列としてください.つまり,\mathrm{Q_{size}} = \mathrm{K_{size}} = \mathrm{V_{size}} = (文の長さ LoS, 分散表現の次元数 d_{model})です.
 ここで,Q は (LoS, d_{model}) ですが,わかりやすくするために1単語のみについて考えてみましょう.つまり,\bold{q_1} = (1, d_{model}) です.K は転置されているので (d_{model}, LoS) となっており,行列積をとると,サイズは (1, LoS) になります.ここで重要なのが行列積をとることは単語 \bold{q_1} と \bold{k_1}, ..., \bold{k_{LoS}} との類似度計算に相当するということです.その後,\sqrt{d_k} で除算していますがいったん無視すると,softmax では単に確率に変換しているだけなので,\bold{q_1} と似ているキー \bold{k_x} (1 \le x \le LoS) に基づいて V を重み付けする操作に相当します.つまり (1, LoS)(LoS, d_{model}) の行列積なので,(1, d_{model}) となり,入力と出力のサイズは一致しますし,何よりも問い合わせの単語 \bold{q_1} に対して,似ている単語の重みを大きくした V が得られることがわかると思います.

補足

(ii) Multi-Head Attention

先ほどの Scaled Dot-Product Attention を並列に複数行うことを指します.ただし単純に複数回行った場合,同じ結果が複数回得られるだけでなんの意味もありませんし,増加した次元に対してどのような操作を行うのかも疑問です.この部分をどのように工夫しているのかを見ていきましょう.
式は,次の通りです.

\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat(head_1, ..., head_h)}W^O \\ \mathrm{where \quad head_i = Attention}(QW_i^Q, KW_i^K, VW_i^V)

この式を見ると,先ほどの Scaled Dot-Product Attention が where 句に現れています.この引数部分に注目すると,Q, K, V に何か行列がかかっていることに気づきます.これが先ほどの「補足」部分で述べた変換です.この行列をかますことで,Q, K, V をそれぞれいい感じに変換しています.その後の Attention により得られた行列を Head と呼び,それを h 個 concat (連結)した後に,特定の行列 W^O を掛けることで出力を得ています.
 この意味としては,それぞれ異なった役割を持つ Head をそれぞれ組み合わせることで,最終的に豊かな表現力を持つ Multi-Head を得ることと捉えることができます.
 もちろん,これだけでは出力される行列が増えてしまい困ります.したがって,それぞれの Head で乗算する W_i^Q, W_i^K, W_i^V のサイズを (d_{model}, \lfloor d_{model} / h \rfloor) とすることで,Attention の Q, K, V のサイズを (LoS, \lfloor d_{model} / h \rfloor) まで落とし,それに伴って各 Head のサイズも (LoS, \lfloor d_{model}/h \rfloor) まで落ちます.その後,行列の横方向に h 個 concat することで,(LoS, \lfloor d_{model} / h \rfloor \times h) のサイズへと戻します.最後に,(\lfloor d_{model} / h \rfloor \times h, d_{model}) サイズの W_O によって,(LoS, d_{model}) まで戻します.以上の操作で行列のサイズは入力と出力で等しくなります.

補足

(iii) Self-Attention

ここでは,「K, Q, V が元々同じ部分から得られる Attention のこと」という理解で大丈夫です.

補足

2.2 Add & Norm

先ほどから,入出力のサイズを一致させていたことに気づいたと思いますが,この理由がここでわかります.まず,式を見てみましょう.

\mathrm{LayerNorm}(x + \mathrm{Sublayer(x)})

ここで用いられている工夫は2つあり,1つは Shortcut Connectionx + \mathrm{Sublayer(x)} 部分)で,もう1つは Layer Normalization\mathrm{LayerNorm}() 部分)です.
Shortcut Connection は ResNet において,劣化問題(degradation)(層を深くすると,test だけでなく train の精度まで下がってしまう問題)を解決するために導入された機構です.この機構を導入するためには,Sublayer の出力と入力 x のサイズが一致している必要がありますが,深層学習モデルの層を段違いに深くすることが可能になるため,非常によく用いられます.
Layer Normalization は,バッチサイズに平均・分散が依らず,また入力サイズが一定でないsequence データを用いるタスクにおいてよく使用され,勾配消失・爆発などを避け,学習を安定化させる効果を見込めます.


[2] より一部引用


[2] より一部引用

補足

2.3 Feed Forward Network (FFN)

Attention 層を通った後,2層の全結合層を通ります.具体的には以下の式で表されます.

\mathrm{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2

式を見ると,「全結合層 \rightarrow ReLU \rightarrow 全結合層」となっていることがわかります.入力 x(1, d_{model}) で表されている点に注意してください.

補足

2.4 Add & Norm

2.2 と同様です.

3. Decoder

続いて Decoder です.全体像は以下の通りです


[1] より一部引用

ほとんど Encoder と同じですが,1つ異なる点があります.
それは Masked Multi-Head Attention です.先ほどの Scaled Dot-Product Attention を確認すると,Scale と Softmax の間に Mask が存在しています.

Mask は,Decoder の入力において答えとなる単語の情報を含めないようにする目的で導入されています.(学習時にはまとめて文を入力するため必要となりますが,推論時は逐次処理なので必要ないようです
行列のサイズは,Encoder と同様に考えることで,(LoS, d_{model}) を保つことができます.

補足

4. Linear & Softmax

いよいよ最後です.


[1] より一部引用

Decoder で出力された値を 線形層 \rightarrow softmax に通すことで,次のトークン(単語)の確率に変換しています.ここでの線形層の重みは embedding 層の重みと共有されています.
行列の形状は,(LoS, d_{model}) \rightarrow (LoS, N) となります.

補足

おわりに

かなり長くなってしまいましたが,全体的にやりたいことが理解できた気がしています.
実験結果や学習時の工夫など,細かいところまで知りたい方は是非元論文をご参照ください.
百聞は一見に如かずということで,次回は実装して動かすところをやりたいと思います.

参考

  1. transformer元論文
  2. ResNet元論文
  3. The Illustrated Transformer
  4. 全力解説!Transformer
  5. Python(PyTorch)で自作して理解するTransformer
  6. 作って理解する Transformer / Attention
  7. Transformerを理解したい
  8. 図で理解するTransformer
  9. 埋め込み層 (Embedding Layer) [自然言語処理の文脈で]
  10. 位置符号化 (Positional Encoding, 位置エンコーディング) [Transformer の部品]
  11. 【自然言語処理】Attentionとは何か
  12. バッチ正規化(Batch Normalization)とその発展型
  13. Transformer と seq2seq with attention の違いは?[系列変換]【Q and A 記事】
  14. 自然言語処理の必須知識 Transformer を徹底解説!

Discussion