BERTでseq2seqしたいときのモデルの理解
BERT generation周り学習時の備忘用.
huggingfaceの翻訳とほしい部分の要約.
以下の順:
- Encoderの定式化
- Decoderの定式化
- EDモデルの定式化
- pre-trained BERTでEDモデルするにはどうするか
まとめ
- BERTはencoder onlyのため, せっかくがっつり事前学習してweightを保存していてもseq2seqには向かない (事前にsequence lengthが欲しい)
- decoder onlyのGPT2はseq2seqには向くものの,
の情報をX_{i+1} に使えないのがもったいない.X_i - Rothe et al. (2020)ではBERTやGPT2といった事前学習モデルをwarm-startingなencoder-decoder (ED)モデルとして活用する際のTIPSをまとめた.
seq2seqの定式化
このとき各単語ベクトルを用いて
BERTの定式化
Encoder側の代表.
BERTは
その後, 適宜pooling layerとclassification layerを追加して,
重要な点は, BERTのようなencoder-onlyモデルは入力seqを長さが決まっている出力seqへと射影するモデルであり, 出力seqの長さは入力seqに依存しない こと.
そのためBERTは素直に使うとseq2seqには向かない.
GPT2の定式化
Decoder側の代表.
GPT2はunidirectional self-attentionを用いて
logitベクトルをsoftmaxに通せば単語列
この確率分布は
,where
GPT2は文章の生成に向くが, unidirectionalな性質のため,
EDモデルの定式化
EDモデル一般の話.
まずEncoderにより, 長さnの入力seq
その後, Decoderにより,
, where
BERTによるwarm-startingなseq2seqのアプローチ
いよいよpre-trained BERTによるseq2seq.
pre-trained BERTのweightをseq2seqに利活用するにあたり, encoderとdecoderに分けて考える.
encoder側
encoder側はBERTと同じ構造なので, そのままpre-trainedのweightをセットすればいい.
decoder側
decoderは3箇所BERTと異なる.
唯一
- cross-attentionの追加
- ここが 最も異なる
- decoderはコンテキストを反映したseq
によって条件付けされる必要がある\overline{X_{1:n}} - 各BERTブロック内self-attentionとfeed-forwardの間に乱数で初期化した,
とのcross-attentionを入れる\overline{X_{1:n}}
- bi-directional self-attentionをuni-directionalに変更
- BERTブロックのself-attentionはbi-directionalだが, decoderにするためにはここがuni-directionalである必要がある
- ただquery, key, valueの関係性は同一なので, weight自体はBERTのものをセットする
- (雑感) 少し不思議な気もするが, 感覚的にweightのmeanやvarがそれっぽい値になるだけでもfine-tuningする上では捗る印象なので, 奏効するのか
- LM-headの追加
- decoder出力
はロジットのベクトル列なので, 欲しい条件付き確率L_{1:m} に変換するLM headがdecoderブロックの最後に来るp_{\theta_{dec}}(Y_{1:n} | \overline{X}) - LM headはlanguage model headの略, GPT2由来
- LM headのweightはword embedding
に対応するので, BERTモデルのW_{emb} をdecoderのLM headにセットするW_{emb}
- decoder出力
(オマケ) weight-sharing
Raffel et al.(2020)は, seq2seqではencoder/decoderのweightをshareしても(同じにしても)それほど性能が下がらないことを指摘している.
この辺りとwarm-startingは相性が良さそう.
テスト
テストは本家に丁寧に載っているので割愛, ページのちょうど真ん中ら辺からpracticeを載せてくれている.
ここにコードだけをまとめたバージョンもある.
transformers
のEncoderDecoderModel
にて, encoderとdecoderにそれぞれpre-trainedのBERTを指定すれば後は普段通りという感じ.
スクラッチの際には渡せる形のpre-trainedを作ることが少し手間になるが, 構造を変えない限りは入れ替えるだけなので簡単.
Discussion