PythonによるWavelet Matrixの実装
この記事はJij Inc. Advent Calendar 2023の3日目の記事です。
はじめまして、株式会社Jij の篠原です。
概要
WaveletMatrix[1]は元々, 文字列におけるrank
, select
クエリを効率的に行うデータ構造Wavelet Tree[2] として提案されたものを簡潔に実装し直したものになります.
競技プログラミングの文脈だと, 整数値からなる配列
WaveletMatrixでは, 数列Tの各要素を各桁のビットにわけて基数ソートして0-1簡潔ビットベクトル (完備辞書) として保存します.
最初にWaveMatrixができることを示し, その次に軽く0-1簡潔ビットベクトルについて述べ, 最後にWaveletMatrixのPython実装を示します.
Notation
T
: 元の配列.
T[i]
: 元の配列の i 番目
x
: Tの任意の要素.
bit_size
:
できること
簡潔ビットベクトルのrank, selectを
前処理
操作 | 説明 | 計算量 |
---|---|---|
access(i) |
T[i] の要素を計算. (Wavelet Matrixでは, 通常元の配列Tは持たない.) | O(bit_size) |
rank(x, right) |
T[0..right) における x の出現回数を計算 | O(bit_size) |
rank_range(x, left, right) |
T[left..right) における x の出現回数を計算 | O(bit_size) |
select(x, k) |
T の k 個目の x の出現位置を計算 | O(bit_size) |
quantile(left, right, k) |
T[left..right) の中の k 番目に小さい値を計算 | O(bit_size) |
range_freq(left, right, lower, upper) |
T[left..right) の lower ≤ x < upperとなる x の個数を計算 | O(bit_size) |
prev_value(left, right, upper) |
T[left..right) の x < upper となる最大のxを計算 | O(bit_size) |
next_value(left, right, lower) |
T[left..right) の lower ≤ x となる最小のxを計算 | O(bit_size) |
簡潔ビットベクトル
概要
rank
とselect
を
空間計算量を
できること
操作 | 説明 | 計算量 |
---|---|---|
access(i) |
元の配列のi番目の要素を返す. | O(1) |
rank_0(i) |
元の配列B[0..i)の0の個数 | O(1) |
rank_1(i) |
元の配列B[0..i)の1の個数 | O(1) |
select_0(k) |
0がk番目に現れる元の配列のindexを返す | O(1) |
select_1(k) |
1がk番目に現れる元の配列のindexを返す | O(1) |
横着
私の場合, 更新系のクエリが無いことから累積和 + 二分探索を使った
実装
from typing import Optional
from itertools import accumulate
class BitVector:
def __init__(self, B: list[int]):
"""累積和を使用したビットベクトル
Args:
B (list[int]): 要素は0 or 1
"""
self.B = B
self.acc = list(accumulate(B))
def __len__(self) -> int:
"""元の配列の長さ
Returns:
int: 元の配列の長さ
"""
return len(self.B)
def __getitem__(self, i: int) -> int:
"""B[i]を取得
Args:
i (int): i番目の要素
Returns:
int: B[i]
"""
return self.B[i]
def rank0(self, i: int) -> int:
"""元の配列B[0..i)の0の個数
Args:
i (int): 上限
Returns:
int: 0の個数
"""
if i <= 0:
return 0
i = min(i, len(self))
return i - self.rank1(i)
def rank1(self, i: int) -> int:
"""元の配列B[0..i)の1の個数
Args:
i (int): 上限
Returns:
int: 1の個数
"""
if i <= 0:
return 0
i = min(i, len(self))
return self.acc[i - 1]
def rank0_all(self) -> int:
"""元の配列Bの0の個数
Returns:
int: 0の個数
"""
return self.rank0(len(self.B))
def rank1_all(self) -> int:
"""元の配列Bの1の個数
Returns:
int: 1の個数
"""
return self.rank1(len(self.B))
def select0(self, k: int) -> Optional[int]:
"""0がk番目に現れるindexを返す
Args:
k (int): k番目. 1-indexed.
Returns:
Optional[int]: 該当のindex. 存在しない場合はNone
"""
if (k <= 0) or (k > self.rank0_all()):
return None
left, right = 0, len(self)
while (right - left) > 1:
mid = (left + right) // 2
if self.rank0(mid) < k:
left = mid
else:
right = mid
return left
def select1(self, k: int) -> Optional[int]:
"""1がk番目に現れるindexを返す
Args:
k (int): k番目. 1-indexed.
Returns:
Optional[int]: 該当のindex. 存在しない場合はNone
"""
if (k <= 0) or (k > self.rank1_all()):
return None
left, right = 0, len(self)
while (right - left) > 1:
mid = (left + right) // 2
if self.rank1(mid) < k:
left = mid
else:
right = mid
return left
実装
from typing import Optional
class WaveletMatrix:
def __init__(self, T: list[int]):
"""WaveletMatrix
Args:
T (list[int]): 整数列
"""
self.bit_size: int = max(T).bit_length() if T else 0
self.wavelet_matrix: list[BitVector] = self._build(T)
def _build(self, T: list[int]) -> list[BitVector]:
"""WaveletMatrixを構築する
Args:
T (list[int]): 元の配列
Returns:
list[BitVector]: WaveletMatrix
"""
wavelet_matrix: list[BitVector] = []
for digit in range(self.bit_size)[::-1]:
zeros = [t for t in T if not self._get_i_bit(t, digit + 1)]
ones = [t for t in T if self._get_i_bit(t, digit + 1)]
T = zeros + ones
wavelet_matrix.append(BitVector([self._get_i_bit(t, digit) for t in T]))
return wavelet_matrix
def __len__(self) -> int:
"""len(T).
Returns:
int: len(T)
"""
if len(self.wavelet_matrix) == 0:
return 0
return len(self.wavelet_matrix[0])
def __getitem__(self, i: int) -> int:
"""元の配列T[i]を返す
Args:
i (int): index
Returns:
int: T[i]
"""
return self.access(i)
def _get_i_bit(self, x: int, digit: int) -> int:
"""xの(下から数えて)digit桁目のbitを計算
Args:
x (int): 整数
digit (int): 桁数
Returns:
int: 0 or 1
"""
return 1 & (x >> digit)
def _next_range(self, B: BitVector, bit: int, left: int, right: int) -> (int, int):
"""B[j][left..right)におけるbit(0 or 1)が, B[j+1]においてどの範囲になるかを計算
Args:
B (BitVector): Bit Vector B[j].
bit (int): 範囲に含まれるbit. 0 or 1.
left (int): Bの範囲の下限.
right (int): Bの範囲の上限.
Returns:
(int, int): B[j+1]の[left..right)
"""
if bit == 0:
left = B.rank0(left)
right = B.rank0(right)
else:
left = B.rank0_all() + B.rank1(left)
right = B.rank0_all() + B.rank1(right)
return (left, right)
def access(self, i: int) -> int:
"""元の配列T[i]を返す
Args:
i (int): index
Returns:
int: T[i]
"""
x = 0
for digit, B in enumerate(self.wavelet_matrix):
bit = B[i]
x |= bit << (self.bit_size - digit - 1)
if bit == 0:
i = B.rank0(i + 1) - 1
else:
i = B.rank0_all() + B.rank1(i + 1) - 1
return x
def rank(self, x: int, right: int) -> int:
"""T[0..right)における, xの出現回数を返す
Args:
x (int): 対象の要素
right (int): Tの範囲の上限
Returns:
int: 出現回数
"""
return self.rank_range(x, 0, right)
def rank_range(self, x: int, left: int, right: int) -> int:
"""T[left..right)におけるxの出現回数を返す
Args:
x (int): 対象の要素
left (int): Tの範囲の下限
right (int): Tの範囲の上限
Returns:
int: 出現回数
"""
if (left >= right) or (x.bit_length() > self.bit_size) or (x < 0):
return 0
left = max(left, 0)
right = min(right, len(self))
for digit, B in enumerate(self.wavelet_matrix):
x_bit = self._get_i_bit(x, self.bit_size - digit - 1)
left, right = self._next_range(B, x_bit, left, right)
return right - left
def select(self, x: int, k: int) -> Optional[int]:
"""元の配列Tのk個目のxの出現位置 (index) を返す
Args:
x (int): 対象の要素.
k (int): 何個目の出現か. 1-index.
Returns:
Optional[int]: index. 該当要素が存在しない場合はNone.
"""
if not (0 <= x < 1 << self.bit_size):
return None
if k <= 0:
return None
# xが最終的な配列のどの範囲に入るか計算
left, right = 0, len(self)
for digit, B in enumerate(self.wavelet_matrix):
x_bit = self._get_i_bit(x, self.bit_size - digit - 1)
left, right = self._next_range(B, x_bit, left, right)
if left >= right:
return None
# 上記範囲のk番目が, 元の配列で何番目か計算
index = left + k - 1
for digit, B in enumerate(self.wavelet_matrix[::-1]):
if index is None:
return None
x_bit = self._get_i_bit(x, digit)
if x_bit == 0:
index = B.select0(index + 1)
else:
index = B.select1(index - (B.rank0_all() - 1))
return index
def quantile(self, left: int, right: int, k: int) -> Optional[int]:
"""元の配列T[left..right)の中のk番目に小さい値を返す
Args:
left (int): Tの範囲の下限
right (int): Tの範囲の上限
k (int): 何番目の要素か, 1-index
Returns:
Optional[int]: k番目に小さい値. 該当要素が存在しない場合はNone.
"""
if (left >= right) or (k <= 0) or (right - left < k):
return None
left = max(left, 0)
right = min(right, len(self))
x = 0
for digit, B in enumerate(self.wavelet_matrix):
num_zeros = B.rank0(right) - B.rank0(left)
# 該当要素が0の中か or 1の中か
bit = 0 if k <= num_zeros else 1
left, right = self._next_range(B, bit, left, right)
if bit == 1:
k -= num_zeros
x |= bit << (self.bit_size - digit - 1)
return x
def range_freq_to(self, left: int, right: int, upper: int) -> int:
"""T[left..right)の0 <= x < upper となるxの個数を計算する
Args:
left (int): Tの範囲の下限
right (int): Tの範囲の上限
upper (int): 要素の上限
Returns:
int: 0 <= x < upperとなる要素の数
"""
if (left >= right) or (upper <= 0):
return 0
# 全ての要素が < upper
if upper.bit_length() > self.bit_size:
return right - left
cnt = 0
left = max(left, 0)
right = min(right, len(self))
for digit, B in enumerate(self.wavelet_matrix):
upper_bit = self._get_i_bit(upper, self.bit_size - digit - 1)
if upper_bit == 1:
cnt += B.rank0(right) - B.rank0(left)
left, right = self._next_range(B, upper_bit, left, right)
return cnt
def range_freq_from(self, left: int, right: int, lower: int) -> int:
"""T[left..right)のlower <= x となるxの個数を計算する
Args:
left (int): Tの範囲の下限
right (int): Tの範囲の上限
lower (int): 要素の下限
Returns:
int: lower <= x となる要素の数
"""
if (left >= right) or (lower > (1 << self.bit_size)):
return 0
left = max(left, 0)
right = min(right, len(self))
return (right - left) - self.range_freq_to(left, right, lower)
def range_freq(self, left: int, right: int, lower: int, upper: int) -> int:
"""T[left..right)のlower <= x < upper となるxの個数を計算する
Args:
left (int): Tの範囲の下限
right (int): Tの範囲の上限
lower (int): 要素の下限
upper (int): 要素の上限
Returns:
int: lower <= x < upperとなる要素の数
"""
if (left >= right) or (lower >= upper) or (lower > (1 << self.bit_size)):
return 0
return self.range_freq_to(left, right, upper) - self.range_freq_to(
left, right, lower
)
def prev_value(self, left: int, right: int, upper: int) -> Optional[int]:
"""T[left..right)の中で, 0 <= x < upperを満たす最大のxを返す
Args:
left (int): Tの範囲の下限
right (int): Tの範囲の上限
upper (int): 要素の上限
Returns:
Optional[int]: 0 <= x < upperを満たす最大のx. 存在しない場合None.
"""
# T[left..right)内の, 0 <= x < upperのxの個数
cnt = self.range_freq_to(left, right, upper)
if cnt == 0:
return None
return self.quantile(left, right, cnt)
def next_value(self, left: int, right: int, lower: int) -> int:
"""T[left..right)の中で, lower <= x を満たす最小のxを返す
Args:
left (int): Tの範囲の下限
right (int): Tの範囲の上限
lower (int): 要素の下限
Returns:
int: lower <= x を満たす最小のx
"""
# T[left..right)内の, lower <= xとなるxの個数
cnt = self.range_freq_from(left, right, lower)
if cnt == 0:
return None
return self.quantile(left, right, right - left - cnt + 1)
参考
[1]: F.Claude nad G. Navarro. The wavelet marxi. In Proceedings of the 19th International Symposium on String Processing nad Information Retrieval (SPIRE), pp.167-179, 2012.
[2]: R. Grossi, A. Gupta, and J. S. Vitter, High-order entropy-compressed text indexes, Proceedings of the 14th Annual SIAM/ACM Symposium on Discrete Algorithms (SODA), January 2003, 841-850.
[3]: 定兼 邦彦, 簡潔データ構造, アルゴリズム・サイエンスシリーズ―数理技法編, 共立出版, 2018
[4]: 岡野原大輔, 高速文字列解析の世界――データ圧縮・全文検索・テキストマイニング, 岩波書店, 2012
[5]: ウェーブレット行列(wavelet matrix)
Verify
Library Checker - Range Kth Smalleset
最後に
\Rustエンジニア・数理最適化エンジニア募集中!/
株式会社Jijでは、数学や物理学のバックグラウンドを活かし、量子計算と数理最適化のフロンティアで活躍するRustエンジニア、数理最適化エンジニアを募集しています!
詳細は下記のリンクからご覧ください。皆さんのご応募をお待ちしております!
Rustエンジニア: https://open.talentio.com/r/1/c/j-ij.com/pages/51062
数理最適化エンジニア: https://open.talentio.com/r/1/c/j-ij.com/pages/75132
-
F.Claude nad G. Navarro. The wavelet marxi. In Proceedings of the 19th International Symposium on String Processing nad Information Retrieval (SPIRE), pp.167-179, 2012. ↩︎
-
R. Grossi, A. Gupta, and J. S. Vitter, High-order entropy-compressed text indexes, Proceedings of the 14th Annual SIAM/ACM Symposium on Discrete Algorithms (SODA), January 2003, 841-850. ↩︎
Discussion
select
の実装を参考にできて助かりました!