💡

Segment Tree

2024/02/24に公開

はじめに

セグメント木(Segment Tree)は区間演算を効率よく行うためのデータ構造の1つです。区間演算とは、「配列のi番目からj番目までの要素についてvを加える・vを掛ける・総和をとる・総積をとる・最大値をとる・最小値をとる」などの操作のことを言います。セグメント木は(n-1)メモリ分のコストを支払うことでこれらの操作を最悪計算量O(\log n)実現することができます。

区間演算に有用なアルゴリズム・データ構造には他にも累積和・いもす法・フェニック木(Fenwick Tree)[1]などがあります。その中でもセグメント木は追加のメモリコストを支払うことで高い表現能力[2]を実現しています。また、セグメント木を発展させた遅延セグメント木(Lazy Segment Tree)では、追加のメモリ・計算コストを支払うことでより豊かな表現能力を実現しています。使い分けが肝心ということです。

アルゴリズム編

セグメント木の原型

セグメント木は完全二分木の1つです。図1の最下段にある緑色の領域がオリジナルのデータを表現しています。青色と赤色に塗られた領域はセグメント木を構成するためのバッファです。無駄なバッファを赤色で表現しています。

バッファには区間演算の種類に応じて値が記録されます。例として、区間加算について考えてみましょう。区間全域に1を加える場合、各要素を愚直に更新していけばO(n)の計算コストがかかりますが、代わりに完全二分木の根に1を加えるとO(\log n)で実現できます。親をたどっていくのにO(\log n)かかります。任意の区間の場合も親をたどりながらバッファを更新していけばよいので最悪O(\log n)で実現できます。i番目の要素を復元するには親の値を加算していけばよいのでO(\log n)で実現できます。


図1

メモリを節約したセグメント木

前節で紹介した完全二分木構造のセグメント木では、データサイズnが大きくなると無駄なバッファも大きくなりがちです。消しちゃいましょう。このとき、使用メモリは(2n-1)に抑えられます[3]

図2の赤色の領域は不正なデータ[4]を保持しているメモリ領域を表しています。二分木ですらなくなり一見破綻しているように見えますが、親をたどるアルゴリズムがいい感じに働いて、うまく機能します。


図2

セグメント木の実態は図3のようになります。数字はセグメント木の内部実装であるVecのインデックスを表しています。図3をもとに「親をたどるアルゴリズム」を解説していきます。

区間が十分長い時を考えます。区間の始点が左の子であれば右の子にも同じ操作がされるので親をたどります。右の子であれば自身の値を更新して親の右に行きます。区間の終点も同様です。この操作を繰り返していけば区間の幅が小さくなり、やがて始点と終点が一致するか入れ替わります。これがこのアルゴリズムの終了条件です。前者の場合、最後に値を更新しておくのを忘れないようにしましょう。

最後にインデックスからそれが右の子か左の子かを判別する方法を紹介します。図2においてa_1は必ず左の子であり、インデックスは奇数です。最下段の上に必ず奇数個のメモリ領域が載っているからです。結局、インデックスの偶奇から判定できます。


図3

データサイズが6のとき

このときも上記のアルゴリズムは成り立ちます。確かめてみてください。

実装編

セグメント木には様々な区間演算を実装できますが、同時に実装できるのは1つまでです。誤った使い方を防ぐためにstd::marker::PhantomDataを活用してメソッドのアクセス制御をしましょう。こうすることで、SegmentTreeという名前を様々な区間演算について使いまわすことができます。

共通部分

src/lib.rs
pub mod common;
pub mod interval_add;
pub mod interval_sum;
src/common.rs
use std::marker::PhantomData;

pub trait IntervalOps {}

pub struct Add;
impl IntervalOps for Add {}
pub struct Sum;
impl IntervalOps for Sum {}
pub struct Mul;
impl IntervalOps for Mul {}
pub struct Gross;
impl IntervalOps for Gross {}

#[derive(Debug)]
pub struct SegmentTree<T, U>
where
    U: IntervalOps,
{
    pub(crate) inner: Vec<T>,
    pub(crate) interval_ops: PhantomData<U>,
}

impl<T, U: IntervalOps> SegmentTree<T, U> {
    pub(crate) fn par_index(index: usize) -> usize {
        (index + 1 >> 1) - 1
    }

    pub(crate) fn inner_index(&self, index: usize) -> usize {
        (self.inner.len() >> 1) + index
    }
}

区間加算

区間加算のセグメント木ではバッファに差分を記録します。バッファの初期値はゼロ[5]なので、セグメント木の構築は簡単です。任意の数値型を記録したい場合、トレイト境界で加法が定義されていることとゼロを持つことを保証する必要があります。

コード
src/interval_add.rs
use std::{marker::PhantomData, ops::Range};

use anyhow::ensure;
use num::Zero;

use crate::common::{Add, SegmentTree};

impl<T: Copy + std::ops::Add<Output = T> + Zero> SegmentTree<T, Add> {
    pub fn build(data: Vec<T>) -> anyhow::Result<Self> {
        let len = data.len();
        ensure!(len > 0, "the length of the given data must be more than 0");

        let mut inner = Vec::with_capacity(2 * len - 1);
        // 初期化に単位元が必要
        inner.extend((1..len).map(|_| T::zero()));
        inner.extend(data.into_iter());

        Ok(Self {
            inner,
            interval_ops: PhantomData,
        })
    }

    pub fn add(&mut self, indexes: Range<usize>, value: T) {
        let mut il = self.inner_index(indexes.start);
        let mut ir = self.inner_index(indexes.end - 1);

        loop {
            match il.cmp(&ir) {
                std::cmp::Ordering::Less => {
                    // 区間の始点が右の子であれば、差分を記録して内側の要素に移る。
                    // これらは異なる親を持つ。
                    if il & 1 == 0 {
                        self.inner[il] = self.inner[il] + value;
                        il += 1;
                    }
                    // 区間の終端が左の子である場合も同様。
                    if ir & 1 == 1 {
                        self.inner[ir] = self.inner[ir] + value;
                        ir -= 1;
                    }
                    il = Self::par_index(il);
                    ir = Self::par_index(ir);
                }
                // 区間の始点と終点のインデックスが同じ場合、差分を記録してからループを抜ける。
                std::cmp::Ordering::Equal => {
                    self.inner[il] = self.inner[il] + value;
                    break;
                }
                std::cmp::Ordering::Greater => break,
            }
        }
    }

    pub fn calc(&self, index: usize) -> Option<T> {
        let mut i = self.inner_index(index);
        if i >= self.inner.len() {
            return None;
        }

        let mut output = self.inner[i];
        while i > 0 {
            i = Self::par_index(i);
            output = output + self.inner[i]
        }

        Some(output)
    }

    // 更新後の区間を取り出す。計算量はO(n)。
    pub fn reduce(mut self) -> Vec<T> {
        for i in 0..self.inner.len() >> 1 {
            self.inner[i + 1 << 1] = self.inner[i + 1 << 1] + self.inner[i];
            self.inner[(i + 1 << 1) - 1] = self.inner[(i + 1 << 1) - 1] + self.inner[i];
        }
        self.inner[self.inner.len() >> 1..].to_owned()
    }
}
テストコード
src/interval_add.rs
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn build_depth3() -> anyhow::Result<()> {
        assert_eq!(
            SegmentTree::<i32, Add>::build(vec![1, 2, 3, 4])?.inner,
            vec![0, 0, 0, 1, 2, 3, 4]
        );
        assert_eq!(
            SegmentTree::<i32, Add>::build(vec![1, 2, 3])?.inner,
            vec![0, 0, 1, 2, 3]
        );

        Ok(())
    }

    #[test]
    fn interval_add() -> anyhow::Result<()> {
        let mut seg_tree = SegmentTree::<i32, Add>::build(vec![1, 2, 3, 4, 5, 6])?;
        assert_eq!(seg_tree.inner, vec![0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6]);

        seg_tree.add(0..4, 2);
        assert_eq!(seg_tree.inner, vec![0, 0, 2, 2, 0, 1, 2, 3, 4, 5, 6]);

        seg_tree.add(4..5, -2);
        assert_eq!(seg_tree.inner, vec![0, 0, 2, 2, 0, 1, 2, 3, 4, 3, 6]);

        seg_tree.add(0..6, 10);
        assert_eq!(seg_tree.inner, vec![0, 10, 12, 2, 0, 1, 2, 3, 4, 3, 6]);

        Ok(())
    }

    #[test]
    fn calc() -> anyhow::Result<()> {
        let mut seg_tree = SegmentTree::<i32, Add>::build(vec![0, 1, 2, 3, 4])?;
        seg_tree.add(0..1, 2);
        seg_tree.add(3..5, 3);
        seg_tree.add(2..5, -9);

        assert_eq!(seg_tree.calc(0), Some(2));
        assert_eq!(seg_tree.calc(1), Some(1));
        assert_eq!(seg_tree.calc(2), Some(-7));
        assert_eq!(seg_tree.calc(3), Some(-3));
        assert_eq!(seg_tree.calc(4), Some(-2));

        Ok(())
    }

    #[test]
    fn reduce() -> anyhow::Result<()> {
        let mut seg_tree = SegmentTree::<i32, Add>::build(vec![0, 1, 2, 3, 4])?;
        seg_tree.add(0..1, 2);
        seg_tree.add(3..5, 3);
        seg_tree.add(2..5, -9);

        assert_eq!(seg_tree.reduce(), vec![2, 1, -7, -3, -2]);

        Ok(())
    }
}

区間の総和

区間の総和に対応したセグメント木ではバッファに子の値の和を記録します。

区間の総和はフェニック木でも実装できます。セグメント木の約半分のメモリしか使用しないため、フェニック木の方が有利です。

コード
src/interval_sum.rs
use std::{marker::PhantomData, ops::Range};

use anyhow::ensure;
use num::Zero;

use crate::common::{SegmentTree, Sum};

impl<T: Copy + std::ops::Add<Output = T> + Zero> SegmentTree<T, Sum> {
    pub fn build(data: Vec<T>) -> anyhow::Result<Self> {
        let len = data.len();
        ensure!(len > 0, "the length of the given data must be more than 0");

        let mut inner = Vec::with_capacity(2 * len - 1);
        inner.extend((1..len).map(|_| data[0])); // ダミーの値
        inner.extend(data.into_iter());
        // ダミーデータの更新
        for i in (1..len).rev() {
            inner[i - 1] = inner[2 * i - 1] + inner[2 * i]
        }

        Ok(Self {
            inner,
            interval_ops: PhantomData,
        })
    }

    pub fn add(&mut self, index: usize, value: T) {
        let mut i = self.inner_index(index);
        self.inner[i] = self.inner[i] + value;

        while i > 0 {
            i = Self::par_index(i);
            self.inner[i] = self.inner[i] + value;
        }
    }

    pub fn sum(&self, indexes: Range<usize>) -> T {
        let mut il = self.inner_index(indexes.start);
        let mut ir = self.inner_index(indexes.end).min(self.inner.len()) - 1;

        // 単位元が必要。
        let mut output = T::zero();

        loop {
            match il.cmp(&ir) {
                std::cmp::Ordering::Less => {
                    if il & 1 == 0 {
                        output = output + self.inner[il];
                        il += 1;
                    }
                    if ir & 1 == 1 {
                        output = output + self.inner[ir];
                        ir -= 1;
                    }
                    il = Self::par_index(il);
                    ir = Self::par_index(ir);
                }
                std::cmp::Ordering::Equal => {
                    output = output + self.inner[il];
                    break;
                }
                std::cmp::Ordering::Greater => break,
            }
        }

        output
    }

    pub fn reduce(self) -> Vec<T> {
        self.inner[self.inner.len() >> 1..].to_owned()
    }
}

おわりに

ここまで読んでいただきありがとうございます。至らない点があればコメント等でお知らせください。アイデアや情報もお待ちしています。

脚注
  1. BIT(Binary Indexed Tree)と呼ばれることもあります。 ↩︎

  2. 実現可能な区間演算の種類が豊富という意味です。 ↩︎

  3. 完全二分木のメモリサイズは2n-1の特別な場合に相当します。 ↩︎

  4. 初期値のまま変わりません。区間加算の場合、初期値は単位元の1です。 ↩︎

  5. 加算の単位元 ↩︎

Discussion