🔖

Strongly Connected Components without Recursion (iterative DFS)

2023/03/05に公開

(問題ページ)
https://atcoder.jp/contests/typical90/tasks/typical90_u


1.考え方

SCC(強連結成分分解)では、帰りがけの頂点を番号ラベリングする必要があります。
この部分をどうするかがポイントになると思います。
(参考)強連結成分(SCC) | technical-note

おもに以下の2点を考慮する必要があります。
(1)stackに、「頂点番号(帰り)、頂点番号(行き)」を追加していくようにします。

stackの内容は、以下のようになります。

  • 走査前、「V1(帰り), V1(行き)」を追加。
      stack = [V1(帰り), V1(行き)]
  • 現在地V1のとき、「V2(帰り), V2(行き)」を追加。
      stack = [V1(帰り), V2(帰り), V2(行き)]
  • 現在地V2のとき、「V3(帰り), V3(行き)」を追加。
      stack = [V1(帰り), V2(帰り), V3(帰り), V3(行き)]
  • 現在地V3のとき、追加なし。
      stack = [V1(帰り), V2(帰り), V3(帰り)]

(2)同じ頂点に複数の経路が存在する場合
上記(1)のやり方だけでは、同じ頂点に、複数の頂点からの経路がある場合に不具合が生じます。
頂点V4に対し、V1,V2,V3から経路がある以下のような有向グラフのときを考えます。
そして、DFSで走査経路が、「V1 → V2 → V3 → V4」になる場合を考えます。

  • 走査前、「V1(帰り), V1(行き)」を追加。
      stack = [V1(帰り), V1(行き)]
  • 現在地V1のとき、「V4(帰り), V4(行き), V2(帰り), V2(行き)」を追加。
      stack = [V1(帰り), V4(帰り), V4(行き), V2(帰り), V2(行き)]
  • 現在地V2のとき、「V4(帰り), V4(行き), V3(帰り), V3(行き)」を追加。
      stack = [V1(帰り), V4(帰り), V4(行き), V2(帰り), V4(帰り), V4(行き), V3(帰り), V3(行き)]
  • 現在地V3のとき、「V4(帰り), V4(行き)」を追加。
      stack = [V1(帰り), V4(帰り), V4(行き), V2(帰り), V4(帰り), V4(行き), V3(帰り), V4(帰り), V4(行き)]
  • 現在地V4のとき、追加なし
      stack = [V1(帰り), V4(帰り), V4(行き), V2(帰り), V4(帰り), V4(行き), V3(帰り), V4(帰り)]

上記の内容で分かるように、頂点V4(帰りがけ)が複数回stackに保存されてしまいます。
そのため、「帰りがけ」を走査済みかどうかをチェックする必要があります。

今回の実装では、DFSで頂点が「行きがけ」操作済みかどうかを管理するために、配列visitedを用いており、
  ・値が0の時、未visited
  ・値が1の時、visited (「行きがけ」操作済み)
としています。

この配列に入れる数値に、
  ・値が2の時、finished
として、これを「帰りがけ」を走査済みのステータスにします。

stackから頂点番号を取り出した後、それが帰りがけの頂点だったとき、
  ・配列visitedの値が2で無ければ、その頂点を「帰りがけの頂点」として
   配列v_stackに記録します。
  ・値が2だったときは何もしない
というようにしました。

また、stackに頂点番号をpushするときに、「帰りがけ」を表すためにビット反転(~num)した値をpushすることにしました。
これについては、以下のサイトを参考にしています。
非再帰 Euler Tour を Python でやる

DFSは以下のように実装しました。

void dfs(int v_sta) {
    vector<int> dfs_stack = {~v_sta, v_sta};

    while (!dfs_stack.empty()) {
        int v = dfs_stack.back();
        dfs_stack.pop_back();

        if (v < 0) {
            if (visited[~v] != 2) {
                visited[~v] = 2;
                v_stack.push_back(~v);
            }
            continue;
        }
        if (visited[v] != 0) {
            continue;
        }

        visited[v] = 1;

        for (int nv : edge[v]) {
            if (visited[nv] == 0) {
                dfs_stack.push_back(~nv);
                dfs_stack.push_back(nv);
            }
        }
    }
}

3.code

コード全体は、こうなりました。

#include <iostream>
#include <vector>
#include <unordered_set>
#include <unordered_map>
#include <map>
using namespace std;
#include <algorithm>
#include <queue>
#include <set>


int N, M;
int num_max = 100000*2 + 100;
vector<unordered_set<int>> edge(num_max), edge_rev(num_max);
vector<int> visited(num_max, 0);  // 1: visited, 2: finished
vector<int> visited_rev(num_max, 0);
vector<int> v_stack(0);


void dfs(int v_sta) {
    vector<int> dfs_stack = {~v_sta, v_sta};

    while (!dfs_stack.empty()) {
        int v = dfs_stack.back();
        dfs_stack.pop_back();

        if (v < 0) {
            if (visited[~v] != 2) {
                visited[~v] = 2;
                v_stack.push_back(~v);
            }
            continue;
        }
        if (visited[v] != 0) {
            continue;
        }

        visited[v] = 1;

        for (int nv : edge[v]) {
            if (visited[nv] == 0) {
                dfs_stack.push_back(~nv);
                dfs_stack.push_back(nv);
            }
        }
    }
}


int dfs_rev(int v_sta) {
    vector<int> dfs_rev_stack = {v_sta};

    int cnt = 0;
    while (!dfs_rev_stack.empty()) {
        int v = dfs_rev_stack.back();
        dfs_rev_stack.pop_back();

        if (visited_rev[v] != 0) {
            continue;
        }

        visited_rev[v] = 1;
        cnt++;

        for (int nv : edge_rev[v]) {
            if (visited_rev[nv] == 0) {
                dfs_rev_stack.push_back(nv);
            }
        }
    }
    return cnt;
}


int main() {
    cin >> N >> M;

    vector<int> A(M), B(M);
    for (int i = 0; i < M; i++) {
        cin >> A[i] >> B[i];
        A[i]--;
        B[i]--;
    }

    for (int i = 0; i < M; i++) {
        edge[A[i]].insert(B[i]);
        edge_rev[B[i]].insert(A[i]);
    }

    for (int i = 0; i < N; i++) {
        if (visited[i] == 0) {
            dfs(i);
        }
    }

    long long ans = 0;
    for (int i = 0; i < N; i++) {
        int v = v_stack.back();
        v_stack.pop_back();
        if (visited_rev[v] == 0) {
            int c = dfs_rev(v);
            ans += c * (c - 1LL) / 2LL;
        }
    }
    printf("%lld\n", ans);
}

Discussion