🌿

ビームサーチ用の固定長ヒープ

2024/01/17に公開

TL;DR

  • コンパイル時に最大サイズが決められる、二分ヒープFixedHeapを実装しました
  • 決めた最大サイズより多く挿入すると、値の小さいものを決めたサイズだけ保持します

モチベーション

ヒューリスティックな競技プログラミングコンテストで、ビームサーチを打ちたい場面は山ほどありますが、いつも評価値上位の状態の管理をstd::collections::BinarhHeapで無理やりなんとかしています。

ヒープの大きさがビーム幅を超えるたびに、1つ削除する処理を書くのは面倒。ビーム幅は予め決まっているのでどうにか楽にできない?と思ってどうにかしました。

機能・つかいかた

  • Ordトレイトを実装している型と、ヒープの最大サイズをジェネリックに取ります
  • push()で挿入ができますが、pop()はありません
  • iter()IntoIterトレイトが実装されています
  • peek()で最大値の参照が取れます

TODO

  • ソートされた配列を返すinto_sorted_vec()

実装

シンプルな二分ヒープがベースですが、ヒープが埋まってから挿入するときは、現状の最大値より小さければ、順序木の根を書き換えdown-heapを行うことで挿入を行います。

あんまりちゃんとデバックしてないので、バグを見つけたら教えてください。

#[allow(dead_code)]
mod procon_lib {
    pub struct FixedHeap<T: Ord, const L: usize>(usize, [T; L]);

    impl<T: Ord, const L: usize> FixedHeap<T, L> {
        pub fn new() -> Self {
            Self(0, unsafe { std::mem::zeroed() })
        }

        pub fn len(&self) -> usize {
            self.0
        }

        pub fn iter(&self) -> std::slice::Iter<'_, T> {
            self.1.iter()
        }

        pub fn peek(&self) -> Option<&T> {
            if 0 < self.0 {
                Some(&self.1[0])
            } else {
                None
            }
        }

        pub fn push(&mut self, item: T) {
            if self.0 < L {
                self.up_heap(item);
            } else if item < self.1[0] {
                self.down_heap(item);
            }
        }

        fn up_heap(&mut self, item: T) {
            self.1[self.0] = item;

            let mut index = self.0;
            while index > 0 && self.1[(index - 1) >> 1] < self.1[index] {
                self.1.swap((index - 1) >> 1, index);
                index = (index - 1) >> 1;
            }

            self.0 += 1;
        }

        fn down_heap(&mut self, item: T) {
            self.1[0] = item;

            let mut index = 0;
            while (index << 1) + 1 < L {
                let left = (index << 1) + 1;
                let right = (index + 1) << 1;

                if right < self.len() && self.1[left] < self.1[right] {
                    if self.1[index] < self.1[right] {
                        self.1.swap(right, index);
                        index = right;
                    } else {
                        break;
                    }
                } else {
                    if self.1[index] < self.1[left] {
                        self.1.swap(left, index);
                        index = left;
                    } else {
                        break;
                    }
                }
            }
        }
    }

    impl<T: Ord, const L: usize> Default for FixedHeap<T, L> {
        fn default() -> Self {
            Self::new()
        }
    }

    impl<T: Ord, const L: usize> IntoIterator for FixedHeap<T, L> {
        type Item = T;
        type IntoIter = std::array::IntoIter<T, L>;
        fn into_iter(self) -> Self::IntoIter {
            self.1.into_iter()
        }
    }
}
GitHubで編集を提案

Discussion