🔣

RWKV world tokenizer の情報と C++ での実装メモ

2024/01/29に公開
2

背景

LLM 向けコーパス構築で, TB 単位の文章に対して Suffix Tree 構築したいが,
そのままでは 4N(4GB 以下の文章. N = 入力文章のバイト数) or 8N(4GB 以上の文章)の容量が必要となる. つまり 1 TB の文章があったら 4 TB or 8 TB が必要...

メモリ節約したいので入力文章を tokenize したい.
しかし tokenizer の速度や vocab の少なさ(65536 以下なら uint16 で表現できる)も重要...

RWKV world tokenizer がよさそうです!

RWKV world tokenizer

https://github.com/BlinkDL/ChatRWKV/blob/main/tokenizer/rwkv_tokenizer.py

特徴は...

  • 多言語対応
  • 数値もいい感じ
  • ソースコードもいける
  • 最長一致法(longest common prefix)でマッチングして tokenize
    • ので実装が比較的かんたん(Trie 木をつくればよい)
  • rwkv world tokenizer 自体の論文とかはなさそう

https://github.com/BlinkDL/ChatRWKV/blob/main/tokenizer/rwkv_vocab_v20230424.txt

nvocab は 65536 とありますが, 実際は 20230424 版は要素数 65045, 最大 id 65529 になっています.
0 は <endoftext> で, vocab file には含まれていません.

vocab ファイルを見ると, 漢字などは基本 1 文字 1 token になっています.
また,

...
65525 '############################################################################' 76
65526 '****************************************************************************' 76
65527 '################################################################################' 80
65528 '--------------------------------------------------------------------------------' 80
65529 '                                                                                                                                ' 128

とめっちゃ長いのがあります. ソースコードの表現にはいいのかもしれません.

最長一致でトークナイズするので, 基本的には id と string のペアから TRIE 木(prefix tree)作ればいいだけです.

学習

どのように vocab を学習したのかは不明です...

日本語特化で, 日本語 2~3 文字を一つのトークンにするとかしてみたいところですね.

とりま同じ最長一致で形態素解析する jagger https://zenn.dev/syoyo/articles/9ac920632ba5c9 の学習を使いまわしでいける.. かも?

実装

https://github.com/Tessil/hat-trie

とりま hat-trie の longest_prefix を使ってぺろっとできますが...

めちゃ遅い版


    std::map<std::string, int> str_to_id_map;
    
    // read vocab file and fill str_to_id_map...
    
    // We can use uint16_t as value type.
    tsl::htrie_map<char, int> trie_map;

    for (const auto &it : str_to_id_map) {
      trie_map[it.first] = it.second;
    }
    
    // encode UTF-8 string
    std::string input_str = u8"吾輩は猫である。";

    while (!input_str.empty()) {

      auto longest_prefix = trie_map.longest_prefix(input_str);
      if (longest_prefix != trie_map.end()) {
        std::cout << "{" << longest_prefix.key() << ", " << *longest_prefix << "}\n";
        input_str.erase(0, longest_prefix.key().size());
      } else {
        // TODO: fallback to emit fallback token + UTF-8 bytes
      }
    }
    
{吾, 11080}
{輩, 17065}
{は, 10139}
{猫, 14398}
{である, 58552}
{。, 10080}

Voila~

ただ, 入力文字列が長いと(100 文字くらいでも) longest_prefix がめっちゃ遅いです...

最適化版

しょうがないので, 1 UTF-8 文字づつ find_ks で一致を求めていって, 一致がなくなったらそれまでのが最長一致になるので, そのようにして longest match を探すようにしました.

  bool encode(const std::string &_input_str, std::vector<int> &output_ids) {
    std::vector<int> dst;

    const size_t s_len = _input_str.size();

    if (s_len == 0) {
      // empty input
      return false;
    }

    size_t char_idx = 0;
    int prev_id = -1;  // Track previously matched result.
    size_t key_size = 0;

    // Find match for each UTF-8 character,
    // Since `longest_prefix` is quite slow for larger input string.

    while ((char_idx + key_size) < s_len) {
      // Extract UTF-8 char.
      uint32_t charlen = utf8_len(_input_str[char_idx]);
      if (charlen == 0) {
        // Found invalid UTF-8 string.
        return false;
      }

      key_size += charlen;

      auto it = _trie_map.find_ks(&_input_str[char_idx], key_size);
      if (it == _trie_map.cend()) {
        if (prev_id > 0) {
          // prev_id = id of longest matched key
          dst.push_back(prev_id);

          // pop current UTF-8 character.
          key_size -= charlen;

        } else {
          // UTF-8 byte fallback
          // Should be single UTF-8 character
          if (key_size != charlen) {
            // This should not happen. Just in case.
            return false;
          }

          for (size_t i = 0; i < charlen; i++) {
            dst.push_back(int(uint8_t(_input_str[char_idx + i])) +
                          _utf8_id_offset);
          }
        }

        prev_id = -1;

        char_idx += key_size;
        key_size = 0;
      } else {
        prev_id = *(it);

        // Continue search
      }
    }

    // Remainder
    if (prev_id) {
      dst.push_back(prev_id);
    }

    output_ids = dst;
    return true;
  }

これでそれなりに高速に tokenize できるようになりました!

コードはこちら

https://github.com/lighttransport/nanotokenizer

さらなる高みへ...

高速化

TRIE 木なので, ダブル配列(めっちゃややこしい和製日本語...)で配列インデックスベースで高速化できるでしょう.

ただ, tokenizer 用途であれば, 今回は高々 65536 要素なので, ↑の速度の問題がなければ, HAT-trie でぺろっと対応のままでいきたいところ. ベンチマークではダブル配列の cedar より速いっぽいし(cedar の website https://www.tkl.iis.u-tokyo.ac.jp/~ynaga/cedar/ 見ると hat-trie より速いとあるが...)

ダブル配列は, ほぼ配列表現の二重連鎖木と同じだと思うので, 配列表現の二重連鎖木の実装を使いまわすのも手でしょう.

N 分木(多分木)を二重連鎖木にしてオフセット配列表現するメモ
https://qiita.com/syoyo/items/6b44949148a606862db5

未知語を UTF-8 バイト文字にする

未知語(e.g. 絵文字とか)があったら UTF-8 文字に fallback したい.
幸いにも, RWKV World tokenizer では, 1 ~ 256 が ASCII に対応しています(+1 されている).

ASCII 文字単体 127 ~ 255 は通常テキストには単体としてはほぼ現れない(そのため JSON 版では削除されている模様)ので, これを fallback の UTF-8 バイトの id として使えます.

https://en.wikipedia.org/wiki/UTF-8

encode 時は,

trie longest_prefix 探索で見つからなかったら, UTF-8 バイト列(通常日本語は 3 バイト)を 1 バイトづつ 1 tokenize(つまり日本語文字 fallback の場合は通常 3 tokens に変換される).

がでいけます.

頑張って UTF-16 にするとか, uint16 x 2 にパックしてバイト数減らす手もありますが, 未知語が出るのはレアケースだとおもうのでそこまで頑張らなくてもいいかもです.

decode 時は, 128 ~ 256 に当たっtoken ids は UTF-8 バイトとして解釈して, 文字を復元, でいけます.

絵文字と顔文字

絵文字は10 個くらいしか語彙に含まれていないため, 絵文字があると, UTF-8 fallback で最低 4 token 使いますので, 効率が悪いです.

特に日本語を想定しているのであれば, いくつか英単語やソースコード用語彙を削って, 絵文字と顔文字を追加するとよいでしょう.

https://github.com/lighttransport/japanese-llama-experiment/tree/main/build_rwkv_world_ja_tokenizer

に絵文字と顔文字を追加する builder スクリプトを置きました.

多言語 LLM の tokenizer にも活用する

一から事前学習必要になりますが, tokenizer 部分の実装がシンプルかつ高速になり, いい感じになる可能性があります.

TODO

  • 実際にベンチマークを取る

Discussion

山田(ymd)山田(ymd)

続編の記事を読んで帰ってきました。
(自分で動かしていないで、ソース読んだだけでコメントしているので、的外れでしたらすみません。また、既に別実装を色々進めていらっしゃるようなので、このコメント自体不要かもしれませんが。。。)

longest_prefix がめっちゃ遅いです

  1. const参照を介して、constメンバーの longest_prefix() を利用する
  2. (C++14以前のみ) ハッシュ関数を差し替える

詳細1

実装を辿っていくと、tsl::htrie_hash::longest_prefix_impl() にたどり着きますが、constメンバーと非constメンバーがオーバーロードされていて、非const版ではconst版の結果を tsl::htrie_hash::mutable_iterator() により非constに変換しているようです。
これ以降の実装の読解は甘いですが、内部の hash array を複製している可能性はありそうなので、const版で変換を避ける事にメリットがあるかもしれません。

const auto& const_trie_map = trie_map;
const_trie_map.longest_prefix(input_str);

https://github.com/Tessil/hat-trie/blob/906e6abd1e7063f1dacd3a6b270aa654b525eb0a/include/tsl/htrie_hash.h#L1556-L1618

詳細2

READMEに記載がありますが、 std::hash<std::string> が内部にコピーをしてしまうため、C++17以上では std::string_view を利用、C+14以前では単純なハッシュを実装して利用しているようです。
このC++14以前の単純なハッシュを外部のハッシュに変えると20%ぐらい性能向上が観測されたとのこと。

以前別の記事で、manylinux対応のために、C++14前提で実装取り組まれていた事もあったので、今回もそれに該当しているのかもしれないと思い。

syoyosyoyo

ありがとうございます.

C++17 でコンパイルしても遅かったですね...
const でカイゼンできそうですね!

ただ, find_ks 利用で満足する速度でたので, longest_prefix の詳細調べるのはしばらくはなさそうです...