Open27

RustでBERTのfine-tuningしたいけどよくわかんない

denjirydenjiry

token-classificationの前処理でdoccano形式からの変換書いたらバグったので、いっそ全部をRustで書き直したいが、rust_bertは学習をサポートしてないしcl-tohokuの日本語Bertは.otを提供してないのでそもそもrust_bertで読み込めない

denjirydenjiry

https://github.com/LaurentMazare/tch-rs#importing-pre-trained-weights-from-pytorch-using-safetensors
safetensorなる新しい学習済みモデルの保存方式があり、Pythonのpickleに依存しない。tch-rsもこれを使っているらしい

denjirydenjiry

pip install safetensorsで入るsafetensorsを使えば任意のhuggingfaceで公開されているモデルをtch-rsで読み込むことができそう
一旦AutoModelForTokenClassificationなどでベースモデルを読み込んで、その後stt.save_file(model.state_dict(), 'modelname.safetensors')で保存し、tch-rsから改めて読み込めばよさそう

denjirydenjiry

https://huggingface.co/learn/nlp-course/ja/chapter7/2
手で書かれた(PyTorchと一部Transformers)トレーニングループの例があるので、これをtch-rsに移植したい

denjirydenjiry

「手で書かれた」とはいえ、半分くらいの機能は引き続きTransformersを使ってるので、がんばって読んで剥がす必要あり

denjirydenjiry

https://github.com/LaurentMazare/tch-rs/issues/549#issuecomment-1296840898
tch-rsのFAQの”What are the best practices for Python to Rust model translations?”に答として参照されていた

重みファイルだけみてリバースエンジニアリングしようとするのはしんどいしミスりやすいのでやめたほうがよいらしい

私はいつも、Rust バージョンを実装する方法のガイドとして Python 実装を使用していました。重要な点は、モデルの移植を容易にするために、tch が PyTorch のデフォルトの動作をすべて模倣しようとすることです。これは、変数の初期化や関数呼び出しのオプションの引数などの場合に当てはまります。ハグフェイスのディフューザーライブラリに基づいた安定した拡散の例を
プッシュしました。ディフューザーのコードベースはかなり大きいため、数日間の作業が必要でしたが、Rust の構造は Python コードの構造にうまく準拠しており、途中で大きな問題は発生しませんでした。

要はtchはPyTorchのデフォルト動作を模倣しようとしている(まさにPython→Rust移植がしやすいように)ので、Python実装を参考に真似してtch-rsで書くのがベストらしい

https://github.com/LaurentMazare/tch-rs/blob/main/examples/tensor-tools.rs
デバッグのときこれで重みファイルの中身みれて便利らしいが上述のとおりにやって出番が来ないほうがうれしいですね。
(他にはこれのlsを使うとsafetensors使わなくてもnpzからotへの変換ができそうだが、safetensorsはtch-rsのREADMEで言及されていることもありなんとなくsafetensorsの優先度が高い)

denjirydenjiry

上記を考えればrust_bertのコード見ればtch-rsでのBERTのモデル定義の方法がわかるし(重みファイルを読み込む以上、この作業は欠かせないはず)、日本語トークナイザーを使うなどの部分はTransformersの元のコードを読んでがんばれば移植できそう

denjirydenjiry

rust-bertのBERTの構造体ちょっとよんだ感じ結構詳しく書いてあり、モデルの定義はこのままrust_bertのやつ使って、トレーニングループだけ自分で書けばよさそうな気がしてきた

denjirydenjiry

rustbertでの重みの読み込みが、tchの例から想像できた挙動と一致してたところまで理解できた
VarStoreからsafetensorsなどをloadできる
BertTokenClassificationOutput.logitsを取り出して、ループの中で

let loss = logits
    .view([BATCH_SIZE * BLOCK_SIZE, labels])
    .cross_entropy_for_logits(&ys.view([BATCH_SIZE * BLOCK_SIZE]));
opt.backward_step_clip(&loss, 0.5);

すれば学習まわりそう

rust_bertのTokenClassificationModel.predict()が割と参考になる

空の重み行列群(各レイヤー)をnew()で作るが、重みファイル自体は、あらかじめ作っといたVarStoreで直接読み込む

denjirydenjiry

モデルの重みの読み込みはなんとかなりそうだが、tokenizerの情報どうすんのか忘れてた

denjirydenjiry

tests/bert.rs:bert_for_token_classification()に汎用モデルのコンフィグファイルからだいたいスクラッチで作る方法のってた

denjirydenjiry

Tokenizerとそれ以外を分けて考えてもよい確証は得られた気がするが、日本語用Tokenizerを使うので例をコピペというわけにはいかなそう

denjirydenjiry

Tokenizerのvocab.txtは0-originぽい
vocab.txtのファイルとしての先頭が

[PAD]
[UNK]
[CLS]
[SEP]
[MASK]
<unused0>
....

のとき、BertJapaneseTokenizerのspecial tokenを眺めると

ipdb> tokenizer.all_special_ids
[1, 3, 0, 2, 4]
ipdb> tokenizer.all_special_tokens
['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]']

であるため

denjirydenjiry

wordpieceがrust_tokenizersクレートにあったので自分で実装する必要がなかった

takuma satotakuma sato

candle-exampleのllamaでは

for rfilename in [
                "model-00001-of-00002.safetensors",
                "model-00002-of-00002.safetensors",
            ] {
                match &args.local_weights {
                    Some(path) => {
                        filenames.push((path.to_owned() + rfilename).into());
                    }
                    _ => {
                        let filename = api.get(rfilename)?;
                        filenames.push(filename);
                    }
                };
            }

となっているのでやはりsafetensorへの変換が必要かと思ったが、同じくcandle-exampleのbertのサンプルでは

let (config_filename, tokenizer_filename, weights_filename) = {
            let api = Api::new()?;
            let api = api.repo(repo);
            let config = api.get("config.json")?;
            let tokenizer = api.get("tokenizer.json")?;
            let weights = if self.use_pth {
                api.get("pytorch_model.bin")?
            } else {
                api.get("model.safetensors")?
            };
            (config, tokenizer, weights)
        };

となっているので、.binのままdownloadしてきて読み込むことも可能?

takuma satotakuma sato

pytorch_model.binをdownloadしようとしているのはcandle_examplesのなかではbertとdistillbert
どちらもencoder-onlyモデルなので、llama等decoder-onlyモデルでも行けるかは不明(やってみる)