🥷
TensorFlow形式で保存されたBERTをPyTorch形式にする
コマンド
このコードに対して、以下のコマンドを実行する。
python convert_bert_original_tf_checkpoint_to_pytorch.py --tf_checkpoint_path \path\to\the\tf\model --bert_config_file \path\to\the\config\file --pytorch_dump_path \path\to\the\output\pytorch\model
解説
著者の金子はBERTなどの言語モデルを使う時は、Hugging Face🤗のtransformersを用いることが多い。一方で、世の中にはTensorFlow (TF) で保存されたモデルもたくさんあるため、必ずしもそのままtransformersでロードできなかったりする。その時は、transformersのmodelsディレクトリで提供されている、TF形式のモデルをPyTorch形式にするコードでモデルを変換する。
BERTに対する例
例としてBERTを変換する。Googleが提供しているBERTを以下のコマンドでダウンロードする。
wget https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-2_H-128_A-2.zip # 練習のため小規模なBERTモデルをダウンロード
unzip uncased_L-2_H-128_A-2.zip # 解凍
BERTモデルを変換するために、models/bert
にあるこのコードを、必要なライブラリがあれば適宜インストールし使う。変化するためのコードはモデルごとに提供されているため、対象モデルごとにコードを変える必要がある。
wget https://raw.githubusercontent.com/huggingface/transformers/master/src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py
python convert_bert_original_tf_checkpoint_to_pytorch.py --tf_checkpoint_path bert_model.ckpt.index --bert_config_file bert_config.json --pytorch_dump_path pytorch_model.bin
-
--tf_checkpoint_path
はTF形式のモデルのパス -
--bert_config_file
はconfigファイル -
--pytorch_dump_path
はPyTorch形式のモデルを出力するパス
Discussion