[EDPC-Z問題 Frog 3] Educational DP Contest を Pythonで解く

2022/06/28に公開

はじめに

競技プログラミング(使用言語はPython)を日々精進中です。
DP(動的計画法)をマスターするために、AtCoderのEducational DP Contestを解いています。

Z問題 Frog 3

https://atcoder.jp/contests/dp/tasks/dp_z

問題概要

N個の足場があり、足場の高さは、h_1 < h_2 < h_3 < … < h_N となっている。足場iからのジャンプ先は i+1, i+2, … , N までのどこでも選べるが、コストとして、(h_i - h_j)^2+Cを支払う必要がある。足場1からスタートして足場Nに辿り着くまでのコストの総和の最小値を求める。

制約

2 ≦ N ≦ 2×\def\bar#1{#1^5} \bar{10}

考え方

DP[i]は、足場iに到達するまでのコスト最小値と定義して、TLEとなるコードであれば、

N,C=map(int,input().split())
h = list(map(int,input().split()))
INF = 10**18
DP = [INF]*N
DP[0] = 0

for i in range(1,N):
    for j in range(i):
        DP[i] = min(DP[i], DP[j]+(h[i]-h[j])**2+C)

ans = DP[N-1]
print(ans)

で正しい出力自体は得られます。

ここで、計算のボトルネックとなっているのは下記の部分なので、展開してj以外のものをminの外に出してしまう。

DP_i = \operatorname*{min}\limits_{1≦j<i} (DP_j+(h_i-h_j)^2+C)
DP_i = \operatorname*{min}\limits_{1≦j<i} (DP_j+h_i^2+h_j^2-2h_i h_j+C)
DP_i = \operatorname*{min}\limits_{1≦j<i} (DP_j+h_j^2-2h_j h_i)+h_i^2+C

minの中をh_iに関する式として見ると、傾き:-2h_j 切片:DP_j+h_j^2 の一次式となっています。さらにh_jは単調増加なので、jが大きくなると、傾き-2h_jは単調減少します。

Convex Hull Trick という傾きが単調減少となる直線群における最小値を求める方法を用います。

y=-2h_1x+DP_1+h_1^2
y=-2h_2x+DP_2+h_2^2
y=-2h_3x+DP_3+h_3^2
y=-2h_{i-1}x+DP_{i-1}+h_{i-1}^2

これらの直線上でのx=h_iにおける最小値がDP_iとなります。つまり、直線の特徴値である傾き、切片を保持しながら、x=h_iでの最小値を求めていけば良いです。全ての直線を記録しておく必要はなく、最小値には無関係な直線の情報は適宜破棄していくことになります。例えば下図では、新しい直線(赤)が加わったときに、直線 y=a_2x+b_2 は不要となる場合があります。直線 y=a_1x+b_1y=a_2x+b_2 の交点は、\left({\dfrac{b_2-b_1}{a_1-a_2} , \dfrac{a_1b_2-a_2b_1}{a_1-a_2}}\right) となるので、

傾きa、切片bの直線を追加したときに、元の直線がまだ必要となる条件は、

\dfrac{a_1b_2-a_2b_1}{a_1-a_2} < \dfrac{b_2-b_1}{a_1-a_2} ×a + b
a_1b_2-a_2b_1 < a(b_2-b_1)+b(a_1-a_2)
a_1b_2-a_2b_2+a_2b_2-a_2b_1 < a(b_2-b_1)-b(a_2-a_1)
-b_2(a_2-a_1)+a_2(b_2-b_1) < a(b_2-b_1)-b(a_2-a_1)
b(a_2-a_1)-b_2(a_2-a_1) < a(b_2-b_1)-a_2(b_2-b_1)
(b-b_2)(a_2-a_1) < (a-a_2)(b_2-b_1)

また、x=h_iは単調増加なので、最小値を求める際に使われなかった直線は廃棄できる。

実装メモ

from collections import deque
N,C = map(int,input().split())
h = list(map(int,input().split()))
dp = [0] * N
fs = deque()
def add(i):
    hi = h[i]
    a,b = -2*hi,hi**2+dp[i]
    while len(fs) >= 2:
        a1,b1 = fs[-2]
        a2,b2 = fs[-1]
        if (a2-a1)*(b-b2) < (b2-b1)*(a-a2):
            break
        fs.pop()
    fs.append((a,b))

一番右端の直線を条件を確認して削除する。

def get_min(x):
    a1,b1 = fs[0]
    y = a1*x+b1
    while len(fs) >= 2:
        a2,b2 = fs[1]
        if y < a2*x+b2:
            break
        y = a2*x+b2
        fs.popleft()
    return y

最小値を求める際に必要なければ、一番左端の直線を削除する。

add(0)
for i in range(1,N):
    hi = h[i]
    dp[i] = get_min(hi) + hi**2 + C
    add(i)

ans = dp[-1]
print(ans)

Discussion