AtCoder ABC411-E upsolve 備忘録
ABC411-E を解説みながら、upsolve していたのですが、結構つまづいたので、備忘録を残しておきます。
※解説記事ではないです。厳密な理解は公式解説を読んでください。
この問題の本質は「何かの最大値が特定の値に一致する確率を求めるより、最大値が特定の値以下になる確率を求める方が往々にして簡単である」です。
解説では、これを適用した後に、色々式変形をしていますが、これを本番で思いつくのは個人的には厳しいと思ったので、自分で思いつくやり方で解かないと意味ないなと思いました。
※式変形自体を理解できるようにはしておきたい。。。
愚直解
異なる 2 つの特定の値以下(仮にp<q
なるp
とq
を考える)になる確率を求めることができれば、それらの差をとることで、(p,q]
の範囲になる確率を求めることができます(累積和っぽいね)。
今回の場合は、最大値の候補をソートして、隣り合う値の確率の差をとると、離散的なのでq
になる確率を求めることができます。
まず、各サイコロの上を向いた面に書かれた数の最大値の候補s
を列挙します。
a
の要素を全て列挙し、重複削除して、ソートするだけで良いです。
n=int(input())
a=[list(map(int,input().split())) for _ in range(n)]
s=[]
for ai in a:
s.extend(ai)
s=list(set(s))
s.sort()
というのは実は嘘で、各サイコロの最小の最大より小さい値が候補となることはありえません。
なので、候補としてあり得る値の最小値s_min
を事前に求めておきます。
s_min=0 # 各サイコロの最小値の最大値(これより小さいと確率が0になる)
for i,ai in enumerate(a):
s_min=max(s_min,min(ai))
次に、値がs_min
以下(実際はs_min
のみ)になる確率を求めます。
サイコロのs_min
以下の値を数えて、その値を 6 で割ったものが単一サイコロのs_min
以下になる確率です。
これを各サイコロで計算して、それらの積が求める確率になります。
a
を逆順にソートして、小さい値から pop して数えやすいように工夫します。
for ai in a:
ai.sort(reverse=True) # bをカウントする際に、小さい値からpopするために逆順にソート
b=[0 for _ in range(n)] # i番目のサイコロの目がsj以下の個数
inv6=pow(6,-1,mod) # 6の逆元
prob=1
for i,ai in enumerate(a):
while len(ai)>0 and ai[-1]<=s_min:
ai.pop()
b[i]+=1
prob*=b[i]*inv6
prob%=mod
あとはs
を順次見ていき、同じように計算し、加算していけば答えを得ることができます。
直前に計算した確率を次に使用するのでメモしておきます。
ans=(prob*s_min)%mod # s_minとなる期待値を計算
before=prob
# s_minより大きい目が出る確率を順次計算
for sj in s:
if sj<=s_min:
continue
# TLEだが愚直に計算
prob=1
for i,ai in enumerate(a):
while len(ai)>0 and ai[-1]<=sj:
ai.pop()
b[i]+=1
prob*=b[i]*inv6
prob%=mod
ans+=sj*(prob-before+mod)
ans%=mod
before=prob
print(ans)
愚直解のコード
mod=998244353
n=int(input())
a=[list(map(int,input().split())) for _ in range(n)]
s=[]
for ai in a:
s.extend(ai)
ai.sort(reverse=True) # 後にbをカウントする際に、小さい値からpopするために逆順にソート
s=list(set(s))
s.sort()
s_min=0 # 各サイコロの最小値の最大値(これより小さいと確率が0になる)
for i,ai in enumerate(a):
s_min=max(s_min,min(ai))
b=[0 for _ in range(n)] # i番目のサイコロの目がsj以下の個数
inv6=pow(6,-1,mod) # 6の逆元
# s_minの目が出る確率を計算
prob=1
for i,ai in enumerate(a):
while len(ai)>0 and ai[-1]<=s_min:
ai.pop()
b[i]+=1
prob*=b[i]*inv6
prob%=mod
ans=(prob*s_min)%mod
before=prob
# s_minより大きい目が出る確率を順次計算
for sj in s:
if sj<=s_min: # s_min以下は計算済みなのでスキップ
continue
# TLEだが愚直に計算
prob=1
for i,ai in enumerate(a):
while len(ai)>0 and ai[-1]<=sj:
ai.pop()
b[i]+=1
prob*=b[i]*inv6
prob%=mod
ans+=sj*(prob-before+mod)
ans%=mod
before=prob
print(ans)
AC 解
愚直解だと、毎回サイコロ全部を照会しており、TLE するので、高速化をする必要があります。
よくよく考えてみると、このカウントの計算は全部で6n
回しか行われないはずで、どのサイコロが更新されるのかを予め計算しておけば、無駄な照会が減り、高速化できそうです。
また、確率に関しては、直前の確率から、更新のあるサイコロの確率を差分更新すれば良さそうです(直前の確率で割って、更新後の確率を掛ければよい)。
まず、高速参照するための辞書つくります(座圧して配列でもってもよい)。
d=defaultdict(list) # サイコロの目がsjのindexリスト
for i,ai in enumerate(a):
for aij in ai:
d[aij].append(i)
愚直に計算していた箇所は以下のように書き換えます。
prob=before
for i in d[sj]:
ai=a[i]
prob*=pow(b[i]*inv6,-1,mod) # 直前の確率で割る
while len(ai)>0 and ai[-1]<=sj:
ai.pop()
b[i]+=1
prob*=b[i]*inv6 # 更新後の確率を掛ける
prob%=mod
AC 解のコード
from collections import defaultdict
mod=998244353
n=int(input())
a=[list(map(int,input().split())) for _ in range(n)]
s=[]
for ai in a:
s.extend(ai)
ai.sort(reverse=True) # 後にbをカウントする際に、小さい値からpopするために逆順にソート
s=list(set(s))
s.sort()
d=defaultdict(list) # サイコロの目がsjのindexリスト
mx=0 # 各サイコロの最小値の最大値(これより小さいと確率が0になる)
for i,ai in enumerate(a):
mx=max(mx,min(ai))
for aij in ai:
d[aij].append(i)
b=defaultdict(int) # i番目のサイコロの目がsj以下の個数
inv6=pow(6,-1,mod) # 6の逆元
prob=1
for i,ai in enumerate(a):
while len(ai)>0 and ai[-1]<=mx:
ai.pop()
b[i]+=1
prob*=b[i]*inv6
prob%=mod
ans=(prob*mx)%mod
before=prob
for sj in s:
if sj<=mx: # mx以下は計算済みなのでスキップ
continue
prob=before
for i in d[sj]:
ai=a[i]
prob*=pow(b[i]*inv6,-1,mod) # 直前の確率で割る
while len(ai)>0 and ai[-1]<=sj:
ai.pop()
b[i]+=1
prob*=b[i]*inv6 # 更新後の確率を掛ける
prob%=mod
ans+=sj*(prob-before+mod)
ans%=mod
before=prob
print(ans)
デバッグについて
確率や期待値の有理数 mod だと途中の処理があっているかどうかを判断するのが難しいです。
他の競プロ er の人たちはどうやっているか良くわかりませんが、単に float で計算する方法、もしくは以下のように復元する方法を使いました。
※このコードは生成 AI に書いてもらいました。
from math import gcd
MOD=998244353
def recover_fraction(x, max_denom=100000):
"""
mod上の整数 x が a/b (mod MOD) に相当するとき、
分母bを1からmax_denomまで試して (a, b) を推定する
"""
results = []
for b in range(1, max_denom + 1):
a = (x * b) % MOD
if gcd(a, b) == 1:
results.append((a, b))
return results # 一般に複数候補がある
a,b=1,6
if (a,b) in recover_fraction(a*pow(b,MOD-2,MOD)%MOD): # 1/6となるか確認
print("OK")
else:
print("NG")
Discussion