💬

Atcoder ABC246 でnumpyを使いたい

2022/04/07に公開約2,700字

numpyに慣れたくて強引にnumpyで書く方法をいろいろ試行錯誤してドツボにハマった。

C - Coupon

これは解法さえ思いつけばnumpyのブロードキャストを使った四則演算で素直に書けた。
for文、list内包表記無しで書ける。
np.sortのkind="stable"オプションは無くとも大丈夫と思うが、np.sort使うときは念のため使うようにしている。

ABC246C.py
import numpy as np
n,k,x=map(int,input().split())
a=np.array(list(map(int,input().split())))
div=a//x
divsum=np.sum(div)
if divsum>=k:
    print(np.sum(a)-k*x)
    exit()
k=k-divsum
if k>=n:
    print(0)
    exit()
mod=a%x
mod=np.sort(mod,kind="stable")
print(np.sum(mod[:n-k]))

D - 2-variable Function

めちゃくちゃハマった。最初numpy解法のアイディアを思いついたときはシンプルに書けると思ったが、途中から無駄な縛りを入れて解いている気分になった。正直、公式解法見た方がシンプルだし間違いがない。
aを1増やしたら条件を満たすまでbを減らし続け、またaを1増やす...というfor文を書けば良いが、次のbの初期値候補が1個前のfor文ループで求めたbの値になっている。このような1個前の状態に依存して次の値が決まる量をnumpyで計算するのは難しい。こういう計算ができるのはnp.cumsum、np.cumprodくらい?そこで、bを増減させて探索するのではなく、O(1)で代数的に求めてみた。aの候補をnp.arange関数で生成した後、a^{3}+a^{2}b+ab^{2}+b^{3}-n=0をbについての3次式と見て解く。「3次方程式」「解の公式」で検索すると複雑な式が出てくるが、そのまま打ち込む。
...エラー。根号の中が負になっている模様。欲しいのは実数解なので虚数を使った解は不要のはず。タイプミスを疑い他のページも見比べてみる。wikipedia物理のかぎしっぽでは3乗根の前の符号が違っているが、どちらも表記ミスがあるとは考えにくいし、試しに書き換えてもエラーは消えない。いろいろ検索して、どこのページが失念したが、3乗根の中は負になり得る、という記述を見て気づいた。ルートだと根号の中が負の場合虚数になるが、\sqrt[3]{-1}=-1なので、3乗根の中の符号は負でも実数になる。ただし、np.power関数は負値を入力するとエラーになるので、絶対値を入力して根号の中の符号は外に出さないといけない。これに気づくのにだいぶ時間がかかった。
次のハマりはオーバーフロー。入力例3,10^{18}に近い数を入力すると変な値が出力される。64bit整数の範囲内で計算していたつもりだったが、いろいろ調べた結果、3次方程式の解の公式を使う過程でオーバーフローしていた模様。numpyのdtypeをintからfloatに変えると正しそうな値にはなったが、floatの有効数字15桁とすると、18桁の入力に対しては精度不足。解の公式から求めたbの値を整数化し、a^{3}+a^{2}b+ab^{2}+b^{3}がn未満ならb+1とすれば良い、と考えていたが、おそらく精度不足が理由で入力例3の正しい答えが出てこない。b+2まで考えるとやっと答えが合った。

ABC246D.py
import numpy as np
import math
import sys
n=int(input())
if n==0:
    print(0)
    sys.exit()
aupper=pow(n,1/3)
aupper=math.floor(aupper)
if pow(aupper,3)<n:
    aupper+=1
alower=pow(n/4,1/3)
alower=math.floor(alower)
if 4*pow(alower,3)<n:
    alower+=1

def func(a,b):
    return (a+b)*(a*a+b*b)

def cardano(a,b,c): #カルダノの公式の実数解
    s1=(27*c+2*a*a*a-9*a*b)/54
    s4=(3*b-a*a)/9
    t1=-s1+np.sqrt(s1*s1+s4*s4*s4)
    tsign=np.sign(t1) #np.signで符号を取得
    t1=tsign*np.power(np.abs(t1),1/3) #np.powerの外に符号を出す
    t2=-s1-np.sqrt(s1*s1+s4*s4*s4)
    tsign=np.sign(t2)
    t2=tsign*np.power(np.abs(t2),1/3)
    x=t1+t2-a/3
    return x

a=np.arange(alower,aupper+1,dtype=np.int64) #aの取りうる範囲
af=a.astype(np.float) #カルダノの公式計算用にはfloatを使う
b=cardano(af,af*af,af*af*af-n)
b=np.where(b<0,0,b) #b<0の解は不要
b=np.floor(b).astype(np.int64)
x1=func(a,b)
x2=func(a,b+1)
x3=func(a,b+2)
x=np.where(x1>=n,x1,x2)
x=np.where(x>=n,x,x3) #b+2まで考慮しないと正解にならない
print(np.amin(x))

for文を使わずになんとかnumpyで処理できたが、スマートではない。pypyの提出と比べてもそこまで早くない(200ms前後)。無理にnumpy使う必要は無いと痛切に感じた一問。符号を求めるnp.sign関数の存在を知ることができたのが収穫。

Discussion

ログインするとコメントできます