Chapter 09無料公開

重み付き Union-Find

Ryo Suzuki
Ryo Suzuki
2023.01.01に更新

C++ 標準ライブラリを用いた、重み付き Union-Find(ポテンシャル付き Union-Find)の実装です。

1. 重み付き Union-Find のテンプレート

機能 説明 1.1 1.2
.find(i) i を含むグループの root (代表) を返す
.merge(a, b, w) a を含むグループと b を含むグループを併合し、b の重みは a と比べて w 大きくなるようにする
.diff(a, b) (b の重み) - (a の重み) を返す。a と b が同じグループに属さない場合の結果は不定
.connected(a, b) a と b が同じグループに属すかを返す
.size(i) i を含むグループの要素数を返す
計算量の削減 経路圧縮と union by size

1.1 シンプルな実装

コード
#include <iostream>
#include <vector>
#include <numeric> // std::iota()
#include <utility> // std::swap()

/// @brief 重み付き Union-Find 木
/// @tparam Type 重みの表現に使う型
/// @note 1.1 シンプルな実装
/// @see https://zenn.dev/reputeless/books/standard-cpp-for-competitive-programming/viewer/weighted-union-find
template <class Type>
class WeightedUnionFind
{
public:

	WeightedUnionFind() = default;

	/// @brief 重み付き Union-Find 木を構築します。
	/// @param n 要素数
	explicit WeightedUnionFind(size_t n)
		: m_parents(n)
		, m_sizes(n, 1)
		, m_diffWeights(n)
	{
		std::iota(m_parents.begin(), m_parents.end(), 0);
	}

	/// @brief 頂点 i の root のインデックスを返します。
	/// @param i 調べる頂点のインデックス
	/// @return 頂点 i の root のインデックス
	int find(int i)
	{
		if (m_parents[i] == i)
		{
			return i;
		}

		const int root = find(m_parents[i]);

		m_diffWeights[i] += m_diffWeights[m_parents[i]];

		// 経路圧縮
		return (m_parents[i] = root);
	}

	/// @brief a のグループと b のグループを統合します。
	/// @param a 一方のインデックス
	/// @param b 他方のインデックス
	/// @param w (b の重み) - (a の重み)
	void merge(int a, int b, Type w)
	{
		w += weight(a);
		w -= weight(b);

		a = find(a);
		b = find(b);

		if (a != b)
		{
			// union by size (小さいほうが子になる)
			if (m_sizes[a] < m_sizes[b])
			{
				std::swap(a, b);
				w = -w;
			}

			m_sizes[a] += m_sizes[b];
			m_parents[b] = a;
			m_diffWeights[b] = w;
		}
	}

	/// @brief (b の重み) - (a の重み) を返します。
	/// @param a 一方のインデックス
	/// @param b 他方のインデックス
	/// @remark a と b が同じグループに属さない場合の結果は不定です。
	/// @return (b の重み) - (a の重み)
	Type diff(int a, int b)
	{
		return (weight(b) - weight(a));
	}

	/// @brief a と b が同じグループに属すかを返します。
	/// @param a 一方のインデックス
	/// @param b 他方のインデックス
	/// @return a と b が同じグループに属す場合 true, それ以外の場合は false
	bool connected(int a, int b)
	{
		return (find(a) == find(b));
	}

	/// @brief i が属するグループの要素数を返します。
	/// @param i インデックス
	/// @return i が属するグループの要素数
	int size(int i)
	{
		return m_sizes[find(i)];
	}

private:

	// m_parents[i] は i の 親,
	// root の場合は自身が親
	std::vector<int> m_parents;

	// グループの要素数 (root 用)
	// i が root のときのみ, m_sizes[i] はそのグループに属する要素数を表す
	std::vector<int> m_sizes;

	// 重み
	std::vector<Type> m_diffWeights;

	Type weight(int i)
	{
		find(i); // 経路圧縮
		return m_diffWeights[i];
	}
};

1.2 省メモリ化

コード
#include <iostream>
#include <vector>
#include <utility> // std::swap()

/// @brief 重み付き Union-Find 木
/// @tparam Type 重みの表現に使う型
/// @note 1.2 省メモリ化
/// @see https://zenn.dev/reputeless/books/standard-cpp-for-competitive-programming/viewer/weighted-union-find
template <class Type>
class WeightedUnionFind
{
public:

	WeightedUnionFind() = default;

	/// @brief 重み付き Union-Find 木を構築します。
	/// @param n 要素数
	explicit WeightedUnionFind(size_t n)
		: m_parentsOrSize(n, -1)
		, m_diffWeights(n) {}

	/// @brief 頂点 i の root のインデックスを返します。
	/// @param i 調べる頂点のインデックス
	/// @return 頂点 i の root のインデックス
	int find(int i)
	{
		if (m_parentsOrSize[i] < 0)
		{
			return i;
		}

		const int root = find(m_parentsOrSize[i]);

		m_diffWeights[i] += m_diffWeights[m_parentsOrSize[i]];

		// 経路圧縮
		return (m_parentsOrSize[i] = root);
	}

	/// @brief a のグループと b のグループを統合します。
	/// @param a 一方のインデックス
	/// @param b 他方のインデックス
	/// @param w (b の重み) - (a の重み)
	void merge(int a, int b, Type w)
	{
		w += weight(a);
		w -= weight(b);

		a = find(a);
		b = find(b);

		if (a != b)
		{
			// union by size (小さいほうが子になる)
			if (-m_parentsOrSize[a] < -m_parentsOrSize[b])
			{
				std::swap(a, b);
				w = -w;
			}

			m_parentsOrSize[a] += m_parentsOrSize[b];
			m_parentsOrSize[b] = a;
			m_diffWeights[b] = w;
		}
	}

	/// @brief (b の重み) - (a の重み) を返します。
	/// @param a 一方のインデックス
	/// @param b 他方のインデックス
	/// @remark a と b が同じグループに属さない場合の結果は不定です。
	/// @return (b の重み) - (a の重み)
	Type diff(int a, int b)
	{
		return (weight(b) - weight(a));
	}

	/// @brief a と b が同じグループに属すかを返します。
	/// @param a 一方のインデックス
	/// @param b 他方のインデックス
	/// @return a と b が同じグループに属す場合 true, それ以外の場合は false
	bool connected(int a, int b)
	{
		return (find(a) == find(b));
	}

	/// @brief i が属するグループの要素数を返します。
	/// @param i インデックス
	/// @return i が属するグループの要素数
	int size(int i)
	{
		return -m_parentsOrSize[find(i)];
	}

private:

	// m_parentsOrSize[i] は i の 親,
	// ただし root の場合は (-1 * そのグループに属する要素数)
	std::vector<int> m_parentsOrSize;

	// 重み
	std::vector<Type> m_diffWeights;

	Type weight(int i)
	{
		find(i); // 経路圧縮
		return m_diffWeights[i];
	}
};

2. 重み付き Union-Find の例題

AOJ DSL_1_B - Weighted Union Find Trees

コード
#include <iostream>
#include <vector>
#include <numeric> // std::iota()
#include <utility> // std::swap()

/// @brief 重み付き Union-Find 木
/// @tparam Type 重みの表現に使う型
/// @note 1.1 シンプルな実装
/// @see https://zenn.dev/reputeless/books/standard-cpp-for-competitive-programming/viewer/weighted-union-find
template <class Type>
class WeightedUnionFind
{
public:

	WeightedUnionFind() = default;

	/// @brief 重み付き Union-Find 木を構築します。
	/// @param n 要素数
	explicit WeightedUnionFind(size_t n)
		: m_parents(n)
		, m_sizes(n, 1)
		, m_diffWeights(n)
	{
		std::iota(m_parents.begin(), m_parents.end(), 0);
	}

	/// @brief 頂点 i の root のインデックスを返します。
	/// @param i 調べる頂点のインデックス
	/// @return 頂点 i の root のインデックス
	int find(int i)
	{
		if (m_parents[i] == i)
		{
			return i;
		}

		const int root = find(m_parents[i]);

		m_diffWeights[i] += m_diffWeights[m_parents[i]];

		// 経路圧縮
		return (m_parents[i] = root);
	}

	/// @brief a のグループと b のグループを統合します。
	/// @param a 一方のインデックス
	/// @param b 他方のインデックス
	/// @param w (b の重み) - (a の重み)
	void merge(int a, int b, Type w)
	{
		w += weight(a);
		w -= weight(b);

		a = find(a);
		b = find(b);

		if (a != b)
		{
			// union by size (小さいほうが子になる)
			if (m_sizes[a] < m_sizes[b])
			{
				std::swap(a, b);
				w = -w;
			}

			m_sizes[a] += m_sizes[b];
			m_parents[b] = a;
			m_diffWeights[b] = w;
		}
	}

	/// @brief (b の重み) - (a の重み) を返します。
	/// @param a 一方のインデックス
	/// @param b 他方のインデックス
	/// @remark a と b が同じグループに属さない場合の結果は不定です。
	/// @return (b の重み) - (a の重み)
	Type diff(int a, int b)
	{
		return (weight(b) - weight(a));
	}

	/// @brief a と b が同じグループに属すかを返します。
	/// @param a 一方のインデックス
	/// @param b 他方のインデックス
	/// @return a と b が同じグループに属す場合 true, それ以外の場合は false
	bool connected(int a, int b)
	{
		return (find(a) == find(b));
	}

	/// @brief i が属するグループの要素数を返します。
	/// @param i インデックス
	/// @return i が属するグループの要素数
	int size(int i)
	{
		return m_sizes[find(i)];
	}

private:

	// m_parents[i] は i の 親,
	// root の場合は自身が親
	std::vector<int> m_parents;

	// グループの要素数 (root 用)
	// i が root のときのみ, m_sizes[i] はそのグループに属する要素数を表す
	std::vector<int> m_sizes;

	// 重み
	std::vector<Type> m_diffWeights;

	Type weight(int i)
	{
		find(i); // 経路圧縮
		return m_diffWeights[i];
	}
};

int main()
{
	int n, q;
	std::cin >> n >> q;

	WeightedUnionFind<int> uf(n);

	while (q--)
	{
		int t;
		std::cin >> t;

		if (t == 0)
		{
			int x, y, z;
			std::cin >> x >> y >> z;
			uf.merge(x, y, z);
		}
		else
		{
			int x, y;
			std::cin >> x >> y;

			if (uf.connected(x, y))
			{
				std::cout << uf.diff(x, y) << '\n';
			}
			else
			{
				std::cout << "?\n";
			}
		}
	}
}

ABC 087 D - People on a Line

コード
#include <iostream>
#include <vector>
#include <numeric> // std::iota()
#include <utility> // std::swap()

/// @brief 重み付き Union-Find 木
/// @tparam Type 重みの表現に使う型
/// @note 1.1 シンプルな実装
/// @see https://zenn.dev/reputeless/books/standard-cpp-for-competitive-programming/viewer/weighted-union-find
template <class Type>
class WeightedUnionFind
{
public:

	WeightedUnionFind() = default;

	/// @brief 重み付き Union-Find 木を構築します。
	/// @param n 要素数
	explicit WeightedUnionFind(size_t n)
		: m_parents(n)
		, m_sizes(n, 1)
		, m_diffWeights(n)
	{
		std::iota(m_parents.begin(), m_parents.end(), 0);
	}

	/// @brief 頂点 i の root のインデックスを返します。
	/// @param i 調べる頂点のインデックス
	/// @return 頂点 i の root のインデックス
	int find(int i)
	{
		if (m_parents[i] == i)
		{
			return i;
		}

		const int root = find(m_parents[i]);

		m_diffWeights[i] += m_diffWeights[m_parents[i]];

		// 経路圧縮
		return (m_parents[i] = root);
	}

	/// @brief a のグループと b のグループを統合します。
	/// @param a 一方のインデックス
	/// @param b 他方のインデックス
	/// @param w (b の重み) - (a の重み)
	void merge(int a, int b, Type w)
	{
		w += weight(a);
		w -= weight(b);

		a = find(a);
		b = find(b);

		if (a != b)
		{
			// union by size (小さいほうが子になる)
			if (m_sizes[a] < m_sizes[b])
			{
				std::swap(a, b);
				w = -w;
			}

			m_sizes[a] += m_sizes[b];
			m_parents[b] = a;
			m_diffWeights[b] = w;
		}
	}

	/// @brief (b の重み) - (a の重み) を返します。
	/// @param a 一方のインデックス
	/// @param b 他方のインデックス
	/// @remark a と b が同じグループに属さない場合の結果は不定です。
	/// @return (b の重み) - (a の重み)
	Type diff(int a, int b)
	{
		return (weight(b) - weight(a));
	}

	/// @brief a と b が同じグループに属すかを返します。
	/// @param a 一方のインデックス
	/// @param b 他方のインデックス
	/// @return a と b が同じグループに属す場合 true, それ以外の場合は false
	bool connected(int a, int b)
	{
		return (find(a) == find(b));
	}

	/// @brief i が属するグループの要素数を返します。
	/// @param i インデックス
	/// @return i が属するグループの要素数
	int size(int i)
	{
		return m_sizes[find(i)];
	}

private:

	// m_parents[i] は i の 親,
	// root の場合は自身が親
	std::vector<int> m_parents;

	// グループの要素数 (root 用)
	// i が root のときのみ, m_sizes[i] はそのグループに属する要素数を表す
	std::vector<int> m_sizes;

	// 重み
	std::vector<Type> m_diffWeights;

	Type weight(int i)
	{
		find(i); // 経路圧縮
		return m_diffWeights[i];
	}
};

int main()
{
	// N 頂点 M 辺
	int N, M;
	std::cin >> N >> M;

	WeightedUnionFind<int> uf(N);

	for (int i = 0; i < M; ++i)
	{
		int l, r, d;
		std::cin >> l >> r >> d;
		--l, --r;

		if (!uf.connected(l, r))
		{
			uf.merge(l, r, d);
		}
		else if (uf.diff(l, r) != d)
		{
			std::cout << "No\n";
			return 0;
		}
	}

	std::cout << "Yes\n";
}

ABC 280 F - Pay or Receive

コード
#include <iostream>
#include <vector>
#include <numeric> // std::iota()
#include <utility> // std::swap()

/// @brief 重み付き Union-Find 木
/// @tparam Type 重みの表現に使う型
/// @note 1.1 シンプルな実装
/// @see https://zenn.dev/reputeless/books/standard-cpp-for-competitive-programming/viewer/weighted-union-find
template <class Type>
class WeightedUnionFind
{
public:

	WeightedUnionFind() = default;

	/// @brief 重み付き Union-Find 木を構築します。
	/// @param n 要素数
	explicit WeightedUnionFind(size_t n)
		: m_parents(n)
		, m_sizes(n, 1)
		, m_diffWeights(n)
	{
		std::iota(m_parents.begin(), m_parents.end(), 0);
	}

	/// @brief 頂点 i の root のインデックスを返します。
	/// @param i 調べる頂点のインデックス
	/// @return 頂点 i の root のインデックス
	int find(int i)
	{
		if (m_parents[i] == i)
		{
			return i;
		}

		const int root = find(m_parents[i]);

		m_diffWeights[i] += m_diffWeights[m_parents[i]];

		// 経路圧縮
		return (m_parents[i] = root);
	}

	/// @brief a のグループと b のグループを統合します。
	/// @param a 一方のインデックス
	/// @param b 他方のインデックス
	/// @param w (b の重み) - (a の重み)
	void merge(int a, int b, Type w)
	{
		w += weight(a);
		w -= weight(b);

		a = find(a);
		b = find(b);

		if (a != b)
		{
			// union by size (小さいほうが子になる)
			if (m_sizes[a] < m_sizes[b])
			{
				std::swap(a, b);
				w = -w;
			}

			m_sizes[a] += m_sizes[b];
			m_parents[b] = a;
			m_diffWeights[b] = w;
		}
	}

	/// @brief (b の重み) - (a の重み) を返します。
	/// @param a 一方のインデックス
	/// @param b 他方のインデックス
	/// @remark a と b が同じグループに属さない場合の結果は不定です。
	/// @return (b の重み) - (a の重み)
	Type diff(int a, int b)
	{
		return (weight(b) - weight(a));
	}

	/// @brief a と b が同じグループに属すかを返します。
	/// @param a 一方のインデックス
	/// @param b 他方のインデックス
	/// @return a と b が同じグループに属す場合 true, それ以外の場合は false
	bool connected(int a, int b)
	{
		return (find(a) == find(b));
	}

	/// @brief i が属するグループの要素数を返します。
	/// @param i インデックス
	/// @return i が属するグループの要素数
	int size(int i)
	{
		return m_sizes[find(i)];
	}

private:

	// m_parents[i] は i の 親,
	// root の場合は自身が親
	std::vector<int> m_parents;

	// グループの要素数 (root 用)
	// i が root のときのみ, m_sizes[i] はそのグループに属する要素数を表す
	std::vector<int> m_sizes;

	// 重み
	std::vector<Type> m_diffWeights;

	Type weight(int i)
	{
		find(i); // 経路圧縮
		return m_diffWeights[i];
	}
};

int main()
{
	int N, M, Q;
	std::cin >> N >> M >> Q;

	WeightedUnionFind<long long> uf(N);
	std::vector<bool> ng(N);

	for (int i = 0; i < M; ++i)
	{
		int A, B, C;
		std::cin >> A >> B >> C;
		--A, --B;

		if (!uf.connected(A, B))
		{
			uf.merge(A, B, C);
		}
		else if (uf.diff(A, B) != C)
		{
			// 矛盾が発生する場合、そのグループは負閉路を含む
			ng[uf.find(A)] = true;
		}
	}

	while (Q--)
	{
		int X, Y;
		std::cin >> X >> Y;
		--X, --Y;

		if (!uf.connected(X, Y))
		{
			std::cout << "nan\n";
		}
		else if (ng[uf.find(X)])
		{
			std::cout << "inf\n";
		}
		else
		{
			std::cout << uf.diff(X, Y) << '\n';
		}
	}
}