RWKV world tokenizer の情報と C++ での実装メモ
背景
LLM 向けコーパス構築で, TB 単位の文章に対して Suffix Tree 構築したいが,
そのままでは 4N(4GB 以下の文章. N = 入力文章のバイト数) or 8N(4GB 以上の文章)の容量が必要となる. つまり 1 TB の文章があったら 4 TB or 8 TB が必要...
メモリ節約したいので入力文章を tokenize したい.
しかし tokenizer の速度や vocab の少なさ(65536 以下なら uint16 で表現できる)も重要...
RWKV world tokenizer がよさそうです!
RWKV world tokenizer
特徴は...
- 多言語対応
- 数値もいい感じ
- ソースコードもいける
- 最長一致法(longest common prefix)でマッチングして tokenize
- ので実装が比較的かんたん(Trie 木をつくればよい)
- rwkv world tokenizer 自体の論文とかはなさそう
nvocab は 65536 とありますが, 実際は 20230424 版は要素数 65045, 最大 id 65529 になっています.
0 は <endoftext>
で, vocab file には含まれていません.
vocab ファイルを見ると, 漢字などは基本 1 文字 1 token になっています.
また,
...
65525 '############################################################################' 76
65526 '****************************************************************************' 76
65527 '################################################################################' 80
65528 '--------------------------------------------------------------------------------' 80
65529 ' ' 128
とめっちゃ長いのがあります. ソースコードの表現にはいいのかもしれません.
最長一致でトークナイズするので, 基本的には id と string のペアから TRIE 木(prefix tree)作ればいいだけです.
学習
どのように vocab を学習したのかは不明です...
日本語特化で, 日本語 2~3 文字を一つのトークンにするとかしてみたいところですね.
とりま同じ最長一致で形態素解析する jagger https://zenn.dev/syoyo/articles/9ac920632ba5c9 の学習を使いまわしでいける.. かも?
実装
とりま hat-trie の longest_prefix を使ってぺろっとできますが...
めちゃ遅い版
std::map<std::string, int> str_to_id_map;
// read vocab file and fill str_to_id_map...
// We can use uint16_t as value type.
tsl::htrie_map<char, int> trie_map;
for (const auto &it : str_to_id_map) {
trie_map[it.first] = it.second;
}
// encode UTF-8 string
std::string input_str = u8"吾輩は猫である。";
while (!input_str.empty()) {
auto longest_prefix = trie_map.longest_prefix(input_str);
if (longest_prefix != trie_map.end()) {
std::cout << "{" << longest_prefix.key() << ", " << *longest_prefix << "}\n";
input_str.erase(0, longest_prefix.key().size());
} else {
// TODO: fallback to emit fallback token + UTF-8 bytes
}
}
{吾, 11080}
{輩, 17065}
{は, 10139}
{猫, 14398}
{である, 58552}
{。, 10080}
Voila~
ただ, 入力文字列が長いと(100 文字くらいでも) longest_prefix
がめっちゃ遅いです...
最適化版
しょうがないので, 1 UTF-8 文字づつ find_ks
で一致を求めていって, 一致がなくなったらそれまでのが最長一致になるので, そのようにして longest match を探すようにしました.
bool encode(const std::string &_input_str, std::vector<int> &output_ids) {
std::vector<int> dst;
const size_t s_len = _input_str.size();
if (s_len == 0) {
// empty input
return false;
}
size_t char_idx = 0;
int prev_id = -1; // Track previously matched result.
size_t key_size = 0;
// Find match for each UTF-8 character,
// Since `longest_prefix` is quite slow for larger input string.
while ((char_idx + key_size) < s_len) {
// Extract UTF-8 char.
uint32_t charlen = utf8_len(_input_str[char_idx]);
if (charlen == 0) {
// Found invalid UTF-8 string.
return false;
}
key_size += charlen;
auto it = _trie_map.find_ks(&_input_str[char_idx], key_size);
if (it == _trie_map.cend()) {
if (prev_id > 0) {
// prev_id = id of longest matched key
dst.push_back(prev_id);
// pop current UTF-8 character.
key_size -= charlen;
} else {
// UTF-8 byte fallback
// Should be single UTF-8 character
if (key_size != charlen) {
// This should not happen. Just in case.
return false;
}
for (size_t i = 0; i < charlen; i++) {
dst.push_back(int(uint8_t(_input_str[char_idx + i])) +
_utf8_id_offset);
}
}
prev_id = -1;
char_idx += key_size;
key_size = 0;
} else {
prev_id = *(it);
// Continue search
}
}
// Remainder
if (prev_id) {
dst.push_back(prev_id);
}
output_ids = dst;
return true;
}
これでそれなりに高速に tokenize できるようになりました!
コードはこちら
さらなる高みへ...
高速化
TRIE 木なので, ダブル配列(めっちゃややこしい和製日本語...)で配列インデックスベースで高速化できるでしょう.
ただ, tokenizer 用途であれば, 今回は高々 65536 要素なので, ↑の速度の問題がなければ, HAT-trie でぺろっと対応のままでいきたいところ. ベンチマークではダブル配列の cedar より速いっぽいし(cedar の website https://www.tkl.iis.u-tokyo.ac.jp/~ynaga/cedar/ 見ると hat-trie より速いとあるが...)
ダブル配列は, ほぼ配列表現の二重連鎖木と同じだと思うので, 配列表現の二重連鎖木の実装を使いまわすのも手でしょう.
N 分木(多分木)を二重連鎖木にしてオフセット配列表現するメモ
未知語を UTF-8 バイト文字にする
未知語(e.g. 絵文字とか)があったら UTF-8 文字に fallback したい.
幸いにも, RWKV World tokenizer では, 1 ~ 256 が ASCII に対応しています(+1 されている).
ASCII 文字単体 127 ~ 255 は通常テキストには単体としてはほぼ現れない(そのため JSON 版では削除されている模様)ので, これを fallback の UTF-8 バイトの id として使えます.
encode 時は,
trie longest_prefix 探索で見つからなかったら, UTF-8 バイト列(通常日本語は 3 バイト)を 1 バイトづつ 1 tokenize(つまり日本語文字 fallback の場合は通常 3 tokens に変換される).
がでいけます.
頑張って UTF-16 にするとか, uint16 x 2 にパックしてバイト数減らす手もありますが, 未知語が出るのはレアケースだとおもうのでそこまで頑張らなくてもいいかもです.
decode 時は, 128 ~ 256 に当たっtoken ids は UTF-8 バイトとして解釈して, 文字を復元, でいけます.
絵文字と顔文字
絵文字は10 個くらいしか語彙に含まれていないため, 絵文字があると, UTF-8 fallback で最低 4 token 使いますので, 効率が悪いです.
特に日本語を想定しているのであれば, いくつか英単語やソースコード用語彙を削って, 絵文字と顔文字を追加するとよいでしょう.
に絵文字と顔文字を追加する builder スクリプトを置きました.
多言語 LLM の tokenizer にも活用する
一から事前学習必要になりますが, tokenizer 部分の実装がシンプルかつ高速になり, いい感じになる可能性があります.
TODO
- 実際にベンチマークを取る
Discussion
続編の記事を読んで帰ってきました。
(自分で動かしていないで、ソース読んだだけでコメントしているので、的外れでしたらすみません。また、既に別実装を色々進めていらっしゃるようなので、このコメント自体不要かもしれませんが。。。)
longest_prefix()
を利用する詳細1
実装を辿っていくと、
tsl::htrie_hash::longest_prefix_impl()
にたどり着きますが、constメンバーと非constメンバーがオーバーロードされていて、非const版ではconst版の結果をtsl::htrie_hash::mutable_iterator()
により非constに変換しているようです。これ以降の実装の読解は甘いですが、内部の hash array を複製している可能性はありそうなので、const版で変換を避ける事にメリットがあるかもしれません。
詳細2
READMEに記載がありますが、
std::hash<std::string>
が内部にコピーをしてしまうため、C++17以上ではstd::string_view
を利用、C+14以前では単純なハッシュを実装して利用しているようです。このC++14以前の単純なハッシュを外部のハッシュに変えると20%ぐらい性能向上が観測されたとのこと。
以前別の記事で、manylinux対応のために、C++14前提で実装取り組まれていた事もあったので、今回もそれに該当しているのかもしれないと思い。
ありがとうございます.
C++17 でコンパイルしても遅かったですね...
const でカイゼンできそうですね!
ただ,
find_ks
利用で満足する速度でたので,longest_prefix
の詳細調べるのはしばらくはなさそうです...