💬

RustでUnionFind

2021/02/11に公開1

例題

以下の問題を例に、UnionFind を実装していきます。
問題の概要は、N 個のノードと、2 つのノードを結ぶ M 個の辺があるとき、どのノードからもすべてのノードに移動できるようにしたい。最小で何本の辺を追加する必要があるか、というものです。
https://atcoder.jp/contests/arc032/tasks/arc032_2

制約

N(1≦N≦100,000)
M(0≦M≦100,000)

実装

M 個の辺で結ばれたノードを unite メソッドで結びつけます。その後、任意のノード(ここでは 0)をすべてのノードに対して unite メソッドを走らせます。

このメソッドは、2 つのノードのルートが同じなら false、違えばそれらを連結後に true を返します。そのため、「true ならインクリメント」すると、答えがでます。

この UnionFind はけんちょんさんの本のコードを Rust に落とし込んだものなので、理論面や詳細は同書を参照ください。

use std::io::*;
use std::str::FromStr;

fn read<T: FromStr>() -> T {
    let stdin = stdin();
    let stdin = stdin.lock();
    let token: String = stdin
        .bytes()
        .map(|c| c.expect("filed to read char") as char)
        .skip_while(|c| c.is_whitespace())
        .take_while(|c| !c.is_whitespace())
        .collect();
    token.parse().ok().expect("failed to parse token")
}

struct UnionFind {
    par: Vec<usize>,
    siz: Vec<usize>,
}

impl UnionFind {
    fn new(n: usize) -> Self {
        UnionFind {
            par: (0..n).collect(),
            siz: vec![1; n],
        }
    }

    fn root(&mut self, x: usize) -> usize {
        if self.par[x] == x {
            return x;
        }
        self.par[x] = self.root(self.par[x]);
        self.par[x]
    }

    fn issame(&mut self, x: usize, y: usize) -> bool {
        self.root(x) == self.root(y)
    }

    fn unite(&mut self, mut parent: usize, mut child: usize) -> bool {
        parent = self.root(parent);
        child = self.root(child);

        if parent == child {
            return false;
        }

        if self.siz[parent] < self.siz[child] {
            std::mem::swap(&mut parent, &mut child);
        }

        self.par[child] = parent;
        self.siz[parent] += self.siz[child];
        true
    }

    fn size(&mut self, x: usize) -> usize {
        let root = self.root(x);
        self.siz[root]
    }
}

fn main() {
    let n: usize = read();
    let m: usize = read();
    let mut uf = UnionFind::new(n);
    for _ in 0..m {
        let a: usize = read();
        let b: usize = read();
        uf.unite(a - 1, b - 1);
    }
    let mut cnt = 0;
    for i in 0..n {
        if uf.unite(0, i) {
            cnt += 1;
        }
    }
    println!("{}", cnt);
}