LLM 日本語データセット向けに C++ で minhash 重複除去を行うメモ

2023/09/13に公開

背景

LLM 向けデータセット(コーパス)では重複のない品質の高いデータセットが重要となります.
よくあるのは minhash で fuzzy に行うのと, suffix array で exact に行うもののようです.
(RefinedWeb 論文 https://arxiv.org/abs/2306.01116 より)

今回は minhash での fuzzy dedup を取り上げます.

https://github.com/HojiChar/HojiChar/blob/main/hojichar/filters/deduplication.py

参考となる実装は HojiChar にありますが, いかんせん Python はやはり遅い...
C++ でやります!

情報

https://cran.r-project.org/web/packages/textreuse/vignettes/textreuse-minhash.html

https://medium.com/carbon-consulting/explaining-lsh-minhash-simhash-c3cc33040030

minhash

https://www.arxiv-vanity.com/papers/2107.06499/#A1

  • テキストの N-gram 表現を作る. これに対して hash を計算する.
  • hash を BUCKET_SIZE(b) * N_BUCKET(r) の合計 k 個用意する
    • 通常はハッシュ関数(アルゴリズム)は同じで, seed を変えて対応
    • ハッシュ関数は衝突が少ないものがよい. murmurhash3 がよく使われるようであるが, FNV などの軽量ハッシュ関数でもいいかも https://qiita.com/yamasaki-masahide/items/d478dd111690bc84784a
    • N-gram の item M 個に対してハッシュを計算し, 値が小さいのを選んで保存 => k 個の hash 値
  • v: b 個のハッシュを連結して一つの値にする
    • optional: b-bit minhash: BUCKET_SIZE から b bit(HojiChar では 2 byte(b = 16bit))切り出し, BUCKET_SIZE * (b / 8) バイトのハッシュ値を N_BUCKET ぶん作る
  • bucket(r 個)内のいずれかで連結した値 v が一致したら, potential match とする

Jaccard similarity

minhash で potential match したものにしたいして,

1 - (1 - s^b)^r で potential match 率(?)が求まる.

s は Jaccard index(生真面目に計算した Jaccard 係数)っぽい.

b = 20, r = 450 だと, Jaccard index s がおよそ 0.8 になると, potential match が 1.0 となる(つまり, hash の重複が一つでもあると, 確率的に Jaccard 係数はおよそ 0.8 以上)

b = 20, r = 10 だと 0.96 くらい

b = 20, r = 40 だと 0.9 くらい

したがって r(N_BUCKETS) は大きいほうがよいが, 計算量との兼ね合い.

false positive 対策のため, potential match したのに対して実際に Jaccard index を求めるのがよいようであるが, 計算量次第ではこれはスキップしてもよいでしょう.
(RefinedWeb ではスキップしている)

また, よりよい dedup をするためにここから The edit similarity を算出(後述参照)する.

日本語の場合

日本語の場合は UTF-8 文字単位で処理が多い. 今回は UTF-8 文字単位にしましたが,
rinna tokenizer など, 日本語トークナイザを使う手もあるかもです.

また, 日本語は事前に dedup 向けに正規化(normalize)しておくのが推奨です.

5-gram?

HojiChar では 5-gram(5 文字)にしています.
The Pile や RefinedWeb も 5-gram ですが, これは英語で 5 単語づつなので, 日本語の場合はもうちょっと増やしてもいいのかもはしれません(8 とか)

実装

https://github.com/lighttransport/japanese-llama-experiment/blob/main/cpp/dedup.hh

https://github.com/lighttransport/japanese-llama-experiment/blob/a7df8ea0eebc4da40d43b302ef2962a95fb58f19/cpp/main.cc#L266

マルチスレッド化と template で struct など固定サイズにして効率化しましたが, 結構 naiive に実装したからか, すごい速いとはなりませんでした.

minhash 計算時間

75B tokens(jsonl + zstd 圧縮時 80 GB. 展開で 240 GB ほど)で, Ryzen9 3950X 16 コアで 225 分かかりました.
(json 読み書き, ngram 分解などの時間含む)

概ね 5.5 M chars(日本語文字)/sec という感じでしょうか.

dedup

あとは minhash のハッシュ値で重複判定すればいけます.

bucket 単位で判定したり, マッチした hash / key_items で近似 Jaccard 係数を出してしきい値(0.7 など)以上であれば一致と判断するのがよいようですが, LLM データセットの場合は重複が無いのが重要となるため, 今回は bucket 単位は無視し, 一個でも hash 値がすでにあるかどうかで重複判断するようにしました.

これにより, 重複していないが, 重複と判断されてしまうのも出てくるでしょうがが,
LLM データセットでは重複があるよりは無い方がよいのでまあヨシ! とします.
ちなみに 1 document 200 個(20 x 10) hash 計算では概ね 8 ~ 11 % ほどの重複判断となりました(下記参照)

unordered_set の keyer

とりま重複管理に std::unordered_set 使いました.

unordered_set, unordered_map の場合は key の hasher が必要になります.
key の各要素はすでにハッシュ値になっていますので, hash combine をするだけでよいです.

今回は cityhash を使いました.

murmurhash3 なハッシュ値を cityhash で key combine して(効率良いのかどうか)は要検証ですね.

template<uint32_t BUCKET_SIZE = 10, uint32_t B = 2>
struct MinHashValHasher
{
  static_assert(B == 2, "B must be 2 for now.");

  size_t operator()(const MinHashVal<BUCKET_SIZE, B> &k) const {
    size_t seed = k.vals[0];
    for (uint32_t i = 1; i < k.nitems(); i++) {
      hash_combine(seed, k.vals[i]);
    }

    return seed;
  }
};

実装

https://github.com/lighttransport/japanese-llama-experiment/blob/a7df8ea0eebc4da40d43b302ef2962a95fb58f19/cpp/main.cc#L363

dedup 処理時間

TOTAL: duplicated 10066046 documents(total 91701905). ratio = 10.9769 %

real    49m51.016s

およそ 75 B tokens(~200 GB), ドキュメント数(JSON アイテム数) 9,170 万で 11 % ほどの重複を除去となりました.
処理時間は i9-12900K 1 core で 50 分ほどでした.

今回は dedup 用にテキストの正規化(数字は 0 にする, 句点は除くなど)はしていないので, テキスト正規化して minhash 計算したらもうちょっと dedup rate 上がるかもしれません.

さらなる高みを目指して

mimhash 高速化

https://gudok.xyz/minhash1/

あたりが参考になるでしょう.

日本語の場合, ほぼ 3 bytes として 3 * N gram 単位で処理したり, uint32_t で char 表現(トークナイズ)してもいいかもしれません.

dedup のマルチスレッド化, out-of-core 対応

minhash 自体はスレッド化しやすいですが, dedup で set はそのままではマルチスレッド化しづらいです.
bucket 単位もしくは, set の key のハッシュ値の上位 8 byte で 256 entry 作るなどしたらいい感じになるやもはしれません.

LLM 向けデータセット構築ではおおむね minhash dedup での重複率が 5 ~ 10 % と想定し, 全文(全ドキュメント)に対して重複判定したいだけであれば, hash 値を一括で読んでソートさせてで判定もよさそうです.
(hash 値の場合ある程度一様に分布していると考えられるので, ソート処理も worst case にはならずにできると予想)

あとは今後より大きなデータセットを dedup するときに, メモリに入らない場合は NVMe 使って out-of-core 処理とか.

よりよい dedup

Deduplicating Training Data Makes Language Models Better
https://arxiv.org/abs/2107.06499

RefinedWeb では Lee et al. 2001 と同じ設定で, 5-gram 表現を使い, ドキュメントあたり 9000 hashes(20 * 450 buckets)を計算しています!

https://arxiv.org/abs/2306.01116

↑の HojiChar だと The Pile と同様に 20 * 10 buckets(200 hashes)になっています.
(RefienedWeb では 10 buckets では lower dedup rate になってしまうとある)

Lee et al. では minhash での potential match と求まった Jaccard index が 0.8 以上の場合に, edit similarity を計算して類似度判定をしています(document の graph(なんの?)を構築する必要がありちょっと実装はめんどそう)

TODO

  • minhash で Bucket 単位で dedup 場合としない場合でどうなるか調査する
    • 非 bucket 単位と判定の精度が同じくらいになるのであれば, bucket 単位でマルチスレッドで dedup 処理することができそう

Discussion