🍊

# ABC189 C - Mandarin Orange

2021/01/25に公開

ABC189 C - Mandarin Orange

問題へのリンク

https://atcoder.jp/contests/abc189/tasks/abc189_c

問題概要

N 枚の皿が一列に並べられており、左から i 番目の皿には A_i 個のみかんが置かれている。
次の3つの条件をすべて満たすような整数の組 (l, r, x) を1つ選ぶ。

  • 1 \leq l \leq r \leq N
  • 1 \leq x
  • l 以上 r 以下の全ての整数 i について、 x \leq A_i

その後、 l 番目から r 番目まで (両端を含む) のすべての皿からみかんを x 個ずつとって食べる。
最大で何個のみかんを食べることができるか。

制約

1 \leq N \leq 10^4
1 \leq A_i \leq 10^5

ABC中の解答

制約が 1 \leq N \leq 10^4 なので O(N^2) はギリギリ間に合わないかな?と最初は思った。
(一方で普段 10^5 のことが多いのに 10^4 なの?とも思った)
そこで計算コストで下げるために尺取法を実装してsampleがすべてACになって提出したけどこれは間違いだった。
例えば私の尺取法では 2 3 3 1 1 1 2 1 のような場合は 2 2 2 あるいは 3 3 が最大と判断され
6 が解となってしまうが実際には 1 1 1 1 1 1 1 18 が最大であるからだ。
そこで計算コストを下げるために二次元配列で i 番目以降の皿の中の累積Minを記録し、それを用いて区間 (l, r) の最小値を求めて解を求めた。C問題にしては難産だったがAC。

import sys

input = sys.stdin.buffer.readline
N = int(input())
A = list(map(int, input().split()))
cummin = [[0] * N for _ in range(N)]
for i in range(N):
    cummin[i][i] = A[i]
    for j in range(i + 1, N):
        cummin[i][j] = min(cummin[i][j - 1], A[j])

ans = 0
for left in range(N):
    for right in range(left, N):
        x = cummin[left][right]
        ans = max(ans, (right - left + 1) * x)
print(ans)

公式解法1

あれこれ無駄に考えてしまったが制約が 1 \leq N \leq 10^4 なので O(N^2) でもACとなるようだ。
ただしPythonだと TLE になるので PyPy3 で提出する必要がある。

import sys

input = sys.stdin.buffer.readline
N = int(input())
A = list(map(int, input().split()))

ans = 0
for left in range(N):
    x = A[left]
    for right in range(left, N):
        x = min(x, A[right])
        ans = max(ans, (right - left + 1) * x)
print(ans)

https://atcoder.jp/contests/abc189/submissions/19679250

Pythonで提出したら TLE

https://atcoder.jp/contests/abc189/submissions/19679232

他の解法1

公式の解説ページに O(N) で解くこともできると書いている。maspyさん極大長方形で解けるとtweetしており、その後、競プロ仲間に同義の最大長方形を紹介しているページを教えてもらった。

https://twitter.com/maspy_stars/status/1352974845185724416

http://algorithms.blog55.fc2.com/blog-entry-132.html

上の記事はすごいわかりやすく、また最大長方形のDPのアルゴリズムがめちゃくちゃかしこくてすごい感心してしまった。勉強のために記事を読んだ後に自分で実装してみた。最後に stacks に残っている中身を処理するために A0 を追加しているところがオシャレポイント。

from collections import deque
import sys

input = sys.stdin.buffer.readline
N = int(input())
A = list(map(int, input().split())) + [0]

ans = 0
stacks = deque([(A[0], 0)])
for i in range(1, len(A)):
    if stacks[-1][0] < A[i]:
        stacks.append((A[i], i))
        continue
    if stacks[-1][0] > A[i]:
        while stacks and stacks[-1][0] >= A[i]:
            h, j = stacks.pop()
            ans = max(ans, h * (i - j))
        stacks.append((A[i], j))
print(ans)

https://atcoder.jp/contests/abc189/submissions/19679644

Discussion