BERT 実装入門
BERT の理論的な部分はなんとなく理解していたけれど、世に出回っているライブラリの実装がどうなっているかはよく知らなかった。
それで、色々入門記事っぽいのを漁っていたが、とりあえず以下の記事が一番良さそうに見えたのでざっと読んでみることにした。
必要な箇所を自分用にかいつまんでいるだけなので、翻訳記事ではない。
以下「ライブラリ」と言っているのは huggingface の Transformers のことである。
ライブラリの構成
ライブラリでは大元の BERT 以外にも RoBERTa や XLNet のような亜種も使うことができる。
これらの BERT や RoBERTa のような transformer を部品としたモデル各々に対し、以下の3つのクラスが用意されている。
- model クラス: モデルのアーキテクチャが PyTorch で実装されている(BERT であれば
BertModel
) - configuration クラス: モデルを組み立てるのに必要なパラメータが格納されている(BERT であれば
BertConfig
) - tokenizer クラス: それぞれのモデルの vocabulary や、文字列とトークンの間の変換を行うメソッドが提供されている(BERT であれば
BertTokenizer
)
from_pretrained()
では、ライブラリが用意している pre-trained モデルや、自分で作ったモデルの model / configuration / tokenizer を読み込む。
save_pretrained()
では、model / configuration / tokenizer をローカルに保存する。
model クラス
class transformers.BertModel
なお、BERT には
- L: transformer レイヤーの層数
- H: transformer における隠れ層の大きさ
- A: transformer における multi head attention のヘッド数
というパラメータがあって、 BERT BASE というのが L = 12, H = 768, A = 12 のモデルである。
また、よく使われるのは bert-base-uncased
だが、uncased
とは前処理として「小文字化」しているということ。
固有表現抽出や品詞タグ付けのような、大文字の情報が効くタスクでない限り、この uncased
を用いた方が良いそうである。
tokenizer クラス
class transformers.BertTokenizer(
vocab_file,
do_lower_case=True,
do_basic_tokenize=True,
never_split=None,
unk_token='[UNK]',
sep_token='[SEP]',
pad_token='[PAD]',
cls_token='[CLS]',
mask_token='[MASK]',
tokenize_chinese_chars=True,
**kwargs)
- それぞれの入力の最初のトークンは必ず
[CLS]
である -
[CLS]
の位置の最後の隠れ層は、文全体の埋め込み(sentence representation)として、文書分類などのタスクで使われる - 文どうしの区別は、まず文と文の間に
[SEP]
トークンを入れ、さらに、異なる文に対して異なる Segment Embeddings を加えることで行う - さらに、それぞれのトークンごとに Position Embeddings を加える
tokenize に関しては、以下の記事もわかりやすかった。
configuration クラス
class transformers.PretrainedConfig(**kwargs)
pre-trained モデルをそのまま使う場合は、configuration を明示的に呼び出さない場合も多い。
モデルの作成
Pre-training: すごい時間がかかる(Cloud TPU でも数日程度かかる)。Pre-trained されたモデルが配布されているので、普通の開発者はやる必要がない。
Fine-tuning: 時間はそこまでかからない。例えば、SQuAD という QA タスクの訓練データを学習するのは TPU で 30 分、GPU でも数時間程度である。
Pre-trained モデルには
- MLM(Masked LM) 穴埋め問題 でパラメータ学習
- NSP(Next Sentence Prediction) 次の文の予測 でパラメータ学習
の 2 種類がある。
Pre-trained モデルを動かして穴埋め問題を解いてみよう
例として、
Who was Jim Henson? Jim ★ was a puppeteer.
の★を推測してみよう。
以下、Google Colaboratory で実装したので、そのリンクを貼る。
Discussion