🖋️
ABC 015 D - 高橋くんの苦悩 PythonでのMLE,TLE対策
PythonでのMLEとTLEの対策
DPの練習にもなったけれど、MLE、TLEの解決方法について勉強になったいい問題
今の条件(メモリ1G)であれば、MLEは出ない気もするけれど知っていて損はないハズ
Pythonで3次元配列は通らない
見た数と幅と個数制限で3次元取って、dpは[i][cnt][sum_w]で見ればよいのだ!
って通らんやん!!
# template
# 解説みた
from math import gcd
from collections import defaultdict, Counter
import sys
sys.setrecursionlimit(10 ** 8)
input = lambda: sys.stdin.readline().rstrip()
ii = lambda: int(input())
mi = lambda: map(int, input().split())
li = lambda: list(mi())
inf = 2 ** 63 - 1
# mod = 10 ** 9 + 7
# mod = 998244353
w=ii()
n,k=mi()
a,b=[],[]
for i in range(n):
A,B=mi()
a.append(A)
b.append(B)
dp = [[[0] * (w+1) for i in range(k+1)] for i in range(n+1)]
dp[0][0][0] = 0
# dp[i][cnt][sum_w]
# i番目まで見た中で、cnt以下で選択し、sum_w以下の幅で選択したうちの最大値
# この問題だと重量と個数制限がついているのでそれぞれ見る必要があるのだ
for i in range(n):
for cnt in range(k+1):
for sum_w in range(w+1):
# sum_w以下、cnt以下であれば最大値を比較する
if sum_w - a[i]>=0 and cnt - 1 >=0:
dp[i+1][cnt][sum_w] = max(dp[i][cnt][sum_w],dp[i][cnt-1][sum_w-a[i]]+b[i])
else:
dp[i+1][cnt][sum_w] =dp[i][cnt][sum_w]
print(dp[n][k][w])
サンプルだと解けてるし挙動は問題ないので、C言語ならこの方針でもで通りそう。
ただし、PyPyだとMLE、PythonだとTLEとなるので対応が必要!!
MLEの原因
自分で調べた感じだと原因は以下の通り
- 3次元DPだと配列作成でメモリを多く消費する
- PyPyはメモリの消費量が多い
- 初期の問題だからなのかメモリ制限が256Mbと現行の25%と少ない
MLE対策
3次元で駄目なら、2次元に落とそう!
今回のDPは一つ前の情報があれば更新できる。
従って、dp[i][cnt][sum_w]のうちの[i]は直前の値(dp_prev)と次の値(dp_next)の2つで十分。
つまり、dp_prev[cnt][sum_w]、dp_next[cnt][sum_w]の2つがあればDPはできる。
dp[i-1]→dp[i]の更新方法でハマる
方針は分かっても、dp_nextをdp_prevへ更新する箇所で詰まった
- dp_prev=dp_next とすると参照渡しとなり、以降同じアドレスを参照するようになる
従って、以降はdp_prevとdp_nextは同じ値となるのでWAとなる - dp_prev=copy.deepcopy(dp_next) とするとTLEとなる
情報量が多いと複製するのに時間がかかるためだと思われる - dp_prev, dp_next = dp_next, dp_prev とすると、うまい具合に参照渡しが実施される
- 新たなメモリに値を格納する手間が省けるので処理が速い
- 更新後のdp_nextは不要だがゴミ箱として使用できる(常に上書き保存されるイメージ)
終わり
ということでMLE対策、次元の落とし方、と学びが多い問題でありました!
ACコード
# template
from math import gcd
from collections import defaultdict, Counter
import sys
sys.setrecursionlimit(10 ** 8)
input = lambda: sys.stdin.readline().rstrip()
ii = lambda: int(input())
mi = lambda: map(int, input().split())
li = lambda: list(mi())
inf = 2 ** 63 - 1
# mod = 10 ** 9 + 7
# mod = 998244353
w = ii()
n, k = mi()
a, b = [], []
for i in range(n):
A, B = mi()
a.append(A)
b.append(B)
dp_prev = [[0] * (w + 1) for i in range(k + 1)]
dp_next = [[0] * (w + 1) for i in range(k + 1)]
# dp[i][cnt][sum_w]
# この問題だと重量と個数制限がついているのでそれぞれ見る必要があるのだ
# i番目まで見た中で、cnt以下で選択し、sum_w以下の幅で選択したうちの最大値
# とすると、3次元配列だとメモリ容量食うのでMLEで敗北する
# dp_prev[cnt][sum_w] dp_next[cnt][sum_w]として
# iが更新されるたびにdp配列を更新することでメモリを節約する
for i in range(n):
for cnt in range(k + 1):
for sum_w in range(w + 1):
# sum_w以下、cnt以下であれば最大値を比較する
if sum_w - a[i] >= 0 and cnt - 1 >= 0:
dp_next[cnt][sum_w] = \
max(dp_prev[cnt][sum_w], dp_prev[cnt - 1][sum_w - a[i]] + b[i])
else:
dp_next[cnt][sum_w] = dp_prev[cnt][sum_w]
# iを更新するのでdp_prevにdp_nextを入れる
# dp_prev=dp_next とするとアドレス渡しとなり、以降同じ値を取るためNG
# dp_prev = copy.deepcopy(d) とするとTLEとなりNG
dp_prev, dp_next = dp_next, dp_prev
print(dp_prev[k][w])
Discussion