🥷

TensorFlow形式で保存されたBERTをPyTorch形式にする

2022/01/23に公開

コマンド

このコードに対して、以下のコマンドを実行する。

python convert_bert_original_tf_checkpoint_to_pytorch.py --tf_checkpoint_path \path\to\the\tf\model --bert_config_file \path\to\the\config\file --pytorch_dump_path \path\to\the\output\pytorch\model

解説

著者の金子はBERTなどの言語モデルを使う時は、Hugging Face🤗のtransformersを用いることが多い。一方で、世の中にはTensorFlow (TF) で保存されたモデルもたくさんあるため、必ずしもそのままtransformersでロードできなかったりする。その時は、transformersのmodelsディレクトリで提供されている、TF形式のモデルをPyTorch形式にするコードでモデルを変換する。

BERTに対する例

例としてBERTを変換する。Googleが提供しているBERTを以下のコマンドでダウンロードする。

wget https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-2_H-128_A-2.zip # 練習のため小規模なBERTモデルをダウンロード
unzip uncased_L-2_H-128_A-2.zip # 解凍

BERTモデルを変換するために、models/bertにあるこのコードを、必要なライブラリがあれば適宜インストールし使う。変化するためのコードはモデルごとに提供されているため、対象モデルごとにコードを変える必要がある。

wget https://raw.githubusercontent.com/huggingface/transformers/master/src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py
python convert_bert_original_tf_checkpoint_to_pytorch.py --tf_checkpoint_path bert_model.ckpt.index --bert_config_file bert_config.json --pytorch_dump_path pytorch_model.bin
  • --tf_checkpoint_pathはTF形式のモデルのパス
  • --bert_config_fileはconfigファイル
  • --pytorch_dump_pathはPyTorch形式のモデルを出力するパス

参考文献

Discussion