
Implementing Segment Trees in Rust


Introduction

A Segment Tree is a data structure for efficient range queries. Range queries are operations such as adding a value $v$ to the elements from index $i$ to $j$, multiplying them by $v$, or computing the sum, product, maximum, or minimum over a range. A Segment Tree performs these operations in $O(\log n)$ worst-case time at the cost of $n-1$ additional memory slots.

Other useful algorithms and data structures for range queries include Prefix Sums, the Difference Array method, and Fenwick Trees (also known as Binary Indexed Trees or BITs). Among these, the Segment Tree achieves higher expressive power[1] by paying additional memory costs. Furthermore, the Lazy Segment Tree, an evolution of the Segment Tree, achieves even richer expressive power at the expense of more memory and computational cost. Choosing the right tool for the job is key.
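To make the trade-off concrete, here is a minimal sketch of the two simplest alternatives, assuming `i64` data and half-open ranges `start..end`; the names `prefix_sums`, `range_sum`, and `DiffArray` are hypothetical. Prefix sums answer a range sum in $O(1)$ but need an $O(n)$ rebuild after any update, while a difference array applies a range addition in $O(1)$ but needs an $O(n)$ pass to read the values back. A Segment Tree sits between them at $O(\log n)$ for both kinds of operation.

```rust
// Sketch only: hypothetical names, i64 data, half-open ranges start..end.

/// Prefix sums: prefix[k] = data[0] + ... + data[k - 1], so prefix[0] = 0.
/// Building is O(n); each range sum is then O(1), but any update to the
/// data would require an O(n) rebuild.
fn prefix_sums(data: &[i64]) -> Vec<i64> {
    let mut prefix = vec![0; data.len() + 1];
    for (k, &x) in data.iter().enumerate() {
        prefix[k + 1] = prefix[k] + x;
    }
    prefix
}

/// Sum of data[start..end] in O(1).
fn range_sum(prefix: &[i64], start: usize, end: usize) -> i64 {
    prefix[end] - prefix[start]
}

/// Difference array: diff[k] = data[k] - data[k - 1] (with data[-1] = 0).
/// A range addition is O(1), but reading the values back is O(n).
struct DiffArray {
    diff: Vec<i64>,
}

impl DiffArray {
    fn new(data: &[i64]) -> Self {
        // One extra slot so that `add` can always write to diff[end].
        let mut diff = vec![0; data.len() + 1];
        let mut prev = 0;
        for (k, &x) in data.iter().enumerate() {
            diff[k] = x - prev;
            prev = x;
        }
        DiffArray { diff }
    }

    /// Add `value` to every element in start..end.
    fn add(&mut self, start: usize, end: usize, value: i64) {
        self.diff[start] += value;
        self.diff[end] -= value;
    }

    /// Rebuild the current values with a running sum.
    fn values(&self) -> Vec<i64> {
        let n = self.diff.len() - 1;
        let mut acc = 0;
        (0..n)
            .map(|k| {
                acc += self.diff[k];
                acc
            })
            .collect()
    }
}
```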

Algorithm

The Prototypical Segment Tree

A Segment Tree is a type of complete binary tree. In Figure 1, the green area at the bottom represents the original data, while the blue and red areas form the buffer for the Segment Tree. Red signifies redundant buffer space.

Values are recorded in the buffer according to the kind of range operation. Take range addition as an example. Adding 1 to every element naively costs $O(n)$, but since the root of the complete binary tree covers the whole range, adding 1 to the root alone has the same effect in $O(1)$. The tree has height $O(\log n)$, so traversing from a leaf up to the root takes $O(\log n)$ steps. For an arbitrary range, one updates the buffer while traversing up through the parents, which costs $O(\log n)$ in the worst case. To restore the $i$-th element, one simply adds up the values on the path from its leaf to the root, which also takes $O(\log n)$.


Figure 1
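As a rough sketch of the prototypical layout in Figure 1, the code below pads the bottom layer up to the next power of two and stores the tree in a 0-indexed array where node `k` has children `2k + 1` and `2k + 2`. It is only an illustration under those assumptions: the struct name is hypothetical, and range addition is written here as a top-down recursion rather than the bottom-up parent traversal the article develops later. The point query does follow the description above, summing from the leaf up to the root.

```rust
// Sketch only: hypothetical struct, i64 values, half-open ranges start..end.
/// Prototypical range-add segment tree: the bottom layer is padded to a
/// power of two, so slots past the data act as redundant buffer.
struct PaddedRangeAdd {
    size: usize,    // number of bottom-layer slots (a power of two)
    tree: Vec<i64>, // 2 * size - 1 slots; node k has children 2k + 1, 2k + 2
}

impl PaddedRangeAdd {
    fn new(data: &[i64]) -> Self {
        let size = data.len().max(1).next_power_of_two();
        let mut tree = vec![0; 2 * size - 1];
        // Leaves occupy the last `size` slots; internal nodes start at zero.
        tree[size - 1..size - 1 + data.len()].copy_from_slice(data);
        PaddedRangeAdd { size, tree }
    }

    /// Add `value` to every element in start..end.
    fn add(&mut self, start: usize, end: usize, value: i64) {
        self.add_rec(0, 0, self.size, start, end, value);
    }

    // `node` covers the half-open interval [lo, hi) of the bottom layer.
    fn add_rec(&mut self, node: usize, lo: usize, hi: usize, start: usize, end: usize, value: i64) {
        if end <= lo || hi <= start {
            return; // no overlap with the requested range
        }
        if start <= lo && hi <= end {
            self.tree[node] += value; // fully covered: record once and stop
            return;
        }
        let mid = (lo + hi) / 2;
        self.add_rec(2 * node + 1, lo, mid, start, end, value);
        self.add_rec(2 * node + 2, mid, hi, start, end, value);
    }

    /// Value of element i: the sum of the recorded values from its leaf to the root.
    fn get(&self, i: usize) -> i64 {
        let mut k = self.size - 1 + i;
        let mut total = self.tree[k];
        while k > 0 {
            k = (k - 1) / 2;
            total += self.tree[k];
        }
        total
    }
}
```

With 5 elements the array already has 15 slots, 3 of them padding leaves that never hold data.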

Memory-Optimized Segment Tree

In the complete binary tree structure shown in the previous section, the redundant buffer tends to grow as the data size $n$ increases. Let's remove it. The total memory usage then becomes exactly $2n-1$ slots[2]: $n$ for the data plus $n-1$ for the buffer.

In Figure 2, the red area represents memory slots holding invalid data[3]. It may look broken, since it no longer looks like the tidy binary tree of Figure 1, but the "parent-traversal algorithm" keeps it working correctly.


Figure 2
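The saving can be checked with a little arithmetic. This sketch assumes $n \ge 1$ and that the tree of Figure 1 rounds its bottom layer up to the next power of two; both function names are hypothetical.

```rust
// Sketch only: hypothetical helper names; assumes n >= 1.

/// Slots used by the padded tree of Figure 1: the bottom layer is rounded
/// up to a power of two m, and a perfect binary tree with m leaves has 2m - 1 nodes.
fn padded_slots(n: usize) -> usize {
    2 * n.next_power_of_two() - 1
}

/// Slots used by the memory-optimized layout: n data slots plus n - 1 buffer slots.
fn optimized_slots(n: usize) -> usize {
    2 * n - 1
}
```

For $n = 9$, the padded tree needs 31 slots while the optimized layout needs 17; when $n$ is a power of two, both need $2n - 1$.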

Figure 3 shows the actual layout of the Segment Tree in memory. The numbers are the indices into the internal Vec used in the implementation. We will now explain the "parent-traversal algorithm" using Figure 3.

Consider a sufficiently long range. If the start of the range is a left child, its right sibling is also in the range and will receive the same operation, so we simply move up to the parent. If it is a right child, its sibling lies outside the range, so we update the current node and then move to the node just to the right of its parent. The end of the range is handled symmetrically: if it is a right child, move up to the parent; if it is a left child, update it and move to the node just to the left of its parent. Repeating this shrinks the range until the start and end pointers either meet or cross, which is the termination condition. If they meet, remember to update that node one last time.
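A small sketch can make the traversal concrete. Assuming the 0-indexed layout of Figure 3 (buffer slots `0..n-1`, data element `k` at slot `n - 1 + k`), the hypothetical function below only collects the indices of the nodes a range operation would update, following the steps just described.

```rust
// Sketch only: hypothetical helpers for the 0-indexed layout of Figure 3.

/// Parent of node `i` (valid for i > 0; the root is 0).
fn parent(i: usize) -> usize {
    (i - 1) / 2
}

/// Nodes updated when an operation is applied to data[start..end] in a
/// tree holding `n` elements. Assumes 0 < n and start < end <= n.
fn covering_nodes(n: usize, start: usize, end: usize) -> Vec<usize> {
    let mut nodes = Vec::new();
    // Leaf slots of the first and last element of the range (inclusive).
    let (mut il, mut ir) = (n - 1 + start, n - 1 + end - 1);
    while il < ir {
        if il % 2 == 0 {
            // Start is a right child: its sibling is outside the range.
            nodes.push(il);
            il += 1;
        }
        if ir % 2 == 1 {
            // End is a left child: its sibling is outside the range.
            nodes.push(ir);
            ir -= 1;
        }
        il = parent(il);
        ir = parent(ir);
    }
    if il == ir {
        nodes.push(il); // the pointers met: update that node one last time
    }
    nodes
}
```

For $n = 6$, the range `0..6` touches nodes 2 and 1, and `0..4` touches nodes 2 and 3, which matches the buffers in the `interval_add` test of the implementation.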

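Finally, here is how to tell left and right children apart from the index alone. In Figure 2, $a_1$ is always a left child with an odd index, because the layers above the bottom layer always hold an odd number of memory slots ($1 + 2 + \dots + 2^{d-1} = 2^d - 1$ slots above depth $d$). Since left and right children alternate along each layer, the parity of an index is enough: odd indices are left children, and even indices (other than the root at 0) are right children.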


Figure 3
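The parity rule can be checked directly. This sketch assumes the same 0-indexed layout, where the children of node `p` are `2p + 1` and `2p + 2`; `par_index` uses the same formula as the article's `par_index`, and the other helper names are hypothetical.

```rust
// Sketch only: hypothetical helpers for the 0-indexed layout.

/// Parent index, same formula as the article's `par_index` (valid for i > 0).
fn par_index(i: usize) -> usize {
    ((i + 1) >> 1) - 1
}

/// Left and right children of node `p`.
fn children(p: usize) -> (usize, usize) {
    (2 * p + 1, 2 * p + 2)
}

/// Odd indices are left children; even indices other than the root are right children.
fn is_left_child(i: usize) -> bool {
    i % 2 == 1
}

/// First slot of depth `d`: the layers above it hold 2^d - 1 slots.
fn layer_start(d: u32) -> usize {
    (1usize << d) - 1
}
```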

When the data size is 6

The algorithm works in this case as well. Try tracing it yourself.
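Instead of checking by hand, one can compare the tree against a plain array for every range. The sketch below is a self-contained re-implementation of the range-addition logic, under these assumptions: `i64` values, ranges passed as `start`/`end` with `start < end <= n`, and hypothetical names (`RangeAddTree`, `check_all_ranges`). It covers $n = 6$ along with every other size up to 8.

```rust
// Sketch only: hypothetical names; mirrors the bottom-up range add and
// leaf-to-root point query on the layout [n - 1 buffer slots | n data slots].
struct RangeAddTree {
    inner: Vec<i64>,
}

impl RangeAddTree {
    fn build(data: &[i64]) -> Self {
        let mut inner = vec![0; data.len() - 1]; // buffer starts at the identity
        inner.extend_from_slice(data);
        RangeAddTree { inner }
    }

    fn leaf(&self, index: usize) -> usize {
        (self.inner.len() >> 1) + index
    }

    fn add(&mut self, start: usize, end: usize, value: i64) {
        let (mut il, mut ir) = (self.leaf(start), self.leaf(end - 1));
        while il < ir {
            if il % 2 == 0 {
                self.inner[il] += value; // right child: record, step right
                il += 1;
            }
            if ir % 2 == 1 {
                self.inner[ir] += value; // left child: record, step left
                ir -= 1;
            }
            il = (il - 1) / 2;
            ir = (ir - 1) / 2;
        }
        if il == ir {
            self.inner[il] += value;
        }
    }

    fn calc(&self, index: usize) -> i64 {
        let mut i = self.leaf(index);
        let mut total = self.inner[i];
        while i > 0 {
            i = (i - 1) / 2;
            total += self.inner[i];
        }
        total
    }
}

/// Compare against a plain array for every range of every size up to `max_n`.
fn check_all_ranges(max_n: usize) -> bool {
    for n in 1..=max_n {
        for start in 0..n {
            for end in start + 1..=n {
                let data: Vec<i64> = (0..n as i64).collect();
                let mut tree = RangeAddTree::build(&data);
                let mut naive = data.clone();
                tree.add(start, end, 7);
                for x in &mut naive[start..end] {
                    *x += 7;
                }
                if (0..n).any(|i| tree.calc(i) != naive[i]) {
                    return false;
                }
            }
        }
    }
    true
}
```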

Implementation

A Segment Tree can be built for many kinds of range operations, but a given tree supports only one of them. To prevent misuse, we give SegmentTree a marker type parameter, held through std::marker::PhantomData, and implement each operation's methods only for its own marker. This lets the same SegmentTree name be reused for different kinds of range operations.
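A minimal sketch of the idea, using hypothetical stand-ins for the article's types: each marker is a unit struct, methods are implemented for only one marker, and `PhantomData` is zero-sized, so the marker adds no memory to the struct. The methods are toy placeholders, not segment-tree logic.

```rust
use std::marker::PhantomData;
use std::mem::size_of;

// Hypothetical stand-ins for the article's marker types and struct.
struct Add;
struct Sum;

struct Tree<T, U> {
    inner: Vec<T>,
    interval_ops: PhantomData<U>, // zero-sized: only selects the available methods
}

impl Tree<i64, Add> {
    /// Toy method, only available on trees tagged with `Add`.
    fn add_to_all(&mut self, value: i64) {
        for x in &mut self.inner {
            *x += value;
        }
    }
}

impl Tree<i64, Sum> {
    /// Toy method, only available on trees tagged with `Sum`.
    fn total(&self) -> i64 {
        self.inner.iter().sum()
    }
}
```

Calling `total` on a `Tree<i64, Add>` is rejected at compile time, because no such method exists for that marker, and `size_of::<Tree<i64, Add>>()` equals `size_of::<Vec<i64>>()`.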

Common Components

src/lib.rs
pub mod common;
pub mod interval_add;
pub mod interval_sum;
src/common.rs
use std::marker::PhantomData;

pub trait IntervalOps {}

pub struct Add;
impl IntervalOps for Add {}
pub struct Sum;
impl IntervalOps for Sum {}
pub struct Mul;
impl IntervalOps for Mul {}
pub struct Gross;
impl IntervalOps for Gross {}

#[derive(Debug)]
pub struct SegmentTree<T, U>
where
    U: IntervalOps,
{
    pub(crate) inner: Vec<T>,
    pub(crate) interval_ops: PhantomData<U>,
}

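impl<T, U: IntervalOps> SegmentTree<T, U> {
    // Index of the parent of `index` (valid for index > 0; the root is 0).
    // Equivalent to (index - 1) / 2.
    pub(crate) fn par_index(index: usize) -> usize {
        ((index + 1) >> 1) - 1
    }

    // Position of data element `index` inside `inner`: the first
    // len / 2 = n - 1 slots are the buffer, followed by the n data slots.
    pub(crate) fn inner_index(&self, index: usize) -> usize {
        (self.inner.len() >> 1) + index
    }
}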

Range Addition

In a range-addition Segment Tree, the buffer records differences. Since every buffer slot starts at zero[4], building the tree is simple: fill the buffer with zeros and append the data. To support arbitrary numeric types, the trait bounds must guarantee both addition and the existence of zero; the code below uses std::ops::Add and num::Zero for this.
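If pulling in the `num` crate is undesirable, a small trait can play the role of `num::Zero`. This is a sketch under that assumption: `AdditiveIdentity`, `impl_identity!`, and `build_range_add` are hypothetical names, and only `i32`, `i64`, and `f64` are covered.

```rust
use std::ops::Add;

/// Hypothetical std-only stand-in for `num::Zero`: a type with an additive identity.
trait AdditiveIdentity: Copy + Add<Output = Self> {
    fn zero() -> Self;
}

// Implement the trait for a few numeric types.
macro_rules! impl_identity {
    ($($t:ty => $z:expr),*) => {
        $(impl AdditiveIdentity for $t {
            fn zero() -> Self { $z }
        })*
    };
}
impl_identity!(i32 => 0, i64 => 0, f64 => 0.0);

/// Range-add buffer layout: n - 1 identity elements followed by the data.
fn build_range_add<T: AdditiveIdentity>(data: Vec<T>) -> Option<Vec<T>> {
    if data.is_empty() {
        return None; // empty input is rejected, as in the article's `build`
    }
    let mut inner = Vec::with_capacity(2 * data.len() - 1);
    inner.extend((1..data.len()).map(|_| T::zero()));
    inner.extend(data);
    Some(inner)
}
```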

Code
src/interval_add.rs
use std::{marker::PhantomData, ops::Range};

use anyhow::ensure;
use num::Zero;

use crate::common::{Add, SegmentTree};

impl<T: Copy + std::ops::Add<Output = T> + Zero> SegmentTree<T, Add> {
    pub fn build(data: Vec<T>) -> anyhow::Result<Self> {
        let len = data.len();
        ensure!(len > 0, "the length of the given data must be more than 0");

        let mut inner = Vec::with_capacity(2 * len - 1);
        // The identity element is required for initialization
        inner.extend((1..len).map(|_| T::zero()));
        inner.extend(data.into_iter());

        Ok(Self {
            inner,
            interval_ops: PhantomData,
        })
    }

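    pub fn add(&mut self, indexes: Range<usize>, value: T) {
        // An empty range changes nothing; returning early also avoids the
        // underflow in `indexes.end - 1` below.
        if indexes.start >= indexes.end {
            return;
        }
        // Reject ranges past the data instead of silently updating wrong nodes.
        assert!(indexes.end <= (self.inner.len() + 1) / 2, "range out of bounds");

        let mut il = self.inner_index(indexes.start);
        let mut ir = self.inner_index(indexes.end - 1);

        loop {
            match il.cmp(&ir) {
                std::cmp::Ordering::Less => {
                    // If the start is a right child (even index), its sibling lies
                    // outside the range: record the difference here, then step right
                    // to the next node, which has a different parent.
                    if il & 1 == 0 {
                        self.inner[il] = self.inner[il] + value;
                        il += 1;
                    }
                    // Likewise, if the end is a left child (odd index), record the
                    // difference and step left.
                    if ir & 1 == 1 {
                        self.inner[ir] = self.inner[ir] + value;
                        ir -= 1;
                    }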
                    il = Self::par_index(il);
                    ir = Self::par_index(ir);
                }
                // If the start and end indices are the same, record the difference and exit the loop.
                std::cmp::Ordering::Equal => {
                    self.inner[il] = self.inner[il] + value;
                    break;
                }
                std::cmp::Ordering::Greater => break,
            }
        }
    }

    pub fn calc(&self, index: usize) -> Option<T> {
        let mut i = self.inner_index(index);
        if i >= self.inner.len() {
            return None;
        }

        let mut output = self.inner[i];
        while i > 0 {
            i = Self::par_index(i);
            output = output + self.inner[i]
        }

        Some(output)
    }

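    // Push every recorded difference down to the leaves and return the data.
    // Children always have larger indices than their parent, so one forward
    // pass over the internal nodes suffices. Time complexity is O(n).
    pub fn reduce(mut self) -> Vec<T> {
        for i in 0..self.inner.len() >> 1 {
            let (left, right) = (2 * i + 1, 2 * i + 2);
            self.inner[left] = self.inner[left] + self.inner[i];
            self.inner[right] = self.inner[right] + self.inner[i];
        }
        self.inner[self.inner.len() >> 1..].to_owned()
    }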
}
Test Code
src/interval_add.rs
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn build_depth3() -> anyhow::Result<()> {
        assert_eq!(
            SegmentTree::<i32, Add>::build(vec![1, 2, 3, 4])?.inner,
            vec![0, 0, 0, 1, 2, 3, 4]
        );
        assert_eq!(
            SegmentTree::<i32, Add>::build(vec![1, 2, 3])?.inner,
            vec![0, 0, 1, 2, 3]
        );

        Ok(())
    }

    #[test]
    fn interval_add() -> anyhow::Result<()> {
        let mut seg_tree = SegmentTree::<i32, Add>::build(vec![1, 2, 3, 4, 5, 6])?;
        assert_eq!(seg_tree.inner, vec![0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6]);

        seg_tree.add(0..4, 2);
        assert_eq!(seg_tree.inner, vec![0, 0, 2, 2, 0, 1, 2, 3, 4, 5, 6]);

        seg_tree.add(4..5, -2);
        assert_eq!(seg_tree.inner, vec![0, 0, 2, 2, 0, 1, 2, 3, 4, 3, 6]);

        seg_tree.add(0..6, 10);
        assert_eq!(seg_tree.inner, vec![0, 10, 12, 2, 0, 1, 2, 3, 4, 3, 6]);

        Ok(())
    }

    #[test]
    fn calc() -> anyhow::Result<()> {
        let mut seg_tree = SegmentTree::<i32, Add>::build(vec![0, 1, 2, 3, 4])?;
        seg_tree.add(0..1, 2);
        seg_tree.add(3..5, 3);
        seg_tree.add(2..5, -9);

        assert_eq!(seg_tree.calc(0), Some(2));
        assert_eq!(seg_tree.calc(1), Some(1));
        assert_eq!(seg_tree.calc(2), Some(-7));
        assert_eq!(seg_tree.calc(3), Some(-3));
        assert_eq!(seg_tree.calc(4), Some(-2));

        Ok(())
    }

    #[test]
    fn reduce() -> anyhow::Result<()> {
        let mut seg_tree = SegmentTree::<i32, Add>::build(vec![0, 1, 2, 3, 4])?;
        seg_tree.add(0..1, 2);
        seg_tree.add(3..5, 3);
        seg_tree.add(2..5, -9);

        assert_eq!(seg_tree.reduce(), vec![2, 1, -7, -3, -2]);

        Ok(())
    }
}

Range Sum

In a range-sum Segment Tree, each buffer slot records the sum of its two children's values.
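A sketch of how the buffer is filled, assuming `i64` data, a non-empty input, and the 0-indexed layout of Figure 3 (node `k` has children `2k + 1` and `2k + 2`); the function name is hypothetical.

```rust
// Sketch only: hypothetical name; assumes data is non-empty.
/// Build the range-sum layout: n - 1 internal nodes followed by n leaves,
/// where node k holds inner[2k + 1] + inner[2k + 2].
fn build_sum_tree(data: &[i64]) -> Vec<i64> {
    let n = data.len();
    let mut inner = vec![0; n - 1];
    inner.extend_from_slice(data);
    // Children always have larger indices than their parent, so filling the
    // internal nodes from the last one down to the root sees finished children.
    for k in (0..n - 1).rev() {
        inner[k] = inner[2 * k + 1] + inner[2 * k + 2];
    }
    inner
}
```

For `[1, 2, 3, 4, 5]` this yields `[15, 10, 5, 9, 1, 2, 3, 4, 5]`: the root holds the total, and node 3 holds 4 + 5.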

Range sums can also be implemented with a Fenwick Tree. It needs about half the memory of a Segment Tree (roughly $n$ slots instead of $2n-1$), so if point updates and range sums are all you need, the Fenwick Tree is the better choice.
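For comparison, here is a minimal Fenwick Tree sketch for `i64` values. The struct and method names are hypothetical; it uses the common 1-indexed formulation, in which `i & i.wrapping_neg()` isolates the lowest set bit of `i`, and stores $n + 1$ slots with slot 0 unused.

```rust
// Sketch only: hypothetical names; point addition and range sums over i64.
struct Fenwick {
    tree: Vec<i64>, // tree[0] is unused padding so internal indices start at 1
}

impl Fenwick {
    fn new(n: usize) -> Self {
        Fenwick { tree: vec![0; n + 1] }
    }

    /// Add `value` to element `index` (0-indexed).
    fn add(&mut self, index: usize, value: i64) {
        let mut i = index + 1;
        while i < self.tree.len() {
            self.tree[i] += value;
            i += i & i.wrapping_neg(); // move to the next node whose range contains i
        }
    }

    /// Sum of elements 0..end.
    fn prefix_sum(&self, end: usize) -> i64 {
        let mut i = end;
        let mut total = 0;
        while i > 0 {
            total += self.tree[i];
            i -= i & i.wrapping_neg(); // drop the lowest set bit
        }
        total
    }

    /// Sum of elements start..end.
    fn sum(&self, start: usize, end: usize) -> i64 {
        self.prefix_sum(end) - self.prefix_sum(start)
    }
}
```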

Code
src/interval_sum.rs
use std::{marker::PhantomData, ops::Range};

use anyhow::ensure;
use num::Zero;

use crate::common::{SegmentTree, Sum};

impl<T: Copy + std::ops::Add<Output = T> + Zero> SegmentTree<T, Sum> {
    pub fn build(data: Vec<T>) -> anyhow::Result<Self> {
        let len = data.len();
        ensure!(len > 0, "the length of the given data must be more than 0");

        let mut inner = Vec::with_capacity(2 * len - 1);
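        // Placeholder values for the internal nodes; they are overwritten below.
        inner.extend((1..len).map(|_| data[0]));
        inner.extend(data.into_iter());
        // Fill each internal node with the sum of its two children. Children have
        // larger indices than their parent, so iterate from the last node down.
        for i in (1..len).rev() {
            inner[i - 1] = inner[2 * i - 1] + inner[2 * i];
        }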

        Ok(Self {
            inner,
            interval_ops: PhantomData,
        })
    }

    pub fn add(&mut self, index: usize, value: T) {
        let mut i = self.inner_index(index);
        self.inner[i] = self.inner[i] + value;

        while i > 0 {
            i = Self::par_index(i);
            self.inner[i] = self.inner[i] + value;
        }
    }

    pub fn sum(&self, indexes: Range<usize>) -> T {
        let mut il = self.inner_index(indexes.start);
        let mut ir = self.inner_index(indexes.end).min(self.inner.len()) - 1;

        // An identity element is required.
        let mut output = T::zero();

        loop {
            match il.cmp(&ir) {
                std::cmp::Ordering::Less => {
                    if il & 1 == 0 {
                        output = output + self.inner[il];
                        il += 1;
                    }
                    if ir & 1 == 1 {
                        output = output + self.inner[ir];
                        ir -= 1;
                    }
                    il = Self::par_index(il);
                    ir = Self::par_index(ir);
                }
                std::cmp::Ordering::Equal => {
                    output = output + self.inner[il];
                    break;
                }
                std::cmp::Ordering::Greater => break,
            }
        }

        output
    }

    pub fn reduce(self) -> Vec<T> {
        self.inner[self.inner.len() >> 1..].to_owned()
    }
}
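Since the range-sum code has no tests of its own, one way to gain confidence is to compare it against plain slice sums. This sketch copies the logic into hypothetical free functions over `i64`, drops the clamping of `end` (so it assumes `end <= n`), and checks every range for sizes 1 through 12 after one point update.

```rust
// Sketch only: hypothetical free-function copies of build / add / sum.
fn build(data: &[i64]) -> Vec<i64> {
    let n = data.len();
    let mut inner = vec![0; n - 1];
    inner.extend_from_slice(data);
    for k in (0..n - 1).rev() {
        inner[k] = inner[2 * k + 1] + inner[2 * k + 2];
    }
    inner
}

// Point update: the leaf and every ancestor grow by `value`.
fn point_add(inner: &mut [i64], index: usize, value: i64) {
    let mut i = (inner.len() >> 1) + index;
    inner[i] += value;
    while i > 0 {
        i = (i - 1) / 2;
        inner[i] += value;
    }
}

// Bottom-up range sum over start..end (assumes end <= n).
fn range_sum(inner: &[i64], start: usize, end: usize) -> i64 {
    if start >= end {
        return 0;
    }
    let half = inner.len() >> 1;
    let (mut il, mut ir) = (half + start, half + end - 1);
    let mut total = 0;
    while il < ir {
        if il % 2 == 0 {
            total += inner[il];
            il += 1;
        }
        if ir % 2 == 1 {
            total += inner[ir];
            ir -= 1;
        }
        il = (il - 1) / 2;
        ir = (ir - 1) / 2;
    }
    if il == ir {
        total += inner[il];
    }
    total
}

/// Compare against slice sums for every range and every size up to `max_n`.
fn check_all_ranges(max_n: usize) -> bool {
    for n in 1..=max_n {
        let mut data: Vec<i64> = (1..=n as i64).map(|x| x * x).collect();
        let mut inner = build(&data);
        point_add(&mut inner, n / 2, -5);
        data[n / 2] -= 5;
        for start in 0..n {
            for end in start..=n {
                if range_sum(&inner, start, end) != data[start..end].iter().sum::<i64>() {
                    return false;
                }
            }
        }
    }
    true
}
```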

Conclusion

Thank you for reading this far. Please let me know in the comments if you spot any shortcomings. I also look forward to hearing your ideas and any related information.

Footnotes
  1. That is, a wider variety of range operations can be realized.

  2. This coincides with the memory size of a complete binary tree in the special case where that size is exactly $2n-1$, i.e., when $n$ is a power of two.

  3. Values left at their initial state; for range addition, the initial value is the additive identity, 0.

  4. The additive identity.
