RustでもBERTでFine-tuningしたい
本記事はLabBase テックカレンダー Advent Calendar 2023 5 日目です。
概要
Pythonのtransformersライブラリと同程度に簡単に、RustでもBERTでFine-tuningしたかったので調べたところ、惜しくもできる方法は無かった。
しかしちょっと頑張ると可能そうだし、ちょっと待てばhuggingface/candleでできるようになりそう。
RustでもBERTでFine-tuningしたい
業務で文章からキーワード抽出するタスクを解きたかったので、
huggingfaceの「トークン分類」のドキュメント
を参考に、あらかじめ用意したアノテーション済みデータを入力として、
huggingfaceのベースモデルを元にFine-tuningするまでをPythonで実装しました。
ですが、せっかくなのでRustでも実装したくなったので、Rustでも同様の
ことができないか調査しました。
このとき、Pythonでは
- トークン分類特有の前処理
- huggingfaceからベースモデルの重み行列のロード
- 学習ループ
がそれぞれ実装されていました。今回はRustでFine-tuningする部分にフォーカスしたいので、トークン分類特有の前処理の話は割愛します。
よってそれぞれのフェーズで、Rustによる再実装をおこなうにはどうすればよいかを検討していきます。
huggingfaceからベースモデルの重み行列のロード
前提として、https://huggingface.co/cl-tohoku/bert-base-japanese-v3 のようにhuggingfaceで公開されているモデルの重み行列はtransformersライブラリからの利用を想定しており、そのままではRustから読み込むことができません。
そこで、safetensorsを使います。
safetensors
はPythonとRust両方に対応した行列の保存形式で、これをつかってhuggingfaceで公開されている任意の重み行列をRustへ橋渡しすることができます。
具体的には、以下のように一旦transformersでhuggingfaceからモデルを初期化したのち、safetensors
でその重み行列をsafetensors形式で保存します。
# c.f. https://github.com/LaurentMazare/tch-rs#importing-pre-trained-weights-from-pytorch-using-safetensors
from safetensors import torch as stt
from transformers import AutoModelForTokenClassification, BertForTokenClassification
def main(model_name: str):
model: BertForTokenClassification = AutoModelForTokenClassification.from_pretrained(model_name)
stt.save_file(model.state_dict(), "exported.safetensors")
if __name__ == "__main__":
main("cl-tohoku/bert-base-japanese-v3")
これによって、毎回一度safetensors形式への変換を挟むものの、huggingfaceの任意のモデルの重み行列がRustでも読み込めるようになりました。
学習ループ
実はRustにtransformersライブラリ相当のものが存在しないかというとそんなことはなく、rust_bert
というクレートが存在しており、またpytorch相当であるtch-rs
クレートも存在しています。
これら2つのクレートとライブラリの関係は、ちょうど図のような対応にあるのですが、
図にも書いてある通り、rust_bert
にはtransformersと同程度に抽象度の高い学習用の機能は実装されていません。
したがって残念ながら、モデルの定義はrust_bert
のものを借りつつ、学習ループについてはtch-rs
を使い手動で書く必要があり、現状ではtransformersと同程度の簡単さでFine-tuningをすることはできないようでした。
candle
ここで今年huggingfaceが新しく開発している https://github.com/huggingface/candle を紹介します。
candle
は先程の図のtch-rs
とおなじ位置に存在するクレートなのですが、tch-rs
と違ってpytorchのC++コア(libtorch)への依存がなく、GPUでの計算などから自力で実装しているようです。(CUDAカーネルを生成するようなコードが見て取れます)
実際READMEにも「軽量かつPythonへの依存がゼロにできる」ということが他のフレームワークとの差分として挙げられています。
また、tch-rs
のような低レイヤな機能だけでなくtransformersのようなレイヤの高い機能の実装(candle-transformers)やデータセット、ONNX、WASM、ドキュメント、exampleなど、すでに幅広い範囲をカバーするサブのクレートが存在しています。
candle-transformersにはBERTやLlamaなど様々なアーキテクチャのモデルの定義やユーティリティ的機能が存在するので、huggingfaceからモデルをロードして使用することも可能なようです。
しかしながらBERTのモデルの実装を見てみたところ、10月にexampleでhuggingfaceのモデルを使う例が実装されていたのですが、Dropoutの学習時の処理が未だ実装されていないようなので、Fine-tuningをそのままでは行えないようでした。
(開発が活発なので、すこし待ってればそのうち実装されるかもしれません)
まとめ
RustでもBERTでFine-tuningがしたかったので調べてみましたが、残念ながらすぐには実現できなそうでした。
実はtch-rsでコードを書いていたのですが、途中までしか終わっていないのと、candleを待つなりTODOに自分でパッチ当てるなりしたほうが楽そうだなと思いほったらかしになっています。
現在トークン分類自体はすぐに必要なものではなくなったので一旦そのままにするつもりですが、今後もっと大きなモデルをRustでセルフホストするなどのケースに備えて引き続き調査する予定です。
参考
Discussion