👌

ビームサーチの上位 N件を高速に取る手法について考えてみる

2023/01/26に公開

ビームサーチでは候補となるノードを評価してその上位 N件を取得する処理が必要になるのですが、この部分をどう実装するのが高速なのかが度々話題になったりならなかったりします。

候補としては以下の 4つになります。(C++)

  • sort
  • nth_element
  • partial_sort
  • priority_queue

気になったので簡易的なコードを書いて実験してみました。(10万回ランダムに値を生成して、その中の値の小さい方から 100件を取得する。それを 1000回繰り返すコード)

sort ver

#include <algorithm>
#include <iostream>
#include <chrono>
#include <vector>

using namespace std;
using namespace std::chrono;
typedef long long ll;

int main() {
  int seed = 1;
  srand(seed);

  int beam_width = 100;

  high_resolution_clock::time_point begin = high_resolution_clock::now();

  for (int depth = 0; depth < 1000; ++depth) {
    vector<int> numbers;

    for (int i = 0; i < 100000; ++i) {
      int v = rand() % 100000;

      numbers.push_back(v);
    }

    sort(numbers.begin(), numbers.end());
    for (int i = 0; i < beam_width; ++i) {
      int v = numbers[i];
    }
  }

  high_resolution_clock::time_point end = high_resolution_clock::now();
  ll count = duration_cast<microseconds>(end - begin).count();

  printf("%.4f\n", count / 1000.0);

  return 0;
}

nth_element ver

  for (int depth = 0; depth < 1000; ++depth) {
    vector<int> numbers;

    for (int i = 0; i < 100000; ++i) {
      int v = rand() % 100000;

      numbers.push_back(v);
    }

    nth_element(numbers.begin(), numbers.begin() + beam_width, numbers.end());
    for (int i = 0; i < beam_width; ++i) {
      int v = numbers[i];
    }
  }

partial_sort ver

  for (int depth = 0; depth < 1000; ++depth) {
    vector<int> numbers;

    for (int i = 0; i < 100000; ++i) {
      int v = rand() % 100000;

      numbers.push_back(v);
    }

    partial_sort(numbers.begin(), numbers.begin() + beam_width, numbers.end());
    for (int i = 0; i < beam_width; ++i) {
      int v = numbers[i];
    }
  }

priority_queue

  for (int depth = 0; depth < 1000; ++depth) {
    priority_queue <int, vector<int>, greater<int>> pque;

    for (int i = 0; i < 100000; ++i) {
      int v = rand() % 100000;
      pque.push(v);
    }

    for (int i = 0; i < beam_width && not pque.empty(); ++i) {
      int v = pque.top();
      pque.pop();
    }
  }

以下は最適化を無効化にして(-O0) 実行した結果になります。

type time(ms)
sort 9995.0
nth_element 7679.0
partial_sort 6245.1
priority_queue 13295.8

今回の実験では partial_sort が一番速く priority_queue が一番遅いという結果になりました。(最適化とか他のコードとかではまた違った結果になる気もしますが)

ただ、何回かコードを書いているうちに上位 N件を取得するということは N件目の値より悪いものは候補から事前に外しておけるのではないかと思い下記のコードを書いてみました。

  for (int depth = 0; depth < 1000; ++depth) {
    priority_queue <int> pque;

    for (int i = 0; i < 100000; ++i) {
      int v = rand() % 100000;

      if (pque.size() >= beam_width && pque.top() <= v) continue;
      pque.push(v);

      if (pque.size() > beam_width) pque.pop();
    }

    for (int i = 0; i < beam_width && not pque.empty(); ++i) {
      int v = pque.top();
      pque.pop();
    }
  }

まず priority_queue を昇順ではなく降順で管理するようにして、キューのサイズがビーム幅を超えた場合は現在管理している中で一番大きい値と比較してそれより値が大きい場合は候補から外します。またキューのサイズがビーム幅を超えた場合にはビーム幅になるように値を取り出します。

if (pque.size() >= beam_width && pque.top() <= v) continue;
pque.push(v);

if (pque.size() > beam_width) pque.pop();

この処理によって N番目の値と比較してそれより悪いものについては push の処理がスキップ出来るのである程度の高速化が見込めます。実際に実行してみると 1261.3ms となり大幅に高速化出来ました。

type time(ms)
sort 9995.0
nth_element 7679.0
partial_sort 6245.1
priority_queue 13295.8
priority_queue(足切り ver) 1261.3

ただこのコードだとビームサーチにおける重要な重複処理が出来ないのでその辺りが課題です。ただ、状態が重複しにくい問題とかもあると思うのでそのケースの場合は今回みたいな高速化手法も候補に入ってくるのかなと考えています。

追記 (23/01/26)

koyumeishi さんからの指摘で一定個数キューに溜まるごとに sort するほうが速いのではということだったので試してみました。

  for (int depth = 0; depth < 1000; ++depth) {
    vector<int> numbers;
    bool sorted = false;

    for (int i = 0; i < 100000; ++i) {
      int v = rand() % 100000;

      if (sorted && numbers[beam_width - 1] <= v) continue;
      numbers.push_back(v);

      if (numbers.size() >= 2 * beam_width) {
        sort(numbers.begin(), numbers.end());
        numbers.resize(beam_width);
        sorted = true;
      }
    }

    sort(numbers.begin(), numbers.end());
    for (int i = 0; i < beam_width; ++i) {
      int v = numbers[i];
    }
  }

ソート済みかどうかのフラグをつけて足切り機能を追加すると 791.8ms で動いてこっちのほうが priority_queue よりも高速に動作しました。

type time(ms)
sort 9995.0
nth_element 7679.0
partial_sort 6245.1
priority_queue 13295.8
priority_queue(足切り ver) 1261.3
sort(足切り vre) 791.8

情報ありがとうございます 🙏

Discussion