🔡

高速UTF-8バリデーションを通してビット演算を学ぶ

2024/03/26に公開

はじめに

こんにちは。前回は文字列から浮動小数点数への変換を高速化する話を書きました。その中で、もっと文字列を読む部分も高速化できないかと思い複数バイトを一度に操作する方法について色々実験してみたところ想像以上に面白かったので作ったものなど紹介したいと思います。

目指すもの・目指さないもの

今回はGo言語の標準ライブラリであるutf8.Validを高速化することを目標にして進めていきます。この関数は、与えられたバイト配列が正しくUTF-8の符号化に従えているかを確認するものです。もちろん、この内容を過去に研究した人もおり、有名どころで言うとDaniel Lemireさんの論文などがありますが、今回はこの論文に紹介されていた方法とは違いSIMD使わないことにします。

理由はそれぞれのプラットフォームへの対応が簡単になることが大きく、コードを統一できたり、メンテナンスが楽になります。特にGo言語ではSIMDを直接利用するのが難しく(少なくとも僕には難しかった)、uint64(符号なし64-bit整数)ならアーキテクチャに依存せずに存在していることもひとつです。

ではまずUTF-8が何で、これがなぜ重要なのかから見ていきたいと思います。

UTF-8

文字符号化形式の一つで、実はGo言語の初期チームと同じ方々が作りました。文字符号化形式とは具体的に、文字を表現するためにそれぞれに割り当てた数字を、コンピューター上で取り扱えるようにバイト列で表す方法のことです。僕はこれを最初聞いたとき簡単に思えたのですが、仕組みを知ると間違いに気づきました。例えば、日本でよく使われていたShift JISエンコーディングですが、使用するバイトにかぶりがあるため、単純なバイトの一致を検索すると誤検知が多くなってしまう欠点があります。

ほかにも、UTF-8はA, a, 0などの一般的なASCII文字と完全な互換性があるため、よく考えられていることがわかります。

本題ですが、ここで問題になってくるのはこの形式に則っているかを判定する方法を開発することです。そもそもこれが必要なのかと思った方がいるかもしれません。が、過去にUTF-8の書式を正しく確認していなかったことで起こった脆弱性があることを考えると、これは必要だと言えるでしょう。

では、高速化のへの道筋を立てます。

同時に複数のバイトを処理する

一般的に、コンピューターは64-bitのワード(CPUレベルの変数の一単位のようなもの)を持っているため、8バイトまで(1バイト = 8-bit)は同時に計算を行うことができるということになります。
幸運なことに、Go言語には最適化されたバイト列の中身を組み合わせて一つのより大きな整数にする処理が埋め込まれているため、(関連したコード)これは問題にはなりません。

実際の計算のステップは以下のようになります。

ASCIIを検知する

ASCIIは先ほど言った通り有効なUTF-8の一部です。かつ、アルファベットや数字は多く存在するので、これを先に検知しておくのも高速化という観点において重要であると思います。

検知はさほど難しくはなく、ASCII文字は

  • 全て1バイトで表される。
  • 一番上のビット(MSB)は常に0である。

ことを利用します。一番上のビットだけを取り出す時はビット的AND(&)をします。一番上のビットだけが立っているマスクは0x80のため、存在する8バイト全てにこれをマップして0x8080808080808080を使います。
実際の実装はこのようになります。

func isASCII(data uint64) bool {
	return data&0x8080808080808080 == 0
}

(go.dev/play)

常に無効なバイトを検知する

では、ASCIIではないバイトがあったらどうするのでしょうか...僕は一番最初にどこに存在したとしても無効になるバイトに着目しました。具体的には、

  • 0xf4より大きいバイトは有効なUTF-8には存在しません。
  • 0xc00xc1は有効なUTF-8には存在しません。

この時点でこの部分を単純なバイト検索に問題を縮めることができます。この方法はどこかで紹介したような気がするのですが、以下のように書けます。

func contains(data, mask uint64) bool {
	data ^= mask
	return (data-0x0101010101010101)&^data&0x8080808080808080 != 0
}

(go.dev/play)

引数maskは8つの各バイトに検索したい文字が入っています。例えば上のリンクの場合、0を探していて0は16進数で0x30なので引数は0x3030303030303030のようになっているのがわかると思います。^はビット的XORであり、ビットが一致しない部分に1が立ちます。つまり、探したい文字が入ったバイトのみ0になります。次に全てのバイトから1を引きます。もちろん、0以外はなんともないのですが、一致したバイトだけはオーバーフローを起こし各バイトの最上位ビットが1になります。

ちなみに&^はビット的AND-NOTであり、最上位ビットが1である場合から、元々1が立っていた場合を除いで誤検知を防ぐためのものです。

ただこの方法で複数ケース計算していては結局非効率的になってしまいます。なんとかして0xc00xc1を検知する方法はないのでしょうか?

実は答えは簡単で、最初に各バイトに1をビット的ORしておくだけです。それと同様に"0xf4より大きい"も前回と同様にオーバーフローを利用して表現することができます。

特別な範囲を検知する

The Unicode standardによると、最初のバイトによって次のバイトの許容された値の範囲が変わることがあるとわかります。具体的には0xe0, 0xed, 0xf0, 0xf4の場合特別な範囲がつきます。通常の範囲はのちに判定します。

先ほど使ったテクニックをそのまま利用すると以下のようになります...

func isSpecial(data uint64) bool {
	xed := data ^ 0xedededededededed
	xf0 := data | 0x1010101010101010 ^ 0xf0f0f0f0f0f0f0f0
	xf4 := data ^ 0xf4f4f4f4f4f4f4f4
	xed = (xed - 0x0101010101010101) &^ xed
	xf0 = (xf0 - 0x0101010101010101) &^ xf0
	xf4 = (xf4 - 0x0101010101010101) &^ xf4
	return (xed|xf0|xf4)&0x8080808080808080 != 0
}

(go.dev/play)

命令数多いですね。16回の演算をする必要があります(たしかAND-NOTはANDとNOTの2つの命令の組み合わせなので)。これは若干改善することができます。例えば

func isSpecial(data uint64) bool {
	top := data & 0x8080808080808080
	btm := data & 0x7f7f7f7f7f7f7f7f
	xed := btm ^ 0x6d6d6d6d6d6d6d6d - 0x0101010101010101
	xf0 := btm | 0x1010101010101010 ^ 0x7070707070707070 - 0x0101010101010101
	xf4 := btm ^ 0x7474747474747474 - 0x0101010101010101
	return top&(xed|xf0|xf4) != 0
}

(go.dev/play)

このようにすると12回に抑えられます。トリックは上下に分けて空のビットを作っておくことによって、前のセクションで話した誤検知防止を全体で一回に済ませることができることです。

このように調べたいバイトに少しでもパターンがある場合にはさらに最適化を加えることができます。

バイトの順序等

最後に各バイトの先頭の連続した1が前や後のバイトに対応しているかどうかを調べる必要があります。
そのためにはまず"先頭の連続した1"だけを残す必要があるのですが、これがなかなか厄介です。
問題が後ろ側の連続したビットを残すというものなら実装は簡単で効率的です。具体的には以下のようにして実装できるからです。

data&(^data-0x0101010101010101)

ちなみに効率的なハミング重みの関数があればこれを使ってCTZ(Count Trailing Zeros)を自分で実装することもできます(1ワード単位ならハードウェア実装があることが多いので使いどころは少ないですが)。

ただ今回はビットが先頭にあり、これを当てはめるには各バイト内のビットを反転させる必要があり、もちろんビットの反転は重い作業です。簡単な実装としては以下のようなものがありますが、これはそれぞれの列が前の列の答えに依存しているため値を並列して計算する最適化を行いにくいことも難点です。

func reverse64(data uint64) uint64 {
	data = data>>1&0x5555555555555555 | data&0x5555555555555555<<1
	data = data>>2&0x3333333333333333 | data&0x3333333333333333<<2
	data = data>>4&0x0f0f0f0f0f0f0f0f | data&0x0f0f0f0f0f0f0f0f<<4
	return data
}

(go.dev/play)

反転が必要ないアプローチとして、右シフトを順番にビット的ORしていく方法があります。

	u64 = ^u64
	u64 |= u64 & 0xfefefefefefefefe >> 1
	u64 |= u64 & 0xfcfcfcfcfcfcfcfc >> 2
	u64 |= u64 & 0xf0f0f0f0f0f0f0f0 >> 4
	u64 = ^u64

(go.dev/play)

ただ、これも少しだけ早くなります。なぜなら、先頭のビットが4つよりも大きくなるケースは先ほど紹介した"常に無効なバイトを検知する"で既に弾かれているからです。よって3つ目のORを消すことができます。

これの最初のバイトの部分をbits.Mul64等(64-bitプラットフォームで1つの命令にコンパイルされます)でそれ以外のバイトの部分とマッチさせ、合ったらこの文字列は有効ということになります。もしこの掛け算が上位ワードのどこかのバイトのMSBにオーバーフローしたなら、次に判定を持ち越します。

それでは、これらを組み合わせてベンチマークを走らせてみましょう。

ベンチマーク

下の表は一秒に何バイトの文字の有効性を確認できたかを表していて(高いほうが良い)、左が標準ライブラリ、右が今回僕が作ったものを表しています。

                     │ ./before.txt │              ./after.txt              │
                     │     B/s      │      B/s       vs base                │
Valid/ascii-small-4    1.699Gi ± 2%    1.783Gi ± 1%   +4.99% (p=0.000 n=10)
Valid/ascii-large-4    11.01Gi ± 1%    15.82Gi ± 1%  +43.64% (p=0.000 n=10)
Valid/kanji-small-4    982.4Mi ± 1%   1099.9Mi ± 2%  +11.95% (p=0.000 n=10)
Valid/kanji-large-4    890.4Mi ± 1%   1334.1Mi ± 1%  +49.82% (p=0.000 n=10)
Valid/unicode.json-4   821.6Mi ± 7%   1226.3Mi ± 2%  +49.26% (p=0.000 n=10)
geomean                1.658Gi         2.162Gi       +30.44%

ascii-smallは10文字のASCIIで、ascii-largeは100KBのASCIIです。速度が改善しているのは、配列の範囲チェックをBCEを使ってできるだけなくしたことと(自動判定が結構トリッキーなところがあり苦労しました)、細かいメモリロードを8バイト以上の文字列で発生させないように工夫したところが影響していると考えています。

kanji-smallは10文字の漢字で、kanji-largeは100KBのの漢字、unicode.jsonは複数の言語や絵文字を組み合わせて作ったファイルです。速度が大きく改善していることが解ると思います。これには今回紹介したアルゴリズムが大きく影響していると考えています。

レポジトリ

https://github.com/sugawarayuuta/charcoal

このライブラリには、標準ライブラリのテストに加えて活発にファジングテストも行っています。

おわりに

今回はUTF-8の仕組みとバリデーションの高速化の方法の例を紹介しました。もし気になった点・質問・協力・提案などありましたらどんな方法でもお知らせください。読んでいただきありがとうございました。

Discussion