競プロerの皆さん、DFSはオーダーいくつですか?
O(n^2) ?
DFSの計算量はこんにちは、私はAtCoder勉強中の茶コーダーです。
AtCoderの過去問を解いている中で、DFSの時間計算量について理解できていなかった部分があり、躓いてしまったので、備忘録も兼ねてここで共有させていただきます。
これは私が、
基本的な計算量の計算はできるけど、典型的なアルゴリズムは速いって知ってるから当たり前に使ってるという人はぜひ読んでいってください。
取り扱う問題とアプローチ
今回はAtCoder Beginner Contest 302 - Dを例題として扱っていきます。
問題概要
- 座標平面上にN個のノードがある
- 以下の形式でM個の情報が与えられる
- ノード
からみて、ノードA_i は、B_i 軸正方向にx 、X_i 軸正方向にy 離れた地点にいるY_i
- ノード
- それぞれのノードの座標を求めよ(座標が一意に決まらないノードは
undecidable
と出力せよ) - ノード0は原点にある
主な制約
1≤N≤2×10^5 1≤M≤2×10^5
アプローチ
- 各ノードに対して、情報で与えられる頂点対
をエッジとするグラフを考える(A_i, B_i) - グラフは有向グラフ
- 各エッジは
に対する相対位置(A_i, B_i) を持つ(X_i, Y_i) - ノード0からDFSして座標を決定
- ノード0と非連結なノードは
undecidable
要するに
- 典型的なDFS問題
- 重み付き有向グラフ(重み=相対位置
)(X_i, Y_i)
失敗した実装
まずは、私が初めに書いたコードはこちらです。
def dfs(u):
if visited[u]:
return
visited[u] = True
for v in range(n):
if graph[u][v] == (INF, INF):
continue
if visited[v]:
continue
pos[v] = [pos[u][0] + graph[u][v][0], pos[u][1] + graph[u][v][1]]
dfs(v)
INF = 10**10
n, m = map(int, input().split())
# 隣接行列で情報を受け取る
graph = [[(INF, INF)] * n for _ in range(n)]
for _ in range(m):
a, b, x, y = map(int, input().split())
a -= 1
b -= 1
# 行列の各マスに相対位置を追加
graph[a][b] = (x, y)
graph[b][a] = (-x, -y)
visited = [False] * n
pos = [(0, 0)] * n
dfs(0)
for i in range(n):
if visited[i]:
print(*pos[i])
else:
print('undecidable')
結果はTLEとなってしまいました。
DFSは明らかに正解っぽいので頭を抱えてしまいました。
この実装のダメなポイント
ソースコードをじっくり見て、時間計算量を確認してみましょう。
DFSの中身に注目してください。
def dfs(u):
if visited[u]:
return
visited[u] = True
for v in range(n):
if graph[u][v] == (INF, INF):
continue
if visited[v]:
continue
pos[v] = [pos[u][0] + graph[u][v][0], pos[u][1] + graph[u][v][1]]
dfs(v)
この関数の中でもforループの部分
for v in range(n):
これ、入力のノードu
に対して、u
とつながっているかに関わらず全てのノードに対してforブロックの中の処理を行っています。もちろん、
if graph[u][v] == (INF, INF):
continue
の部分で、u
とv
がつながっていなければスキップという処理があるのですが、このスキップ処理自体が行われてしまうので関係ありません。
この実装方法では時間計算量は
ここで私は思いました。
「速いアルゴリズム」のDFSが
なぜ私のDFSは
グラフ構造に関する前提知識
データ構造とアルゴリズムにおいて、グラフ構造の表現方法には主に2種類の表現方法があります。隣接リストと隣接行列です。それぞれの表現方法を簡単に説明します。基本的なことしか説明しないので、知ってる方は読み飛ばしてください。また、詳しく知りたい方はこちらの記事などが分かりやすいと思います。
隣接リスト
- つながっているノードだけを保持する表現方法
- 各ノードはつながっているノードのリストを持っている
隣接行列
- グラフを表で表す表現方法
-
個のノードに対してn の表を考えるn \times n - マス
の値は(i, j) - ノード
とi がつながっているならマスj は(i, j) 1 - つながっていないなら
などとする0
- ノード
お互いの長所短所がある
これらの表現方法はどっちでもいいよ、というわけではなくて、時と場合によって使い分ける必要があります。その選択によって、メモリ効率や時間計算量は大きく変わってきます。
隣接行列を選択してしまった
今回私は、グラフの表現方法として、隣接行列表現を選択してしまいました。
なぜそうしたかというと、隣接行列表現なら各マスに相対位置
しかし、これが失敗の原因でした。
DFSの時間計算量
これまでDFSの時間計算量はあまり気にしてきませんでした。
というのも、競プロではいくつか有力なアルゴリズムというものがあり、それを適切に選択すればだいたいうまくいくからです。(もちろんC, D問題くらいまでの話です)
なので、DFSっぽい問題は、ちゃんとDFSに気づいて実装できれば、DFSの計算量がオーダーいくつか知らなくても、ACできてしまっていました。
そんなダメな勉強をしてしまうと、今回のような問題に陥ってしまうのです。ですので、この機会にしっかりとDFSの時間計算量を知っておきましょう。
🔥DFSの時間計算量はグラフの表現方法で違う🔥
ノードの数を
DFSの時間計算量は
-
隣接行列を使用した場合:
O(N^2) - ある頂点から隣接する頂点を見つけるために、その頂点の行を全て調べる必要があるため
-
隣接リストを使用した場合:
O(N+M) - すべての頂点を1回ずつ訪れ、その頂点の隣接する頂点を調べる必要があるため
となります。ですので疎なグラフ(辺の数
今回の問題に戻ってみると
今回の問題ではグラフを隣接リストと隣接行列のどちらで表現する方が適切でしょうか?
制約の部分を見ると
修正後のコードはこちらになります。
import sys
LIMIT = 10**6
sys.setrecursionlimit(LIMIT)
def dfs(u):
if visited[u]:
return
visited[u] = True
for v in graph[u]:
pos[v] = (pos[u][0] + diff[(u, v)][0], pos[u][1] + diff[(u, v)][1])
dfs(v)
n, m = map(int, input().split())
# グラフは隣接リストで表現
graph = [[] for _ in range(n)]
# 相対位置はO(1)で参照できる別のデータ構造(タプルキーの辞書型)で表現
diff = {}
for _ in range(m):
a, b, x, y = map(int, input().split())
a -= 1
b -= 1
graph[a].append(b)
graph[b].append(a)
diff[(a, b)] = (x, y)
diff[(b, a)] = (-x, -y)
visited = [False] * n
pos = [(0, 0)] * n
dfs(0)
for i in range(n):
if visited[i]:
print(*pos[i])
else:
print('undecidable')
これだとACにすることができました。(再帰を使っている関係上PyPyではなくCythonで実行しないとTLEします)
さいごに
今回は、グラフの表現方法とDFSの時間計算量に関するちょっとしたお話でした。
DFSの時間計算量以外にも隣接リスト表現と隣接行列表現ではたくさんの違いがあるので、もっと勉強が必要ですね。
競技プログラミングをやっていると、問題の解法や実装にばかり目がいって、データ構造とアルゴリズムの大事な部分をおろそかにしてしまいがちだと思うのでこの記事を書かせていただきました。少しでもいいと思っていただけたらぜひいいね❤️お願いします。
Discussion