Chapter 08無料公開

Union-Find

Ryo Suzuki
Ryo Suzuki
2023.01.09に更新

C++ 標準ライブラリを用いた、Union-Find (別名: Disjoint-set) の実装です。

1. Union-Find のテンプレート

機能 説明 1.1 1.2 1.3 1.4
.find(i) i を含むグループの root (代表) を返す
.merge(a, b) a を含むグループと b を含むグループを併合する
.connected(a, b) a と b が同じグループに属すかを返す
.size(i) i を含むグループの要素数を返す
経路圧縮 計算量をおよそ O(log N) に減らす
union by size 経路圧縮との組み合わせで、計算量を競プロの実用範囲でほぼ O(1) に減らす

1.1 シンプルな実装

コード
#include <iostream>
#include <vector>
#include <numeric> // std::iota()

/// @brief Union-Find 木
/// @note 1.1 シンプルな実装
/// @see https://zenn.dev/reputeless/books/standard-cpp-for-competitive-programming/viewer/union-find
class UnionFind
{
public:

	UnionFind() = default;

	/// @brief Union-Find 木を構築します。
	/// @param n 要素数
	explicit UnionFind(size_t n)
		: m_parents(n)
	{
		std::iota(m_parents.begin(), m_parents.end(), 0);
	}

	/// @brief 頂点 i の root のインデックスを返します。
	/// @param i 調べる頂点のインデックス
	/// @return 頂点 i の root のインデックス
	int find(int i)
	{
		if (m_parents[i] == i)
		{
			return i;
		}

		// 経路圧縮
		return (m_parents[i] = find(m_parents[i]));
	}

	/// @brief a のグループと b のグループを統合します。
	/// @param a 一方のインデックス
	/// @param b 他方のインデックス
	void merge(int a, int b)
	{
		a = find(a);
		b = find(b);

		if (a != b)
		{
			m_parents[b] = a;
		}
	}

	/// @brief a と b が同じグループに属すかを返します。
	/// @param a 一方のインデックス
	/// @param b 他方のインデックス
	/// @return a と b が同じグループに属す場合 true, それ以外の場合は false
	bool connected(int a, int b)
	{
		return (find(a) == find(b));
	}

private:

	// m_parents[i] は i の 親,
	// root の場合は自身が親
	std::vector<int> m_parents;
};

1.2 グループの要素数取得対応

コード
#include <iostream>
#include <vector>
#include <numeric> // std::iota()

/// @brief Union-Find 木
/// @note 1.2 グループの要素数取得対応
/// @see https://zenn.dev/reputeless/books/standard-cpp-for-competitive-programming/viewer/union-find
class UnionFind
{
public:

	UnionFind() = default;

	/// @brief Union-Find 木を構築します。
	/// @param n 要素数
	explicit UnionFind(size_t n)
		: m_parents(n)
		, m_sizes(n, 1)
	{
		std::iota(m_parents.begin(), m_parents.end(), 0);
	}

	/// @brief 頂点 i の root のインデックスを返します。
	/// @param i 調べる頂点のインデックス
	/// @return 頂点 i の root のインデックス
	int find(int i)
	{
		if (m_parents[i] == i)
		{
			return i;
		}

		// 経路圧縮
		return (m_parents[i] = find(m_parents[i]));
	}

	/// @brief a のグループと b のグループを統合します。
	/// @param a 一方のインデックス
	/// @param b 他方のインデックス
	void merge(int a, int b)
	{
		a = find(a);
		b = find(b);

		if (a != b)
		{
			m_sizes[a] += m_sizes[b];
			m_parents[b] = a;
		}
	}

	/// @brief a と b が同じグループに属すかを返します。
	/// @param a 一方のインデックス
	/// @param b 他方のインデックス
	/// @return a と b が同じグループに属す場合 true, それ以外の場合は false
	bool connected(int a, int b)
	{
		return (find(a) == find(b));
	}

	/// @brief i が属するグループの要素数を返します。
	/// @param i インデックス
	/// @return i が属するグループの要素数
	int size(int i)
	{
		return m_sizes[find(i)];
	}

private:

	// m_parents[i] は i の 親,
	// root の場合は自身が親
	std::vector<int> m_parents;

	// グループの要素数 (root 用)
	// i が root のときのみ, m_sizes[i] はそのグループに属する要素数を表す
	std::vector<int> m_sizes;
};

1.3 高速化

コード
#include <iostream>
#include <vector>
#include <numeric> // std::iota()
#include <utility> // std::swap()

/// @brief Union-Find 木
/// @note 1.3 高速化
/// @see https://zenn.dev/reputeless/books/standard-cpp-for-competitive-programming/viewer/union-find
class UnionFind
{
public:

	UnionFind() = default;

	/// @brief Union-Find 木を構築します。
	/// @param n 要素数
	explicit UnionFind(size_t n)
		: m_parents(n)
		, m_sizes(n, 1)
	{
		std::iota(m_parents.begin(), m_parents.end(), 0);
	}

	/// @brief 頂点 i の root のインデックスを返します。
	/// @param i 調べる頂点のインデックス
	/// @return 頂点 i の root のインデックス
	int find(int i)
	{
		if (m_parents[i] == i)
		{
			return i;
		}

		// 経路圧縮
		return (m_parents[i] = find(m_parents[i]));
	}

	/// @brief a のグループと b のグループを統合します。
	/// @param a 一方のインデックス
	/// @param b 他方のインデックス
	void merge(int a, int b)
	{
		a = find(a);
		b = find(b);

		if (a != b)
		{
			// union by size (小さいほうが子になる)
			if (m_sizes[a] < m_sizes[b])
			{
				std::swap(a, b);
			}

			m_sizes[a] += m_sizes[b];
			m_parents[b] = a;
		}
	}

	/// @brief a と b が同じグループに属すかを返します。
	/// @param a 一方のインデックス
	/// @param b 他方のインデックス
	/// @return a と b が同じグループに属す場合 true, それ以外の場合は false
	bool connected(int a, int b)
	{
		return (find(a) == find(b));
	}

	/// @brief i が属するグループの要素数を返します。
	/// @param i インデックス
	/// @return i が属するグループの要素数
	int size(int i)
	{
		return m_sizes[find(i)];
	}

private:

	// m_parents[i] は i の 親,
	// root の場合は自身が親
	std::vector<int> m_parents;

	// グループの要素数 (root 用)
	// i が root のときのみ, m_sizes[i] はそのグループに属する要素数を表す
	std::vector<int> m_sizes;
};

1.4 高速化 + 省メモリ化

コード
#include <iostream>
#include <vector>
#include <utility> // std::swap()

/// @brief Union-Find 木
/// @note 1.4 高速化 + 省メモリ化
/// @see https://zenn.dev/reputeless/books/standard-cpp-for-competitive-programming/viewer/union-find
class UnionFind
{
public:

	UnionFind() = default;

	/// @brief Union-Find 木を構築します。
	/// @param n 要素数
	explicit UnionFind(size_t n)
		: m_parentsOrSize(n, -1) {}

	/// @brief 頂点 i の root のインデックスを返します。
	/// @param i 調べる頂点のインデックス
	/// @return 頂点 i の root のインデックス
	int find(int i)
	{
		if (m_parentsOrSize[i] < 0)
		{
			return i;
		}

		// 経路圧縮
		return (m_parentsOrSize[i] = find(m_parentsOrSize[i]));
	}

	/// @brief a のグループと b のグループを統合します。
	/// @param a 一方のインデックス
	/// @param b 他方のインデックス
	void merge(int a, int b)
	{
		a = find(a);
		b = find(b);

		if (a != b)
		{
			// union by size (小さいほうが子になる)
			if (-m_parentsOrSize[a] < -m_parentsOrSize[b])
			{
				std::swap(a, b);
			}

			m_parentsOrSize[a] += m_parentsOrSize[b];
			m_parentsOrSize[b] = a;
		}
	}

	/// @brief a と b が同じグループに属すかを返します。
	/// @param a 一方のインデックス
	/// @param b 他方のインデックス
	/// @return a と b が同じグループに属す場合 true, それ以外の場合は false
	bool connected(int a, int b)
	{
		return (find(a) == find(b));
	}

	/// @brief i が属するグループの要素数を返します。
	/// @param i インデックス
	/// @return i が属するグループの要素数
	int size(int i)
	{
		return -m_parentsOrSize[find(i)];
	}

private:

	// m_parentsOrSize[i] は i の 親,
	// ただし root の場合は (-1 * そのグループに属する要素数)
	std::vector<int> m_parentsOrSize;
};

2. Union-Find の例題

ATC 001 B - Union Find

コード
#include <iostream>
#include <vector>
#include <numeric> // std::iota()

// Union-Find 木 (1.1 シンプルな実装)
class UnionFind
{
public:

	UnionFind() = default;

	// n 個の要素
	explicit UnionFind(size_t n)
		: m_parents(n)
	{
		std::iota(m_parents.begin(), m_parents.end(), 0);
	}

	// i の root を返す
	int find(int i)
	{
		if (m_parents[i] == i)
		{
			return i;
		}

		// 経路圧縮
		return (m_parents[i] = find(m_parents[i]));
	}

	// a の木と b の木を統合
	void merge(int a, int b)
	{
		a = find(a);
		b = find(b);

		if (a != b)
		{
			m_parents[b] = a;
		}
	}

	// a と b が同じ木に属すかを返す
	bool connected(int a, int b)
	{
		return (find(a) == find(b));
	}

private:

	// m_parents[i] は i の 親,
	// root の場合は自身が親
	std::vector<int> m_parents;
};

int main()
{
	int N, Q;
	std::cin >> N >> Q;

	UnionFind uf(N);

	while (Q--)
	{
		int t, u, v;
		std::cin >> t >> u >> v;

		if (t == 0)
		{
			uf.merge(u, v);
		}
		else // t == 1
		{
			std::cout << (uf.connected(u, v) ? "Yes\n" : "No\n");
		}
	}
}

ABC 075 C - Bridge

コード
#include <iostream>
#include <vector>
#include <numeric> // std::iota()

// Union-Find 木 (1.1 シンプルな実装)
class UnionFind
{
public:

	UnionFind() = default;

	// n 個の要素
	explicit UnionFind(size_t n)
		: m_parents(n)
	{
		std::iota(m_parents.begin(), m_parents.end(), 0);
	}

	// i の root を返す
	int find(int i)
	{
		if (m_parents[i] == i)
		{
			return i;
		}

		// 経路圧縮
		return (m_parents[i] = find(m_parents[i]));
	}

	// a の木と b の木を統合
	void merge(int a, int b)
	{
		a = find(a);
		b = find(b);

		if (a != b)
		{
			m_parents[b] = a;
		}
	}

	// a と b が同じ木に属すかを返す
	bool connected(int a, int b)
	{
		return (find(a) == find(b));
	}

private:

	// m_parents[i] は i の 親,
	// root の場合は自身が親
	std::vector<int> m_parents;
};

int main()
{
	// N 頂点 M 辺
	int N, M;
	std::cin >> N >> M;

	std::vector<int> A(M), B(M);

	for (int i = 0; i < M; ++i)
	{
		int a, b;
		std::cin >> a >> b;
		A[i] = --a;
		B[i] = --b;
	}

	int ans = 0;

	// 各辺について
	for (int i = 0; i < M; ++i)
	{
		UnionFind uf(N);

		// 辺 i を取り除いた Union-Find 木を作る
		for (int k = 0; k < M; ++k)
		{
			if (i != k)
			{
				uf.merge(A[k], B[k]);
			}
		}

		// root がいくつあるか
		int count = 0;

		// 各頂点について
		for (int k = 0; k < N; ++k)
		{
			// 自身が root なら count を増やす
			if (k == uf.find(k))
			{
				++count;
			}
		}

		// 最終的にグラフが非連結になっていたら
		if (1 < count)
		{
			// 削除した辺は橋であった
			++ans;
		}
	}

	std::cout << ans << '\n';
}

ABC 177 D - Friends

コード
#include <iostream>
#include <vector>
#include <numeric> // std::iota()
#include <utility> // std::swap()

// Union-Find 木 (1.3 高速化)
class UnionFind
{
public:

	UnionFind() = default;

	// n 個の要素
	explicit UnionFind(size_t n)
		: m_parents(n)
		, m_sizes(n, 1)
	{
		std::iota(m_parents.begin(), m_parents.end(), 0);
	}

	// i の root を返す
	int find(int i)
	{
		if (m_parents[i] == i)
		{
			return i;
		}

		// 経路圧縮
		return (m_parents[i] = find(m_parents[i]));
	}

	// a の木と b の木を統合
	void merge(int a, int b)
	{
		a = find(a);
		b = find(b);

		if (a != b)
		{
			// union by size (小さいほうが子になる)
			if (m_sizes[a] < m_sizes[b])
			{
				std::swap(a, b);
			}

			m_sizes[a] += m_sizes[b];
			m_parents[b] = a;
		}
	}

	// a と b が同じ木に属すかを返す
	bool connected(int a, int b)
	{
		return (find(a) == find(b));
	}

	// i が属するグループの要素数を返す
	int size(int i)
	{
		return m_sizes[find(i)];
	}

private:

	// m_parents[i] は i の 親,
	// root の場合は自身が親
	std::vector<int> m_parents;

	// グループの要素数 (root 用)
	// i が root のときのみ, m_sizes[i] はそのグループに属する要素数を表す
	std::vector<int> m_sizes;
};

int main()
{
	std::cin.tie(0)->sync_with_stdio(0);

	int N, M;
	std::cin >> N >> M;

	UnionFind uf(N);

	while (M--)
	{
		int A, B;
		std::cin >> A >> B;
		--A; --B;
		uf.merge(A, B);
	}

	int answer = 0;

	for (int i = 0; i < N; ++i)
	{
		answer = std::max(answer, uf.size(i));
	}

	std::cout << answer << '\n';
}

ABC 049 D - 連結

AOJ GRL_2_A - Minimum Spanning Tree

  • クラスカル法
コード
#include <iostream>
#include <vector>
#include <numeric> // std::iota()
#include <algorithm> // std::sort()

// Union-Find 木 (1.1 シンプルな実装)
class UnionFind
{
public:

	UnionFind() = default;

	// n 個の要素
	explicit UnionFind(size_t n)
		: m_parents(n)
	{
		std::iota(m_parents.begin(), m_parents.end(), 0);
	}

	// i の root を返す
	int find(int i)
	{
		if (m_parents[i] == i)
		{
			return i;
		}

		// 経路圧縮
		return (m_parents[i] = find(m_parents[i]));
	}

	// a の木と b の木を統合
	void merge(int a, int b)
	{
		a = find(a);
		b = find(b);

		if (a != b)
		{
			m_parents[b] = a;
		}
	}

	// a と b が同じ木に属すかを返す
	bool connected(int a, int b)
	{
		return (find(a) == find(b));
	}

private:

	// m_parents[i] は i の 親,
	// root の場合は自身が親
	std::vector<int> m_parents;
};

struct Edge
{
	int from;
	int to;
	int cost;

	// コストに基づく大小定義
	bool operator <(const Edge& other) const
	{
		return (cost < other.cost);
	}
};

int main()
{
	int V, E;
	std::cin >> V >> E;

	std::vector<Edge> edges(E);
	for (auto& edge : edges)
	{
		std::cin >> edge.from >> edge.to >> edge.cost;
	}

	std::sort(edges.begin(), edges.end());

	UnionFind uf(V);
	long long sum = 0;

	for (const auto& edge : edges)
	{
		if (!uf.connected(edge.from, edge.to))
		{
			uf.merge(edge.from, edge.to);
			sum += edge.cost;
		}
	}

	std::cout << sum << '\n';
}

ARC 032 B - 道路工事

ABC 264 E - Blackout 2

コード 1
#include <iostream>
#include <vector>
#include <numeric> // std::iota()
#include <utility> // std::swap()
#include <algorithm> // std::reverse()

// Union-Find 木 (1.3 高速化)
class UnionFind
{
public:

	UnionFind() = default;

	// n 個の要素
	explicit UnionFind(size_t n)
		: m_parents(n)
		, m_sizes(n, 1)
	{
		std::iota(m_parents.begin(), m_parents.end(), 0);
	}

	// i の root を返す
	int find(int i)
	{
		if (m_parents[i] == i)
		{
			return i;
		}

		// 経路圧縮
		return (m_parents[i] = find(m_parents[i]));
	}

	// a の木と b の木を統合
	void merge(int a, int b)
	{
		a = find(a);
		b = find(b);

		if (a != b)
		{
			// union by size (小さいほうが子になる)
			if (m_sizes[a] < m_sizes[b])
			{
				std::swap(a, b);
			}

			m_sizes[a] += m_sizes[b];
			m_parents[b] = a;
		}
	}

	// a と b が同じ木に属すかを返す
	bool connected(int a, int b)
	{
		return (find(a) == find(b));
	}

	// i が属するグループの要素数を返す
	int size(int i)
	{
		return m_sizes[find(i)];
	}

private:

	// m_parents[i] は i の 親,
	// root の場合は自身が親
	std::vector<int> m_parents;

	// グループの要素数 (root 用)
	// i が root のときのみ, m_sizes[i] はそのグループに属する要素数を表す
	std::vector<int> m_sizes;
};

struct Edge
{
	int from;
	int to;
};

int main()
{
	int N, M, E;
	std::cin >> N >> M >> E;

	std::vector<Edge> edges(E);
	for (auto& edge : edges)
	{
		std::cin >> edge.from >> edge.to;
		--edge.from; --edge.to;
	}

	int Q;
	std::cin >> Q;

	// ある送電線が最終時点で残っているかを記録する配列
	std::vector<bool> finallyConnected(E, true);

	std::vector<int> X(Q);
	for (auto& x : X)
	{
		std::cin >> x;
		--x;
		finallyConnected[x] = false;
	}

	// 各地点が配電されているか (root 用)
	std::vector<bool> electrified_root(N + M);
	for (int i = N; i < (N + M); ++i)
	{
		electrified_root[i] = true;
	}

	UnionFind uf(N + M);

	// 配電されている都市の数
	int sum = 0;

	// 最終時点での接続情報を構築する
	for (int i = 0; i < E; ++i)
	{
		if (!finallyConnected[i])
		{
			continue;
		}

		// 接続構築
		{
			const Edge edge = edges[i];

			if (uf.connected(edge.from, edge.to))
			{
				continue;
			}

			const bool eFrom = electrified_root[uf.find(edge.from)];
			const bool eTo = electrified_root[uf.find(edge.to)];

			if ((eFrom == false) && (eTo == true)) // eFrom 側が新たに電化する
			{
				sum += uf.size(edge.from);
			}
			else if ((eFrom == true) && (eTo == false)) // eTo 側が新たに電化する
			{
				sum += uf.size(edge.to);
			}

			uf.merge(edge.from, edge.to);
			electrified_root[uf.find(edge.from)] = (eFrom || eTo);
		}
	}

	std::vector<int> results;

	// 接続イベントを逆からたどる
	std::reverse(X.begin(), X.end());

	for (const auto& x : X)
	{
		// 現時点での配電都市数を記録
		results.push_back(sum);

		// 接続構築
		{
			const Edge edge = edges[x];

			if (uf.connected(edge.from, edge.to))
			{
				continue;
			}

			const bool eFrom = electrified_root[uf.find(edge.from)];
			const bool eTo = electrified_root[uf.find(edge.to)];

			if ((eFrom == false) && (eTo == true)) // eFrom 側が新たに電化する
			{
				sum += uf.size(edge.from);
			}
			else if ((eFrom == true) && (eTo == false)) // eTo 側が新たに電化する
			{
				sum += uf.size(edge.to);
			}

			uf.merge(edge.from, edge.to);
			electrified_root[uf.find(edge.from)] = (eFrom || eTo);
		}
	}

	std::reverse(results.begin(), results.end());

	for (const auto& result : results)
	{
		std::cout << result << '\n';
	}
}
コード 2
#include <iostream>
#include <vector>
#include <numeric> // std::iota()
#include <utility> // std::swap()
#include <algorithm> // std::reverse()

// Union-Find 木 (1.3 高速化)
class UnionFind
{
public:

	UnionFind() = default;

	// n 個の要素
	explicit UnionFind(size_t n)
		: m_parents(n)
		, m_sizes(n, 1)
	{
		std::iota(m_parents.begin(), m_parents.end(), 0);
	}

	// i の root を返す
	int find(int i)
	{
		if (m_parents[i] == i)
		{
			return i;
		}

		// 経路圧縮
		return (m_parents[i] = find(m_parents[i]));
	}

	// a の木と b の木を統合
	void merge(int a, int b)
	{
		a = find(a);
		b = find(b);

		if (a != b)
		{
			// union by size (小さいほうが子になる)
			if (m_sizes[a] < m_sizes[b])
			{
				std::swap(a, b);
			}

			m_sizes[a] += m_sizes[b];
			m_parents[b] = a;
		}
	}

	// a と b が同じ木に属すかを返す
	bool connected(int a, int b)
	{
		return (find(a) == find(b));
	}

	// i が属するグループの要素数を返す
	int size(int i)
	{
		return m_sizes[find(i)];
	}

private:

	// m_parents[i] は i の 親,
	// root の場合は自身が親
	std::vector<int> m_parents;

	// グループの要素数 (root 用)
	// i が root のときのみ, m_sizes[i] はそのグループに属する要素数を表す
	std::vector<int> m_sizes;
};

struct Edge
{
	int from;
	int to;
};

int main()
{
	int N, M, E;
	std::cin >> N >> M >> E;

	std::vector<Edge> edges(E);
	for (auto& edge : edges)
	{
		std::cin >> edge.from >> edge.to;
		edge.from = std::min((edge.from - 1), N);
		edge.to = std::min((edge.to - 1), N);
	}

	int Q;
	std::cin >> Q;

	// ある送電線が最終時点で残っているかを記録する配列
	std::vector<bool> finallyConnected(E, true);

	std::vector<int> X(Q);
	for (auto& x : X)
	{
		std::cin >> x;
		--x;
		finallyConnected[x] = false;
	}

	UnionFind uf(N + 1);

	for (int i = 0; i < E; ++i)
	{
		if (!finallyConnected[i])
		{
			continue;
		}

		const Edge edge = edges[i];

		uf.merge(edge.from, edge.to);
	}

	std::vector<int> results;

	// 接続イベントを逆からたどる
	std::reverse(X.begin(), X.end());

	for (const auto& x : X)
	{
		// 現時点での配電都市数を記録
		results.push_back(uf.size(N) - 1);

		const Edge edge = edges[x];

		uf.merge(edge.from, edge.to);
	}

	std::reverse(results.begin(), results.end());

	for (const auto& result : results)
	{
		std::cout << result << '\n';
	}
}

ABC 183 F - Confluence