AtCoder ABC254 でscipyを使いたい

2022/06/08に公開

scipyはnumpyよりもさらに使う人が少ない気がする。
自分もnumpy以上に使い方が分かっておらず、どんな便利な関数があるかレベルで人のコードを参考にしたいのだけど、皆もっと使いませんか。
pypyで満足ですか、そうですか。

E - Small d and k

https://atcoder.jp/contests/abc254/tasks/abc254_e
グラフ上である頂点から最大3個先までの頂点を列挙する問題。

Dijkstra法

最初に思いついたのはDijkstra法。
scipyのdijkstra関数は自分が使うとAtCoderの問題では高確率で制限時間超過する。
人類の叡智が詰まったすごい最適化コードで動いていると思うのだが、自分が使い方を間違えているのか、とにかく通らない。
例えばABC252E - Road Reduction
https://atcoder.jp/contests/abc252/tasks/abc252_e
この問題とか、scipy.sparse.csgraph.dijkstra使って最短経路を出すプロセスだけで時間超過するので、誰かscipyでの通し方知っていたら教えてください。

今回の問題も時間超過するだろうな、と思いつつ、dijkstra関数オプションのlimitを3にしてみる。

dist_matrix=dijkstra(G,directed=False,indices=x,unweighted=True,limit=3,min_only=True)

コストがlimitを超えると探索を打ち切る、と書いてあるので、もしかしたら早く処理できるかも、と考えたが、これでクエリ回数分for文回すと案の定時間超過。

疎行列べき乗

解説見るとDFS/BFS使うとのこと。
scipyのドキュメント見るとscipy.sparse.csgraph.depth_first_order等の関数あるが、dijkstra関数のlimitのような探索範囲を制限できるオプションがないので使え無さそう。
今回も解けそうにないと思いつつグラフの入力に使う行列(隣接行列というらしい)を見てると、行列要素の位置が今行ける頂点を表していて、行列を二乗(n乗)したら2回(n回)の移動で行ける頂点に遷移できそう。
wikipediaにもそれぽいことが書いてある。
https://ja.wikipedia.org/wiki/隣接行列
今回の問題は最大3回しか移動しないし、行列を3乗したら3回後に行ける頂点を計算できそう。
問題は頂点数Nが最大10^5台で、行列サイズがNxNになりそもそもメモリに乗り切らない、行列演算の計算時間はN^3で計算が一見不可能に思える。が、グラフ問題で使っているscipyの疎行列なら3乗くらいなら力業で計算できるのではないか。
ということで、AtCoderでの隣接グラフの作り方をコピペして隣接グラフを作った。
https://ikatakos.com/pot/programming_algorithm/route_search/scipy

import sys
import numpy as np
from scipy.sparse import csr_matrix

inp=np.array(list(map(int,sys.stdin.buffer.read().split())))
n,m=inp[:2]
a=inp[2:2*m+1:2]
b=inp[3:2*m+2:2]
q=inp[2*(m+1)]
x=inp[2*(m+1)+1::2]
k=inp[2*(m+1)+2::2]

c=np.ones((m,),dtype=np.int) #コスト1の辺のつもり
graph=csr_matrix(c,(a-1,b-1)),(n,n))

scipyのdijkstra等グラフ系の関数では、無向グラフならa→bに行く辺だけ記述すればb→aを陽に書かなくとも勝手に良い感じに計算してくれるが、今回はまじめに逆方向も要素指定しないといけない。
要は対称行列にすればよいので、転置処理して足せばok。
転置処理はtranspose()という関数がちゃんと用意されていた。
自分自身への遷移(ループ)も考えるために、対角成分にも要素を入れる。

e=eye(n) #単位行列。scipy.sparse.eyeからimport必要
graph=graph+graph.transpose()+e

あとはこの行列を3乗して、要素のある列(行)番号を抽出すれば3回の遷移で行ける頂点を列挙できる。
疎行列のある行から要素が存在する列番号を高速に抽出する方法があればよい。
が、何かスマートな方法あるのかな。
べき乗した後の0でない要素が全て1なら、[1,2,3,...,n]のベクトルとの内積とれば問題の答え(列挙した頂点番号の合計)になる。
疎行列の要素を1に置換するためにnumpy.where関数に相当するようなものを探したが見つけられないため、思いつきで疎行列の型をTrue/Falseのbool型に指定してみた。

c=np.ones((m,),dtype=np.int)
e=eye(n,dtype=bool)
graph=csr_matrix((c,(a-1,b-1)),(n,n),dtype=bool)
graph=graph+graph.transpose()+e

d0=np.arange(1,n+1,dtype=np.int).reshape((n,1)) #[1,2,3,...,n]ベクトル
d0s=csr_matrix(d0) #ベクトルの疎行列表現
d1=graph*d0s #1乗
g=graph**2
d2=g*d0s #2乗
g=graph*g
d3=g*d0s #3乗
dist=np.zeros((4,n),dtype=np.int) #各頂点からの距離3以下の頂点番号総和の配列
dist[0,:]=d0.reshape((n,))
dist[1,:]=d1.toarray().reshape((n,))
dist[2,:]=d2.toarray().reshape((n,))
dist[3,:]=d3.toarray().reshape((n,))
xarray=dist[k,x-1] #クエリ順に配列を並び替え

print(*xarray,sep="\n")

bool型での行列積演算や、bool型とint型の掛け算(上のコードのgraph*d0sの部分等)がどう計算されるか全く分からなかったが、試しに問題のテストケースで計算させると正しい答えが出てきたので、これで良いぽい。
numpy,sicpy、違う型でも計算してくれるの本当にすごいですね。
最終的に、問題にある通り各クエリに一つずつ答えるループをfor文で回す、という方法ではなく、全頂点から最大3回までの移動で行ける場所を全て計算しておく、というライブラリに頼った強引な解き方になった。
計算しなくても良い頂点も計算するので計算時間どうかなと思ったが、結果はpythonの中ではなかなか高速な部類。
この問題でfor文を一切書かず解ける、bool型と整数型混じっていても計算できる、というので、scipyすごいですね。

行列の各行の中で0でない要素の列番号がその行の番号から行ける頂点で、非ゼロの列要素を全部探す、という考えなので、本質的にはBFSと同じことをやっていると思われる。
0の要素をそもそも計算しないので早い、探索先を添え字から辿るのではなく行列計算の単純なループ計算に任せているので、BFSの自前実装よりも最適化が効いて早い?
計算量はよくわからない。1行の最大非ゼロ要素数~40 x 行数N?

ちなみに疎行列どうしの演算(特に行列積)はゼロ要素が多いほど早いと思われる。
自己ループの対角成分を入れてしまうとそれだけで問題によっては要素数が10^5個増えるので、隣接行列に最初から対角成分加えてn乗計算するよりは、後で足したほうがおそらく早い。
注意点は、対角成分がないと自分自身に行けない。
n乗したときはn-1→nで移動できる頂点に非ゼロ要素が新たにできるが、n-1時点でいた頂点からはいなくなっている可能性があるので、n回以下で行ける頂点を探すのであれば隣接行列の0乗(単位行列)、1乗、・・・(n-1)乗も足さないと行けない。

e=eye(n,dtype=bool)
graph=csr_matrix((c,(a-1,b-1)),(n,n),dtype=bool)
graph=graph+graph.transpose() #単位行列はまだ足さない

dist=np.zeros((n,4),dtype=np.int)
d0=np.arange(1,n+1,dtype=np.int).reshape((n,1))
d0s=csr_matrix(d0)
d1s=graph*d0s+d0s #(graph+e)*d0s #距離1以下は隣接行列graph+単位行列
g=graph*graph
sumg=g+graph+e #graph**2+graph+eのこと。graph**2と、次数1以下の項との和は3乗計算時にそれぞれ再利用するので、g,sumgとしてメモリに入れておく
#graph**2で自分自身に戻れるので、2乗のときはeは足さなくともよいはず
d2s=sumg*d0s #2乗の計算
d3s=(graph*g+sumg)*d0s #3乗の計算 graph**3+graph**2+...+e
dist=np.zeros((4,n),dtype=np.int) #各頂点からの距離3以下の頂点番号総和の配列
dist[0,:]=d0.reshape((n,))
dist[1,:]=d1.toarray().reshape((n,))
dist[2,:]=d2.toarray().reshape((n,))
dist[3,:]=d3.toarray().reshape((n,))
xarray=dist[k,x-1]

print(*xarray,sep="\n")

また、疎行列に疎行列をかけるのではなくベクトルをかけるようすればもっと早くなると思うが、こちらは簡単に計算する方法が分からず。
ただこれでも多くのpypy解答より早いので、numpy,scipy,networkXライブラリに頼ったごり押し計算・スマートな計算はもっと流行ってほしい。

Discussion