【TypeScript】配列のある区間の合計を高速に計算する『SegmentTree』

4 min read読了の目安(約4200字

やりたいこと

こういう問題があったとします。

  1. 配列が与えられる
  2. この二つのコマンドがたくさん流れてくる
    1. 配列の要素を更新
    2. 配列のある区間の合計を求める

愚直に書く分には簡単そうですね。
でも、配列が大きかったりコマンドが多すぎると動作が遅くなってしまいます。
どうにか2-1と2-2を高速に実行できないでしょうか。

SegmentTreeの概要

SegmentTree(セグメント木)というものを紹介します。

まずこういう表があったとします。

例えばこれの3~10を足した合計を求めたいとします。

ここで、要素の下に段を追加します。
そして隣り合った数値同士を足した値を書いていきます。
ひとまず、元々の配列をレイヤー0、今追加した段をレイヤー1とします。
例えばレイヤー1の一番左のマスは3+5で8です。

これをマスの数が1になるまでピラミッド型に重ねます。

さて、3~10を足したいなら、ここを足していけばいいです。
なるべく大きなマスを見ていけばいいということです。

メリット

今回の例では4マスを足すだけで答えが求まりました。
もともとの長さは8マスなのでちょっと恩恵がわかりづらいですね。
もしこれが100万件のデータの場合20~40マス程度を足すだけであらゆる区間の合計を求めることができます。

また、同じ理屈で足し算だけでなく掛け算や区間の最大値・最小値を求めることができます。

実装

ポイント

まず、レイヤーが何段になるのかですが、
これは1,2,4,8,16……としていった時に何回目で配列の長さ以上になるかと同じです。
なので、log2(配列の長さ)で求めることができます。

処理の流れを説明します。
各マスが主人公です。
各マスは自分に問い合わせが来た場合

  • 自分のマスの担当範囲と完全に一致するなら自分のマスの数値を返す
  • 一致しないなら、自分の左上のマスと右上のマスにクエリを転送し、それを足した値を返す

という対応をします。
これを一番下の段から再帰的に繰り返します。

コード

ところで、この界隈のしきたりに習って半開区間という範囲の指定方法になっています。

  • 左端は含む
  • 右端は含まない

です。
例えば、query(2,5)は[2,3,4]を表します。
最初は違和感があると思いますが、慣れればわかりやすいです。


class SegmentTree {

    private numbers: number[][];
    private log: number; // 段の高さ

    constructor(size: number) {
        this.log = Math.ceil(Math.log2(size));
        this.numbers = new Array(this.log);
        for (let l = 0; l <= this.log; l++) {
            const m = Math.pow(2, l);
            const p = Math.ceil(size / m);
            this.numbers[l] = new Array(p);
            for (let j = 0; j < p; j++) this.numbers[l][j] = 0;
        }
    }

    public set(index: number, value: number) {
        this.numbers[0][index] = value;
        for (let layer = 1; layer <= this.log; layer++) {
            let li = Math.floor(index / Math.pow(2, layer));
            let lj1 = li * 2; // 左上のマスのindex
            let lj2 = li * 2 + 1; // 右上のマスのindex

            // 配列の長さを超えることがあるので、その場合は0とみなす
            let lv1 = lj1 >= this.numbers[layer - 1].length ? 0 : this.numbers[layer - 1][lj1];
            let lv2 = lj2 >= this.numbers[layer - 1].length ? 0 : this.numbers[layer - 1][lj2];

            this.numbers[layer][li] = lv1 + lv2;

        }
    }

    // 半開区間
    public query(start: number, end: number) {
        return this.recursive(start, end, this.log, 0);
    }


    public recursive(start: number, end: number, layer: number, index: number) {
        // この段は1マスあたり何要素を受け持つか
        const layerCellCount = Math.pow(2, layer);

        // このマスが受け持つ要素たちの右端と左端
        const layerCellStart = index * layerCellCount;
        const layerCellEnd = layerCellStart + layerCellCount;

        // マスの長さと完全に一致する場合
        if (start == layerCellStart && end == layerCellEnd) {
            return this.numbers[layer][index];
        }

        // この段の上の段は1マスあたり何要素を受け持つか
        const childLayerCellCount = Math.pow(2, layer - 1);
        // 左上のインデックス
        const leftChildIndex = index * 2;
        // 右上のインデックス
        const rightChildIndex = leftChildIndex + 1;

        // 左上・右上のマスが受け持つ要素たちの右端と左端
        const leftChildStart = leftChildIndex * childLayerCellCount;
        const leftChildEnd = leftChildStart + childLayerCellCount;
        const rightChildStart = leftChildEnd;
        const rightChildEnd = rightChildStart + childLayerCellCount;

        // 左上のマスだけに用事がある
        if (end <= leftChildEnd) {
            return this.recursive(start, end, layer - 1, leftChildIndex);
        }

        // 右上のマスだけに用事がある
        if (rightChildStart <= start) {
            return this.recursive(start, end, layer - 1, rightChildIndex);
        }

        // 左上と右上のマスを足す
        return this.recursive(start, leftChildEnd, layer - 1, leftChildIndex)
            + this.recursive(rightChildStart, end, layer - 1, rightChildIndex);
    }
}
const array: number[] = [
    3, 5, 2, 4, 1, 7, 5, 9, 5, 2, 0, 5, 9, 4,
];
const seg = new SegmentTree(array.length);
array.forEach((v, i) => seg.set(i, v));

console.log(seg.query(3, 11));