👋

RustのBinaryHeapを少しだけ使いやすくした

に公開

背景

Rustの標準ライブラリにあるちょっとマイナーなコンテナのBinaryHeapは優先度付きキューを実現するもので、C++のpriority_queueやPythonのheapqに相当するものです。
Rustに限らずそれらの言語でも共通することですが、入れるアイテム自体が順序付可能であること(RustでいうOrd)が要求されていて任意の構造体をスコアと一緒に入れることができなくて個人的にはちょっと使いにくく感じています。
そこでスコア関数を含めて少しだけ使いやすくした構造体を作ってみました。

成果物

実装方式

ヘルパー構造体

まず任意の構造体とスコアをまとめるヘルパー構造体を作ります。外部には公開しないのでcrateレベルのスコープにします。

pub(crate) struct ScoredItem<T, S: Ord + Copy> {
    pub(crate) score: S,
    pub(crate) item: T,
}

Copyは本当は不要なのですが、スコアとしては基本的には数値型を想定しているのでつけておいた方がいろいろ便利になります。

そしてこれにscoreのみを使って比較するようにOrdを実装します。

impl<T, S: Ord + Copy> PartialEq for ScoredItem<T, S> {
    fn eq(&self, other: &Self) -> bool {
        self.score == other.score
    }
}

impl<T, S: Ord + Copy> Eq for ScoredItem<T, S> {}

impl<T, S: Ord + Copy> Ord for ScoredItem<T, S> {
    fn cmp(&self, other: &Self) -> Ordering {
        self.score.cmp(&other.score)
    }
}

impl<T, S: Ord + Copy> PartialOrd for ScoredItem<T, S> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

Queue本体

本体は単純に標準ライブラリのBinaryHeapと評価関数のトレイとオブジェクトをまとめたものです。評価関数はもともとのデータの参照&Tを受け取っってスコアSに変換するものです。

pub struct PriorityQueue<T, S: Ord + Copy> {
    heap: BinaryHeap<ScoredItem<T, S>>,
    score_fn: Box<dyn Fn(&T) -> S>,
}

そしてデータTをpushする際には先に評価関数でスコアを計算し、ScoredItemに格納してから入れます。popやpeekする際にはScoredItemをただのタプルにばらします。

impl<T, S: Ord + Copy> PriorityQueue<T, S> {
    pub fn new(score_fn: Box<dyn Fn(&T) -> S>) -> Self {
        let heap = BinaryHeap::new();
        Self { heap, score_fn }
    }
    
    pub fn push(&mut self, item: T) {
        let score = (self.score_fn)(&item);
        self.push_with_score(item, score);
    }
    
    pub fn push_with_score(&mut self, item: T, score: S) {
        self.heap.push(ScoredItem { item, score });
    }
    
    pub fn peek(&self) -> Option<(S, &T)> {
        self.heap
            .peek()
            .map(|scored_item| (scored_item.score, &scored_item.item))
    }

    pub fn pop(&mut self) -> Option<(S, T)> {
        self.heap
            .pop()
            .map(|scored_item| (scored_item.score, scored_item.item))
    }
}

以上で実装完了です。以下は使用例です。文字列を長い順に取り出せるQueueです。

// a score function that returns the length of a string
let score_fn = Box::new(|s: &String| s.len());
// create a new priority queue with the score function
let mut queue = PriorityQueue::new(score_fn);

// the queue is empty at the beginning
assert!(queue.peek().is_none());

// push some items into the queue
// the score function is used to calculate the score of each item
queue.push("a".to_string()); // score = 1
queue.push("ccc".to_string()); // score = 3
queue.push("bb".to_string()); // score = 2

// you can also push an item with a explicit score
queue.push_with_score("b".to_string(), 10); // score = 10

// peek the item with the highest priority
assert_eq!(queue.peek(), Some((10, &"b".to_string())));

// pop the item with the highest priority
assert_eq!(queue.pop(), Some((10, "b".to_string())));
assert_eq!(queue.pop(), Some((3, "ccc".to_string())));
assert_eq!(queue.pop(), Some((2, "bb".to_string())));
assert_eq!(queue.pop(), Some((1, "a".to_string())));

短い順にしたい場合はスコア関数の数値をマイナスにするか、あるいはstd::cmd::Reverseを使用してください。

// you can also use a reverse order
let score_fn = Box::new(|s: &String| Reverse(s.len()));
let mut queue = PriorityQueue::new(score_fn);

Discussion