🐤

Haskell で Union-Find とクラスカルのアルゴリズム

2022/12/26に公開

Haskell で、優先度付きキューを使ったダイクストラ法 でダイクストラ法を実装したので、続けて最少全域木問題に取り組んでみようと思います。

最少全域木問題といえばクラスカルのアルゴリズムです。以下に詳しい解説があります。

そして、確かずいぶん前に一度実装したような記憶が朧気ながらあって、検索してみたところ過去に自分が書いたブログがヒットしその日付は2009年でした。内容については全く覚えていませんでした。

クラスカルのアルゴリズム

気をとりなおしてクラスカルのアルゴリズムですが、このアルゴリズムは直感的には割と簡単というのが面白いところです。

グラフの辺の中から、重みが小さい順に辺を選んでいく。このとき選んだ辺によりグラフに閉路ができるならそれは選ばない。これだけで、最少全域木が構成される。

直感的には簡単ですが、実装には一つ山があります。重みが小さい順に辺を選んでいく・・・というのは単に、辺を重み順でソートすれば良いので難しくない。問題は「選んだ辺によりグラフに閉路ができるなら」をどう判定するかです。絵を書いて眺めたらわりと明らかなことですが、実装するとなるとそうもいかない。

この判定には互いに素な集合 (Disjoint Sets) を扱えるデータ構造である Union-Find を使うと良いことが分かっています。

Haskell で Union-Find

さて Haskell で Union-Find をどう実装するかです。

アルゴ式に Union-Find の練習問題があるので、これを解きます。

Union-Find は互いに素な集合に対し、二つの集合を併合する (グラフでいうと、グラフ同士を連結する) ですとか、集合の代表元を求める (グラフでいうと木の根にあたる頂点を求める) ことを可能にするデータ構造で、シンプルながら応用が効き、競技プログラミングでも活用場面が多いそうです。

Union-Find の実装には計算量を抑えるためのコツがふたつあります。

  • 集合の併合時に、常にサイズもしくは木の高さが大きい集合に、小さい集合を併合させる (union by size / union by rank )
  • 代表元 (根) を求める経路を適宜単純化する (圧縮する)

このうち後者は、根を求める手続き root(x) の中で副作用を起こしてついでに経路を圧縮するよう実装するのが定番のようです。うーん、副作用、ここが Haskell での悩ましポイントですね。

イミュータブルな Union-Find

経路圧縮は実装しなくても union by size だけでもそれなりに計算量は抑えられるらしいので、まずは経路圧縮は保留してイミュータブルに実装してみます。

頂点集合の表現には Data.IntMap.Strict を使いました。ある頂点の親の頂点 (parent) と、ある代表元が含まれる集合のサイズを UnionFind 型に持たせます。その UnionFind 型に対し

  • ある頂点が含まれる集合の代表元 (根) をみつける root
  • ふたつの頂点それぞれが含まれる集合を併合する unite
  • ふたつの頂点が同じ集合に含まれているかどうか判定する isSame

関数を実装します。

ほかの言語では経路圧縮は root の中で副作用を起こして parent を更新するのが定番のようですが、前述のとおりいったんそれは保留します。

{-# LANGUAGE TupleSections #-}

import Control.Monad (replicateM)
import qualified Data.ByteString.Char8 as BS
import Data.Char (isSpace)
import qualified Data.IntMap.Strict as IM
import Data.List (mapAccumL, unfoldr)

getInts :: IO [Int]
getInts = unfoldr (BS.readInt . BS.dropWhile isSpace) <$> BS.getLine

data UnionFind = UnionFind
  { parent :: IM.IntMap Int,
    size :: IM.IntMap Int
  }
  deriving (Show)

newUF :: (Int, Int) -> UnionFind
newUF (s, e) =
  UnionFind
    { parent = IM.fromList $ map (,-1) [s .. e],
      size = IM.fromList $ map (,1) [s .. e]
    }

-- TODO: 経路圧縮
root :: UnionFind -> Int -> Int
root uf x
  | p == -1 = x
  | otherwise = root uf p
  where
    p = parent uf IM.! x

-- union by size
unite :: UnionFind -> Int -> Int -> UnionFind
unite uf x y
  | x' == y' = uf
  | sizeX > sizeY = update uf x' (y', sizeY)
  | otherwise = update uf y' (x', sizeX)
  where
    x' = root uf x
    y' = root uf y
    sizeX = size uf IM.! x'
    sizeY = size uf IM.! y'

    -- a を b の親にする更新
    update :: UnionFind -> Int -> (Int, Int) -> UnionFind
    update u a (b, sizeB) =
      u
        { parent = IM.insert b a (parent u),
          size = IM.adjust (+ sizeB) a (size u)
        }

isSame :: UnionFind -> Int -> Int -> Bool
isSame uf x y = root uf x == root uf y

main :: IO ()
main = do
  [n, m] <- getInts
  abs_ <- replicateM m getInts

  let uf0 = newUF (0, n - 1)

  let result =
        mapAccumL
          ( \uf [a, b] ->
              let x = isSame uf a b
               in (if x then uf else unite uf a b, x)
          )
          uf0
          abs_

  mapM_ (\x -> putStrLn $ if x then "Yes" else "No") (snd result)

これで AC します。

気になる速度のほうはといいますと

うーん、遅い。他の人の C++ の提出をみていると 300 〜 400ms ぐらいで通している人が多いので、1.4 sec はちょっと。これは経路圧縮してない問題というより Map を使っているからの方が大きいかもしれない。

MArray で Union-Find

あきらめてミュータブルなデータ構造に切り替えます。
STUArray / IOUArray それぞれをベースにした実装を作りました。こちらは経路圧縮もやっています。

STUArray で構成したもの

import Control.Monad (forM, replicateM, unless, when)
import Control.Monad.ST (ST, runST)
import Data.Array.MArray (readArray, writeArray)
import Data.Array.ST (MArray (newArray), STUArray)
import qualified Data.ByteString.Char8 as BS
import Data.Char (isSpace)
import Data.List (unfoldr)

data UnionFind s
  = UnionFind
      (STUArray s Int Int) -- parent
      (STUArray s Int Int) -- size

newUF :: (Int, Int) -> ST s (UnionFind s)
newUF (s, e) =
  UnionFind
    <$> newArray (s, e) (-1)
    <*> newArray (s, e) 1

root :: UnionFind s -> Int -> ST s Int
root uf@(UnionFind parent _) x = do
  p <- readArray parent x
  if p == (-1)
    then return x
    else do
      p' <- root uf p
      writeArray parent x p' -- 経路圧縮
      return p'

unite :: UnionFind s -> Int -> Int -> ST s ()
unite uf@(UnionFind parent size) x y = do
  x' <- root uf x
  y' <- root uf y

  when (x' /= y') $ do
    sizeX <- readArray size x'
    sizeY <- readArray size y'

    -- union by size
    if sizeX > sizeY
      then do
        writeArray parent y' x'
        writeArray size x' (sizeX + sizeY)
      else do
        writeArray parent x' y'
        writeArray size y' (sizeX + sizeY)

isSame :: UnionFind s -> Int -> Int -> ST s Bool
isSame uf x y = (==) <$> root uf x <*> root uf y

getInts :: IO [Int]
getInts = unfoldr (BS.readInt . BS.dropWhile isSpace) <$> BS.getLine

main :: IO ()
main = do
  [n, m] <- getInts
  qs <- replicateM m getInts

  let result = runST $ do
        uf <- newUF (0, n - 1)
        forM qs $ \[a, b] -> do
          same <- isSame uf a b
          unless same $ unite uf a b
          return same

  mapM_ (\x -> putStrLn $ if x then "Yes" else "No") result

IOUArray で構成したもの

import Control.Monad (forM_, replicateM, unless, when)
import Data.Array.IO (IOUArray)
import Data.Array.MArray (readArray, writeArray)
import Data.Array.ST (MArray (newArray))
import qualified Data.ByteString.Char8 as BS
import Data.Char (isSpace)
import Data.List (unfoldr)

data UnionFind = UnionFind (IOUArray Int Int) (IOUArray Int Int)

newUF :: (Int, Int) -> IO UnionFind
newUF (s, e) = UnionFind <$> newArray (s, e) (-1) <*> newArray (s, e) 1

root :: UnionFind -> Int -> IO Int
root uf@(UnionFind parent _) x = do
  p <- readArray parent x
  if p == (-1)
    then return x
    else do
      p' <- root uf p
      writeArray parent x p'
      return p'

unite :: UnionFind -> Int -> Int -> IO ()
unite uf@(UnionFind parent size) x y = do
  x' <- root uf x
  y' <- root uf y

  when (x' /= y') $ do
    sizeX <- readArray size x'
    sizeY <- readArray size y'

    if sizeX > sizeY
      then do
        writeArray parent y' x'
        writeArray size x' (sizeX + sizeY)
      else do
        writeArray parent x' y'
        writeArray size y' (sizeX + sizeY)

isSame :: UnionFind -> Int -> Int -> IO Bool
isSame uf x y = (==) <$> root uf x <*> root uf y

getInts :: IO [Int]
getInts = unfoldr (BS.readInt . BS.dropWhile isSpace) <$> BS.getLine

main :: IO ()
main = do
  [n, m] <- getInts
  qs <- replicateM m getInts

  uf <- newUF (0, n - 1)
  forM_ qs $ \[a, b] -> do
    same <- isSame uf a b
    unless same $ unite uf a b
    putStrLn $ if same then "Yes" else "No"

ばりばり手続き型のプログラムになってしまいますが、まあ仕方がない。
速度のほうはと言いますと

といずれも 200m 程度。だいぶ速くなりました、まずまずでしょう。

Union-Find を使ってクラスカルのアルゴリズム

Union-Find の実装ができたので、最少全域木問題を解きます。こちらもアルゴ式に練習問題がありました。

Union-Find さえできてしまえば、あとは簡単です。辺を重みでソートして、その順番に Union-Find で辺に含まれる頂点同士を併合してグラフを全域木を構築していきます。このとき、閉路ができないよう isSameTrue になるケースは飛ばします。

import Control.Monad (forM, replicateM, when)
import Data.Array.IO (IOUArray)
import Data.Array.MArray (readArray, writeArray)
import Data.Array.ST (MArray (newArray))
import qualified Data.ByteString.Char8 as BS
import Data.Char (isSpace)
import Data.List (sortOn, unfoldr)
import Data.Maybe (catMaybes)

data UnionFind = UnionFind (IOUArray Int Int) (IOUArray Int Int)

newUF :: (Int, Int) -> IO UnionFind
newUF (s, e) = UnionFind <$> newArray (s, e) (-1) <*> newArray (s, e) 1

root :: UnionFind -> Int -> IO Int
root uf@(UnionFind parent _) x = do
  p <- readArray parent x
  if p == (-1)
    then return x
    else do
      p' <- root uf p
      writeArray parent x p'
      return p'

unite :: UnionFind -> Int -> Int -> IO ()
unite uf@(UnionFind parent size) x y = do
  x' <- root uf x
  y' <- root uf y

  when (x' /= y') $ do
    sizeX <- readArray size x'
    sizeY <- readArray size y'

    if sizeX > sizeY
      then do
        writeArray parent y' x'
        writeArray size x' (sizeX + sizeY)
      else do
        writeArray parent x' y'
        writeArray size y' (sizeX + sizeY)

isSame :: UnionFind -> Int -> Int -> IO Bool
isSame uf x y = (==) <$> root uf x <*> root uf y

getInts :: IO [Int]
getInts = unfoldr (BS.readInt . BS.dropWhile isSpace) <$> BS.getLine

-- クラスカルのアルゴリズムで最少全域木を構成する
kruskal :: (Int, Int) -> [(Int, (Int, Int))] -> IO [(Int, (Int, Int))]
kruskal (s, e) edges = do
  uf <- newUF (s, e)
  es <- forM edges' $ \edge@(u, (v, _)) -> do
    same <- isSame uf u v
    
    -- 閉路を作らないケースのみ併合
    if not same
      then unite uf u v >> return (Just edge)
      else return Nothing
  return (catMaybes es)
  where
    edges' = sortOn (snd . snd) edges

main :: IO ()
main = do
  [n, m] <- getInts
  edges <- map (\[u, v, w] -> (u, (v, w))) <$> replicateM m getInts

  es <- kruskal (0, n -1) edges

  print $ sum $ map (\(_, (_, w)) -> w) es

無事 AC しました。

状態の更新の違いとプログラミングパラダイム

ところで Union-Find を実装し、それをベースに問題を解いていると Union-Find は併合を繰り返す中、そのときどきのグラフや集合のグループ (連結成分) の状態を管理しているデータ構造なんだということに気がつきます。

この「状態管理」という観点でイミュータブルな Union-Find とミュータブルな Union-Find の実装を比較してみると、関数型のパラダイムと手続き型のパラダイムの差がとてもはっきり出るのがよくわかります。

  • Union-Find をミュータブルなデータ構造として構成して手続き的に書くと、状態はある単一の Union-Find のデータ構造の中に留まり、変更があった場合はそれ自体が書き換わる
  • 一方、イミュータブルなデータ構造で状態を表現している場合は、状態 (Union-Find) が更新されるとそれは別の値になる。 mapAccumL のような関数を使うとはっきりわかるように、すべての状態遷移は別の値として列挙することができる (もちろん畳み込むこともできる)

当たり前といえば当たり前なんですが、この差がプログラミングスタイルに与える影響は顕著で

  • 状態の変更にあたって状態を保持しているオブジェクトそのものを更新するのではなく、イミュータブルにそれを扱う場合、状態の変化は手続きの戻り値として明示的に返却する必要がある。つまりそれは関数になる。
  • 対照的に、オブジェクトそのものを更新する場合、値そのものが更新されるので手続きから返却する必要がない。つまり返値を伴わない。結果、手続き的なプログラミングになる

ということが見て取れます。

ミュータブルなデータ構造を中心に置くとそこから芋づる式的に手続き型プログラミングになっていく、というのは Haskell をやっているとよくわかることですが、同じ目的のデータ構造をイミュータブル版 / ミュータブル版で実装してみるとなんでそうなのかがはっきり見えて来て面白いです。

Discussion