ファインチューニングでmBARTの日→英翻訳モデルを作成してhuggingfaceで公開してみた
おはこんばんちは。最近は自分自信が頑張るしかねえなって思うことが多いken11です。
身の回りでいろいろ起きるけど、自分の身は自分しか守れない…
さて、今日はついにhuggingfaceに自分で学習させたモデルを公開するところまでいったのでその話です。
成果物
まあいろいろ長々と書きまくる前に成果物を置いておきます。
急ぎの人はここだけチェックすれば大丈夫です。
できあがったモデル
なにより最初にモデルを出せと言われる気がするので、ほいどうぞ
sacrebleu(大文字小文字無視)で18くらいのスコアにはなったので、日→英翻訳としてはぼちぼちじゃない?って個人的には思ったり
適当な記事の冒頭を投げてみるとこんな感じで返ってきます、いい線いってるしニュアンスは伝わるが果たしてって感じ?英語わかんない←
まあ本気で使いたかったらもっと学習してねって感じでしょうか
学習のコード
ほいさ
一応、SageMakerStudioで動いたのでたぶん大丈夫だと思う
恐らく少しやってみようとした方はわかると思うんですが、mBARTのpre-trainedモデルで一番有名なfacebookのものはサイズが大きくて一苦労します。
その一苦労への対処(後述)も込みでノートブックにしてあるので、便利に使ってもらえたらと :bow:
苦労したポイント
モデルでかすぎ問題
今回、mBARTの事前学習済みモデルとしてfacebookのものを利用しました。
これが25言語収録というのもあってなかなかオバケサイズで、実は当初バッチサイズを1にしてもメモリに乗らないという悲劇に見舞われていました。
この辺で議論されているOOM問題です。
これにどう対処するかっていうと、同じissueで書かれているモデルのカットオフを実行します。
モデルを削る
なかなか面白い経験でした。
ノートブックにも書いたんですが、どういう話かというと
- mbart-large-cc25は25言語の約250000単語をボキャブラリーとして抱えている
- ファインチューニングする際、関係ない言語の単語やその情報は不要になる
- 学習データを元にトークナイザーとボキャブラリーファイルを作成しなおし、そのボキャブラリーに基づいて必要な情報だけ残したモデルを作成する
- そのモデルをベースにファインチューニングする
ということです。
いやーこの方法思いついたfansiawangさんすごいですね。マジでありがとうございます。
削るコード
実際に削るコードですが、基本はissueに書かれているコードを実行するだけです。
ただ、いくつか気をつける点があります。
まず学習データを元にボキャブラリーファイルをつくる必要があるんですが、それは普通にSentencepieceの学習をする形で問題ないです。
import sentencepiece as spm
spm.SentencePieceTrainer.Train("--input=tmp.txt --model_prefix=new_spm_model --vocab_size=64000 --vocabulary_output_piece_score=false --model_type=bpe")
ただこのままだとfairseqのDictionaryでloadできないので、少し加工します。
edited = []
for line in open("new_spm_model.vocab", 'r', encoding='utf-8'):
if line in ["<unk>\n", "<s>\n", "</s>\n"]:
continue
new_line = line.rstrip('\n') + " 1\n"
edited.append(new_line)
with open('new_dict.txt', 'w') as f:
for e in edited:
f.write(e)
※ちなみになぜ <unk>
などを消しこんでいるかというと、そのままだと後ほどduplicateで怒られるからです。
次に、実際にカットオフするコードですが、自分の場合issueのコメントにあるコードだけでは足りませんでした。
コメントのコードだと "encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
ですが、自分がやったときは "model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "model.shared.weight", "lm_head.weight"
というように他の部分も対応が必要でした。
最終的に新しいものをsaveすることもふまえて、こんな感じでモデルを削るようにしています。
(詳しくはノートブックを参照ください)
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-cc25")
org_sd = model.state_dict()
resized_sd = model.state_dict()
mapping: List[int] = []
for i in range(len(ft_dict)):
word = ft_dict[i]
mapping.append(pre_dict.index(word))
for name in ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "model.shared.weight", "lm_head.weight"]:
pre_tensor: torch.Tensor = org_sd[name]
ft_tensor = torch.zeros(
[len(ft_dict), 1024], dtype=pre_tensor.dtype, layout=pre_tensor.layout, device=pre_tensor.device,
)
for ft_i, pre_i in enumerate(mapping):
ft_tensor[ft_i] = pre_tensor[pre_i]
resized_sd[name] = ft_tensor
resized_sd["final_logits_bias"] = resized_sd["final_logits_bias"][:, :len(ft_dict)]
config = MBartConfig.from_pretrained("facebook/mbart-large-cc25")
config.vocab_size = len(ft_dict)
print(config)
new_model = MBartForConditionalGeneration.from_pretrained(None, config=config, state_dict=resized_sd)
new_model.save_pretrained("./reduced_model")
これによりモデルサイズは1.6GB程度まで減り、無事に貧弱なメモリでも学習できるようになります。
今回公開したモデルは、これを元にJESCのデータセットで5エポックくらいぶん回したものです。
だいぶお金がかかった
でも(自分としては)そこそこ納得のいくモデルができたので満足ですね。
huggingfaceでの公開
これについては特に難しいことはなく、チュートリアル通りにやればできると思います。
git(とlfs)を使って管理されている状態なので、GitHubに載せるのと感覚的にはあまり変わらないかなと。
ただ、READMEはmodel cardといって方言が強いので、こちらのドキュメントを読んで書いたらいいと思います。
書いてない人多いし、完璧に書く必要なんてないのだろうけど、載せられる情報は載せておいたほうが使いたい人の役に立つかなと思います。
余談ですが、このmodel cardに付けられる言語タグを元に日本語の翻訳モデルを探すとほぼ見つからないのが現実です。
みんなもっと自然言語処理やろうぜ。
ちなみに、mBARTのmodel card widgetは壊れている気がするので僕のモデルもそうですがwidgetは生きてないです。
なんかたぶんAutoModelでmBART読もうとするとエラーしてる気がする。
issueあるのか確認してないけど、直せそうなら直したいところではある。
感想
mBARTの学習、マジで最初はOOMし続けて「これ絶対無理ゲーじゃん」とか思ってたんですが、なんとかうまく学習させることができてほんとよかったです。
OOM見飽きて発狂しそうだった
あとボキャブラリーに従っていらない単語の情報を消すってあまり考えたことなかったので勉強になりました。たしかに、その手があったか、と。
そしてそして、今まではモデルをつくっても自分の手元で完結していたので、公に出すっていうのはちょっと楽しいですね。(まだ誰も使ってないけど笑)
日本語関連のモデルはhuggingfaceでも少ない状態なのは勿体ないなと思います。
BERTだって日本語はみんな東北大使うけど、京大やNICTもあるから。
京大やNICTのモデルもhuggingfaceに乗ればみんな使うんじゃないかな〜って勝手に思ってました。
僕もまたなにか作ったら公開します!
Discussion