RustでBERTのfine-tuningしたいけどよくわかんない
token-classificationの前処理でdoccano形式からの変換書いたらバグったので、いっそ全部をRustで書き直したいが、rust_bertは学習をサポートしてないしcl-tohokuの日本語Bertは.otを提供してないのでそもそもrust_bertで読み込めない
前処理を普通にRustで書き、さらに推論時は学習したモデルをONNX形式にしたものをONNXのRustランタイムから使えば、Transformersの使用量を前処理後の学習だけにできそう
日本語でHugging Face Tokenizersを動かす
huggingfaceのtokenizersを使わずに日本語用形態素解析器を使っているが、この記事ではhuggingfaceのをつかっている
cl-tohokuもこの早稲田のモデルもこれを読まずにこのタスクに取り掛かったのを後悔した
safetensorなる新しい学習済みモデルの保存方式があり、Pythonのpickleに依存しない。tch-rsもこれを使っているらしい
pip install safetensors
で入るsafetensors
を使えば任意のhuggingfaceで公開されているモデルをtch-rsで読み込むことができそう
一旦AutoModelForTokenClassification
などでベースモデルを読み込んで、その後stt.save_file(model.state_dict(), 'modelname.safetensors')
で保存し、tch-rsから改めて読み込めばよさそう
手で書かれた(PyTorchと一部Transformers)トレーニングループの例があるので、これをtch-rsに移植したい
「手で書かれた」とはいえ、半分くらいの機能は引き続きTransformersを使ってるので、がんばって読んで剥がす必要あり
tch-rsのFAQの”What are the best practices for Python to Rust model translations?”に答として参照されていた
重みファイルだけみてリバースエンジニアリングしようとするのはしんどいしミスりやすいのでやめたほうがよいらしい
私はいつも、Rust バージョンを実装する方法のガイドとして Python 実装を使用していました。重要な点は、モデルの移植を容易にするために、tch が PyTorch のデフォルトの動作をすべて模倣しようとすることです。これは、変数の初期化や関数呼び出しのオプションの引数などの場合に当てはまります。ハグフェイスのディフューザーライブラリに基づいた安定した拡散の例を
プッシュしました。ディフューザーのコードベースはかなり大きいため、数日間の作業が必要でしたが、Rust の構造は Python コードの構造にうまく準拠しており、途中で大きな問題は発生しませんでした。
要はtchはPyTorchのデフォルト動作を模倣しようとしている(まさにPython→Rust移植がしやすいように)ので、Python実装を参考に真似してtch-rsで書くのがベストらしい
(他にはこれのls
を使うとsafetensors使わなくてもnpzからotへの変換ができそうだが、safetensorsはtch-rsのREADMEで言及されていることもありなんとなくsafetensorsの優先度が高い)
上記を考えればrust_bertのコード見ればtch-rsでのBERTのモデル定義の方法がわかるし(重みファイルを読み込む以上、この作業は欠かせないはず)、日本語トークナイザーを使うなどの部分はTransformersの元のコードを読んでがんばれば移植できそう
ちらっとみたかんじ大体PyTorch
tch-rsにdownload-libtorch
なるfeatureが生えてて、これをオンにするとCPU版のlibtorchのバイナリを自分で落としてつかってくれるので便利だった
rust-bertのBERTの構造体ちょっとよんだ感じ結構詳しく書いてあり、モデルの定義はこのままrust_bertのやつ使って、トレーニングループだけ自分で書けばよさそうな気がしてきた
tch-rsの代わりに https://github.com/huggingface/candle 使おうかと思ったが、DropOutなどの学習には使う層が実質からで、自分で実装埋めれるほど理解度高くないので一旦撤退
Dropoutの学習時の処理が未だにTODOだが、10月にexampleでhuggingfaceのモデルを使う例が実装されており、推論で使う分には大丈夫なのでは?
rustbertでの重みの読み込みが、tchの例から想像できた挙動と一致してたところまで理解できた
VarStoreからsafetensorsなどをloadできる
BertTokenClassificationOutput.logitsを取り出して、ループの中で
let loss = logits
.view([BATCH_SIZE * BLOCK_SIZE, labels])
.cross_entropy_for_logits(&ys.view([BATCH_SIZE * BLOCK_SIZE]));
opt.backward_step_clip(&loss, 0.5);
すれば学習まわりそう
rust_bertのTokenClassificationModel.predict()が割と参考になる
空の重み行列群(各レイヤー)をnew()で作るが、重みファイル自体は、あらかじめ作っといたVarStoreで直接読み込む
Tokenizerのvocab.txtは0-originぽい
vocab.txtのファイルとしての先頭が
[PAD]
[UNK]
[CLS]
[SEP]
[MASK]
<unused0>
....
のとき、BertJapaneseTokenizer
のspecial tokenを眺めると
ipdb> tokenizer.all_special_ids
[1, 3, 0, 2, 4]
ipdb> tokenizer.all_special_tokens
['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]']
であるため
candle-exampleのllamaでは
for rfilename in [
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
] {
match &args.local_weights {
Some(path) => {
filenames.push((path.to_owned() + rfilename).into());
}
_ => {
let filename = api.get(rfilename)?;
filenames.push(filename);
}
};
}
となっているのでやはりsafetensorへの変換が必要かと思ったが、同じくcandle-exampleのbertのサンプルでは
let (config_filename, tokenizer_filename, weights_filename) = {
let api = Api::new()?;
let api = api.repo(repo);
let config = api.get("config.json")?;
let tokenizer = api.get("tokenizer.json")?;
let weights = if self.use_pth {
api.get("pytorch_model.bin")?
} else {
api.get("model.safetensors")?
};
(config, tokenizer, weights)
};
となっているので、.bin
のままdownloadしてきて読み込むことも可能?
pytorch_model.bin
をdownloadしようとしているのはcandle_examples
のなかではbertとdistillbert
どちらもencoder-onlyモデルなので、llama等decoder-onlyモデルでも行けるかは不明(やってみる)
ELYZAのllama-13bのリポジトリはこちら