🐷

Kickstart 2020 RoundC: Candies

2021/08/11に公開約4,800字

https://codingcompetitions.withgoogle.com/kickstart/round/000000000019ff43/0000000000337b4d

問題

N個の要素を持つ配列が与えられる。i番目の要素A_iはcandyの甘さを表す値とする。また、以下の二つの操作を定義する。

  1. Update
    • A_iの値を更新する
  2. Query
    • subarrayの合計甘さスコアを問い合わせ
    • スコアは、indexがlからrまでの場合、A_l \times 1 - A_{l+1} \times 2 + A_{l+2} \times 3 - A_{l+3} \times 4 + A_{l+4} \times 5 ...と計算
    • つまり、(-1)^{i-l} A_{i} \times (i - l + 1)でiがlからrまでのsumを求める

合計Q回の操作を行い、そのうち全てのQueryのsumを求める問題。Queryの回数が0の時は0となる。

例えば次のケースの場合、最初のQueryは、3 \times 1 - 9 \times 2 + 8 \times 3 = 9、次のQueryは、2 \times 1 = 2、最後のQueryは[1,10,9,8,2]にUpdateした後なので、1 \times 1 - 10 \times 2 = -19で、合計は、9 + 2 - 19 = -8となる。

1 3 9 8 2
Q 2 4
Q 5 5
U 2 10
Q 1 2

アルゴリズム

Queryのsumを都度求める場合、N個の要素に対するQ回の計算になり、1 \leq N \leq 2 \times 10^51 \leq Q \leq 10^5という条件においては、もっと効率的な解法が必要になる。都度計算するのではなく、事前にsumを計算しておいて、指定された範囲の合計を計算するということを考える。事前に計算しておいたsumから、任意の範囲の合計を計算するのは、公式の解説にある図を見ると分かりやすい。この例では、A = [5, 2, 7, 4, 6, 3, 9, 1, 8]に対して、Q 1 9の計算を図示している。さらに、例えばQ 5 8のsumは図の青い領域に相当し、これは黒い枠線の領域に相当するsumから、オレンジの領域と灰色の領域を引いたものになることが分かる。ただし、範囲の左境界がevenの場合は、それぞれの領域の計算結果を反転させる必要がある。

このような計算をするために事前に求めておくsumを二種類用意する。一つが階段状のsumを求めるMultiple Prefix Sum(MS)で、\rm{MS(0)=0 \ \text{and} \ MS(i)=(-1)^{i-1} A_i \times i + MS(i-1) \ \text{for} \ i \geq 1}と定義する。もう一つが、通常のPrefix Sum(S)で、\rm{S(0)=0 \ \text{and} \ S(i)=(-1)^{i-1} A_i + S(i-1) \ \text{for} \ i \geq 1}と定義する。この二種類のsumを使って、Q l rを求めるには、\rm{(-1)^{l-1} (MS(r) - MS(l-1) - (l-1) \times (S(r) - S(l-1))}と計算する。Q 5 8で考えると、\rm{MS(r)}が黒い枠線の領域、\rm{MS(l-1)}がオレンジの領域、\rm{(l-1) \times (S(r) - S(l-1))}が灰色の領域に相当する。このアルゴリズムでは、二種類のPrefix Sumを求めるのにそれぞれO(N)、各クエリの処理はconstantに実行できるので、Test set1のUpdateの回数が高々5回ということであれば、全体としてO(N+Q)の計算量となる。しかし、Test set2ではUpdateの回数制限がないので、Updateで何度もPrefix Sumの再計算が行われてしまい、O(NQ)の計算量となってしまう。

https://github.com/satojkovic/algorithms/blob/3d2b27d39c66773ed9d9f2483b6c9f690c3fe748/kickstart/2020/RoundC/candies.py

Test set2をPassするために、sumの事前計算のデータ構造としてSegment Treeを利用する。MSに対応するSegment TreeをMT、Sに対応するSegment TreeをSTとすると、\rm{(-1)^{l-1} (MT(l,r)-(l-1)\times ST(l,r))}として計算できる。N個の配列からSegment Treeを構築するのにO(N)、UpdateとQueryの処理はそれぞれO(logN)で、全体としてO(N + QlogN)となり、Test set2にもPassすることができます。(Segment Treeは0-indexedで実装しているので、MTとSTを使うときにインデックスを渡すときだけ-1している)

import math


class SegmentTreeSum:
    def __init__(self, arr=None):
        if arr:
            self.build(arr)

    def build(self, arr):
        def _build(curr, arr, left, right):
            if left == right:
                self.data[curr] = arr[left]
            else:
                mid = (left + right) // 2
                _build(2 * curr + 1, arr, left, mid)
                _build(2 * curr + 2, arr, mid + 1, right)
                self.data[curr] = self.data[2 * curr + 1] + \
                    self.data[2 * curr + 2]
        self.n = len(arr)
        x = 2 ** math.ceil(math.log2(self.n))
        self.data = [0] * (2 * x - 1)
        _build(0, arr, 0, self.n - 1)

    def update(self, i, x):
        def _update(i, x, curr, left, right):
            if left == right:
                self.data[curr] = x
            else:
                mid = (left + right) // 2
                if mid >= i:
                    _update(i, x, 2 * curr + 1, left, mid)
                else:
                    _update(i, x, 2 * curr + 2, mid + 1, right)
                self.data[curr] = self.data[2 * curr + 1] + \
                    self.data[2 * curr + 2]

        _update(i, x, 0, 0, self.n - 1)

    def query(self, range_l, range_r):
        def _query(curr, left, right, range_l, range_r):
            if right < range_l or range_r < left:
                return 0
            if range_l <= left and right <= range_r:
                return self.data[curr]
            mid = (left + right) // 2
            left_sum = _query(2 * curr + 1, left, mid, range_l, range_r)
            right_sum = _query(2 * curr + 2, mid + 1,
                               right, range_l, range_r)
            return left_sum + right_sum

        return _query(0, 0, self.n - 1, range_l, range_r)


T = int(input())
for t in range(1, T + 1):
    N, Q = list(map(int, input().split()))
    arr = list(map(int, input().split()))
    s = [(-1) ** i * a for i, a in enumerate(arr)]
    ms = [(-1) ** i * a * (i+1) for i, a in enumerate(arr)]
    st_s = SegmentTreeSum(s)
    st_ms = SegmentTreeSum(ms)
    ret = 0
    for _ in range(Q):
        line = input().split()
        op, l, r = line[0], int(line[1]), int(line[2])
        sign = (-1) ** (l-1)
        if op == 'Q':
            q = sign * (st_ms.query(l-1, r-1) -
                        (l-1) * st_s.query(l-1, r-1))
            ret += int(q)
        elif op == 'U':
            st_s.update(l-1, sign * r)
            st_ms.update(l-1, sign * r * l)
    print('Case #{}: {}'.format(t, ret))

Discussion

ログインするとコメントできます