Kickstart 2020 RoundC: Candies
問題
N個の要素を持つ配列が与えられる。i番目の要素
- Update
-
の値を更新するA_i
-
- Query
- subarrayの合計甘さスコアを問い合わせ
- スコアは、indexがlからrまでの場合、
と計算A_l \times 1 - A_{l+1} \times 2 + A_{l+2} \times 3 - A_{l+3} \times 4 + A_{l+4} \times 5 ... - つまり、
でiがlからrまでのsumを求める(-1)^{i-l} A_{i} \times (i - l + 1)
合計Q回の操作を行い、そのうち全てのQueryのsumを求める問題。Queryの回数が0の時は0となる。
例えば次のケースの場合、最初のQueryは、
1 3 9 8 2
Q 2 4
Q 5 5
U 2 10
Q 1 2
アルゴリズム
Queryのsumを都度求める場合、N個の要素に対するQ回の計算になり、Q 1 9
の計算を図示している。さらに、例えばQ 5 8
のsumは図の青い領域に相当し、これは黒い枠線の領域に相当するsumから、オレンジの領域と灰色の領域を引いたものになることが分かる。ただし、範囲の左境界がevenの場合は、それぞれの領域の計算結果を反転させる必要がある。
このような計算をするために事前に求めておくsumを二種類用意する。一つが階段状のsumを求めるMultiple Prefix Sum(MS)で、Q l r
を求めるには、Q 5 8
で考えると、
Test set2をPassするために、sumの事前計算のデータ構造としてSegment Treeを利用する。MSに対応するSegment TreeをMT、Sに対応するSegment TreeをSTとすると、
import math
class SegmentTreeSum:
def __init__(self, arr=None):
if arr:
self.build(arr)
def build(self, arr):
def _build(curr, arr, left, right):
if left == right:
self.data[curr] = arr[left]
else:
mid = (left + right) // 2
_build(2 * curr + 1, arr, left, mid)
_build(2 * curr + 2, arr, mid + 1, right)
self.data[curr] = self.data[2 * curr + 1] + \
self.data[2 * curr + 2]
self.n = len(arr)
x = 2 ** math.ceil(math.log2(self.n))
self.data = [0] * (2 * x - 1)
_build(0, arr, 0, self.n - 1)
def update(self, i, x):
def _update(i, x, curr, left, right):
if left == right:
self.data[curr] = x
else:
mid = (left + right) // 2
if mid >= i:
_update(i, x, 2 * curr + 1, left, mid)
else:
_update(i, x, 2 * curr + 2, mid + 1, right)
self.data[curr] = self.data[2 * curr + 1] + \
self.data[2 * curr + 2]
_update(i, x, 0, 0, self.n - 1)
def query(self, range_l, range_r):
def _query(curr, left, right, range_l, range_r):
if right < range_l or range_r < left:
return 0
if range_l <= left and right <= range_r:
return self.data[curr]
mid = (left + right) // 2
left_sum = _query(2 * curr + 1, left, mid, range_l, range_r)
right_sum = _query(2 * curr + 2, mid + 1,
right, range_l, range_r)
return left_sum + right_sum
return _query(0, 0, self.n - 1, range_l, range_r)
T = int(input())
for t in range(1, T + 1):
N, Q = list(map(int, input().split()))
arr = list(map(int, input().split()))
s = [(-1) ** i * a for i, a in enumerate(arr)]
ms = [(-1) ** i * a * (i+1) for i, a in enumerate(arr)]
st_s = SegmentTreeSum(s)
st_ms = SegmentTreeSum(ms)
ret = 0
for _ in range(Q):
line = input().split()
op, l, r = line[0], int(line[1]), int(line[2])
sign = (-1) ** (l-1)
if op == 'Q':
q = sign * (st_ms.query(l-1, r-1) -
(l-1) * st_s.query(l-1, r-1))
ret += int(q)
elif op == 'U':
st_s.update(l-1, sign * r)
st_ms.update(l-1, sign * r * l)
print('Case #{}: {}'.format(t, ret))
Discussion