🌳

素集合データ構造

2024/01/23に公開

用語

  • 素集合データ構造
    • まとめるのが得意
    • 削除は苦手
    • 複数の木を持つ
  • Disjoint-Set
    • 素集合データ構造のこと
    • 正式名 → Disjoint-set data structure
    • Disjoint-Sets と書かれていることもあり、英語圏でも表記が曖昧
  • DSU
  • Union-Find
    • 素集合データ構造に対する操作のこと
      • 実装・アルゴリズムの一つが Union-Find とする見方もある 1 2 3
      • アルゴリズム図鑑でも「素集合データ構造と、それを操作する Union-Find アルゴリズム」と解説されている
    • いずれにせよ素集合データ構造を指して Union-Find と呼んではいけない
      • スタックを Push-Pop と呼ぶぐらいピントが外れた感じになる
  • Union-Find 木
  • 連結成分
    • 木のこと
  • 非連結
    • 木が二つ以上ある状態のこと
  • 頂点 (vertex)
    • 節のこと
    • ノードともいう
  • 辺 (edge)
    • 幹のこと
    • 末端にある節のこと
    • 親と言うよりグループ内でしかたなくリーダーを任された人
  • ランク
    • 木の高さのこと
    • 葉から根までの距離のこと
  • 経路圧縮
    • 強力な最適化の一つ
    • Path Compression のこと
    • 根のメモ化もどき
    • union by rank や union by size とシナジーがある
  • union by rank
    • 最適化の一つ
    • 木の高さで調整する
    • 距離が遠い人を気にかける
  • union by size
    • 最適化の一つ
    • 木の大きさで調整する
    • 多数派の方が偉いと思っている
    • ついでに木の大きさがすぐにわかる
  • Quick Union
    • 木の形を調整せず、経路圧縮もしない、もっとも素朴な実装のこと
    • Union はただ繋げるだけしかしないのでたしかに Quick ではある
    • これに経路圧縮を足しただけで充分実用的になる
  • Quick Find
    • Union のタイミングで Find が最速になるように工夫した、あまりいけてない実装のこと
    • たしかに Find は超 Quick で O(1) だけど Union は O(n) になってしまう
    • 実用性が低い

素集合森に対する主な操作

Method 動作 必須 方言
root(x) 根を返す leader, find
unite(x, y) 同じ木にする merge, union, unify
same?(x, y) 同じ木か? connected
size(x) 木の大きさを返す component_size
make_set(x) 大きさ1の木を作る
groups 森の情報を返す
  • root, unite
    • Union-Find の名前は Union (併合) と Find (検索) の操作を主体に行うところから来ているものの、実装時のメソッド名はそれに従わず、実装者によってばらばらである
    • この二つがあればなんとでもなる
  • same?(x, y)
    • root(x) == root(y) のこと
  • size(x)
    • これがないときは同じ根をもった仲間を数えれば一応わかる
  • make_set(x)
    • なくても支障ないので用意していない実装が多い
  • groups
    • 根と節が1対多の関係になっている、森の全容がわかる情報を返す

実装

class DisjointSet
  attr_reader :parents

  def initialize(nodes = [])
    @parents = Hash.new { |h, k| h[k] = k }
    nodes.each { |e| make_set(e) }
  end

  def make_set(x)
    @parents[x]
  end

  def root(x)
    if @parents[x] == x
      x
    else
      root(@parents[x])
    end
  end

  def same?(x, y)
    root(x) == root(y)
  end

  def unite(x, y)
    x = root(x)
    y = root(y)
    if x != y
      chain(x, y)
    end
    self
  end

  def inspect
    @parents.inspect
  end

  private

  def chain(x, y)
    @parents[x] = y
  end
end

この素朴な実装を Quick Union と言う。

動作確認

最初はすべての木が独立している。

ds = DisjointSet.new("A".."E")

A と B を繋ぐ。

ds.unite("A", "B")  # => {"A"=>"B", "B"=>"B", "C"=>"C", "D"=>"D", "E"=>"E"}

さらに C と B も繋ぐ。

ds.unite("C", "B")  # => {"A"=>"B", "B"=>"B", "C"=>"B", "D"=>"D", "E"=>"E"}

A と C は B 経由で連結しているので同じ木である。

ds.same?("A", "C")  # => true

一方 A と D の木は異なる。

ds.same?("A", "D")  # => false

A と D を繋ぐ。

ds.unite("A", "D")  # => {"A"=>"B", "B"=>"D", "C"=>"B", "D"=>"D", "E"=>"E"}

A の根は B なので B と D が繋がり、A と D は同じ木になった。

ds.same?("A", "D")  # => true

A と D はすでに同じ木なので何もしない。

ds.unite("A", "D")  # => {"A"=>"B", "B"=>"D", "C"=>"B", "D"=>"D", "E"=>"E"}

E と E を同じ木とするが他の木との連結は生じない。

ds.unite("E", "E")  # => {"A"=>"B", "B"=>"D", "C"=>"B", "D"=>"D", "E"=>"E"}

E と E の木は同じ。

ds.same?("E", "E")  # => true

となって正しく動作しているので次から問題を解く。

[ATC001 B] Union Find

https://atcoder.jp/contests/atc001/tasks/unionfind_a

森を構築しながら二つの節が同じ連結成分かを確認する問題で、これはすでに試したので通るだろう。

提出したコード
class DisjointSet
  def initialize
    @parents = Hash.new { |h, k| h[k] = k }
  end

  def root(x)
    if @parents[x] == x
      x
    else
      root(@parents[x])
    end
  end

  def same?(x, y)
    root(x) == root(y)
  end

  def unite(x, y)
    x = root(x)
    y = root(y)
    if x != y
      @parents[x] = y
    end
  end
end

N, Q = gets.split.collect(&:to_i)                   # => [8, 9]
A = Q.times.collect { gets.split.collect(&:to_i) }  # => [[0, 1, 2], [0, 3, 2], [1, 1, 3], [1, 1, 4], [0, 2, 4], [1, 4, 1], [0, 4, 2], [0, 0, 0], [1, 0, 0]]
ds = DisjointSet.new
A.each do |p, a, b|
  if p == 0
    ds.unite(a, b)
  else
    ans = ds.same?(a, b) ? "Yes" : "No"             # => "Yes", "No", "Yes", "Yes"
    puts ans
  end
end

ところが TLE になる。

原因は二つ。

  1. 根は同じなのに何回も辿るから遅い
  2. 根までが遠すぎて遅い

効率化1. 経路圧縮

↓これを、

↓こうする。

順に書くと、A が B で B が C で──となると葉から根までが遠くなる。

ds = DisjointSet.new
ds.unite("A", "B")  # => {"A"=>"B", "B"=>"B"}
ds.unite("B", "C")  # => {"A"=>"B", "B"=>"C", "C"=>"C"}
ds.unite("C", "D")  # => {"A"=>"B", "B"=>"C", "C"=>"D", "D"=>"D"}
ds.unite("D", "E")  # => {"A"=>"B", "B"=>"C", "C"=>"D", "D"=>"E", "E"=>"E"}

そこで一回求めた根をメモする。

DisjointSet.prepend Module.new {
  def root(x)
    if @parents[x] == x
      x
    else
      @parents[x] = root(@parents[x])
    end
  end
}

そして葉から辿る。

ds.root("A")  # => "E"

すると経路が変わり、

一回で行けるようになる。ここで「親」も「親の親」も根に一回で行けるようになっている点がおもしろい。

まとめると経路圧縮は恐くない。経路を圧縮するなどという仰々しい言い回しの真相は単に根をメモしているだけだった。

これで 1 の「根は同じなのに何回も辿るから遅い」が解決する。

再提出したコード (AC)
class DisjointSet
  def initialize
    @parents = Hash.new { |h, k| h[k] = k }
  end

  def root(x)
    if @parents[x] == x
      x
    else
      root(@parents[x])
    end
  end

  def same?(x, y)
    root(x) == root(y)
  end

  def unite(x, y)
    x = root(x)
    y = root(y)
    if x != y
      @parents[x] = y
    end
  end
end

# 経路圧縮
DisjointSet.prepend Module.new {
  def root(x)
    if @parents[x] == x
      x
    else
      @parents[x] = root(@parents[x])
    end
  end
}

N, Q = gets.split.collect(&:to_i)                   # => [8, 9]
A = Q.times.collect { gets.split.collect(&:to_i) }  # => [[0, 1, 2], [0, 3, 2], [1, 1, 3], [1, 1, 4], [0, 2, 4], [1, 4, 1], [0, 4, 2], [0, 0, 0], [1, 0, 0]]
ds = DisjointSet.new
A.each do |p, a, b|
  if p == 0
    ds.unite(a, b)
  else
    ans = ds.same?(a, b) ? "Yes" : "No"             # => "Yes", "No", "Yes", "Yes"
    puts ans
  end
end

効率化2. ポプラ型よりブロッコリー型

↓こうならないように、

↓あらかじめこうする。

これは、

  • 最初からできる限りへらべったくしておけばよくないか?
  • そうした方が root() も最初行くとき負担が少ないだろう

とする Union 側の工夫の一つで、木の高さで調整するタイプを union by rank という。木の高さを言い換えると「根までの距離」であり「根までのステップ数」であるから、たしかに高さで調整するのが最適に思える。

順に書くと、さきほどと同様に A が B で B が C で──とすると一列になる。これがよくない。

ds = DisjointSet.new
ds.unite("A", "B")  # => {"A"=>"B", "B"=>"B"}
ds.unite("B", "C")  # => {"A"=>"B", "B"=>"C", "C"=>"C"}

なので繋ぐときに低い方を高い方に入れる。

DisjointSet.prepend Module.new {
  attr_reader :ranks

  def initialize
    super
    @ranks = Hash.new(1) # 「高さ」自体にそんなに意味はない
  end

  def unite(x, y)
    x = root(x)
    y = root(y)
    if x != y
      case
      when @ranks[x] < @ranks[y]
        chain(x, y)               # 低い方(x) が 高い方(y) にぶらさがる
      when @ranks[y] < @ranks[x]
        chain(y, x)               # 低い方(y) が 高い方(x) にぶらさがる
      else                        # 同じ高さの場合、
        chain(x, y)               # どちらにぶらさげてもよい
        @ranks[y] += 1            # ぶらさげられた方の高さを一段上げる
      end
    end
    self
  end
}

続いて、さきほどと同じ手順で作ると、

ds = DisjointSet.new
ds.unite("A", "B")  # => {"A"=>"B", "B"=>"B"}
ds.unite("B", "C")  # => {"A"=>"B", "B"=>"B", "C"=>"B"}

今度は、

  • C の木の高さ → 1
  • A-B の木の高さ → 2

となるため、C が A-B に入って最初から低い木になっているのがわかる。

当初、低い方が高い方にぶらさがっても、もしくは仮に逆だったとしても、結局高さは同じなのでは? と考えて混乱していたが、相手の末端にぶらさがるのではなく、相手のに直接ぶらさがるから低い形を維持できる。

そこに木 X-Y が登場した。

ds.unite("X", "Y")  # => {"A"=>"B", "B"=>"B", "C"=>"B", "X"=>"Y", "Y"=>"Y"}

そこで A と X と繋ぐとどうなるか?

ds.unite("A", "X")  # => {"A"=>"B", "B"=>"Y", "C"=>"B", "X"=>"Y", "Y"=>"Y"}

まず A と X は、直接は繋げられない。かわりに代表の B と Y の話し合いになる。

上の図では、ひと目、B の木に Y の木が入れば収まりがよさそうだが、結果は逆になった。

これはどういうことか? 比較するのは大きさではなく高さで、B と Y は同じ高さのため、どちらに入ってもよく、つまり「たまたま Y の方に入った」ということになる。

この場合、A B C の3人はちょっと不満かもしれない。こっちの方が多いのになんで移動しないといけないのか。少ない方が来いよ、と。

このようにして 2 の「根までが遠すぎて遅い」が解決し、経路圧縮の助けを借りなくても通る。

再提出したコード (AC)
class DisjointSet
  def initialize
    @parents = Hash.new { |h, k| h[k] = k }
    @ranks = Hash.new(1)
  end

  def root(x)
    if @parents[x] == x
      x
    else
      root(@parents[x])
    end
  end

  def same?(x, y)
    root(x) == root(y)
  end

  def unite(x, y)
    x = root(x)
    y = root(y)
    if x != y
      case
      when @ranks[x] < @ranks[y]
        @parents[x] = y
      when @ranks[y] < @ranks[x]
        @parents[y] = x
      else
        @parents[x] = y
        @ranks[y] += 1
      end
    end
  end
end

N, Q = gets.split.collect(&:to_i)                   # => [8, 9]
A = Q.times.collect { gets.split.collect(&:to_i) }  # => [[0, 1, 2], [0, 3, 2], [1, 1, 3], [1, 1, 4], [0, 2, 4], [1, 4, 1], [0, 4, 2], [0, 0, 0], [1, 0, 0]]
ds = DisjointSet.new
A.each do |p, a, b|
  if p == 0
    ds.unite(a, b)
  else
    ans = ds.same?(a, b) ? "Yes" : "No"             # => "Yes", "No", "Yes", "Yes"
    puts ans
  end
end

高さの初期値

0 か 1 で迷ってしまうのだけど、どちらでもよい。比べるための値なので比べることが可能な値ならなんでもよい。とはいえ、デバッグや可視化のためには 0 か 1 にしておくのがわかりやすい。

0 か 1 かは、

  • 節は点で辺に長さがある → 0
  • 辺は概念上のもので節に長さがある → 1

という考え方の違いになる。この記事では後者にしている。

二つの効率化を合体する

「経路圧縮」と「そもそも最初からできる限り木を低く作る」は相乗効果により合わせて使うと進化する。進化すると O(log n) から O(α(n)) ことアッカーマンの逆関数になる。

アッカーマン関数は、与える数が大きくなると爆発的に計算量が大きくなる。

def ack(x, y)
  @count += 1
  case
  when x == 0
    y + 1
  when y == 0
    ack(x - 1, 1)
  else
    ack(x - 1, ack(x, y - 1))
  end
end

@count = 0
ack(3, 9)  # => 4093
@count     # => 11164370

具体的には x が 4 以上になると大爆発する。

ここではとりあえず「ありえないほど遅くなるアッカーマン関数」の逆で、n がいくら増加したとしても「ありえないほど遅くなっていかない」と覚えておく。

高さの不整合

union by rank に経路圧縮を組み合わせると、保存した高さと実際の高さが一致しなくなる場合がある。そこで親切心から経路圧縮の際に高さをリセットしてあげたくなってくるが、これはありがた迷惑で、やると 1.08 倍遅くなる。

Warming up --------------------------------------
                そのまま     1.165k i/100ms
              高さリセット     1.072k i/100ms
Calculating -------------------------------------
                そのまま     11.639k (± 0.8%) i/s -     58.250k in   5.005246s
              高さリセット     10.756k (± 0.6%) i/s -     54.672k in   5.083295s

Comparison:
                そのまま:    11638.6 i/s
              高さリセット:    10755.6 i/s - 1.08x  slower

したがって、不整合は無視してよい。

冗長な連続併合をまとめる

インスタンス生成後に幹を順番に渡して森を作る処理が多いので、クラスに対して幹をまとめて渡してインスタンスを作るようにする。このようにしておくと複数の森を作る場合もコードを単純化できる。

DisjointSet.prepend Module.new {
  def self.prepended(klass)
    class << klass
      def [](...)
        new.unites(...)
      end

      def unites(...)
        new.unites(...)
      end
    end
  end

  def unites(edges = [])
    edges.each { |e| unite(*e) }
    self
  end
}
edges = [["A", "A"], ["B", "C"], ["C", "D"], ["E", "F"], ["F", "G"], ["G", "H"]]
ds = DisjointSet[edges]  # => {"A"=>"A", "B"=>"C", "C"=>"C", "D"=>"C", "E"=>"F", "F"=>"F", "G"=>"F", "H"=>"F"}

森の様子をうかがう

森の情報を把握したいとき、すべての根を求めて集計する処理を何度も行うことになる。

ds = DisjointSet.new
ds.unite("A", "B")
ds.unite("B", "C")
ds.unite("D", "D")

そこで便利メソッドたちを入れておくと、

DisjointSet.prepend Module.new {
  def nodes
    @parents.keys
  end

  def roots
    @parents.inject({}) { |a, (e, _)| a.merge(e => root(e)) }
  end

  def groups
    @parents.keys.group_by { |e| root(e) }
  end

  def tree_count
    groups.count
  end

  def tree_size_max(...)
    groups.transform_values(&:size).values.max(...)
  end
}

可視化しなくても森の様子がわかる。

節たち
ds.nodes  # => ["A", "B", "C", "D"]
節に対応する根
ds.roots  # => {"A"=>"B", "B"=>"B", "C"=>"B", "D"=>"D"}
木の節たち, 木の大きさ
ds.groups                           # => {"B"=>["A", "B", "C"], "D"=>["D"]}
ds.groups.transform_values(&:size)  # => {"B"=>3, "D"=>1}
辺の数
ds.groups.transform_values { |e| e.size.pred }  # => {"B"=>2, "D"=>0}

関連する問題: ARC037 B バウムテスト

木の数
ds.groups.count  # => 2
ds.tree_count    # => 2

関連する問題: ARC032 B 道路工事

いちばん大きな木の大きさ
ds.tree_size_max  # => 3

関連する問題: ABC177 D Friends

木の数を一瞬で得る

木の数は前述の、

  • 根のユニーク数

とする考え方がシンプルだが、

  • 併合するたびに木が減っていく

とする考え方もある。

後者の場合は全体の併合回数を保持しておく。

DisjointSet.prepend Module.new {
  attr_reader :chain_count

  def initialize
    super
    @chain_count = 0
  end

  def edge_count
    @chain_count
  end

  def tree_count
    @parents.size - @chain_count
  end

  private

  def chain(...)
    super
    @chain_count += 1
  end
}

この場合の木の数 2 は、

ds = DisjointSet.new
ds.unite("A", "B")
ds.unite("B", "C")
ds.unite("D", "E")
ds.unite("E", "F")

併合回数 (=全体の辺の数) が、

ds.chain_count  # => 4
ds.edge_count   # => 4

4 なので全体の節の数 6 から引けば木の数 2 が求まる。

ds.parents.size - ds.chain_count  # => 2

最初からすべての節の数 N がわかっている場合は、併合する度に N をデクリメントしていく方法も一応あるが、「木の数」を求める目的とは分けて、単に全体の併合回数 (または全体の辺の数) の意味でカウントした方が応用の余地が広がる。実際に「全体の辺の数」を求めるだけの問題もある。

関連する問題: ABC206 D KAIBUNsyo

木の大きさを一瞬で得る

木の大きさは「同じ根を持つ仲間の数」になる。

ds = DisjointSet.new
ds.unite("A", "B")
ds.unite("C", "C")

図にすれば A-B の木が 2 で、C が 1 だとすぐにわかるが、これをコードにするとややこしい。

まず、各節の根を求めて頻度を集計する。

roots = ds.parents.keys.collect { |e| ds.root(e) }  # => ["B", "B", "C"]
freq = roots.tally                                  # => {"B"=>2, "C"=>1}

次に、再度 A の根を求めると B なので、

ds.root("A")  # => "B"

やっとここで A の木の大きさは 2 だとわかる。

freq["B"]  # => 2

集計後の情報 (freq) を持っておけば、次のように続けて他の木の大きさもわかって便利だが、

freq[ds.root("C")]  # => 1

もし A の木だけでいいなら「同じ根を持つ仲間の数」の言葉通り、

ds.parents.keys.count { |e| ds.root(e) == ds.root("A") }  # => 2

とする手もある。

が、間違えずに書ける自信はないのでメソッド化しておく。

DisjointSet.prepend Module.new {
  def size(x)
    @parents.count { |e, _| root(e) == root(x) }
  end
}

これで扱いやすくなる。

ds.size("A")  # => 2

ただ遅い。頻繁に大きさを調べる場合にここがボトルネックになる。

最適化

併合のタイミングで、ひっかける方の節の数を、ひっかけられる根の方に足して、木の大きさを常時保持しておく。

DisjointSet.prepend Module.new {
  attr_reader :sizes

  def initialize
    super
    @sizes = Hash.new(1)
  end

  def size(x)
    @sizes[root(x)]
  end

  def tree_size_max(...)
    @parents.collect { |e, _| size(e) }.max(...)
  end

  private

  def chain(x, y)
    super
    @sizes[y] += @sizes[x]
  end
}

このようにしておくと木の大きさが最速で求まる。

ds = DisjointSet.new
ds.unite("A", "B")
ds.size("A")      # => 2
ds.size("B")      # => 2
ds.size("C")      # => 1
ds.tree_size_max  # => 2

グラフの連結成分の辺の数を一瞬で得る

考え方は木の大きさのときと同様で、辺の数は「同じ辺を持つ仲間の数」で求められる。

次のような三角形状に繋がっているグラフがあったとして、

edges = [
  ["A", "B"],
  ["B", "C"],
  ["C", "A"],
]

この場合の辺の数 3 を求めたい場合は、辺に対応する根の数を集計する。

ds = DisjointSet.new
edges.each { |e| ds.unite(*e) }
roots = edges.collect { |x, y| ds.root(x) }  # => ["B", "B", "B"]
edge_counts = roots.tally(Hash.new(0))       # => {"B"=>3}
edge_counts[ds.root("A")]                    # => 3

A だけの木でいいなら言葉通りに「同じ辺を持つ仲間の数」でよい。

edges.count { |x, _| ds.root(x) == ds.root("A") }  # => 3

そして同様に最適化できる。

最適化

DisjointSet.prepend Module.new {
  attr_reader :edge_counts

  def initialize(...)
    super
    @edge_counts = Hash.new(0)
  end

  def unite(x, y)
    x = root(x)
    y = root(y)
    if x == y
      @edge_counts[y] += 1
    end
    super
  end

  def edge(x)
    @edge_counts[root(x)]
  end

  def edge_max(...)
    @parents.collect { |e, _| edge(e) }.max(...)
  end

  private

  def chain(x, y)
    super
    @edge_counts[y] += @edge_counts[x] + 1
  end
}

として辺の数をあらかじめ保持しておくと、

ds = DisjointSet.new
edges.each { |e| ds.unite(*e) }
ds.edge("A")  # => 3

最速で求まる。

関連する問題: ABC292 D Unicyclic Components

関連記事

https://zenn.dev/megeton/articles/0c76b176b768bd

https://zenn.dev/megeton/articles/3757068745300b

グラフの閉路判定

素集合データ構造では次のような閉路があるグラフを表現できないが閉路の判定はできる。

方法1. 併合前にすでに繋がっているか確認する

ds = DisjointSet.new
closed = ds.same?("A", "B")  # => false
ds.unite("A", "B")
closed = ds.same?("B", "C")  # => false
ds.unite("B", "C")
closed = ds.same?("C", "A")  # => true
ds.unite("C", "A")

C と A を接続しようとした時点ですでに C と A は A-B-C で繋がっているため C と A を繋げるとループするのがわかる。

方法2. 辺の数が多すぎる木を判定する

edges = [["A", "B"], ["B", "C"], ["C", "A"]]
ds = DisjointSet.new("A".."C").unites(edges)  # => 親:{"A"=>"B", "B"=>"B", "C"=>"B"} 構図:{"B"=>["A", "B", "C"]} 高さ:{"B"=>2} 節数:{"B"=>3} 木数:1
roots = edges.collect { |x, y| ds.root(x) }   # => ["B", "B", "B"]
edge_counts = roots.tally(Hash.new(0))        # => {"B"=>3}
ds.groups                                     # => {"B"=>["A", "B", "C"]}
closed = ds.groups.any? { |parent, nodes| edge_counts[parent] >= nodes.size }
closed                                        # => true
  • 素集合森の木は閉路がない
  • 閉路がない木は「節の数 - 1 = 辺の数」が成り立つ

よって元のグラフ通りに素集合森を作って辺の数が合わない木は閉路があると判定できる。

関連する問題: ARC037 B バウムテスト

関連記事

https://zenn.dev/megeton/articles/3757068745300b

節を削除する方法

削除したい節を使わずに森を作る。

ds = DisjointSet.new
edges = [["A", "B"], ["B", "C"], ["C", "D"], ["D", "E"]]
edges.each { |e| ds.unite(*e) }

たとえばここで C を削除するには C を除外して最初から作る。

ds = DisjointSet.new
edges = [["A", "B"], ["B", "C"], ["C", "D"], ["D", "E"]]
edges = edges.reject { |e| e.include?("C") }
edges.each { |e| ds.unite(*e) }

関連する問題: ABC075 C Bridge

どちらの木に連結したか調べる

連結後に根が変化していなかった側の木が連結される側になったとわかるので、あらかじめ求めておいた根と連結後の根が同じかを確認する。

ds = DisjointSet.new
x = ds.root("A")  # => "A"
y = ds.root("B")  # => "B"
ds.unite(x, y)
if ds.root(x) == x
  "y -> x"        # => 
else
  "x -> y"        # => "x -> y"
end

この結果から、x → y の向きで連結したことがわかる。

併合処理の戻り値で判断可能な実装も見かけるが、そのようになっていなくても上述の方法で判断できる。

関連する問題: ABC183 F Confluence

ロールバック機能

いろんな実装がありそう。一例として親 (および関連するインスタンス変数) を毎回履歴に入れておくのが簡単そうに思える。デザインパターンで言えば Memento になるのかもしれないが、そんな大袈裟なことはしていない。

DisjointSet.prepend Module.new {
  def initialize
    super
    @history = []
  end

  def rollback
    @parents.replace(@history.pop)
  end

  private

  def chain(...)
    @history << @parents.clone
    super
  end
}
ds = DisjointSet.new
ds.unite("A", "B")
ds.unite("B", "C")

この状態から一手戻す。

ds.rollback

さらに戻す。

ds.rollback

これが役に立ったことはまだない。

重み付き

と、一般的に呼ばれているが、掘ったさつまいもの蔓を持ち上げたときのようなのをイメージしてはいけない。実際はその逆で、根がいちばん軽く、端に行くほど重くなる。したがって、重みと言うよりは「根からスタートした車の燃料代」と考えた方がしっくりくる。それと似た考えなのかはわからないが「重み付き」ではなく「ポテンシャル付き」と呼ぶ人もいる。

DisjointSet.prepend Module.new {
  attr_accessor :diff_weights

  def initialize
    super
    @diff_weights = Hash.new { |h, k| h[k] = 0 }
  end

  def root(x)
    if @parents[x] == x
      x
    else
      r = root(@parents[x])
      @diff_weights[x] += @diff_weights[@parents[x]]
      @parents[x] = r
    end
  end

  def weight(x)
    root(x)
    @diff_weights[x]
  end

  def diff(x, y)
    weight(x) - weight(y)
  end

  def unite(x, y, weight)
    weight -= weight(x)
    weight += weight(y)

    x = root(x)
    y = root(y)
    if x != y
      chain(x, y, weight)
    end
    self
  end

  private

  def chain(x, y, weight)
    super(x, y)
    @diff_weights[x] = weight
  end
}

最初はそれぞれの節が差分を持っている。

ds = DisjointSet.new
ds.unite("A", "B", 2)
ds.unite("B", "C", 2)
ds.unite("C", "D", 2)
ds.diff_weights  # => {"A"=>2, "B"=>2, "C"=>2, "D"=>0}

ここで経路圧縮すると、

ds.root("A")
ds.diff_weights  # => {"A"=>6, "B"=>4, "C"=>2, "D"=>0}

差分から累積和に変わっているのがわかる。

関連する問題: ABC087 D People on a Line

提出したコード (AC)
class DisjointSet
  attr_reader :parents

  def initialize
    @parents = Hash.new { |h, k| h[k] = k }
    @diff_weights = Hash.new(0)
  end

  def root(x)
    if @parents[x] == x
      x
    else
      r = root(@parents[x])
      @diff_weights[x] += @diff_weights[@parents[x]]
      @parents[x] = r
    end
  end

  def same?(x, y)
    root(x) == root(y)
  end

  def weight(x)
    root(x)
    @diff_weights[x]
  end

  def diff(x, y)
    weight(x) - weight(y)
  end

  def unite(x, y, weight)
    weight -= weight(x)
    weight += weight(y)

    x = root(x)
    y = root(y)
    if x != y
      chain(x, y, weight)
    end
    self
  end

  def inspect
    @parents.inspect
  end

  private

  def chain(x, y, weight)
    @parents[x] = y
    @diff_weights[x] = weight
  end
end

if $0 == __FILE__ && ENV["ATCODER"] != "1"
  require "rspec/autorun"
  RSpec.configure do |config|
    config.expect_with :test_unit
  end

  def test(lrd)
    ds = DisjointSet.new
    ans = "Yes"
    lrd.each do |l, r, d|
      if ds.same?(l, r)
        if ds.diff(l, r) != d
          ans = "No"
          break
        end
      else
        ds.unite(l, r, d)
      end
    end
    ans
  end

  describe do
    it "works" do
      assert { test([[1, 2, 1], [2, 3, 1], [1, 3, 2]]) == "Yes" }
      assert { test([[1, 2, 1], [2, 3, 1], [1, 3, 5]]) == "No" }
      assert { test([[2, 1, 1], [2, 3, 5], [3, 4, 2]]) == "Yes" }
      assert { test([[8, 7, 100], [7, 9, 100], [9, 8, 100]]) == "No" }
      assert { test([]) == "Yes" }
    end
  end
end

if $0 == "-"
  require "stringio"
  $stdin = StringIO.new(<<~eos)
3 3
1 2 1
2 3 1
1 3 2
eos
end

N, M = gets.split.collect(&:to_i)
lrd = M.times.collect { gets.split.collect(&:to_i) }
ds = DisjointSet.new
ans = "Yes"
lrd.each do |l, r, d|
  if ds.same?(l, r)
    if ds.diff(l, r) != d
      ans = "No"
      break
    end
  else
    ds.unite(l, r, d)
  end
end
ans  # => "No"
puts ans

参照:

https://qiita.com/drken/items/cce6fc5c579051e64fab

Quick Find 実装

Quick Find は併合処理を行う際に影響を受ける節の親をすべて更新する、という悪い意味で力強い実装になっている。何の知識もない状態で最適化するように言われたら(自分だと)このように書いてしまうかもしれない。

class QuickFind
  attr_reader :parents

  def initialize
    @parents = Hash.new { |h, k| h[k] = k }
  end

  def root(x)
    @parents[x]
  end

  def unite(x, y)
    x = root(x)
    y = root(y)
    if x != y
      @parents.each do |child, parent|
        if parent == x
          @parents[child] = y
        end
      end
    end
    self
  end

  def inspect
    @parents.inspect
  end
end

A → B → C と繋げたつもりが、すでに A は C に直結しているのがわかる。

ds = QuickFind.new
ds.unite("A", "B")  # => {"A"=>"B", "B"=>"B"}
ds.unite("B", "C")  # => {"A"=>"C", "B"=>"C", "C"=>"C"}

このおかげで Find が超 Quick になる。

ds.root("A")  # => "C"

しかし一方で、Union は常時 O(n) なので全体で見れば遅い。

なお詳細は、

https://www.coursera.org/lecture/algorithms-part1/quick-find-EcF3P

で見ることができる。

純粋な union by size 実装

↓こうならないように、

↓あらかじめこうする。

木の「高さ」で調整するのが union by rank だった。一方、木の「大きさ」で調整するのを union by size と言う。

まず、なんの調整も入っていない方法だと、A が B で B が C で──は、一列になる。

ds = DisjointSet.new
ds.unite("A", "B")  # => {"A"=>"B", "B"=>"B"}
ds.unite("B", "C")  # => {"A"=>"B", "B"=>"C", "C"=>"C"}

そこで繋ぐときに小さい方を大きい方にぶらさげる。

DisjointSet.prepend Module.new {
  attr_reader :sizes

  def initialize
    super
    @sizes = Hash.new(1)
  end

  def unite(x, y)
    x = root(x)
    y = root(y)
    if x != y
      if @sizes[x] > @sizes[y]
        x, y = y, x
      end
      @sizes[y] += @sizes[x]
      @parents[x] = y
    end
    self
  end
}

こうしておいて、さきほどと同じ手順で作ると、

ds = DisjointSet.new
ds.unite("A", "B")  # => {"A"=>"B", "B"=>"B"}
ds.unite("B", "C")  # => {"A"=>"B", "B"=>"B", "C"=>"B"}

今度は、

  • C の木の大きさ → 1
  • A-B の木の大きさ → 2

となるため、C が A-B に入って最初から低い木になっているのがわかる。

これは、個数の大小関係が、高さの場合と同じだったため、union by rank と同じ結果になっている。

では、この状況だとどうか?

ds = DisjointSet.new
ds.unite("A", "B")  # => {"A"=>"B", "B"=>"B"}
ds.unite("C", "D")  # => {"A"=>"B", "B"=>"B", "C"=>"D", "D"=>"D"}
ds.unite("B", "D")  # => {"A"=>"B", "B"=>"D", "C"=>"D", "D"=>"D"}

ds.unite("L", "K")  # => {"A"=>"B", "B"=>"D", "C"=>"D", "D"=>"D", "L"=>"K", "K"=>"K"}
ds.unite("M", "K")  # => {"A"=>"B", "B"=>"D", "C"=>"D", "D"=>"D", "L"=>"K", "K"=>"K", "M"=>"K"}
ds.unite("N", "K")  # => {"A"=>"B", "B"=>"D", "C"=>"D", "D"=>"D", "L"=>"K", "K"=>"K", "M"=>"K", "N"=>"K"}
ds.unite("O", "K")  # => {"A"=>"B", "B"=>"D", "C"=>"D", "D"=>"D", "L"=>"K", "K"=>"K", "M"=>"K", "N"=>"K", "O"=>"K"}

ここで K と D を繋ぐ。

ds.unite("K", "D")

K の高さは 2 で、D は 3 (A→B→D) なので、木を低く保ち、根までの距離を最短にするには、K が D に入るのが最適である。

ところが、結果は逆で D が K に入った。これはどういうことか?

いま見ているのは union by size である。union by size は、根までの距離に関心がなく、大きさで比較するため、たんに大きな K の方を根にした、ということであり、それは最適ではないということでもある。

A からしたらこれには不満だろう。なんで K 門下に入らないといけないのか。もともと遠かったのがさらに遠くなったじゃないか。そっちが来い、と。

一方、L M N O および K たちは大所帯にいてよかった、とほっとしていた。

A は不憫だが、L M N O K の立場で考えると5人分の経路が遠くならなかったので「最適ではない」とはっきりとは言い切れない。これはトロッコ問題にも似ているような、似ていないような。はたして一人の立場を尊重した union by rank と、五人の立場を尊重した union by size は、どちらの判断が正しかったのか?

なお、この効率化だけでも ATC001 B Union Find は通る。

提出したコード (AC)
class DisjointSet
  def initialize
    @parents = Hash.new { |h, k| h[k] = k }
    @sizes = Hash.new(1)
  end

  def root(x)
    if @parents[x] == x
      x
    else
      root(@parents[x])
    end
  end

  def same?(x, y)
    root(x) == root(y)
  end

  def unite(x, y)
    x = root(x)
    y = root(y)
    if x != y
      if @sizes[x] > @sizes[y]
        x, y = y, x
      end
      @sizes[y] += @sizes[x]
      @parents[x] = y
    end
  end
end

N, Q = gets.split.collect(&:to_i)                   # => [8, 9]
A = Q.times.collect { gets.split.collect(&:to_i) }  # => [[0, 1, 2], [0, 3, 2], [1, 1, 3], [1, 1, 4], [0, 2, 4], [1, 4, 1], [0, 4, 2], [0, 0, 0], [1, 0, 0]]
ds = DisjointSet.new
A.each do |p, a, b|
  if p == 0
    ds.unite(a, b)
  else
    ans = ds.same?(a, b) ? "Yes" : "No"             # => "Yes", "No", "Yes", "Yes"
    puts ans
  end
end

最適化対決 rank vs size

で、結局 union by rank と union by size はどちらがよいのか?

当初は union by rank の方が圧倒的に優れていると考えていた。距離が重要なのだから距離で調整するのが理に適っている。一方の、大きさを見ている union by size は方向がずれているように見えた。ところが、机上で動かしてみると一概にそうとも言えないような状況があった。

union by rank
class UnionByRank
  attr_reader :parents

  def initialize
    @parents = Hash.new { |h, k| h[k] = k }
    @ranks = Hash.new(1)
  end

  def root(x)
    if @parents[x] == x
      x
    else
      root(@parents[x])
    end
  end

  def unite(x, y)
    x = root(x)
    y = root(y)
    if x != y
      case
      when @ranks[x] < @ranks[y]
        chain(x, y)
      when @ranks[y] < @ranks[x]
        chain(y, x)
      else
        chain(x, y)
        @ranks[y] += 1
      end
    end
  end

  private

  def chain(x, y)
    @parents[x] = y
  end
end
union by size
class UnionBySize
  attr_reader :parents

  def initialize
    @parents = Hash.new { |h, k| h[k] = k }
    @sizes = Hash.new(1)
  end

  def root(x)
    if @parents[x] == x
      x
    else
      root(@parents[x])
    end
  end

  def unite(x, y)
    x = root(x)
    y = root(y)
    if x != y
      if @sizes[x] > @sizes[y]
        x, y = y, x
      end
      chain(x, y)
    end
  end

  private

  def chain(x, y)
    @parents[x] = y
    @sizes[y] += @sizes[x]
  end
end
ベンチマーク
r = 2000
n = 1000
edges = n.times.collect { [rand(r), rand(r)] }

test = -> klass {
  ds = klass.new
  edges.each { |e| ds.unite(*e) }
}

require "benchmark/ips"
Benchmark.ips do |x|
  x.report("union by rank")  { test[UnionByRank] }
  x.report("union by size")  { test[UnionBySize] }
  x.compare!
end
Warming up --------------------------------------
       union by rank   283.000  i/100ms
       union by size   274.000  i/100ms
Calculating -------------------------------------
       union by rank      2.841k (± 1.2%) i/s -     14.433k in   5.081506s
       union by size      2.748k (± 0.9%) i/s -     13.974k in   5.085557s

Comparison:
       union by rank:     2840.7 i/s
       union by size:     2748.0 i/s - 1.03x  slower

結果:

  • union by size は union by rank の 1.03 倍遅い
  • union by rank の方が優秀

とはいえ、この微差に経路圧縮も含めれば、ほぼ違いはないとも言える。

rank+ size+ 参戦

これは勝手に決めた名前で、

  • rank+ → 高さが同じとき大きさで調整する
  • size+ → 大きさが同じとき高さで調整する

となるように改造している。いいとこどりなので微差だが速くなるはずだ。

union by rank+
class UnionByRankPlus < UnionByRank
  def initialize
    super
    @sizes = Hash.new(1)
  end

  def unite(x, y)
    x = root(x)
    y = root(y)
    if x != y
      case
      when @ranks[x] < @ranks[y]
        chain(x, y)
      when @ranks[y] < @ranks[x]
        chain(y, x)
      else
        if @sizes[x] > @sizes[y]
          x, y = y, x
        end
        chain(x, y)
        @ranks[y] += 1
      end
    end
  end

  def chain(x, y)
    @parents[x] = y
    @sizes[y] += @sizes[x]
  end
end
union by size+
class UnionBySizePlus < UnionBySize
  def initialize
    super
    @ranks = Hash.new(1)
  end

  def unite(x, y)
    x = root(x)
    y = root(y)
    if x != y
      if @sizes[x] == @sizes[y]
        case
        when @ranks[x] < @ranks[y]
          chain(x, y)
        when @ranks[y] < @ranks[x]
          chain(y, x)
        else
          chain(x, y)
          @ranks[y] += 1
        end
      else
        if @sizes[x] > @sizes[y]
          x, y = y, x
        end
        chain(x, y)
      end
    end
  end
end
Warming up --------------------------------------
       union by rank   283.000  i/100ms
       union by size   273.000  i/100ms
      union by rank+   234.000  i/100ms
      union by size+   225.000  i/100ms
Calculating -------------------------------------
       union by rank      2.837k (± 1.1%) i/s -     14.433k in   5.087967s
       union by size      2.742k (± 1.6%) i/s -     13.923k in   5.078401s
      union by rank+      2.335k (± 1.5%) i/s -     11.700k in   5.012946s
      union by size+      2.316k (± 1.3%) i/s -     11.700k in   5.053541s

Comparison:
       union by rank:     2837.1 i/s
       union by size:     2742.4 i/s - 1.03x  slower
      union by rank+:     2334.5 i/s - 1.22x  slower
      union by size+:     2315.6 i/s - 1.23x  slower

結果: 逆に遅くなった。

自分が考えた程度の工夫で速くなるならすでに誰かがやって一般化されているだろう。

union by size が多用される理由

競プロでは union by size の実装が多いように見えるのはなぜか?

  • union by rank に比べてコードがシンプル
    • union by size の方が(若干)早く実装できる
      • 競プロ勢の中には毎回素で一から書く人も少なくない
  • ついでに木の大きさがすぐにわかる
    • 実はこちらが主体?
    • union by rank が持つ高さは木の平坦化以外に使い道がない
  • union by rank と比べてほとんど性能が落ちないのを知っているから
  • ググって見つかる解説やコードがほとんどそれだから

と思われる。

実装の工夫から見る想像上の歴史

  • Quick Union: 誕生
    ↓ 遅いので改良
  • Quick Find: 改良の方向を間違えて失敗
    ↓ いったん元に戻して Find の方を改良
  • 経路圧縮: 急激に速くなる
    ↓ 木を作るときに高さでバランスをとってみる
  • 経路圧縮 + union by rank: さらに速くなる
    ↓ 木の大きさを求める機能を追加してみる
  • 経路圧縮 + union by rank (+ size): 少しコード量が増えた
    ↓ 試しに木の大きさでバランスをとってみる
  • 経路圧縮 + union by size: コードが短かくなり速度もそんなに変わらなかった

アンチパターン

次に上げるものは競プロ向けに最適化されているため汎用性や変更容易性が低い。にもかかわらず手放しで賞賛されている印象がある。もし競プロの攻略だけが目的であれば推奨ということになるのかもしれないが、最速および最小メモリで動かなくてよいので、素直にアルゴリズムを学びたい、教えたい、学んだ素集合データ構造を普段使いしたい、より少ない手間で変更して別のことに応用したい場合に適しているとは言えない。

1. 親を連番で初期化する

一般的にノードオブジェクトが整数 0 や 1 から始まる N までの連番になることはないが、競プロに限って言えば (数学の延長のような面があるので) そうなっている。

コード
class DisjointSet
  def initialize(n)
    @parents = (0...n).to_a
  end

  # ...
end

2. 親の値が負のとき絶対値を特別扱いする

これは union by size の別実装で、ノードが正の整数になっているところに着目し、ノードとサイズ(連結成分の大きさ)がたまたま同じ整数型だったことから、サイズを保持する変数を用意せず、ノードが負の値であればその絶対値をサイズと見なす、という仕掛けになっている。一つの変数を二つのことに使うことで、良く言えば、性能が上がる。

コード
class DisjointSet
  attr_reader :parents

  def initialize(n)
    @parents = Array.new(n, -1)
  end

  def root(x)
    if @parents[x] < 0
      x
    else
      @parents[x] = root(@parents[x])
    end
  end

  def same?(x, y)
    root(x) == root(y)
  end

  def unite(x, y)
    x = root(x)
    y = root(y)
    if x != y
      if @parents[x] > @parents[y]
        x, y = y, x
      end
      @parents[y] += @parents[x]
      @parents[x] = y
    end
  end

  def size(x)
    -@parents[root(x)]
  end
end

過去問

問題 種類 解法 リンク
ABC062 A - Grouping 連結確認 連結確認 考察
ARC032 B - 道路工事 木の数 木の数 考察
ABC284 C - Count Connected Components 木の数 木の数 考察
典型90問 012 - Red Painting 二次元 作りつつ同じ木か? 考察
ARC031 B - 埋め立て 二次元 木の数 考察
ABC288 C - Don't be cycle 閉路判定 閉路判定 考察
典型問題集 F - 最小全域木問題 削除 クラスカル法 考察
鉄則 A67 - MST 削除 クラスカル法 考察
ABC287 C - Path Graph? パスグラフ 最大次数 <= 2 考察
ABC235 E - MST + 1 考察力 クラスカル法の応用 考察
ABC065 D - Built? 考察力 X,Y成分毎に辺を作りMST 考察
ABC075 C - Bridge 削除 作り直す, 木の数 考察
ABC333 D - Erase Leaves 削除 作り直す, 最大の木 考察
ABC218 E - Destruction 削除 ソートして作り直す 考察
ABC264 E - Blackout 2 削除 発電所は1つ 考察
ABC304 E - Good Graph 考察力 根に変換して考える 考察
ABC157 D - Friend Suggestions 木の大きさ 木の大きさから引いていく 考察
ABC231 D - Neighbors same? 節同士が同じ木か? 考察
ABC177 D - Friends 最大の木 最大の木 考察
ABC126 E - 1 or 2 木の数 木の数 考察
ABC214 D - Sum of Maximum Weights 併合 大きさ×大きさ×重さ 考察
ABC183 F - Confluence マージ 併合に合わせてマージ 考察
ABC292 D - Unicyclic Components グラフ 辺の数 考察
ARC037 B - バウムテスト グラフ 辺の数 考察
ARC111 B - Reversible Cards グラフ 普通の木と見分ける 考察
ABC189 C - Mandarin Orange 考察力 大きい順に木を作る 考察
ABC120 D - Decayed Bridges 削除 大きさ×大きさ 考察
ARC056 B - 駐車場 辺の削除 節以上の辺と連結する 考察
ABC229 E - Graph Destruction 頂点削除 若い辺と連結する 考察
ABC040 D - 道路の老朽化対策について 辺の削除 新しい順に使える辺を作る 考察
ABC279 F - BOX 根と箱の相互変換 考察
ARC097 D - Equals パネポン 節と連番が同じ木か? 考察
ABC049 D - 連結 複数の森 複数の森 考察
ABC259 D - Circumferences 円交差判定 両端を半径0の円とする 考察
ARC106 B - Values パネポン 木の節の値の合計 考察
ABC206 D - KAIBUNsyo 考察力 全体の辺の数 考察
ABC350 D - New Friends 木の大きさ 完全グラフの辺の数から引く 考察
ABC351 D - Grid and Magnet グリッドから森を作る 考察
ABC352 E - Clique Connect 考察力 クラスカル法 考察

競プロコピペ用

経路圧縮 + union by rank (+ 木の数 + 木の大きさ + 木の辺の数 + 併合回数)
class QuickUnion
  attr_reader :parents

  def initialize(...)
    @parents = Hash.new { |h, k| h[k] = k }
  end

  def root(x)
    if @parents[x] == x
      x
    else
      root(@parents[x])
    end
  end

  def make_set(...)
    root(...)
    self
  end

  def same?(x, y)
    root(x) == root(y)
  end

  def unite(x, y)
    x = root(x)
    y = root(y)
    if x != y
      chain(x, y)
    end
    self
  end

  private

  def chain(x, y)
    @parents[x] = y
  end
end

# 使いやすくする
module HelperMod
  def self.prepended(klass)
    class << klass
      def [](...)
        new.unites(...)
      end

      def unites(...)
        new.unites(...)
      end
    end
  end

  def initialize(nodes = [])
    super
    make_sets(nodes)
  end

  def unites(edges = [])
    edges.each { |e| unite(*e) }
    self
  end

  def make_sets(nodes)
    nodes.each { |e| root(e) }
    self
  end

  def exist?(...)
    @parents.has_key?(...)
  end

  def different?(...)
    !same?(...)
  end

  def roots
    @parents.inject({}) { |a, (e, _)| a.merge(e => root(e)) }
  end

  def groups
    @parents.keys.group_by { |e| root(e) }
  end

  def nodes
    @parents.keys
  end

  def node_count
    @parents.size
  end

  def inspect
    "親:#{@parents} 構図:#{groups}"
  end
end

# 経路圧縮
module PathCompress
  def root(x)
    if @parents[x] == x
      x
    else
      @parents[x] = root(@parents[x])
    end
  end

  # 全体
  def compress!
    @parents.each_key { |e| @parents[e] = root(e) }
  end
end

# union by rank
module UnionByRank
  attr_reader :ranks

  def initialize(...)
    super
    @ranks = Hash.new(1)
  end

  def unite(x, y)
    x = root(x)
    y = root(y)
    if x != y
      case
      when @ranks[x] < @ranks[y]
        chain(x, y)
      when @ranks[y] < @ranks[x]
        chain(y, x)
      else
        chain(x, y)
        @ranks[y] += 1
      end
    end
    self
  end

  def inspect
    [super, "高さ:#{@ranks}"] * " "
  end
end

# size(x) で x で木の大きさを得る
# tree_size_max でいちばん大きな木の大きさを得る
module SizeMod
  attr_reader :sizes

  def initialize(...)
    super
    @sizes = Hash.new(1)
  end

  def size(x)
    @sizes[root(x)]
  end

  def tree_size_max(...)
    @parents.collect { |e, _| size(e) }.max(...)
  end

  def inspect
    [super, "節数:#{@sizes}"] * " "
  end

  private

  def chain(x, y)
    super
    @sizes[y] += @sizes[x]
  end
end

# chain_count → 併合回数
# edge_count  → 全体の辺の数
# tree_count  → 木の数
module CountMod
  attr_reader :chain_count

  def initialize(...)
    super
    @chain_count = 0
  end

  def edge_count
    @chain_count
  end

  def tree_count
    @parents.size - @chain_count
  end

  def forest_size
    tree_count
  end

  def inspect
    [super, "木数:#{tree_count}"] * " "
  end

  private

  def chain(...)
    super
    @chain_count += 1
  end
end

# edge(x) で辺の数を返す(多重辺を含む)
module EdgeMod
  attr_reader :edge_counts

  def initialize(...)
    super
    @edge_counts = Hash.new(0)
  end

  def unite(x, y)
    x = root(x)
    y = root(y)
    if x == y
      @edge_counts[y] += 1
    end
    super
  end

  def edge(x)
    @edge_counts[root(x)]
  end

  def edge_max(...)
    @parents.collect { |e, _| edge(e) }.max(...)
  end

  private

  def chain(x, y)
    super
    @edge_counts[y] += @edge_counts[x] + 1
  end
end

module BoostMod
  def initialize(n)
    super
    @parents = Array.new(n) { |i| i }
    @ranks = Array.new(n) { |i| i }
  end
end

class DisjointSet < QuickUnion
  prepend HelperMod
  prepend PathCompress
  prepend UnionByRank
  prepend SizeMod
  prepend CountMod
  prepend EdgeMod
end

class Minimalist < QuickUnion
  prepend PathCompress
  prepend UnionByRank
  prepend BoostMod
end

DisjointSet2 = Minimalist

if $0 == __FILE__ && ENV["ATCODER"] != "1"
  require "rspec/autorun"
  RSpec.configure do |config|
    config.expect_with :test_unit
  end

  describe DisjointSet do
    it ".new" do
      ds = DisjointSet.new
      assert { ds.forest_size == 0 }
      ds = DisjointSet.new("A".."C")
      assert { ds.forest_size == 3 }
    end

    it "root" do
      ds = DisjointSet.new
      assert { ds.root("A") == "A" }
    end

    it "same? / different?" do
      ds = DisjointSet.new
      assert { !ds.same?("A", "B") }
      assert { ds.different?("A", "B") }
      ds.unite("A", "B")
      assert { ds.same?("A", "B") }
      assert { !ds.different?("A", "B") }
    end

    it "make_set / make_sets / exist?" do
      ds = DisjointSet.new
      assert { !ds.exist?("A") }
      ds.make_set("A")
      assert { ds.exist?("A") }
      ds.make_sets(["B", "C"])
      assert { ds.exist?("C") }
    end

    it "unite" do
      ds = DisjointSet.new
      ds.unite("A", "B")
      assert { ds.parents == {"A" => "B", "B" => "B"} }
    end

    it "unites" do
      ds = DisjointSet.new
      assert { ds.unites([["A", "B"], ["C", "D"]]).forest_size == 2 }
    end

    it ".[] / .unites" do
      assert { DisjointSet[].forest_size == 0 }
      assert { DisjointSet[[["A", "B"], ["C", "D"]]].forest_size == 2 }
      assert { DisjointSet.unites([["A", "B"], ["C", "D"]]).forest_size == 2 }
    end

    it "roots" do
      ds = DisjointSet.new
      assert { ds.roots == {} }
      ds.unite("A", "B")
      ds.unite("C", "C")
      assert { ds.roots == {"A" => "B", "B" => "B", "C" => "C"} }
    end

    it "groups" do
      ds = DisjointSet.new
      assert { ds.groups == {} }
      ds.unite("A", "B")
      ds.unite("C", "C")
      assert { ds.groups == {"B" => ["A", "B"], "C" => ["C"] } }
    end

    it "nodes" do
      ds = DisjointSet.new
      ds.unite("A", "B")
      ds.unite("C", "C")
      assert { ds.nodes == ["A", "B", "C"] }
    end

    it "node_count" do
      ds = DisjointSet.new
      ds.unite("A", "B")
      ds.unite("C", "C")
      assert { ds.node_count == 3 }
    end

    it "size" do
      ds = DisjointSet.new
      assert { ds.size("A") == 1 }
      ds.unite("A", "B")
      assert { ds.size("A") == 2 }
    end

    it "tree_size_max" do
      ds = DisjointSet.new
      assert { ds.tree_size_max == nil }
      ds.unite("A", "A")
      assert { ds.tree_size_max == 1 }
      ds.unite("A", "B")
      assert { ds.tree_size_max == 2 }
    end

    it "tree_count" do
      ds = DisjointSet.new
      assert { ds.tree_count == 0 }
      ds.unite("A", "B")
      assert { ds.tree_count == 1 }
      ds.unite("C", "C")
      assert { ds.tree_count == 2 }
    end

    it "union by rank" do
      ds = DisjointSet.new
      ds.unite("A", "B")
      ds.unite("B", "C")
      assert { ds.root("C") == "B" }
    end

    context "経路圧縮" do
      it "root" do
        ds = DisjointSet.new
        ds.unite("A", "B")
        ds.unite("C", "D")
        ds.unite("B", "D")
        assert { ds.parents["A"] == "B" }
        assert { ds.root("A") == "D" }
        assert { ds.parents["A"] == "D" }
      end

      it "compress!" do
        ds = DisjointSet.new
        ds.unite("A", "B")
        ds.unite("C", "D")
        ds.compress!
        assert { ds.parents == {"A" => "B", "B" => "B", "C" => "D", "D" => "D"} }
      end
    end

    context "辺の数" do
      it "edge / edge_max" do
        ds = DisjointSet.new
        assert { ds.edge_max == nil }
        assert { ds.edge("A") == 0 }
        assert { ds.edge_max == 0 }
        ds.unite("A", "B")
        ds.unite("A", "B")
        ds.unite("C", "D")
        assert { ds.edge("A") == 2 }
        assert { ds.edge("C") == 1 }
        assert { ds.edge("E") == 0 }
        assert { ds.edge_max == 2  }
        assert { ds.edge_counts == {"B" => 2, "D" => 1} }
      end
    end
  end

  describe Minimalist do
    it "works" do
      ds = Minimalist.new(2)
      assert { !ds.same?(0, 1) }
      ds.unite(0, 1)
      assert { ds.same?(0, 1) }
    end
  end
end

Discussion