Transformerを納得できるところまで
はじめに
(今更ですが)ChatGPT が世を席巻しているということで,忘備録として書きます.
元論文の他にもわかりやすい記事が多数ありますので,参考にさせていただきました.
私自身 NLP 専門ではないので,明らかな間違いがございましたらご指摘下さい.
記事の方針と全体像
記事の方針
Deep Learning の基礎を知っている方がある程度納得できるような記事にしたいと考えています.そのために次の2点を意識して書きました:
- 理論的に納得するためには,要所技術の背景知識等があった方がよいと感じています.そのため,知識が足りないと感じたら,それを補いに行けるような記事を「補足」部分に示しています.
- 流れを納得するために,行列の形状を頭に浮かべながら流れを追うと具体的に行っていることがわかりやすいと感じたので,節ごとに付しています.
全体像
- 畳み込み層や再帰構造を用いず,翻訳タスクの SoTA を達成し,現在では自然言語処理の他に画像系タスクにおいても使用されている最強手法
- Attention を軸とした Encoder-Decoder 構造であり,分割すると次の4つから構成されます:
- Embeddings & Positional Encoding
- Encoder
- Decoder
- Linear & Softmax
[1] より一部引用
1 Embeddings & Positional Encoding
まず,transformer全体の中のどの部分かを示します.
[1] より一部引用
1.1 Embedding
単語埋め込みを行う層のこと.単語をそのままコンピュータで扱うことはできないので,いい感じにベクトル化する必要があります.最も簡単には,one-hot ベクトルで各単語を識別する(つまり語彙ごとに固有のIDを割り振る)ことが可能ですが,「各単語の関連性などを考慮したベクトル化ができない」,「単語表現に必要となる空間が広すぎる」 等の理由から,一般的には似た単語は近いベクトルになるように低次元空間に射影する等の工夫が行われます.本論文ではこの手法については詳しく述べられていませんでしたが,学習済みの線形層を通しているとの記述があったため,行列により単語間の関連性をいい感じに保てる低次元空間に写像していることが推測できます.
行列の形状は,
補足
1.2 Positional Encoding
transformer は,再帰構造や畳み込みが存在しないため,sequence 内での位置をモデルが知る術がありません.したがって,相対 or 絶対位置情報を別途補う必要があります.transformer では次の式を(絶対)位置情報として付加することを提案しています.
文字が3つ見当たりますが,まず,
つまり,sequence の何番目か(文中の何単語目か)と,分散表現(1単語のベクトル表現)の何次元目かによって値を加算させているという処理に相当します.このため,次元は変わらず
補足
2. Encoder
いよいよ本質に迫ってきました.Encoder の全体像を示します.
重要なのは,Multi-Head Attention の部分です.ここを理解すれば Encoder だけでなく, Decoder も理解したも同然です.
[1] より一部引用
2.1 Attention
Attention は最も簡単に説明すると,どこに注目すべきかを学習できる機構 です.Attention と一口に言っても種類があり,今回使用しているのは中でも (i) Dot-Product Attention と呼ばれるものを Scaling(正規化) し,(ii) Multi-Head 化した (iii) Self-Attention です.なんだかルー大柴さんみたいになってますが,ご容赦ください.順に説明していきます.
[1] より一部引用
(i) Scaled Dot-Product Attention
Attention の意味と行列の形状を意識しながら話を追っていきます.式は次の通りです.
Q, K, Vのそれぞれの役割をまず認識しましょう.Q (query) は問い合わせです.それに対して,K (key) は,問い合わせに対して何らかの検索をします.V (value) はその検索結果に基づいて重み付けした適切な回答を返すイメージです.
では,この認識を頭に置いて上式を理解していきましょう.一旦,Q, K, Vはすべて同じ大きさの行列としてください.つまり,
ここで,Q は
補足
(ii) Multi-Head Attention
先ほどの Scaled Dot-Product Attention を並列に複数行うことを指します.ただし単純に複数回行った場合,同じ結果が複数回得られるだけでなんの意味もありませんし,増加した次元に対してどのような操作を行うのかも疑問です.この部分をどのように工夫しているのかを見ていきましょう.
式は,次の通りです.
この式を見ると,先ほどの Scaled Dot-Product Attention が where 句に現れています.この引数部分に注目すると,Q, K, V に何か行列がかかっていることに気づきます.これが先ほどの「補足」部分で述べた変換です.この行列をかますことで,Q, K, V をそれぞれいい感じに変換しています.その後の Attention により得られた行列を Head と呼び,それを h 個 concat (連結)した後に,特定の行列
この意味としては,それぞれ異なった役割を持つ Head をそれぞれ組み合わせることで,最終的に豊かな表現力を持つ Multi-Head を得ることと捉えることができます.
もちろん,これだけでは出力される行列が増えてしまい困ります.したがって,それぞれの Head で乗算する
補足
(iii) Self-Attention
ここでは,「K, Q, V が元々同じ部分から得られる Attention のこと」という理解で大丈夫です.
補足
2.2 Add & Norm
先ほどから,入出力のサイズを一致させていたことに気づいたと思いますが,この理由がここでわかります.まず,式を見てみましょう.
ここで用いられている工夫は2つあり,1つは Shortcut Connection (
Shortcut Connection は ResNet において,劣化問題(degradation)(層を深くすると,test だけでなく train の精度まで下がってしまう問題)を解決するために導入された機構です.この機構を導入するためには,Sublayer の出力と入力 x のサイズが一致している必要がありますが,深層学習モデルの層を段違いに深くすることが可能になるため,非常によく用いられます.
Layer Normalization は,バッチサイズに平均・分散が依らず,また入力サイズが一定でないsequence データを用いるタスクにおいてよく使用され,勾配消失・爆発などを避け,学習を安定化させる効果を見込めます.
[2] より一部引用
[2] より一部引用
補足
2.3 Feed Forward Network (FFN)
Attention 層を通った後,2層の全結合層を通ります.具体的には以下の式で表されます.
式を見ると,「全結合層
補足
2.4 Add & Norm
2.2 と同様です.
3. Decoder
続いて Decoder です.全体像は以下の通りです
[1] より一部引用
ほとんど Encoder と同じですが,1つ異なる点があります.
それは Masked Multi-Head Attention です.先ほどの Scaled Dot-Product Attention を確認すると,Scale と Softmax の間に Mask が存在しています.
Mask は,Decoder の入力において答えとなる単語の情報を含めないようにする目的で導入されています.(学習時にはまとめて文を入力するため必要となりますが,推論時は逐次処理なので必要ないようです)
行列のサイズは,Encoder と同様に考えることで,
補足
4. Linear & Softmax
いよいよ最後です.
[1] より一部引用
Decoder で出力された値を 線形層
行列の形状は,
補足
おわりに
かなり長くなってしまいましたが,全体的にやりたいことが理解できた気がしています.
実験結果や学習時の工夫など,細かいところまで知りたい方は是非元論文をご参照ください.
百聞は一見に如かずということで,次回は実装して動かすところをやりたいと思います.
参考
- transformer元論文
- ResNet元論文
- The Illustrated Transformer
- 全力解説!Transformer
- Python(PyTorch)で自作して理解するTransformer
- 作って理解する Transformer / Attention
- Transformerを理解したい
- 図で理解するTransformer
- 埋め込み層 (Embedding Layer) [自然言語処理の文脈で]
- 位置符号化 (Positional Encoding, 位置エンコーディング) [Transformer の部品]
- 【自然言語処理】Attentionとは何か
- バッチ正規化(Batch Normalization)とその発展型
- Transformer と seq2seq with attention の違いは?[系列変換]【Q and A 記事】
- 自然言語処理の必須知識 Transformer を徹底解説!
Discussion