🤖

Rustで多段Step関数

2022/10/16に公開

概要

xがx0を超えてx1以下の時にy=y0となるような計算はよく目にします。例えば宅配では商品のサイズによって料金が何段階かになっており、モバイル通信料にも従量課金プランがあって1GBまでは600円、3GBまでは1000円、10GBまでは1500円、それ以降は2000円というような計算があります。今回はこういった計算を扱うデータ構造をRustで作ってみました。

データ構造

pub struct ThresholdDict<K, V> {
    keys: Vec<K>,
    values: Vec<V>,
    default_value: V,
}

keys は上の説明でいうとxであり、valuesがyです。keys[i] < x <= keys[i+1]の時にy = values[i] になります。そしてxkeysの最後の要素よりも大きい場合はdefault_valueがyの結果になります。keysは小さい順に、valuesもそれに対応してソートされている必要がありますので以下のようにコンストラクタを実装します。

impl<K: PartialOrd, V> ThresholdDict<K, V> {
    /// default constructor
    pub fn new(mut kv: Vec<(K, V)>, default_value: V) -> Self {
        kv.sort_by(|lhs, rhs| lhs.0.partial_cmp(&rhs.0).unwrap());
        let mut keys = vec![];
        let mut values = vec![];
        for (k, v) in kv {
            keys.push(k);
            values.push(v);
        }
        Self {
            keys,
            values,
            default_value,
        }
    }
}

問い合わせアルゴリズム

単純に線形探索をしてもいいのですが、データが多い時には二分探索がO(logN)で高速になります。RustのVecにはpartition_pointというメソッドがあり、ある値以下の最大の点を探すことができますのでこれを使用します。

impl<K: PartialOrd, V> ThresholdDict<K, V> {
    pub fn query(&self, key: &K) -> &V {
        if self.keys.is_empty() {
            return &self.default_value;
        }
        
        let i = self.keys.partition_point(|x| x < key);
        if i == self.keys.len() {
            return &self.default_value;
        }
        self.values.get(i).unwrap()
    }
}

試しに使用してみます。

let dict = ThresholdDict::new(vec![(10, 100), (20, 150), (50, 300)], 500);

assert_eq!(dict.query(&0), &100);
assert_eq!(dict.query(&10), &100);
assert_eq!(dict.query(&15), &150);
assert_eq!(dict.query(&50), &300);
assert_eq!(dict.query(&60), &500);

まとめ

Vecの機能を利用することでとても簡単に実装することができました。成果物はGitHubとcrates.ioに公開してあります。公開しているものは実際にはkeysの長さに閾値を設けて線形探索と二分探索を切り替えられるようにしてあります。

https://github.com/lucidfrontier45/threshold_dict
https://crates.io/crates/threshold-dict

Discussion