
Tracing the ScaNN implementation

katzkatz


After reading the ScaNN paper on product quantization I tried to read the implementation and couldn't make sense of it at all, so I'm leaving notes from tracing through the code.

https://arxiv.org/abs/1908.10396

https://github.com/google-research/google-research/tree/master/scann

Incidentally, SOAR, which comes up a lot in the implementation, is from a recently implemented paper.
At a glance, the idea seems to be to assign each point a second cluster in a MIPS-friendly way (AVQ-like):
https://arxiv.org/abs/2404.00774

katzkatz

Example of building a searcher:

import scann

# normalized_dataset: (N, d) array of unit-normalized vectors
searcher = (
    scann.scann_ops_pybind.builder(normalized_dataset, 10, "dot_product")
    .tree(num_leaves=2000, num_leaves_to_search=100, training_sample_size=250000)
    .score_ah(2, anisotropic_quantization_threshold=0.2)
    .reorder(100)
    .build()
)

Rules of thumb (quoting the ScaNN docs):

ScaNN performs vector search in three phases. They are described below:

  1. Partitioning (optional): ScaNN partitions the dataset during training time, and at query time selects the top partitions to pass onto the scoring stage.
  2. Scoring: ScaNN computes the distances from the query to all datapoints in the dataset (if partitioning isn't enabled) or all datapoints in a partition to search (if partitioning is enabled). These distances aren't necessarily exact.
  3. Rescoring (optional): ScaNN takes the best k' distances from scoring and re-computes these distances more accurately. From these k' re-computed distances the top k are selected.

All three phases can be configured through the ScaNNBuilder class. Before going into details, below are some general guidelines for ScaNN configuration.
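
To make the three phases concrete, here is what a query looks like at search time (a sketch based on the pybind searcher; the kwargs are as I understand them): leaves_to_search controls phase 1, pre_reorder_num_neighbors is the k' handed to rescoring, and final_num_neighbors is the final k.

neighbors, distances = searcher.search_batched(
    queries,
    leaves_to_search=150,
    pre_reorder_num_neighbors=250,
    final_num_neighbors=10,
)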

This is where tree, score_ah, rescoring, and so on show up, none of which appear in the paper under those names.
ah is asymmetric hashing, but shouldn't "anisotropic" show up somewhere on the tree side as well?

https://medium.com/@DataPlayer/scalable-approximate-nearest-neighbour-search-using-googles-scann-and-facebook-s-faiss-3e84df25ba

LSH combined with a k-means tree?

katzkatz

They tackle the MIPS problem with quantization:

q^\top x = \sum_k q^{(k)\top} x^{(k)} \approx \sum_k q^{(k)\top} U^{(k)} \alpha_x^{(k)} = \sum_k q^{(k)\top} u_x^{(k)}

\alpha_x^{(k)} is a one-hot vector and U^{(k)} is a (d/K) \times C_k matrix, so each column of the codebook U^{(k)} should correspond to a cluster centroid.

Note that this approximation is ’asymmetric’ in the sense that only database vectors x are quantized, not the query vector q. One can quantize q as well but it will lead to increased approximation error. In fact, the above asymmetric computation for all the database vectors can still be carried out very efficiently via look up tables similar to [9], except that each entry in the kth table is a dot product between q(k) and columns of U(k).

It seems safe to read "asymmetric hashing" as the asymmetric distance computation (ADC) of [Jégou, 2010].
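
A toy NumPy sketch of that lookup-table computation (shapes and names are my own, not ScaNN's):

import numpy as np

# K chunks, each with a codebook U[k] of shape (d/K, C) whose columns are
# cluster centers; each database vector is stored as K codebook indices.
d, K, C, N = 8, 4, 16, 1000
rng = np.random.default_rng(0)
U = rng.normal(size=(K, d // K, C))        # per-chunk codebooks
codes = rng.integers(0, C, size=(N, K))    # quantized database
q = rng.normal(size=d)                     # the query stays unquantized

# kth lookup table entry c is the dot product q^{(k)T} U^{(k)}[:, c]
tables = np.einsum("kd,kdc->kc", q.reshape(K, d // K), U)  # (K, C)

# approximate q^T x: sum the table entries selected by each code
approx_ip = tables[np.arange(K), codes].sum(axis=1)        # (N,)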

Chapter 6 of [Guo, 2016] is titled "Tree-Quantization Hybrids for Large Scale Search", which looks like the relevant piece. And indeed, in ScaNN, calling tree() brings the TreeXHybrid~ classes into play.

The basic idea of tree-quantization hybrids is to combine tree-based recursive data partitioning with QUIPS applied to each partition. At the training time, one first learns a locality-preserving tree such as hierarchical k-means tree, followed by applying QUIPS to each partition. In practice only a shallow tree is learned such that each leaf contains a few thousand points. Of course, a special case of tree-based partitioners is a flat partitioner such as k-means. At the query time, a query is assigned to more than one partition to deal with the errors caused by hard partitioning of the data. This soft assignment of query to multiple partitions is crucial for achieving good accuracy for high-dimensional data.

Queries are assigned to the nearby partition centers; in the experiments there are 2000 partitions throughout, and each query is assigned to 100 of them (see the sketch below). The paper says this is faster than the brute-force approach, but what does "brute force" mean here?? This is the only place in the paper where the term appears, so it should mean computing the asymmetric distance against every database point.
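
A sketch of the query-time soft assignment for a flat partitioner (function and variable names are mine, not ScaNN's):

import numpy as np

def select_partitions(q, centers, num_leaves_to_search=100):
    # centers: (num_leaves, d) partition centers; for dot_product,
    # higher scores are better, so take the top partitions by q . c
    scores = centers @ q
    return np.argsort(-scores)[:num_leaves_to_search]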

But then, what is the difference between score_ah and score_brute_force in ScaNN?

→ brute_force at this point seems to simply compute the exact values?
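
For comparison, the builder exposes score_brute_force as the exact-scoring alternative; a hedged example (I believe the no-argument form is valid, with an optional quantize flag):

bf_searcher = (
    scann.scann_ops_pybind.builder(normalized_dataset, 10, "dot_product")
    .tree(num_leaves=2000, num_leaves_to_search=100, training_sample_size=250000)
    .score_brute_force()
    .build()
)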

katzkatz

To situate the papers relative to each other:

[Jégou, 2010] → product quantization; the objective is reconstruction error
[Guo, 2016] → product quantization; the objective is MIPS
[Guo, 2020] → score-aware loss function; ScaNN

Something like that...?

In [Guo, 2016] the query distribution q \sim Q seems to be arbitrary; ScaNN assumes the uniform distribution on the unit hypersphere.
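
For reference, my recollection of the score-aware loss from [Guo, 2020] (double-check against the paper): the quantization error of x against its quantized form \tilde{x} is weighted by its inner product with the query,

\ell(x, \tilde{x}, w) = \mathbb{E}_{q \sim Q}\left[\, w(\langle q, x \rangle)\, \langle q, x - \tilde{x} \rangle^2 \,\right]

and for the spherically uniform Q this decomposes into weighted parallel and orthogonal residual terms, h_\parallel \|r_\parallel(x, \tilde{x})\|^2 + h_\perp \|r_\perp(x, \tilde{x})\|^2, which is where "anisotropic" comes from.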

I don't quite see where [Wu, 2017] fits in.
As shown below, with tree + score_ah, residual_quantization seems to be used unless you explicitly specify otherwise.

katzkatz

tree, score_ah, and the like have their function arguments captured into the params field by _factory_decorator, and the config is then built from those params values. I hadn't looked at this properly, which is why I missed that residual_quantization is overridden to true when tree + score_ah are used together.

def _factory_decorator(key):
  """Wraps a function that produces a portion of the ScaNN config proto."""

  def func_taker(f):
    """Captures arguments to function and saves them to params for later."""

    def inner(self, *args, **kwargs):
      if key in self.params:
        raise Exception(f"{key} has already been configured")
      kwargs.update(zip(f.__code__.co_varnames[1:], args))
      self.params[key] = kwargs
      return self

    inner.proto_maker = f
    return inner

  return func_taker
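
The default is then filled in later, when the builder assembles the final config (this fragment is from the config-building step in scann_builder.py, as I read it):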
ah = self.params.get("score_ah")
bf = self.params.get("score_bf")
if ah is not None and bf is None:
  if "residual_quantization" not in ah:
    ah["residual_quantization"] = (
        tree_params is not None and self.distance_measure == "dot_product")
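
In other words, with tree() and dot_product distance, score_ah defaults to residual quantization; judging from the fragment above, you would opt out by passing the flag explicitly:

.score_ah(2, anisotropic_quantization_threshold=0.2, residual_quantization=False)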
katzkatz

From here, tracing through the ScaNN code:
https://github.com/google-research/google-research/blob/25d0f4ab2cc249573e763b0a913d8504ab6137aa/scann/scann/scann_ops/py/scann_ops_pybind.py#L222-L223

https://github.com/google-research/google-research/blob/25d0f4ab2cc249573e763b0a913d8504ab6137aa/scann/scann/scann_ops/cc/scann_npy.cc#L71-L73

https://github.com/google-research/google-research/blob/25d0f4ab2cc249573e763b0a913d8504ab6137aa/scann/scann/base/single_machine_factory_scann.cc#L207-L229

If tree is configured, TreeXHybridFactory is called:

https://github.com/google-research/google-research/blob/25d0f4ab2cc249573e763b0a913d8504ab6137aa/scann/scann/base/internal/tree_x_hybrid_factory.cc#L605-L620

Following it further, training code turns up in two places. For partitioning:
https://github.com/google-research/google-research/blob/25d0f4ab2cc249573e763b0a913d8504ab6137aa/scann/scann/partitioning/kmeans_tree_partitioner.cc#L98-L112

For hashing:
https://github.com/google-research/google-research/blob/25d0f4ab2cc249573e763b0a913d8504ab6137aa/scann/scann/base/internal/tree_x_hybrid_factory.cc#L277-L279


On the partitioning side, anisotropic VQ appears to be off by default, but
https://github.com/google-research/google-research/blob/25d0f4ab2cc249573e763b0a913d8504ab6137aa/scann/scann/scann_ops/py/scann_builder.py#L81-L98

avq can be specified here. What happens after that I don't fully follow.
Is it re-positioning the center of each leaf of the kmeans-tree??

https://github.com/google-research/google-research/blob/1ff9bb3a9685223e9d2ded8888c0a4a9652cf7ba/scann/scann/partitioning/kmeans_tree_partitioner.cc#L512-L515


On the score_ah side, starting from
https://github.com/google-research/google-research/blob/1ff9bb3a9685223e9d2ded8888c0a4a9652cf7ba/scann/scann/hashes/asymmetric_hashing2/training.h#L71-L74
we arrive at

https://github.com/google-research/google-research/blob/1ff9bb3a9685223e9d2ded8888c0a4a9652cf7ba/scann/scann/hashes/internal/asymmetric_hashing_impl.cc#L39-L196
where the training code lives. Within it, the call

SCANN_RETURN_IF_ERROR(gmm.ComputeKmeansClustering(
    chunked_dataset[i], opts.config().num_clusters_per_block(), &centers,
    {.final_partitions = &subpartitions, .weights = weights}));

seems to compute the cluster centers with ordinary k-means. But a little later,
https://github.com/google-research/google-research/blob/1ff9bb3a9685223e9d2ded8888c0a4a9652cf7ba/scann/scann/tree_x_hybrid/tree_ah_hybrid_residual.cc#L341-L343
re-computes the cluster assignments, this time against the residuals from the k-means step...??
https://github.com/google-research/google-research/blob/1ff9bb3a9685223e9d2ded8888c0a4a9652cf7ba/scann/scann/hashes/internal/asymmetric_hashing_impl.cc#L433

https://github.com/google-research/google-research/blob/1ff9bb3a9685223e9d2ded8888c0a4a9652cf7ba/scann/scann/hashes/internal/asymmetric_hashing_impl.cc#L380
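
My reading of what "residual" means here, as a toy sketch (function and variable names are mine, not ScaNN's):

import numpy as np

# With residual quantization, AH quantizes x - c (the residual after
# subtracting the assigned partition center c) instead of x itself, so the
# per-chunk codebooks only need to model within-partition variation.
def residuals(X, centers, assignments):
    # X: (N, d) datapoints; centers: (L, d) leaf centers;
    # assignments: (N,) leaf index per datapoint
    return X - centers[assignments]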

katzkatz

I'd like to debug this, but I can't get it to build on my Mac...

katzkatz

Setting that aside for now: I've finished walking through the build path of the code, so next I want to move on to search.