Translated by AI
Implementing Segment Trees in Rust
Introduction
A Segment Tree is a data structure used for efficient range queries. Range queries are operations on a contiguous range of an array, such as adding a value to every element in the range or computing the sum of its elements.
Other useful algorithms and data structures for range queries include Prefix Sums, the Difference Array method, and Fenwick Trees (also known as Binary Indexed Trees, or BITs). Among these, the Segment Tree supports a wider variety of operations (higher expressive power[1]) at the cost of additional memory. The Lazy Segment Tree, an extension of the Segment Tree, is more expressive still, at the expense of further memory and computation. Choosing the right tool for the job is key.
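For comparison, the two simpler techniques can be sketched in a few lines. This is a minimal illustration, not code from this article. The function names `prefix_sums`, `range_sum` and `apply_range_adds` are hypothetical, and `i64` is assumed for simplicity. Prefix sums answer range-sum queries in O(1) but cannot absorb updates cheaply. A difference array applies each range addition in O(1) but only yields the resulting values after a final O(n) sweep.

```rust
use std::ops::Range;

/// Prefix sums: `p[i]` is the sum of the first `i` elements (n + 1 entries).
pub fn prefix_sums(data: &[i64]) -> Vec<i64> {
    let mut p = Vec::with_capacity(data.len() + 1);
    p.push(0);
    for &x in data {
        let last = *p.last().unwrap();
        p.push(last + x);
    }
    p
}

/// O(1) range sum over a static array, given its prefix sums.
pub fn range_sum(p: &[i64], range: Range<usize>) -> i64 {
    p[range.end] - p[range.start]
}

/// Difference array: each range addition costs O(1), but the final values
/// only become available after one O(n) sweep at the end.
pub fn apply_range_adds(data: &[i64], updates: &[(Range<usize>, i64)]) -> Vec<i64> {
    let mut diff = vec![0i64; data.len() + 1];
    for (range, value) in updates {
        diff[range.start] += value; // start adding here
        diff[range.end] -= value; // stop adding here
    }
    let mut running = 0;
    data.iter()
        .zip(&diff)
        .map(|(&x, &d)| {
            running += d; // accumulated difference at this position
            x + running
        })
        .collect()
}
```

Neither technique handles interleaved updates and queries efficiently. That gap is what Fenwick Trees and Segment Trees fill.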
Algorithm
The Prototypical Segment Tree
A Segment Tree is a type of complete binary tree. In Figure 1, the green area at the bottom represents the original data, while the blue and red areas form the buffer for the Segment Tree. Red signifies redundant buffer space.
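Values are recorded in the buffer according to the type of range operation. Take range addition as an example. Adding 1 to every element of a range naively costs O(n), since each element has to be updated one by one. A Segment Tree instead records the difference only in the O(log n) buffer nodes whose intervals together cover the range.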

Figure 1
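To make the prototypical layout concrete, here is a minimal sketch of a range-add, point-query tree stored as a complete binary tree, with the leaves padded up to a power of two. It is not the author's code. The type `CompleteAddTree` and its methods are hypothetical. It also uses a top-down recursive update, which is a common alternative to the parent-traversal algorithm explained in the next section.

```rust
use std::ops::Range;

/// Hypothetical sketch of the prototypical layout: node i has children
/// 2i + 1 and 2i + 2, and each node stores a pending difference.
pub struct CompleteAddTree {
    size: usize,     // number of leaves, including padding (a power of two)
    nodes: Vec<i64>, // 2 * size - 1 nodes; leaves start at index size - 1
}

impl CompleteAddTree {
    pub fn new(data: &[i64]) -> Self {
        let size = data.len().max(1).next_power_of_two();
        // Internal nodes start with no difference; padding leaves stay 0.
        let mut nodes = vec![0; 2 * size - 1];
        nodes[size - 1..size - 1 + data.len()].copy_from_slice(data);
        Self { size, nodes }
    }

    /// Adds `value` to every element in `range`, touching O(log n) nodes.
    pub fn add(&mut self, range: Range<usize>, value: i64) {
        self.add_rec(0, 0..self.size, &range, value);
    }

    fn add_rec(&mut self, node: usize, span: Range<usize>, range: &Range<usize>, value: i64) {
        if range.end <= span.start || span.end <= range.start {
            return; // disjoint: nothing to do
        }
        if range.start <= span.start && span.end <= range.end {
            self.nodes[node] += value; // fully covered: record the difference once
            return;
        }
        let mid = (span.start + span.end) / 2;
        self.add_rec(2 * node + 1, span.start..mid, range, value);
        self.add_rec(2 * node + 2, mid..span.end, range, value);
    }

    /// Value at `index`: the leaf plus the differences stored in its ancestors.
    pub fn get(&self, index: usize) -> i64 {
        let mut i = self.size - 1 + index;
        let mut total = self.nodes[i];
        while i > 0 {
            i = (i - 1) / 2;
            total += self.nodes[i];
        }
        total
    }
}
```

With six data elements, this tree allocates 8 leaves and 15 nodes. The two padding leaves are the kind of redundant buffer shown in red in Figure 1.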
Memory-Optimized Segment Tree
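In the complete binary tree structure shown in the previous section, the redundant buffer grows as the data size n moves away from a power of two, because the leaves must be padded up to the next power of two. The implementation below avoids this by packing the tree into a Vec of exactly 2n - 1 elements.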
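The difference is easy to quantify. The two helper functions below are hypothetical and assume n ≥ 1. The padded layout needs 2 * n.next_power_of_two() - 1 slots. The compact layout used by the implementation's `build` (`Vec::with_capacity(2 * len - 1)`) needs 2n - 1.

```rust
/// Buffer length of the prototypical layout: leaves padded to a power of two.
pub fn complete_tree_len(n: usize) -> usize {
    2 * n.next_power_of_two() - 1
}

/// Buffer length of the memory-optimized layout (requires n >= 1).
pub fn compact_tree_len(n: usize) -> usize {
    2 * n - 1
}
```

The two lengths coincide when n is a power of two. The worst case is n just above a power of two: for n = 1025, the padded layout needs 4095 slots versus 2049 for the compact one.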
In Figure 2, the red area represents memory regions holding invalid data[3]. The structure may look broken, since it is no longer even a proper binary tree, but the "parent-traversal algorithm" still works correctly on it.

Figure 2
The reality of the Segment Tree is shown in Figure 3. The numbers represent the indices of the internal Vec used in the implementation. We will now explain the "parent-traversal algorithm" based on Figure 3.
Consider a range that is sufficiently long. If the start of the range is a left child, its right sibling is also inside the range, so the two can be handled together by their parent: we simply move up to the parent. If the start is a right child, its parent would also cover the left sibling, which lies outside the range. So we update the current node and then move to the node to the right of its parent. The end of the range is handled symmetrically: if it is a right child we move up, and if it is a left child we update it and move to the node to the left of its parent. Repeating this shrinks the range until the start and end pointers either meet or cross, which is the termination condition. If they meet, remember to update that last node once.
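The traversal can be separated from any particular operation. The following sketch uses a hypothetical function name, `cover_nodes`. It returns the buffer indices that the parent-traversal algorithm selects for a half-open range on a tree with n leaves. It uses the same 0-based layout as the implementation: leaves at n - 1 ..= 2n - 2, and node i's parent at (i - 1) / 2.

```rust
use std::ops::Range;

/// Parent of a non-root node in the 0-based layout (children of i: 2i + 1, 2i + 2).
fn parent(i: usize) -> usize {
    (i - 1) / 2
}

/// Buffer indices selected by the parent-traversal algorithm for `range`
/// on a tree with `n` leaves stored at indices n - 1 ..= 2n - 2.
pub fn cover_nodes(n: usize, range: Range<usize>) -> Vec<usize> {
    let mut nodes = Vec::new();
    if range.is_empty() {
        return nodes;
    }
    let mut il = n - 1 + range.start; // leftmost leaf of the range
    let mut ir = n - 1 + range.end - 1; // rightmost leaf of the range
    while il < ir {
        // Start is a right child (nonzero even index): take it, then step right.
        if il % 2 == 0 {
            nodes.push(il);
            il += 1;
        }
        // End is a left child (odd index): take it, then step left.
        if ir % 2 == 1 {
            nodes.push(ir);
            ir -= 1;
        }
        il = parent(il);
        ir = parent(ir);
    }
    // The pointers met: that node is the last one to take.
    if il == ir {
        nodes.push(il);
    }
    nodes
}
```

For n = 6, the range 0..4 is covered by nodes 2 and 3. These are exactly the nodes that the range-addition test updates for `add(0..4, 2)`.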
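Finally, here is how to distinguish a left child from a right child by its index. In Figure 2, numbering the nodes as in Figure 3, node i has its children at indices 2i + 1 and 2i + 2. Every left child therefore has an odd index, and every right child has a nonzero even index (index 0 is the root).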

Figure 3
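The index arithmetic can be captured in a few helper functions. The names below are hypothetical, but `parent` computes the same value as the implementation's `par_index`, ((i + 1) >> 1) - 1, for every non-root index.

```rust
/// Children of node `i` in the 0-based layout used by the implementation.
pub fn left_child(i: usize) -> usize {
    2 * i + 1
}

pub fn right_child(i: usize) -> usize {
    2 * i + 2
}

/// Parent of a non-root node.
pub fn parent(i: usize) -> usize {
    debug_assert!(i > 0, "the root has no parent");
    (i - 1) / 2
}

/// Left children sit at odd indices, right children at nonzero even ones.
pub fn is_left_child(i: usize) -> bool {
    i % 2 == 1
}

pub fn is_right_child(i: usize) -> bool {
    i != 0 && i % 2 == 0
}
```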
When the data size is 6
The algorithm works in this case as well. Please verify it yourself.
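One way to check this mechanically is a brute-force test. The sketch below uses hypothetical names throughout. It restates the parent-traversal loop for the 0-based layout so that the block compiles on its own. It then collects the data indices stored under each selected node and verifies, for every non-empty range, that the selected nodes are disjoint and cover exactly that range. Running it for n = 6, or for every n up to a few dozen, also exercises the non-power-of-two shapes.

```rust
use std::collections::BTreeSet;
use std::ops::Range;

/// Parent-traversal loop: buffer indices that together cover `range`.
fn cover_nodes(n: usize, range: Range<usize>) -> Vec<usize> {
    let mut nodes = Vec::new();
    if range.is_empty() {
        return nodes;
    }
    let (mut il, mut ir) = (n - 1 + range.start, n - 1 + range.end - 1);
    while il < ir {
        if il % 2 == 0 {
            nodes.push(il); // right child at the start
            il += 1;
        }
        if ir % 2 == 1 {
            nodes.push(ir); // left child at the end
            ir -= 1;
        }
        il = (il - 1) / 2;
        ir = (ir - 1) / 2;
    }
    if il == ir {
        nodes.push(il);
    }
    nodes
}

/// Data indices stored in the leaves below `node` (leaves sit at n - 1 ..= 2n - 2).
fn leaves_under(n: usize, node: usize) -> BTreeSet<usize> {
    let mut out = BTreeSet::new();
    let mut stack = vec![node];
    while let Some(i) = stack.pop() {
        if i >= n - 1 {
            out.insert(i - (n - 1)); // a leaf: convert to a data index
        } else {
            stack.push(2 * i + 1);
            stack.push(2 * i + 2);
        }
    }
    out
}

/// For every non-empty range, the selected nodes must be disjoint and
/// cover exactly the leaves of that range.
pub fn check_all_ranges(n: usize) -> bool {
    for start in 0..n {
        for end in start + 1..=n {
            let mut seen = BTreeSet::new();
            for node in cover_nodes(n, start..end) {
                for leaf in leaves_under(n, node) {
                    if !seen.insert(leaf) {
                        return false; // two selected nodes overlap
                    }
                }
            }
            if seen != (start..end).collect::<BTreeSet<usize>>() {
                return false; // missing or extra leaves
            }
        }
    }
    true
}
```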

Implementation
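A Segment Tree can be built for many kinds of range operations, but a single tree supports only one of them at a time. To prevent misuse, we control which methods are available with a zero-sized marker type held in std::marker::PhantomData: each operation gets its own impl block on SegmentTree<T, Marker>. This also lets the same SegmentTree name, and the same method names such as build, be reused for different kinds of range operations.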
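The marker-type pattern can be seen in isolation in a minimal sketch. `Tree`, `RangeAdd` and `RangeSum` are hypothetical stand-ins for the article's `SegmentTree`, `Add` and `Sum`. Two details are worth noting. First, both impl blocks define `build` and `describe`, which is allowed because `Tree<RangeAdd>` and `Tree<RangeSum>` are distinct types whose inherent impls do not overlap. Second, `PhantomData` adds no size to the struct.

```rust
use std::marker::PhantomData;

// Hypothetical marker types standing in for the article's `Add` and `Sum`.
pub struct RangeAdd;
pub struct RangeSum;

// The operation lives only in the type parameter; PhantomData stores no data.
pub struct Tree<U> {
    inner: Vec<i64>,
    _ops: PhantomData<U>,
}

// Methods shared by every kind of tree go in a generic impl.
impl<U> Tree<U> {
    pub fn len(&self) -> usize {
        self.inner.len()
    }
}

// One impl block per operation; method names can repeat across them.
impl Tree<RangeAdd> {
    pub fn build(data: Vec<i64>) -> Self {
        Tree { inner: data, _ops: PhantomData }
    }
    pub fn describe(&self) -> &'static str {
        "range add, point query"
    }
}

impl Tree<RangeSum> {
    pub fn build(data: Vec<i64>) -> Self {
        Tree { inner: data, _ops: PhantomData }
    }
    pub fn describe(&self) -> &'static str {
        "point add, range sum"
    }
}
```

With this layout, calling a method defined only for `Tree<RangeSum>` on a `Tree<RangeAdd>` is a compile-time error rather than a silent misuse.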
Common Components
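// Crate root: one module per kind of range operation
pub mod common;
pub mod interval_add;
pub mod interval_sum;

// common.rs
use std::marker::PhantomData;

/// Marker trait for the kind of range operation a tree supports.
pub trait IntervalOps {}
pub struct Add;
impl IntervalOps for Add {}
pub struct Sum;
impl IntervalOps for Sum {}
pub struct Mul;
impl IntervalOps for Mul {}
pub struct Gross;
impl IntervalOps for Gross {}

#[derive(Debug)]
pub struct SegmentTree<T, U>
where
    U: IntervalOps,
{
    /// Internal nodes at 0..n-1, followed by the original data (leaves) at n-1..2n-1.
    pub(crate) inner: Vec<T>,
    /// Zero-sized marker selecting which operation's methods are available.
    pub(crate) interval_ops: PhantomData<U>,
}

impl<T, U: IntervalOps> SegmentTree<T, U> {
    /// Parent of a non-root node; node i has its children at 2i + 1 and 2i + 2.
    pub(crate) fn par_index(index: usize) -> usize {
        ((index + 1) >> 1) - 1
    }

    /// Buffer index of the `index`-th element of the original data.
    pub(crate) fn inner_index(&self, index: usize) -> usize {
        (self.inner.len() >> 1) + index
    }
}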
Range Addition
In a range-addition Segment Tree, the buffer records pending differences. Since every difference starts at zero[4], building the tree is simple: fill the internal nodes with zero and copy the data into the leaves. To support arbitrary numeric types, the trait bounds must guarantee both addition (std::ops::Add<Output = T>) and the existence of zero (num::Zero).
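To see where each bound is used, here is a std-only sketch. The local `Zero` trait is a hypothetical stand-in for `num::Zero`, implemented here only for `i32` and `f64`. `build_add_buffer` fills the internal nodes with zero, which is exactly where the identity element is required. `point_value` only needs `Copy` and `Add`.

```rust
use std::ops::Add;

// Hypothetical stand-in for num::Zero so that this sketch needs only std.
pub trait Zero {
    fn zero() -> Self;
}
impl Zero for i32 {
    fn zero() -> Self {
        0
    }
}
impl Zero for f64 {
    fn zero() -> Self {
        0.0
    }
}

/// Range-add buffer in the compact layout: n - 1 zero differences followed
/// by the data as leaves. Zero makes untouched nodes contribute nothing.
pub fn build_add_buffer<T: Copy + Add<Output = T> + Zero>(data: &[T]) -> Vec<T> {
    let mut inner = vec![T::zero(); data.len().saturating_sub(1)];
    inner.extend_from_slice(data);
    inner
}

/// Point query: the leaf value plus the differences stored in its ancestors.
pub fn point_value<T: Copy + Add<Output = T>>(inner: &[T], index: usize) -> T {
    let mut i = inner.len() / 2 + index;
    let mut total = inner[i];
    while i > 0 {
        i = (i - 1) / 2;
        total = total + inner[i];
    }
    total
}
```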
Code
use std::{marker::PhantomData, ops::Range};

use anyhow::ensure;
use num::Zero;

use crate::common::{Add, SegmentTree};

impl<T: Copy + std::ops::Add<Output = T> + Zero> SegmentTree<T, Add> {
    pub fn build(data: Vec<T>) -> anyhow::Result<Self> {
        let len = data.len();
        ensure!(len > 0, "the length of the given data must be more than 0");
        let mut inner = Vec::with_capacity(2 * len - 1);
        // The identity element (zero) is required to initialize the differences
        inner.extend((1..len).map(|_| T::zero()));
        inner.extend(data);
        Ok(Self {
            inner,
            interval_ops: PhantomData,
        })
    }

    pub fn add(&mut self, indexes: Range<usize>, value: T) {
        // An empty range changes nothing (and `end - 1` would underflow for 0..0)
        if indexes.is_empty() {
            return;
        }
        let mut il = self.inner_index(indexes.start);
        let mut ir = self.inner_index(indexes.end - 1);
        loop {
            match il.cmp(&ir) {
                std::cmp::Ordering::Less => {
                    // If the start of the range is a right child (even index), record the
                    // difference and step to the next node, which has a different parent.
                    if il & 1 == 0 {
                        self.inner[il] = self.inner[il] + value;
                        il += 1;
                    }
                    // Likewise if the end of the range is a left child (odd index).
                    if ir & 1 == 1 {
                        self.inner[ir] = self.inner[ir] + value;
                        ir -= 1;
                    }
                    il = Self::par_index(il);
                    ir = Self::par_index(ir);
                }
                // If the start and end indices meet, record the difference and exit the loop.
                std::cmp::Ordering::Equal => {
                    self.inner[il] = self.inner[il] + value;
                    break;
                }
                // The pointers crossed: the whole range has been covered.
                std::cmp::Ordering::Greater => break,
            }
        }
    }

    /// Current value at `index`: the leaf plus the differences of all its ancestors.
    pub fn calc(&self, index: usize) -> Option<T> {
        let mut i = self.inner_index(index);
        if i >= self.inner.len() {
            return None;
        }
        let mut output = self.inner[i];
        while i > 0 {
            i = Self::par_index(i);
            output = output + self.inner[i];
        }
        Some(output)
    }

    // Push every difference down to the leaves and extract the updated data. Time complexity is O(n).
    pub fn reduce(mut self) -> Vec<T> {
        for i in 0..self.inner.len() >> 1 {
            let (left, right) = (2 * i + 1, 2 * i + 2);
            self.inner[right] = self.inner[right] + self.inner[i];
            self.inner[left] = self.inner[left] + self.inner[i];
        }
        self.inner[self.inner.len() >> 1..].to_owned()
    }
}
Test Code
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn build_depth3() -> anyhow::Result<()> {
        assert_eq!(
            SegmentTree::<i32, Add>::build(vec![1, 2, 3, 4])?.inner,
            vec![0, 0, 0, 1, 2, 3, 4]
        );
        assert_eq!(
            SegmentTree::<i32, Add>::build(vec![1, 2, 3])?.inner,
            vec![0, 0, 1, 2, 3]
        );
        Ok(())
    }

    #[test]
    fn interval_add() -> anyhow::Result<()> {
        let mut seg_tree = SegmentTree::<i32, Add>::build(vec![1, 2, 3, 4, 5, 6])?;
        assert_eq!(seg_tree.inner, vec![0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6]);
        seg_tree.add(0..4, 2);
        assert_eq!(seg_tree.inner, vec![0, 0, 2, 2, 0, 1, 2, 3, 4, 5, 6]);
        seg_tree.add(4..5, -2);
        assert_eq!(seg_tree.inner, vec![0, 0, 2, 2, 0, 1, 2, 3, 4, 3, 6]);
        seg_tree.add(0..6, 10);
        assert_eq!(seg_tree.inner, vec![0, 10, 12, 2, 0, 1, 2, 3, 4, 3, 6]);
        Ok(())
    }

    #[test]
    fn calc() -> anyhow::Result<()> {
        let mut seg_tree = SegmentTree::<i32, Add>::build(vec![0, 1, 2, 3, 4])?;
        seg_tree.add(0..1, 2);
        seg_tree.add(3..5, 3);
        seg_tree.add(2..5, -9);
        assert_eq!(seg_tree.calc(0), Some(2));
        assert_eq!(seg_tree.calc(1), Some(1));
        assert_eq!(seg_tree.calc(2), Some(-7));
        assert_eq!(seg_tree.calc(3), Some(-3));
        assert_eq!(seg_tree.calc(4), Some(-2));
        Ok(())
    }

    #[test]
    fn reduce() -> anyhow::Result<()> {
        let mut seg_tree = SegmentTree::<i32, Add>::build(vec![0, 1, 2, 3, 4])?;
        seg_tree.add(0..1, 2);
        seg_tree.add(3..5, 3);
        seg_tree.add(2..5, -9);
        assert_eq!(seg_tree.reduce(), vec![2, 1, -7, -3, -2]);
        Ok(())
    }
}
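The tests above check hand-computed buffers. A complementary approach is to compare the tree against a plain Vec under many pseudo-random range additions. The sketch below re-creates the same compact layout in a standalone `AddTree` so that it runs without the crate. `AddTree` and `matches_naive` are hypothetical names, and the generator is a simple linear congruential generator, so no external crate is needed.

```rust
use std::ops::Range;

/// Standalone copy of the compact range-add layout (hypothetical names).
pub struct AddTree {
    inner: Vec<i64>, // n - 1 differences, then n leaves
}

impl AddTree {
    pub fn build(data: &[i64]) -> Self {
        let mut inner = vec![0; data.len() - 1];
        inner.extend_from_slice(data);
        AddTree { inner }
    }

    pub fn add(&mut self, range: Range<usize>, value: i64) {
        if range.is_empty() {
            return;
        }
        let off = self.inner.len() / 2; // index of the first leaf
        let (mut il, mut ir) = (off + range.start, off + range.end - 1);
        while il < ir {
            if il % 2 == 0 {
                self.inner[il] += value;
                il += 1;
            }
            if ir % 2 == 1 {
                self.inner[ir] += value;
                ir -= 1;
            }
            il = (il - 1) / 2;
            ir = (ir - 1) / 2;
        }
        if il == ir {
            self.inner[il] += value;
        }
    }

    pub fn get(&self, index: usize) -> i64 {
        let mut i = self.inner.len() / 2 + index;
        let mut total = self.inner[i];
        while i > 0 {
            i = (i - 1) / 2;
            total += self.inner[i];
        }
        total
    }
}

/// Applies pseudo-random range additions to the tree and to a plain Vec,
/// returning whether every point query agrees after each step.
pub fn matches_naive(n: usize, rounds: usize, seed: u64) -> bool {
    let mut state = seed;
    // Linear congruential generator (Knuth's MMIX constants).
    let mut next = move |bound: usize| {
        state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
        ((state >> 33) as usize) % bound
    };
    let mut naive: Vec<i64> = (0..n as i64).collect();
    let mut tree = AddTree::build(&naive);
    for _ in 0..rounds {
        let (a, b) = (next(n + 1), next(n + 1));
        let range = a.min(b)..a.max(b); // may be empty
        let value = next(21) as i64 - 10;
        tree.add(range.clone(), value);
        for x in &mut naive[range] {
            *x += value;
        }
        if (0..n).any(|i| tree.get(i) != naive[i]) {
            return false;
        }
    }
    true
}
```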
Range Sum
In a range-sum Segment Tree, each buffer node records the sum of its children's values.
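As a worked example of the bottom-up build, the hypothetical helper below uses `i64` for brevity. Each internal node i receives the sum of its children 2i + 1 and 2i + 2, filled from the last internal node down to the root.

```rust
/// Builds the range-sum buffer in the compact layout (hypothetical helper):
/// leaves at n - 1 ..= 2n - 2, internal node i = inner[2i + 1] + inner[2i + 2].
pub fn build_sum_buffer(data: &[i64]) -> Vec<i64> {
    let n = data.len();
    let mut inner = vec![0; n.saturating_sub(1)];
    inner.extend_from_slice(data);
    // Children always have larger indices, so fill from the bottom up.
    for i in (0..n.saturating_sub(1)).rev() {
        inner[i] = inner[2 * i + 1] + inner[2 * i + 2];
    }
    inner
}
```

For the data [0, 1, 2, 3, 4], node 1 ends up holding 7 = data[3] + data[4] + data[0]. That is a sum over non-contiguous elements, i.e. one of the "invalid" nodes of Figure 2. For n = 5, the parent-traversal algorithm never selects that node (the full range 0..5, for example, uses nodes 4, 2 and 3), so the value is harmless.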
Range sums can also be computed with a Fenwick Tree. A Fenwick Tree needs only about n slots compared with the 2n - 1 used here, roughly half the memory. It is therefore the better choice when point updates and range sums are all you need.
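For reference, here is a minimal Fenwick Tree sketch. The `Fenwick` type and its method names are hypothetical, not from this article. It keeps n + 1 slots, one of them unused, versus 2n - 1 for the Segment Tree above.

```rust
use std::ops::Range;

/// Minimal Fenwick Tree (Binary Indexed Tree) over i64, 1-based internally.
pub struct Fenwick {
    tree: Vec<i64>, // tree[0] is unused
}

impl Fenwick {
    pub fn new(n: usize) -> Self {
        Fenwick { tree: vec![0; n + 1] }
    }

    pub fn from_slice(data: &[i64]) -> Self {
        let mut f = Self::new(data.len());
        for (i, &v) in data.iter().enumerate() {
            f.add(i, v);
        }
        f
    }

    /// Adds `value` to element `index` (0-based).
    pub fn add(&mut self, index: usize, value: i64) {
        let mut i = index + 1;
        while i < self.tree.len() {
            self.tree[i] += value;
            i += i & i.wrapping_neg(); // move to the next responsible node
        }
    }

    /// Sum of the first `end` elements.
    pub fn prefix_sum(&self, end: usize) -> i64 {
        let (mut i, mut s) = (end, 0);
        while i > 0 {
            s += self.tree[i];
            i -= i & i.wrapping_neg(); // drop the lowest set bit
        }
        s
    }

    pub fn range_sum(&self, range: Range<usize>) -> i64 {
        self.prefix_sum(range.end) - self.prefix_sum(range.start)
    }
}
```

Note the design trade-off: `range_sum` is a difference of two prefix sums, so it relies on the operation being invertible. This is part of why the Segment Tree is the more expressive structure.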
Code
use std::{marker::PhantomData, ops::Range};

use anyhow::ensure;
use num::Zero;

use crate::common::{SegmentTree, Sum};

impl<T: Copy + std::ops::Add<Output = T> + Zero> SegmentTree<T, Sum> {
    pub fn build(data: Vec<T>) -> anyhow::Result<Self> {
        let len = data.len();
        ensure!(len > 0, "the length of the given data must be more than 0");
        let mut inner = Vec::with_capacity(2 * len - 1);
        inner.extend((1..len).map(|_| data[0])); // Dummy values, overwritten below
        inner.extend(data);
        // Replace the dummy values bottom-up: node i - 1 holds the sum of
        // its children 2i - 1 and 2i
        for i in (1..len).rev() {
            inner[i - 1] = inner[2 * i - 1] + inner[2 * i];
        }
        Ok(Self {
            inner,
            interval_ops: PhantomData,
        })
    }

    /// Adds `value` to the element at `index` and to the sums of all its ancestors.
    pub fn add(&mut self, index: usize, value: T) {
        let mut i = self.inner_index(index);
        self.inner[i] = self.inner[i] + value;
        while i > 0 {
            i = Self::par_index(i);
            self.inner[i] = self.inner[i] + value;
        }
    }

    /// Sum of the elements in `indexes`, using the parent-traversal algorithm.
    pub fn sum(&self, indexes: Range<usize>) -> T {
        // An identity element is required; it is also the sum of an empty range.
        let mut output = T::zero();
        if indexes.is_empty() {
            return output;
        }
        let mut il = self.inner_index(indexes.start);
        let mut ir = self.inner_index(indexes.end).min(self.inner.len()) - 1;
        loop {
            match il.cmp(&ir) {
                std::cmp::Ordering::Less => {
                    // Start is a right child: take it and step right.
                    if il & 1 == 0 {
                        output = output + self.inner[il];
                        il += 1;
                    }
                    // End is a left child: take it and step left.
                    if ir & 1 == 1 {
                        output = output + self.inner[ir];
                        ir -= 1;
                    }
                    il = Self::par_index(il);
                    ir = Self::par_index(ir);
                }
                std::cmp::Ordering::Equal => {
                    output = output + self.inner[il];
                    break;
                }
                std::cmp::Ordering::Greater => break,
            }
        }
        output
    }

    pub fn reduce(self) -> Vec<T> {
        self.inner[self.inner.len() >> 1..].to_owned()
    }
}
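The range-sum implementation has no tests in this article. One way it could be tested is sketched below. The block re-creates the same compact layout in a standalone `SumTree` and compares it against a plain Vec under pseudo-random point additions and range-sum queries. `SumTree` and `sum_matches_naive` are hypothetical names.

```rust
use std::ops::Range;

/// Standalone copy of the compact range-sum layout (hypothetical names).
pub struct SumTree {
    inner: Vec<i64>,
}

impl SumTree {
    pub fn build(data: &[i64]) -> Self {
        let n = data.len();
        let mut inner = vec![0; n - 1];
        inner.extend_from_slice(data);
        for i in (0..n - 1).rev() {
            inner[i] = inner[2 * i + 1] + inner[2 * i + 2];
        }
        SumTree { inner }
    }

    /// Point addition: update the leaf and every ancestor's sum.
    pub fn add(&mut self, index: usize, value: i64) {
        let mut i = self.inner.len() / 2 + index;
        self.inner[i] += value;
        while i > 0 {
            i = (i - 1) / 2;
            self.inner[i] += value;
        }
    }

    /// Range sum via the parent-traversal algorithm.
    pub fn sum(&self, range: Range<usize>) -> i64 {
        let mut total = 0;
        if range.is_empty() {
            return total;
        }
        let off = self.inner.len() / 2;
        let (mut il, mut ir) = (off + range.start, off + range.end - 1);
        while il < ir {
            if il % 2 == 0 {
                total += self.inner[il];
                il += 1;
            }
            if ir % 2 == 1 {
                total += self.inner[ir];
                ir -= 1;
            }
            il = (il - 1) / 2;
            ir = (ir - 1) / 2;
        }
        if il == ir {
            total += self.inner[il];
        }
        total
    }
}

/// Interleaves random point additions and range sums on the tree and on a
/// plain Vec, returning whether all answers agree.
pub fn sum_matches_naive(n: usize, rounds: usize, seed: u64) -> bool {
    let mut state = seed;
    // Linear congruential generator (Knuth's MMIX constants).
    let mut next = move |bound: usize| {
        state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
        ((state >> 33) as usize) % bound
    };
    let mut naive: Vec<i64> = (0..n as i64).collect();
    let mut tree = SumTree::build(&naive);
    for _ in 0..rounds {
        let (i, value) = (next(n), next(21) as i64 - 10);
        tree.add(i, value);
        naive[i] += value;
        let (a, b) = (next(n + 1), next(n + 1));
        let range = a.min(b)..a.max(b);
        if tree.sum(range.clone()) != naive[range].iter().sum::<i64>() {
            return false;
        }
    }
    true
}
```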
Conclusion
Thank you for reading this far. Please let me know in the comments if there are any shortcomings. I am also looking forward to your ideas and information.