⏱️

競プロerの皆さん、DFSはオーダーいくつですか?

2023/11/02に公開

DFSの計算量はO(n^2)?

こんにちは、私はAtCoder勉強中の茶コーダーです。
AtCoderの過去問を解いている中で、DFSの時間計算量について理解できていなかった部分があり、躓いてしまったので、備忘録も兼ねてここで共有させていただきます。

これは私が、O(n^2)のDFSを実装してしまった話です。
基本的な計算量の計算はできるけど、典型的なアルゴリズムは速いって知ってるから当たり前に使ってるという人はぜひ読んでいってください。

取り扱う問題とアプローチ

今回はAtCoder Beginner Contest 302 - Dを例題として扱っていきます。

問題概要

  • 座標平面上にN個のノードがある
  • 以下の形式でM個の情報が与えられる
    • ノードA_iからみて、ノードB_iは、x軸正方向にX_iy軸正方向にY_i離れた地点にいる
  • それぞれのノードの座標を求めよ(座標が一意に決まらないノードはundecidableと出力せよ)
  • ノード0は原点にある

主な制約

  • 1≤N≤2×10^5
  • 1≤M≤2×10^5

アプローチ

  • 各ノードに対して、情報で与えられる頂点対(A_i, B_i)をエッジとするグラフを考える
  • グラフは有向グラフ
  • 各エッジは(A_i, B_i)に対する相対位置(X_i, Y_i)を持つ
  • ノード0からDFSして座標を決定
  • ノード0と非連結なノードはundecidable

要するに

  • 典型的なDFS問題
  • 重み付き有向グラフ(重み=相対位置(X_i, Y_i))

失敗した実装

まずは、私が初めに書いたコードはこちらです。

def dfs(u):
    if visited[u]:
        return

    visited[u] = True

    for v in range(n):
        if graph[u][v] == (INF, INF):
            continue
        if visited[v]:
            continue
        pos[v] = [pos[u][0] + graph[u][v][0], pos[u][1] + graph[u][v][1]]
        dfs(v)


INF = 10**10
n, m = map(int, input().split())

# 隣接行列で情報を受け取る
graph = [[(INF, INF)] * n for _ in range(n)]
for _ in range(m):
    a, b, x, y = map(int, input().split())
    a -= 1
    b -= 1
    # 行列の各マスに相対位置を追加
    graph[a][b] = (x, y)
    graph[b][a] = (-x, -y)

visited = [False] * n
pos = [(0, 0)] * n

dfs(0)
for i in range(n):
    if visited[i]:
        print(*pos[i])
    else:
        print('undecidable')

結果はTLEとなってしまいました。

DFSは明らかに正解っぽいので頭を抱えてしまいました。

この実装のダメなポイント

ソースコードをじっくり見て、時間計算量を確認してみましょう。
DFSの中身に注目してください。

def dfs(u):
    if visited[u]:
        return

    visited[u] = True

    for v in range(n):
        if graph[u][v] == (INF, INF):
            continue
        if visited[v]:
            continue
        pos[v] = [pos[u][0] + graph[u][v][0], pos[u][1] + graph[u][v][1]]
        dfs(v)

この関数の中でもforループの部分

for v in range(n):

これ、入力のノードuに対して、uとつながっているかに関わらず全てのノードに対してforブロックの中の処理を行っています。もちろん、

if graph[u][v] == (INF, INF):
    continue

の部分で、uvがつながっていなければスキップという処理があるのですが、このスキップ処理自体が行われてしまうので関係ありません。
この実装方法では時間計算量はO(n^2)となってしまい、TLEとなるのは当然です。
ここで私は思いました。

「速いアルゴリズム」のDFSがO(n^2)なんてあり得るわけない!!

なぜ私のDFSはO(n^2)になってしまったのか、本来のDFSの時間計算量はいくらなのかを見ていきたいと思います。

グラフ構造に関する前提知識

データ構造とアルゴリズムにおいて、グラフ構造の表現方法には主に2種類の表現方法があります。隣接リスト隣接行列です。それぞれの表現方法を簡単に説明します。基本的なことしか説明しないので、知ってる方は読み飛ばしてください。また、詳しく知りたい方はこちらの記事などが分かりやすいと思います。

隣接リスト

  • つながっているノードだけを保持する表現方法
  • 各ノードはつながっているノードのリストを持っている

隣接行列

  • グラフを表で表す表現方法
  • n個のノードに対してn \times nの表を考える
  • マス(i, j)の値は
    • ノードijがつながっているならマス(i, j)1
    • つながっていないなら0などとする

お互いの長所短所がある

これらの表現方法はどっちでもいいよ、というわけではなくて、時と場合によって使い分ける必要があります。その選択によって、メモリ効率や時間計算量は大きく変わってきます。

隣接行列を選択してしまった

今回私は、グラフの表現方法として、隣接行列表現を選択してしまいました。
なぜそうしたかというと、隣接行列表現なら各マスに相対位置(x, y)を持たせれば簡単に実装ができると考えたからです。

しかし、これが失敗の原因でした。

DFSの時間計算量

これまでDFSの時間計算量はあまり気にしてきませんでした。
というのも、競プロではいくつか有力なアルゴリズムというものがあり、それを適切に選択すればだいたいうまくいくからです。(もちろんC, D問題くらいまでの話です)
なので、DFSっぽい問題は、ちゃんとDFSに気づいて実装できれば、DFSの計算量がオーダーいくつか知らなくても、ACできてしまっていました。
そんなダメな勉強をしてしまうと、今回のような問題に陥ってしまうのです。ですので、この機会にしっかりとDFSの時間計算量を知っておきましょう。

🔥DFSの時間計算量はグラフの表現方法で違う🔥

ノードの数をN, エッジの数をMとした時に
DFSの時間計算量は

  • 隣接行列を使用した場合: O(N^2)
    • ある頂点から隣接する頂点を見つけるために、その頂点の行を全て調べる必要があるため
  • 隣接リストを使用した場合: O(N+M)
    • すべての頂点を1回ずつ訪れ、その頂点の隣接する頂点を調べる必要があるため

となります。ですので疎なグラフ(辺の数MV^2よりもずっと少ない)の場合、隣接リストを使用したDFSの方が効率的です。一方、密なグラフ(辺の数EV^2に近い)の場合、隣接行列と隣接リストの差はそれほど大きくありませんが、隣接リストの方が一般的には効率的です。

今回の問題に戻ってみると

今回の問題ではグラフを隣接リストと隣接行列のどちらで表現する方が適切でしょうか?
制約の部分を見るとO(N^2)なら当然間に合わず、O(N+M)であれば計算できることがわかると思います。従って今回は隣接リストで実装する方が正しかったわけですね。

修正後のコードはこちらになります。

import sys

LIMIT = 10**6
sys.setrecursionlimit(LIMIT)


def dfs(u):
    if visited[u]:
        return

    visited[u] = True

    for v in graph[u]:
        pos[v] = (pos[u][0] + diff[(u, v)][0], pos[u][1] + diff[(u, v)][1])
        dfs(v)


n, m = map(int, input().split())

# グラフは隣接リストで表現
graph = [[] for _ in range(n)]
# 相対位置はO(1)で参照できる別のデータ構造(タプルキーの辞書型)で表現
diff = {}
for _ in range(m):
    a, b, x, y = map(int, input().split())
    a -= 1
    b -= 1
    graph[a].append(b)
    graph[b].append(a)
    diff[(a, b)] = (x, y)
    diff[(b, a)] = (-x, -y)

visited = [False] * n
pos = [(0, 0)] * n

dfs(0)
for i in range(n):
    if visited[i]:
        print(*pos[i])
    else:
        print('undecidable')

これだとACにすることができました。(再帰を使っている関係上PyPyではなくCythonで実行しないとTLEします)

さいごに

今回は、グラフの表現方法とDFSの時間計算量に関するちょっとしたお話でした。
DFSの時間計算量以外にも隣接リスト表現と隣接行列表現ではたくさんの違いがあるので、もっと勉強が必要ですね。

競技プログラミングをやっていると、問題の解法や実装にばかり目がいって、データ構造とアルゴリズムの大事な部分をおろそかにしてしまいがちだと思うのでこの記事を書かせていただきました。少しでもいいと思っていただけたらぜひいいね❤️お願いします。

GitHubで編集を提案

Discussion