Open29

Haskell関連

4tsuzuru4tsuzuru

DFS・BFS

応用が効きやすいDFSとBFSの雛形。
グラフ構造の余代数がはっきりせず、余再帰やrecursion schemeは使えず。木構造ならシンプルに実装可能。

import Data.Array
import Data.Tree
import Control.Monad.Trans.State
import qualified Data.Set as S

bfs :: Ord a => (a -> [a]) -> S.Set a -> [S.Set a]
bfs getNexts vs= evalState (explore vs) S.empty
  where
    explore ts = do
      seen <- get
      let ts' = S.difference ts seen
      if S.null ts' then return [] else do
        put (S.union seen ts')
        (ts':) <$> (explore $ S.fromList $ concatMap getNexts $ S.elems ts')


dfs :: Ord a => (a -> [a]) -> [a] -> [Tree a]
dfs getNexts vs = evalState (explore vs) S.empty
  where
    explore [] = return []
    explore (v:vs) = do
      seen <- get
      if v `S.member` seen then explore vs else do
        put (S.insert v seen)
        (:) <$> (Node v <$> explore (getNexts v)) <*> explore vs

4tsuzuru4tsuzuru

ダイクストラ法

基本的にはリンク先をみればOK
https://zenn.dev/naoya_ito/articles/f05885f9e302ba
経路復元は一つ前を記録する方法が汎用性あり

import qualified Data.Heap as H
import Data.Heap (Entry (Entry))
import qualified Data.IntMap.Strict as IM
import Data.Array

dijkstra :: Array Int [(Int, Int)] -> Int -> Int -> IM.IntMap Int
dijkstra g n v0 = go (Just (Entry 0 v0, H.empty)) IM.empty (IM.fromList $ (v0, 0) : [(k, maxBound :: Int) | k <- [2 .. n]]) where
  go Nothing _ dist = dist
  go (Just (Entry currentCost i, q)) done dist
    | IM.member i done = go (H.uncons q) done dist
    | otherwise = go (H.uncons q') done' dist' 
      where
        done' = IM.insert i currentCost done
        dist' = foldl (\acc (j, cost) -> IM.insertWith min j (currentCost + cost) acc) dist $ g ! i
        q' = foldl (\acc (j, cost) -> H.insert (Entry (currentCost + cost) j) acc) q $ g ! i

-- 経路復元が必要な時
dijkstra' :: Array Int [(Int, Int)] -> Int -> Int -> IM.IntMap (Int,Int)
dijkstra' g n v0 = go (Just (Entry 0 v0, H.empty)) IM.empty (IM.fromList $ (v0, (0,0)) : [(k, ((maxBound :: Int), -1)) | k <- [2 .. n]]) where
  go Nothing _ dist = dist
  go (Just (Entry currentCost i, q)) done dist
    | IM.member i done = go (H.uncons q) done dist
    | otherwise = go (H.uncons q') done' dist' 
      where
        done' = IM.insert i currentCost done
        q' = foldl (\acc (j, cost) -> H.insert (Entry (currentCost + cost) j) acc) q $ g ! i
        dist' = foldl (\acc (j, cost) -> 
          if currentCost + cost < fst (acc IM.! j) then IM.insert j (currentCost + cost, i) acc else acc) dist $ g ! i

let dist  =  dijkstra graph n 1
let ans = unfoldr (\i -> if i == 0 then Nothing else Just ( i, snd (dist IM.! i))) n
4tsuzuru4tsuzuru

動的計画法 (一次元DP)

基本的にはSTモナドを利用するのが最も分かりやすいし、効率的。
Stateモナドもdynamorphismもsmall setなら問題ないが、large setだとTLEになりやすい。

B17 - Frog 1 with Restorationを例にする
https://atcoder.jp/contests/tessoku-book/tasks/tessoku_book_cp

Dynamorphism

下記ではTLE。改善策はわからず。

pair (f, g) x = (f x, g x)
cross (f, g) (x, y) = (f x, g y)

newtype Fix f = In { out :: f (Fix f) }
 
newtype Cofree f a = Cf { unCf :: (a, f (Cofree f a)) }
extract :: Cofree f t -> t
extract = fst . unCf
sub :: Functor f => Cofree f a -> f (Cofree f a)
sub = snd . unCf

newtype Free f a = Fr { unFr :: Either a (f (Free f a)) }
inject :: a -> Free f a
inject = Fr . Left
 
cata :: Functor f => (f a -> a) -> Fix f -> a
cata phi = phi . fmap (cata phi) . out
ana :: Functor f => (a -> f a) -> a -> Fix f
ana psi = In . fmap (ana psi) . psi
hylo :: Functor f => (f b -> b) -> (a -> f a) -> a -> b
hylo phi psi = cata phi . ana psi -- phi . fmap (hylo phi psi) . psi
para :: Functor f => (f (Fix f, t) -> t) -> Fix f -> t
para phi = phi . fmap (pair (id, para phi)) . out
histo :: Functor f => (f (Cofree f t) -> t) -> Fix f -> t
histo phi = extract . cata (Cf . pair (phi, id)) 
futu :: Functor f => (t -> f (Free f t)) -> t -> Fix f
futu psi = ana (uncurry either (psi, id) . unFr) . inject
chrono :: Functor f => (f (Cofree f b) -> b) -> (a -> f (Free f a)) -> a -> b
chrono phi psi = histo phi . futu psi
dyna :: Functor f => (f (Cofree f b) -> b) -> (a -> f a) -> a -> b
dyna f g = chrono f (fmap inject . g) -- histo f . ana g

data NatF r =
    ZeroF
  | SuccF r
  deriving (Show, Functor)

type Nat = Fix NatF


intToNat :: Int -> Nat
intToNat = ana phi where
  phi 0 = ZeroF
  phi n = SuccF (n-1)

natToInt :: Nat -> Int
natToInt = cata phi where
  phi ZeroF = 0
  phi (SuccF x) = x + 1



compress :: Num a1 => NatF (Cofree NatF a2) -> a1
compress ZeroF = 0
compress (SuccF x) = 1 + compress y where
  y = sub x

solve :: Array Int Int -> Int -> (Int, [Int])
solve hs = dyna phi psi where
  psi 1 = ZeroF
  psi n = SuccF (n-1)

  phi ZeroF = (0, [1])
  phi (SuccF (Cf (_, ZeroF))) = (abs (hs ! 2 - hs ! 1), [2,1])
  phi cur@(SuccF x) = if d1 < d2 then (d1, i:path1) else (d2, i:path2) where
    i = compress cur + 1
    (f1, path1) = extract x 
    (f2, path2) = case sub x of
      ZeroF   -> (0, [1])
      SuccF y -> extract y 
    d1 = f1 + abs (hs ! i - hs ! (i-1))
    d2 = f2 + abs (hs ! i - hs ! (i-2))

下記のようにSTモナドの利用のほうが応用が効きやすい。経路復元は前項参照。

solve ::  Array Int Int -> Int -> Array (Int, Int) Int
solve hs n = runST $ do
  memo <-  newArray ((1,0), (n,1)) 0 :: ST s (STUArray s (Int, Int) Int)
  writeArray memo (2,0) $ abs (hs ! 2 - hs ! 1)
  writeArray memo (2,1) 1
  forM_ [3..n] $ \i -> do
    f1 <- readArray memo ((i-1), 0)
    f2 <- readArray memo ((i-2), 0)
    let d1 = f1 + abs (hs ! i - hs ! (i-1))
    let d2 = f2 + abs (hs ! i - hs ! (i-2))
    if d1 < d2
      then do
        writeArray memo (i,0) (d1)
        writeArray memo (i,1) (i-1)
      else do
        writeArray memo (i,0) (d2)
        writeArray memo (i,1) (i-2)
  freeze memo


solve':: Array Int Int -> Int -> Array Int (Int,Int)
solve' hs n =  memo where
  memo = listArray (1,n)  $  (0,0) : (abs (hs ! 2 - hs ! 1) , 1) : [ if fst d1 < fst d2 then d1 else d2 | i <- [3..n], let d1 = (fst (memo ! (i-1)) + abs (hs ! i - hs ! (i-1)), i-1), let d2 = (fst (memo ! (i-2)) +  abs (hs ! i - hs ! (i-2)), i-2)]
4tsuzuru4tsuzuru

動的計画法 (0-1 ナップサック問題)

https://atcoder.jp/contests/tessoku-book/tasks/tessoku_book_s

1. Accumlation

knapsack :: Int ->  [(Int,Int)] -> (Int,Int) -> [(Int,Int)]
knapsack w dp (wi,vi) = merge (dropWhile ((w<) . fst) $ map (\(w',v') -> (w'+wi, v'+vi)) dp) dp 
  where merge xxs@(x:xs) yys@(y:ys) = case compare (fst x) (fst y) of
          LT -> if snd x < snd y then y : merge xxs ys  else merge xxs ys
          EQ -> if snd x > snd y then x : merge xs ys else y : merge xs ys
          GT -> if snd x > snd y then x : merge xs yys else merge xs yys
        merge [] ys = ys
        merge xs [] = xs

let ans = maximum $ map snd  $ foldl (knapsack w) [(0,0)] $ wvs

2.1. STモナド

knapsack :: Array Int (Int,Int) -> Int -> Int -> Array (Int,Int) Int
knapsack wvs n w =  runST $ do
  memo <- newArray ((0,0),(w,n)) (-1)  :: ST s (STUArray s (Int,Int) Int)
  writeArray memo (0,0) 0
  forM_ [1..n] $ \i -> do
    forM_ [0..w] $ \j -> do
      if (j >= fst (wvs!i)) then do
        f1 <- readArray memo (j, (i-1))
        f2 <- readArray memo (( j - fst (wvs!i)), (i-1))
        if f2 == -1 then do
          writeArray memo (j, i) f1
        else do
          writeArray memo (j, i) $ max f1 (f2 + snd (wvs!i))
      else do
        f1 <- readArray memo (j, (i-1))
        writeArray memo (j, i) f1
  freeze memo
let table = knapsack wvs n w
let ans =  maximum $ map (\i -> table!(i,n)) [0..w]

2.2 STモナド(Weightの上限が巨大な場合)

Unboxed Arrayを使わないとMLEになる

knapsack :: Array Int (Int,Int) -> Int -> Int -> Int-> UArray (Int,Int) Int
knapsack wvs n w vmax=  runST $ do
  memo <- newArray ((0,0),(vmax,n)) (10^15)  :: ST s (STUArray s (Int,Int) Int)
  writeArray memo (0,0) 0
  forM_ [1..n] $ \i -> do
    forM_ [0..vmax] $ \v -> do
       f1 <- readArray memo (v, (i-1))
       if (v >= snd (wvs!i)) then do
         f2 <- readArray memo (( v - snd (wvs!i)), (i-1))
         writeArray memo (v, i) $ min f1 (f2 + fst (wvs!i))
       else do
         writeArray memo (v, i) f1
  freeze memo

3.Dynamorphism

Dynamorphismの定義は2つ上を参照

import qualified Data.Vector.Unboxed as U
data ListF a r =
    NilF
  | ConsF a r
  deriving (Show, Functor)

knapsack'' wvs n w = dyna phi psi n where
  psi 0 = NilF 
  psi i = ConsF (wvs ! i) (i-1) 
 
  phi (ConsF (cw, cv) (Cf (_, NilF))) =  U.generate (w+1) (bool 0 cv . (cw==))
  phi (ConsF (cw, cv) x) =  U.zipWith max prev vec
    where
      prev = extract x
      vec = U.replicate cw 0 U.++ U.map (+cv) prev
4tsuzuru4tsuzuru

動的計画法(LCS)

https://atcoder.jp/contests/tessoku-book/tasks/tessoku_book_t

1. STモナド

lcs :: Array Int Char -> Array Int Char -> Int -> Int -> Int
lcs xs ys n m = runST $ do
  memo <- newArray ((0,0),(n,m)) 0 :: ST s (STUArray s (Int,Int) Int)
  forM_ [1..n] $ \i -> do
    forM_ [1..m] $ \j -> do
      if xs!i == ys!j then do
        f1 <- readArray memo (i-1, j-1)
        writeArray memo (i,j) $ f1+1
      else do
        f2 <- readArray memo (i-1, j)
        f3 <- readArray memo (i, j-1)
        writeArray memo (i,j) $ max f2 f3
  readArray memo (n,m)

2. Dynamorphism

詳しくは下記。
https://lotz84.github.io/recursion-algorithms/RecursionSchemes/Extra.html
TLEのため、改善必要だが。。。

lcsDyna :: Eq a => ([a], [a]) -> [a]
lcsDyna (cs, cs') =
  let n = length cs
   in dyna (g n) f (cs, cs')
  where
  f ([],   [])   = F []     []     Nothing
  f ([],   b:bs) = F []     (b:bs) (Just (cs, bs))
  f (a:as, bs)   = F (a:as) bs     (Just (as, bs))
  g _ (F _      _      Nothing)  = []
  g _ (F []     _      _)        = []
  g _ (F _      []     _)        = []
  g n (F (a:as) (b:bs) (Just x)) =
    let x1 = extract x
        x2 = extract $ iterate pi x !! n
        x3 = extract $ iterate pi x !! (n + 1)
    in if a == b then a:x3 else (if length x1 > length x2 then x1 else x2)
  pi (Cf (_ , F _ _ (Just y))) = y

Fantastic Morphisms and Where to Find Them ⋆
A Guide to Recursion Schemesより引用

lcs' s1 s2 = dyna a g (s1, s2) where
  g ([ ], [ ]) = NilF
  g (x , y) = if null y then ConsF (x , y) (tail x , s2) else ConsF (x , y) (x , tail y)

  a NilF = 0
  a (ConsF (x , y) table)
    | null x || null y = 0
    | head x == head y = index table (offset 1 1) + 1
    | otherwise = max (index table (offset 1 0))
    (index table (offset 0 1))
  index t 0 = extract t
  index (Cf (_,ConsF _ t')) n = index t' (n-1)

  offset n m = n * (length s2 + 1) + m - 1
4tsuzuru4tsuzuru

動的計画法(Edit Distance)

editDistance :: Array Int Char -> Array Int Char -> Int -> Int -> Int
editDistance xs ys n m = runST $ do
  memo <- newArray ((0,0),(n,m)) 0 :: ST s (STUArray s (Int,Int) Int)
  forM_ [0..n] $ \i -> do
    writeArray memo (i,0) i
  forM_ [0..m] $ \j -> do
    writeArray memo (0,j) j
  forM_ [1..n] $ \i -> do
    forM_ [1..m] $ \j -> do
      if xs!i == ys!j then do
        f1 <- readArray memo (i-1, j-1)
        writeArray memo (i,j) $ f1
      else do
        f2 <- readArray memo (i-1, j)
        f3 <- readArray memo (i, j-1)
        f4 <- readArray memo (i-1, j-1)
        writeArray memo (i,j) $ 1 + minimum [f2,f3,f4]
  readArray memo (n,m)
4tsuzuru4tsuzuru

動的計画法

おそらくこのやり方が最も汎用性が高い。maximumとplusはSemiringであることに注意。

knapsack :: (Array Int (Int,Int), Int, Int)  -> Int
knapsack (wvs, n, w) = runST $ do 
  memo <- newArray ((0,1), (w,n)) (-1) :: ST s (STUArray s (Int,Int) Int)
  go (w,1) memo where
    go p memo
      | isTrivial p = return 0
      | otherwise = do
        res <- readArray memo p
        if res /= -1 then return res
        else do
          ret <- maximum <$> mapM (\(s, sp) -> (+s) <$> go sp memo) (subproblems p)
          writeArray memo p ret
          return ret

    isTrivial (w,i) = i > n
    subproblems (w,i) 
      | w >= w1 = [(v1, (w-w1, i+1)), (0, (w, i+1))]
      | otherwise = [(0, (w, i+1))]
      where (w1,v1) = wvs!i


lcs:: (Array Int Char,Array Int Char, Int, Int)  -> Int
lcs (xs, ys, n, m) = runST $ do 

  memo <- newArray ((0,0), (n,m)) (-1) :: ST s (STUArray s (Int,Int) Int)
  writeArray memo (0,0) 0

  go (1,1) memo where
    go p memo
      | isTrivial p = return 0
      | otherwise = do
        res <- readArray memo p
        if res /= -1 then return res
        else do
          ret <- maximum <$> mapM (\(s, sp) -> (+s) <$> go sp memo) (subproblems p)
          writeArray memo p ret
          return ret

    isTrivial (i,j) = i > n || j > m
    subproblems (i,j) 
      | x == y = [(1, (i+1, j+1))]
      | otherwise = [(0, (i+1, j)), (0, (i, j+1))]
      where x = xs!i
            y = ys!j


blockgame :: (Array Int (Int,Int), Int)  -> Int
blockgame (pas, n) = runST $ do 
  memo <- newArray ((1,1), (n,n)) (-1) :: ST s (STUArray s (Int,Int) Int)
  go (1,n) memo where
    go p memo
      | isTrivial p = return 0
      | otherwise = do
        res <- readArray memo p
        if res /= -1 then return res
        else do
          ret <- maximum <$> mapM (\(s, sp) -> (+s) <$> go sp memo) (subproblems p)
          writeArray memo p ret
          return ret

    isTrivial (l,r) = l == r
    subproblems (l, r) = [(v1, (l+1, r)), (v2, (l, r-1))]
      where
        (p1,a1) = pas!l
        (p2,a2) = pas!r
        v1 = if l <= p1 && p1 <= r then a1 else 0
        v2 = if l <= p2 && p2 <= r then a2 else 0

editDistance :: (Array Int Char,Array Int Char, Int, Int)  -> Int
editDistance (xs, ys, n, m) = runST $ do 
  memo <- newArray ((1,1), (n+1,m+1)) (-1) :: ST s (STUArray s (Int,Int) Int)
  go (1,1) memo where
    go p memo
      | isTrivial p = return 0
      | otherwise = do
        res <- readArray memo p
        if res /= -1 then return res
        else do
          ret <- minimum <$> mapM (\(s, sp) -> (+s) <$> go sp memo) (subproblems p)
          writeArray memo p ret
          return ret

    isTrivial (i,j) = i >n && j>m
    subproblems (i,j) 
      | i > n = [(1, (i, j+1))]
      | j > m = [(1, (i+1, j))]
      | x == y = [(0, (i+1, j+1))]
      | otherwise = [(1, (i+1, j+1)), (1, (i+1, j)), (1, (i, j+1))]
      where x = xs!i 
            y = ys!j 

dungeon1 :: (Array Int Int,Array Int Int, Int) -> Int
dungeon1 (as, bs, n) = runST $ do 
  memo <- newArray (1,n) (-1) :: ST s (STUArray s Int Int)
  go n memo where
    go p memo
      | isTrivial p = return 0
      | otherwise = do
        res <- readArray memo p
        if res /= -1 then return res
        else do
          ret <- minimum <$> mapM (\(s, sp) -> (+s) <$> go sp memo) (subproblems p)
          writeArray memo p ret
          return ret
    isTrivial p = p <= 1
    subproblems p 
      | p == 2 = [(as ! 2, 1)]
      | otherwise = [(as ! p, p-1), (bs ! p, p-2)]

dungeon2 :: (Array Int Int, Int) -> Int
dungeon2 (hs, n) = runST $ do 
  memo <- newArray (1,n) (-1) :: ST s (STUArray s Int Int)
  go n memo where
    go p memo
      | isTrivial p = return 0
      | otherwise = do
        res <- readArray memo p
        if res /= -1 then return res
        else do
          ret <- minimum <$> mapM (\(s, sp) -> (+s) <$> go sp memo) (subproblems p)
          writeArray memo p ret
          return ret
    isTrivial p = p <= 1 
    subproblems p 
      | p == 2 = [(abs (hs ! 2 - hs ! 1), 1)]
      | otherwise = [(abs (hs ! p - hs ! (p-1)), p-1), (abs (hs ! p - hs ! (p-2)), p-2)]
   
4tsuzuru4tsuzuru

Binary Search

下記のやり方で、整数・実数ともに統一的に対応できる。
https://byorgey.wordpress.com/2023/01/01/competitive-programming-in-haskell-better-binary-search/

search :: (a -> a -> Maybe a) -> (a -> Bool) -> a -> a -> (a,a)
search mid p = go
  where
    go l r = case mid l r of
      Nothing -> (l,r)
      Just m
        | p m       -> go l m
        | otherwise -> go m r

binary :: Integral a => a -> a -> Maybe a
binary l r
  | r - l > 1 = Just ( (l + r) `div` 2)
  | otherwise = Nothing

continuous :: (Fractional a, Ord a) => a -> a -> a -> Maybe a
continuous eps l r
  | r - l > eps = Just ( l + (r-l) / 2)
  | otherwise = Nothing

具体例

整数

https://atcoder.jp/contests/tessoku-book/tasks/tessoku_book_l

 search binary (\x -> k <= (foldl (\sum cur -> x `div` cur+ sum) 0 as)) 1 (10^9+1)

実数

https://atcoder.jp/contests/tessoku-book/tasks/tessoku_book_ck

search (continuous 0.001) (\x -> x**3+x > n ) 0 50
4tsuzuru4tsuzuru

使いやすさを考えて、二分探索は下記を使うことに。以前のバージョンは実質的に条件を満たす区間の下限のインデックスを求めるものであったが、下記であれば、上限も下限もどちらも求められる。
また、境界条件として、区間全域が条件を満たす場合・満たさない場合でも適応できるようにした。

search :: (a -> a -> Maybe a) -> (a -> Bool) -> a -> a -> Maybe a
search mid p ok ng = 
  if (not . p) ok then Nothing
  else if p ng then Just ng
  else
    go ok ng
    where
      go ok ng = case mid ok ng of
        Nothing -> Just ok
        Just m
          | p m       -> go m ng
          | otherwise -> go ok m

binary :: Integral a => a -> a -> Maybe a
binary l r
  | abs (r - l) > 1 = Just ( (l + r) `div` 2)
  | otherwise = Nothing

continuous :: (Fractional a, Ord a) => a -> a -> a -> Maybe a
continuous eps l r
  | abs(r - l) > eps = Just ( l + (r-l) / 2)
  | otherwise = Nothing

また、二分探索前のソートは下記のようにするのが高速。

import qualified Data.Vector.Unboxed  as VU
import qualified Data.Vector.Algorithms.Intro as VAI

let sorted = VU.modify VAI.sort $ VU.fromList as

二分探索を用いた座標圧縮は下記

import qualified Data.IntSet as IS
import qualified Data.Vector.Unboxed  as VU

compressing :: [Int] -> [Int]
compressing xs = map fromJust $ map (\a -> search binary (\i -> sorted VU.! i <= a) 0 (VU.length sorted - 1)) xs
  where
    sorted = VU.fromList $ IS.toAscList $ IS.fromList xs
4tsuzuru4tsuzuru

しゃくとり法

群に類似した構造を持つ連続部分列に対するアルゴリズム(と自分は思ってる)。
(+, -, 0), (*, div, 1), (max, 左から除去, minBound), (S.insert, S.delete, S.emptry)などの構造を意識すれば、分かりやすい。

shakutori p op invOp identity as = go as as 0 identity
  where
    go lls@(l:ls) [] len res = len : (go ls [] (len-1) (invOp res l))
    go lls@(l:ls) rrs@(r:rs) len res
      | p l r res = go lls rs (len + 1) (op res r)
      | len == 0 = 0:(go ls rs 0 identity)
      | otherwise =  len : (go ls rrs (len-1) (invOp res l))
    go _ _ _ _ = []

具体例

https://atcoder.jp/contests/abc032/tasks/abc032_c

let ans = shakutori (\l r res -> res *r <= k) (*) div 1 as
print $ maximum ans

https://atcoder.jp/contests/tessoku-book/tasks/tessoku_book_cl

let ans = shakutori (\l r res -> res + r <= k) (+) (-) 0 as
print $ sum ans

https://atcoder.jp/contests/abc038/tasks/abc038_c

let ans = shakutori (\l r res -> res < r) max (\res l -> if res == l then minBound else res) minBound as
print $ sum ans

https://atcoder.jp/contests/tessoku-book/tasks/tessoku_book_m

let ans = shakutori (\l r res -> r- l <=k ) (+) (-) 0 as
print $ sum $ map (\x-> x-1) ans

https://atcoder.jp/contests/arc022/tasks/arc022_2

let ans = shakutori (\l r res ->not (S.member r res)) (flip S.insert) (flip S.delete) S.empty as
print $ maximum ans

https://atcoder.jp/contests/abc098/tasks/arc098_b

let ans = shakutori (\l r (a,b) -> a+r == b.|.r) (\(a,b) l -> (a+l, b.|.l)) (\(a,b) l -> (a-l ,b.&.complement l)) (0,0) as
print $ sum $ ans
4tsuzuru4tsuzuru

拡張ユークリッドの互除法の実装

(11, 7)を例に。%は剰余演算。

\begin{pmatrix} 0 & 1 \\ 1 & -1_{(=11\%7)} \\ \end{pmatrix} \begin{pmatrix} 11\\ 7 \\ \end{pmatrix} = \begin{pmatrix} 7\\ 4\\ \end{pmatrix} \\[3mm] \begin{pmatrix} 0 & 1 \\ 1 & -1_{(=7\%4)} \\ \end{pmatrix} \begin{pmatrix} 7\\ 4 \\ \end{pmatrix} = \begin{pmatrix} 4\\ 3\\ \end{pmatrix} \\[3mm] \begin{pmatrix} 0 & 1 \\ 1 & 1_{(=7\%4)} \\ \end{pmatrix} \begin{pmatrix} 0 & 1 \\ 1 & 1_{(=11\%7)} \\ \end{pmatrix} \begin{pmatrix} 11\\ 7 \\ \end{pmatrix} = \begin{pmatrix} 4\\ 3\\ \end{pmatrix}

上記を想定しながらgcd(a,b)についての漸化式を作る。

\begin{pmatrix} s_{0} & t_{0} \\ s_{1} & t_{1} \\ \end{pmatrix} \begin{pmatrix} a \\ b \\ \end{pmatrix} = \begin{pmatrix} r_{0}\\ r_{1} \\ \end{pmatrix} \\[3mm] \begin{pmatrix} 0 & 1 \\ 1 & -(r_{0}\%r_{1}) \\ \end{pmatrix} \begin{pmatrix} s_{0} & t_{0} \\ s_{1} & t_{1} \\ \end{pmatrix} \begin{pmatrix} a \\ b \\ \end{pmatrix} = \begin{pmatrix} r_{1}\\ r_{2} \\ \end{pmatrix}

以上より

s_{2} = s_{0} - (r_{0}/r_{1}) s{1} \\[2mm] t_{2} = t_{0} - (r_{0}/r_{1}) t{1}
exgcd :: Integral a => a -> a -> (a, a, a)
exgcd a b = f $ go a b 1 0 0 1
  where
    go r0 r1 s0 s1 t0 t1
      | r1 == 0   = (r0, s0, t0)
      | otherwise = go r1 r2 s1 s2 t1 t2
      where
        (q, r2) = r0 `divMod` r1
        s2 = s0 - q*s1
        t2 = t0 - q*t1
    f (g,u,v)
      | g < 0 = (-g, -u, -v)
      | otherwise = (g,u,v)
4tsuzuru4tsuzuru

剰余環

(\mathbb{Z} / p\mathbb{Z})^{\times}の実装

べき乗は二分累乗法を使用。拡張ユークリッドの互除法は前項参照。
素数p以外でも使用可能。

{-# LANGUAGE DataKinds,ScopedTypeVariables,KindSignatures #-}
import GHC.TypeNats ( KnownNat, Nat, natVal )
import Data.Proxy ( Proxy(..) )
import Data.Ratio ( denominator, numerator )

newtype IntMod (m :: Nat) = IntMod Integer deriving (Eq,Show)

instance KnownNat m => Num (IntMod m) where
  IntMod x + IntMod y = fromInteger $ x + y
  IntMod x - IntMod y = fromInteger $ x - y
  IntMod x * IntMod y = fromInteger $ x * y
  negate (IntMod x) =  fromInteger $ negate x
  abs a = a
  signum _ = 1
  fromInteger x =  IntMod $ x `mod` (fromIntegral $ natVal (Proxy :: Proxy m))

instance KnownNat m => Fractional (IntMod m) where
  recip a@(IntMod x) = fromInteger r where
    (_, r, _) = exgcd x (fromIntegral $ natVal (Proxy :: Proxy m))
  fromRational x = fromInteger $ numerator x `div` denominator x

instance KnownNat m => Bounded (IntMod m) where
  minBound = IntMod 0
  maxBound = IntMod $ fromIntegral $ natVal (Proxy :: Proxy m) - 1

instance KnownNat m => Enum (IntMod m) where
  toEnum = fromInteger . toInteger
  fromEnum (IntMod x) = fromInteger x

instance KnownNat m => Ord (IntMod m) where
  IntMod x <= IntMod y = x <= y

power a n = go a n 1
  where
    go _ 0 res = res
    go a n res
      | even n = go (a*a) (n `div` 2) res
      | otherwise = go a (n-1) (res*a)

exgcd :: Integral a => a -> a -> (a, a, a)
exgcd a b = f $ go a b 1 0 0 1
  where
    go r0 r1 s0 s1 t0 t1
      | r1 == 0   = (r0, s0, t0)
      | otherwise = go r1 r2 s1 s2 t1 t2
      where
        (q, r2) = r0 `divMod` r1
        s2 = s0 - q*s1
        t2 = t0 - q*t1
    f (g,u,v)
      | g < 0 = (-g, -u, -v)
      | otherwise = (g,u,v)
4tsuzuru4tsuzuru

String Hash

z:: (KnownNat p) => IntMod p
z = 26  

toHash :: (KnownNat p) => BS.ByteString -> IntMod p
toHash str = BS.foldl' (\acc c -> acc * z + fromIntegral (fromEnum c - fromEnum 'a')) 0 str

scanHash :: (KnownNat p) => BS.ByteString -> V.Vector (IntMod p)
scanHash str = V.fromList $  scanl (\acc c -> acc * z + fromIntegral (fromEnum c - fromEnum 'a')) 0 $ BS.unpack str

rangeHash :: (KnownNat p) => Int -> Int -> V.Vector (IntMod p) -> IntMod p
rangeHash l r h  = (h V.! r) - (h V.! (l-1)) * (z `power` (r-l+1))

4tsuzuru4tsuzuru

Unbox化できるように修正。

{-# LANGUAGE DataKinds,ScopedTypeVariables,KindSignatures, TemplateHaskell, TypeFamilies, MultiParamTypeClasses #-}

import GHC.TypeNats ( KnownNat, Nat, natVal )
import Data.Proxy ( Proxy(..) )
import Data.Ratio ( denominator, numerator )
import Language.Haskell.TH
import Data.Vector.Unboxed.Deriving ( derivingUnbox )

newtype IntMod (m :: Nat) = IntMod { unIntMod :: Int } deriving (Eq, Show, Read)

instance KnownNat m => Num (IntMod m) where
  IntMod x + IntMod y = IntMod $ (x+y) `mod` fromIntegral (natVal (Proxy :: Proxy m))
  IntMod x - IntMod y = IntMod $ (x-y) `mod` fromIntegral (natVal (Proxy :: Proxy m))
  IntMod x * IntMod y = IntMod $ (x*y) `mod` fromIntegral (natVal (Proxy :: Proxy m))
  fromInteger x = IntMod $ fromInteger $ x `mod` fromIntegral (natVal (Proxy :: Proxy m))
  abs = id
  signum = const 1


instance KnownNat m => Fractional (IntMod m) where
  fromRational x = fromInteger (numerator x) / fromInteger (denominator x)
  recip a@(IntMod x) = IntMod $ fromIntegral r
    where
    (_, r, _) = exgcd x (fromIntegral $ natVal (Proxy :: Proxy m))

instance KnownNat m => Bounded (IntMod m) where
  minBound = IntMod 0
  maxBound = IntMod $ fromIntegral $ natVal (Proxy :: Proxy m) - 1

instance KnownNat m => Enum (IntMod m) where
  toEnum = fromIntegral
  fromEnum (IntMod x) = x

instance KnownNat m => Ord (IntMod m) where
  IntMod x <= IntMod y = x <= y


power a n = go a n 1
  where
    go _ 0 res = res
    go a n res
      | even n = go (a*a) (n `div` 2) res
      | otherwise = go a (n-1) (res*a)


exgcd :: Integral a => a -> a -> (a, a, a)
exgcd a b = f $ go a b 1 0 0 1
  where
    go r0 r1 s0 s1 t0 t1
      | r1 == 0   = (r0, s0, t0)
      | otherwise = go r1 r2 s1 s2 t1 t2
      where
        (q, r2) = r0 `divMod` r1
        s2 = s0 - q*s1
        t2 = t0 - q*t1
    f (g,u,v)
      | g < 0 = (-g, -u, -v)
      | otherwise = (g,u,v)


derivingUnbox "IntMod"
  [t| forall m. (KnownNat m) => IntMod m -> Int|]
  [| \(IntMod x) -> x |]
  [| IntMod |]

4tsuzuru4tsuzuru

エラトステネスの篩

下記の実装が全く分からず、アルゴリズムを手で追ってなんとか理解。
競プロHaskellerはほとんどこの実装を使っている印象。

下記の論文が元ネタ。
The Genuine Sieve of Eratosthenes

基本的なアイデアはもとのエラトステネスの篩でバツをつける作業で、発見した素数の倍数全てにバツを付けるのではなく、素数と最小の倍数を記録するところにある。

ある素数pを発見した時にそれ以降の全てのpの倍数(無限にある)にバツをつけることはせず、素数pで次にバツをつける数p^2と素数pをセットで、優先度付きキュー(p^2, p)として追加する。そのようにして、これまでに見つかった素数とそれによって消せる最小の倍数の辞書ができる。
そして、以下のように素数によって消せる倍数を更新していく。

次にやってきた数が

  1. 辞書の最小値以下 → その数が素数であることが確定するため、(q^2, q)として辞書に追加する。
  2. 辞書の最小値と等しい → その数は合成数であり、ヒットした辞書の値を素数分インクリメント (q^2+q, q) 。つまり、その素数で次に消す数を更新している。
  3. 辞書の最小値以上 →2の作業が終わっていないので、引き続き、(l*l+l+l+l, l)のように更新していく

では、なぜ下記の実装で(x^2, x)ではなく、(x^2, 2x)となっているかというと、x^2=奇数, 2x=偶数なので、x^2+k*xでインクリメントされる数の中で、すでに偶数は除去されているから。([5,7,9,11,13,...])

もし、自分で手で追う場合は、(q^2, q)を辞書に追加して、[2,3,4...]から始めてみるのが分かりやすい。

import qualified Data.Heap as H

-- @gotoki_no_joe
primes :: [Int]
primes = 2 : 3 : sieve q0 [5,7..]
  where
    q0 = H.insert (H.Entry 9 6) H.empty
    sieve queue xxs@(x:xs) =
      case compare np x of
        LT ->     sieve queue1 xxs
        EQ ->     sieve queue1  xs
        GT -> x : sieve queue2  xs
      where
        H.Entry np p2 = H.minimum queue
        queue1 = H.insert (H.Entry (np+p2) p2) $ H.deleteMin queue
        queue2 = H.insert (H.Entry (x * x) (x * 2)) queue

isPrime :: Int -> Bool
isPrime n = all ((/=0) . (mod n))  $ takeWhile ((<= n) . (^2)) primes
4tsuzuru4tsuzuru

素数列挙のもう一つのやり方

上記の方法に比べて理解しやすいし、パフォーマンスも良い。
5以上の奇数から、すでに列挙した素数で消去できる合成数の和集合を引いているだけである。

primes = 2 : 3 : minus [5,7..] (unionAll [[p*p, p*p+2*p..] | p <- tail primes]) 

minus (x:xs) (y:ys) = case (compare x y) of 
           LT -> x : minus  xs  (y:ys)
           EQ ->     minus  xs     ys 
           GT ->     minus (x:xs)  ys
minus  xs     _     = xs

union (x:xs) (y:ys) = case (compare x y) of 
           LT -> x : union  xs  (y:ys)
           EQ -> x : union  xs     ys 
           GT -> y : union (x:xs)  ys
union  xs     []    = xs
union  []     ys    = ys

unionAll :: Ord a => [[a]] -> [a]
unionAll ((x:xs):t) = x : union xs (unionAll $ pairs t)  where
          pairs ((x:xs):ys:t) = (x : union xs ys) : pairs t

現在のAtCoderでは使えないが、Data.List.Orderedが採用されれば、下記のようになる。
爆速である。

import Data.List.Ordered (minus, unionAll)

primes = 2 : 3 : minus [5,7..] (unionAll [[p*p, p*p+2*p..] | p <- tail primes]) 
4tsuzuru4tsuzuru

UnionFind

UnionFindのrootにサイズ情報(マイナスの長さ)をもたせることで、配列一つで表現可能。

import Data.Array.ST
import Control.Monad.ST

type UnionFind s = STUArray s Int Int

newUnionFind :: Int -> ST s (UnionFind s) 
newUnionFind n = newArray (1,n) (-1)

root :: Int -> UnionFind s -> ST s Int
root x uf = do 
  p <- readArray uf x
  if p < 0 then return x 
  else do
    p' <- root p uf
    writeArray uf x p'
    return p'

isSame :: Int -> Int -> UnionFind s -> ST s Bool
isSame x y uf = (==) <$> root x uf <*> root y uf

unite :: Int -> Int -> UnionFind s -> ST s (UnionFind s)
unite x y uf = do
  rx <- root x uf
  ry <- root y uf
  if rx == ry then return uf else do
    px <- readArray uf rx
    py <- readArray uf ry
    if px <= py then do 
      writeArray uf rx (px + py)
      writeArray uf ry rx
      return uf
    else do
      writeArray uf ry (px + py)
      writeArray uf rx ry
      return uf
4tsuzuru4tsuzuru

最小全域木(クラスカル法)

kruskal :: Int -> [((Int,Int),Int)] -> [((Int,Int),Int)]
kruskal n edges = runST $ do
  uf <- newUnionFind n
  foldM (\res e@((a,b),c) -> do
    is_same <- isSame a b uf
    if is_same then return res 
    else do
      unite a b uf
      return $ e:res
    ) [] $ sortOn snd edges
4tsuzuru4tsuzuru

グラフライブラリFGLを使おう

このスクラップブックでDFSやBFS、クラスカル、ダイクストラ等のアルゴリズムのライブラリ化(整備)を目論んでいたが、既存のFGLライブラリを使った方が良い気がしてきた。
というのも、上記のDFSやBFSの実装は命令形言語のアルゴリズムをHaskellで書き直しただけであり、関数型言語のメリットであるアルゴリズムの本質を捕まえることができていなかったからだ。

Hakellの教科書でListやTreeの話題はよく出てくるが、Graphの話題はほとんど出てこない。
これは、ListやTreeには下記のような適切な始代数があり、それに応じたコンストラクター(パターンマッチング)がうまく機能するからだ。始代数についてはこちらを参照

  1. List の場合 -> 1 + A × X -> Nill / Cons
  2. Tree の場合 -> A + X x X

上記の理由でfoldlやunfold(余代数)などが定義され、関数型言語特有の見通しの良い計算が可能となる。
では、グラフ構造ではどうか?
私はalgebraic-graphsというライブラリで上記のような代数的データ型をうまく定義できているのではないかと期待したが、実際はアルゴリズムレベルでは破綻しており、結局隣接リストなど命令形言語と同じデータ構造を使っていた。
今回紹介するFGLライブラリは、20年以上前に作られた古いライブラリである。にもかかわらず、グラフ構造の代数的データ型の定義に現存するライブラリの中で最も成功していると感じる。しかも、AtCoderで使えるし、パフォーマンスも悪くない。

グラフ構造のコンストラクター(構築子)

このライブラリで使われているグラフコンストラクターは以下のようにListと似ている。
data Graph a b = Empty | Context a b & Graph a b

Context a b というのは、ある頂点に着目して、入ってくる辺、出ていく辺をひとまとめにしたものだ。

type Context a b= ([(辺のpayload b, 入ってくる辺の相手の頂点)], 頂点, 頂点のpayload a, [(辺のpayload b, 出ていく辺の相手の頂点)])

この一見単純なパターンマッチングで驚くほど簡潔にGraph関連のアルゴリズムを簡潔に説明することができる。

DFS

DFSを例に考えてみると、命令形言語では、グラフ表現として隣接リストを用いて、中間の訪問済み頂点を管理するためにリストやセットを用いる。
FGLではどちらも使用せず、直感的なアルゴリズム理解が可能だ。

dfs [] g = []
dfs (v:vs) (c & g') = v: dfs (scc c ++ vs) g'
dfs (v:vs) g = dfs vs g

二行目は探索すべき頂点のスタックの先頭vを取り出し、その頂点vに着目したcontext cをグラフから取り出している。もとのグラフgから、頂点vと流入出辺を取り除いたグラフがg'である。
sccはContext cからの流出先の頂点リストを表しており、それを探索すべき頂点リストの先頭にスタックしている。
三行目は頂点vに着目したContextをグラフgから取り出せなかった時(つまり、すでにグラフから取り除かれている時)、残りの探索リストvsに対してアルゴリズムを続行する。

競プロで使うには

AtCoderで使う際には、すでにライブラリの中にある便利関数(dfs,bfs)などを使ってもいいが、このライブラリの本質はmatch関数(前述したパターンマッチング)なので、それを使って自分でいくつか関数を作っておくと理解が早い。

具体例

DFS①

https://atcoder.jp/contests/tessoku-book/tasks/math_and_algorithm_am
既存のライブラリ内関数を使いAC

import Data.Graph.Inductive.Query.DFS (isConnected)
 let g = mkGraph (zip [1..n] $ repeat ()) $ map (\(x,y) -> (x,y,())) es :: Gr () ()
 let ans = isConnected  g

以降はグラフの作成とimportは省略する。

DFS②

https://atcoder.jp/contests/tessoku-book/submissions/39359250

解法①

dfsPath :: Graph gr => gr a b -> Node -> Node -> (Maybe [Node], gr a b)
dfsPath g start goal = go [start] g
  where 
    go [] g = (Nothing, g)
    go _ g | isEmpty g = (Nothing, g)
    go (v:vs) g = case match v g of
      (Nothing,g') -> go vs g'
      (Just c,g') -> 
        if v == goal then (pure [v] , g')
        else case (go (neighbors' c ) g') of
          (Nothing, g'') -> go vs g''
          (Just path, g'') -> (pure (v:path) , g'')

let ans = fromJust $ fst $ dfsPath g 1 n

解法②(既存のライブラリを使用)

simplePathはdfsで求めた木構造から、単純パスの復元のために使う。

simplePath :: Eq a => a -> Tree a  -> Maybe [a]
simplePath goal (Node v ts) = go v ts
  where
    go v [] = if v == goal then Just [v] else Nothing
    go v (t:ts) = case simplePath goal t of
      Nothing -> go v ts
      Just path -> Just (v:path)

 let ans = fromJust $ simplePath n $ head $ udff [1] g

BFS①

https://atcoder.jp/contests/tessoku-book/tasks/math_and_algorithm_an
有向グラフ用のlevelやlevelnはあるが、無向グラフ用の用のulevelやulevelnがないため、自分で作成した。
undir(有向グラフを無向グラフにする関数)はあるが、これを使って、levelを適用してもTLEとなるので、ulevelを使うしかない。

ulevel :: (Graph gr) => Node -> gr a b -> [(Node,Int)]
ulevel v = uleveln (queuePut (v,0) mkQueue)

uleveln :: (Graph gr, Num b1) => Queue (Node, b1) -> gr a b2 -> [(Node, b1)]
uleveln q g | queueEmpty q || isEmpty g = []
        | otherwise                 =
        case match v g of
          (Just c, g')  -> (v,j):uleveln (queuePutList (zip (neighbors' c) (repeat (j+1))) q') g'
          (Nothing, g') -> uleveln q' g'
          where ((v,j),q') = queueGet q

let ans =  accumArray max (-1) (1,n) $ ulevel 1 g
mapM_ (\i -> putStrLn $ show $ ans ! i ) [1..n]

BFS②

https://atcoder.jp/contests/tessoku-book/tasks/abc007_3
1次元のグラフgを構成して(ここが難しいが)、ライブラリのesp関数を使う。

let ans =  esp ((sy-1)*c + (sx-1)) ((gy-1)*c + (gx-1))  g

ダイクストラ①

https://atcoder.jp/contests/tessoku-book/tasks/tessoku_book_bl
ダイクストラ法を実装していると、このパターンマッチングがいかに各種グラフアルゴリズムの本質をついているのかがわかる。
一番上のダイクストラ法の実装と見比べてみると、確定済み頂点を管理するデータを持たなくても良い。
なぜなら、確定済み頂点とそこに接続する辺がすでにグラフgから取り除かれているからだ。
つまり、グラフg自体が下記の情報を内包している。

  1. (DFSやBFS、ダイクストラ法等の)確定済み配列
  2. 隣接リスト
import qualified Data.Heap as H

udijkstraWith :: (Graph gr, Ord b, Num b) => H.Heap (H.Entry b Edge) -> gr a b -> [(Edge, b)]
udijkstraWith h g | H.size h == 0 || isEmpty g = []
udijkstraWith h g =
    case match v g of
        (Just c@(p,_,_,s), g') -> 
            ((v,pre),currentCost) : udijkstraWith ( H.union h' (H.fromList (map (\(newCost,v') -> H.Entry (currentCost+newCost) (v',v)) $ p++s)) ) g'
        (Nothing, g') -> udijkstraWith h' g'
    where Just (H.Entry currentCost (v,pre), h') = H.uncons h

udijkstra :: (Graph gr, Ord b, Num b) => Node -> gr a b -> [(Edge, b)]
udijkstra v g = udijkstraWith (H.singleton (H.Entry 0 (v,0))) g

let ans  = accumArray max (-1) (1,n) $ map (\((a,b),c)->(a,c)) $ udijkstra 1 g

ダイクストラ②

https://atcoder.jp/contests/tessoku-book/tasks/tessoku_book_ek

shortestPath :: Graph gr => Node -> Node -> gr () Int ->Path
shortestPath s t g = (reverse $ unfoldr (\x -> if x == 0 then Nothing else Just (x, dist ! x)) t)
    where es =  udijkstraWith (H.singleton (H.Entry 0 (s,0))) g 
          dist = accumArray (+) 0 (1,noNodes g) $ map (\((a,b),c)->(a,b)) es

let ans = shortestPath 1 n g

MST

https://atcoder.jp/contests/tessoku-book/tasks/tessoku_book_bo
クラスカル法ではなく、プリム法で実装。
計算量としては、プリム法は 𝑂(|𝐸|+|𝑉|log|𝑉|) 、クラスカル法は 𝑂(|𝐸|log|𝐸|)。

import qualified Data.Heap as H

prim :: (Graph gr) => H.Heap (H.Entry Int Edge) -> gr a Int -> [LEdge Int]
prim h g | H.size h == 0 || isEmpty g = []
prim h g =
    case match to g of
         (Just c@(p,_,_,s),g')  -> (to,from,cost): prim (H.union h' (H.fromList (map (\(l,to')->H.Entry l (to',to)) $ p++s))) g'
         (Nothing,g') -> prim h' g'
    where Just (H.Entry cost (to,from), h') = H.uncons h

mst :: Graph gr =>gr a Int -> [LEdge Int]
mst g = prim (H.singleton $ H.Entry 0 (v,0)) g where ((_,v,_,_),_) = matchAny g

let res = mst g

最大流問題

https://atcoder.jp/contests/tessoku-book/tasks/tessoku_book_bp
BFSもMSTもダイクストラも色々と速度面で手直しが必要だったが、maxFlowは別物。爆速すぎ。
中身のアルゴリズムはEdmonds-Karp法。
論文は下記。
New Algorithms For The Functional Graph Library

let res = maxFlow g 1 n
4tsuzuru4tsuzuru

燃やす埋める問題

色々な記事を見て混乱したけれども、下記だけ見れば良い。
https://koyumeishi.hatenablog.com/entry/2021/01/14/052223

結局は、疑似ブール関数: {1,0}^n -> R を式変形していくと、最小カットと同様の式で表示できるから、それを元にグラフを構成していけば、脳死でできる。

4tsuzuru4tsuzuru

グラフライブラリ(AtCoder貼り付け用)

鉄則問題集はすべてパスしており、パフォーマンス上も大きな問題なし。
速度よりも使い勝手を重視しており、構成は基本的にはFGLライブラリの拡張と考えて良い。
実装されているものは、DFS/BFS, MST(Prim), Dijkstra, Dinic, 木の直径などなど。

{-# LANGUAGE BangPatterns #-}
import Data.Maybe
import Data.Array
import Data.List
import Data.Tree
import Data.IntMap.Strict (IntMap)
import qualified Data.IntMap.Strict as IM
import Control.Arrow (second)
import Data.Sequence (Seq)
import qualified Data.Sequence as Seq
import qualified Data.Heap as H


type Gr a b = IntMap (IntMap b, a, IntMap b)
type Node = Int
type LNode a = (Node, a)
type Edge = (Node, Node)
type LEdge a = (Node, Node, a)
type Capacity = Int
type Flow = Int
type Residual = Int


dinic :: Gr Int (Capacity, Flow, Residual) -> Node -> Node -> Gr Int (Capacity, Flow, Residual)
dinic graph src sink = 
    let (new_graph, can_continue) = dinicStep graph src sink
    in if can_continue 
           then dinic new_graph src sink
           else graph

dinicStep :: Gr Int (Capacity, Flow, Residual) -> Node -> Node -> (Gr Int (Capacity, Flow, Residual), Bool)
dinicStep graph src sink  = 
    let layered_graph = mkLevelGraph (mkResidualGraph graph) src
    in 
      if  getNodeLabel layered_graph sink /= (-1)
      then (dfs_dinic layered_graph src sink, True)
      else (graph, False)

dfs_dinic :: Gr Int (Capacity, Flow, Residual) -> Node -> Node -> Gr Int (Capacity, Flow, Residual)
dfs_dinic graph source sink = 
  if sink_reached then dfs_dinic (update_path graph path flow) source sink else graph
  where (path,flow , sink_reached) = find_augmentPath graph source sink


update_path :: Gr Int (Capacity, Flow, Residual) -> [Edge] -> Flow -> Gr Int (Capacity, Flow, Residual)
update_path graph [] _ = graph
update_path graph ((u, v):path) flow =
    let updatedGraph = IM.adjust (\(p,l,s) -> (p,l,(IM.adjust (\(cap, f, res) -> (cap, f + flow, res )) v s))) u graph
    in update_path updatedGraph path flow

find_augmentPath :: Gr Int (Capacity, Flow, Residual) -> Node ->  Node -> ([Edge], Flow, Bool)
find_augmentPath g node sink =
  let (_, source_level, suc_edges) = g IM.! node
      appropiate_edges = map (\(w,cfr) -> (node,w,cfr)) $ filter (\(w, (_, f, res)) -> f < res && res /= 0 &&  getNodeLabel g w == source_level + 1) $ IM.toList suc_edges
  in 
    if null appropiate_edges 
      then ([], 0, False) 
      else 
        let
          results = map (dfs_edge g node sink) appropiate_edges
          bestPath = dropWhile (\(_, _, reached) -> not reached) results
        in 
          if null bestPath 
          then ([], 0, False)
          else head bestPath

dfs_edge :: Gr Int (Capacity, Flow, Residual) -> Node -> Node -> (Int, Int, (Capacity, Flow, Residual)) -> ([(Int, Int)], Flow, Bool)
dfs_edge graph source sink (u, v, (_, f, r)) =
  if v == sink
    then ([(u, v)], r-f, True)
    else
      let (path, flow, reached) = find_augmentPath graph v sink
       in ((u, v) : path, min (r-f) flow, reached)

mkLevelGraph ::  Gr Int (Capacity,Flow,Residual) -> Int -> Gr Int (Capacity,Flow,Residual)
mkLevelGraph g src = updateLabels g $ IM.fromList $ leveln (pruneGraph g) (Seq.singleton (src, 0))


updateLabels :: Gr Int (Capacity,Flow,Residual) -> IntMap Int -> Gr Int (Capacity,Flow,Residual)
updateLabels g lvMap = IM.mapWithKey updateLabel g
  where
    updateLabel k (p,_, s) = (p,fromMaybe (-1) (IM.lookup k lvMap), s)

augmentGraph :: Gr () Capacity -> Gr Int (Capacity,Flow,Residual) 
augmentGraph g = IM.map (\(p, _, s) -> (IM.map (\cap -> (0,0,0)) p,0, IM.map (\cap -> (cap,0,cap)) s)) g

pruneGraph :: Gr Int (Capacity,Flow,Residual) -> Gr Int (Capacity,Flow,Residual)
pruneGraph g = IM.map ((\(p,l,s) ->(IM.filter (\(_,_,r)-> r/=0) p, l, IM.filter (\(_,_,r)-> r/=0) s))) g

mkResidualGraph :: Gr Int (Capacity,Flow,Residual) -> Gr Int (Capacity,Flow,Residual)
mkResidualGraph g = IM.map ((\(p,l,s) ->(p,l,(IM.map (\(cap, flow, res) -> (cap, 0, res-flow))s))))
    $ insEdges (concatMap addReverseEdge $ edges g') g'
  where
    g' = reflectReverseEdges $ g
    addReverseEdge (src, dest, (cap, flow, res)) = if flow > 0 then [(dest, src, (0, 0, (cap- (res-flow))))] else []


reflectReverseEdges :: Gr Int (Capacity,Flow,Residual) -> Gr Int (Capacity,Flow,Residual)
reflectReverseEdges g = updateEdgeLabels incEdges g
  where
    incEdges = map (\(src, dest, (cap, flow, res))->((dest, src),(\(c,f,r)->(c,f,r+flow))))  $ filter (\(src, dest, (cap, flow, res)) -> cap ==0 && flow /=0) $ edges g



edges :: Gr a b -> [LEdge b]
edges g = concatMap (\(v,c) -> map (\(w, edge) -> (v, w, edge)) $ lsuc c) $ IM.toList g

getNodeLabel :: Gr a b -> Node -> a
getNodeLabel g n =  (\(_,l,_) -> l) $ g IM.! n

maxFlow :: Gr () Capacity -> Int -> Int -> Flow
maxFlow g src sink = 
  let final_graph =  reflectReverseEdges $ dinic (augmentGraph g) src sink
      (_,_,outs) = final_graph IM.! src
  in sum $ map (\(_, (c, f, r)) -> c-(r-f)) $ IM.toList outs


mkGraph :: [LNode a] -> [(LEdge b)] -> Gr a b
mkGraph vs es   = insEdges es
                      . IM.fromList
                      . map (second (\l -> (IM.empty, l,IM.empty)))
                      $ vs

updateEdgeLabel  :: Edge -> (b -> b) -> Gr a b -> Gr a b
updateEdgeLabel (v,w) f g = IM.adjust (\(ins,lab,outs) -> (IM.adjust f v ins,lab,outs)) w $ IM.adjust (\(ins,lab,outs) -> (ins,lab, IM.adjust f w outs)) v $ g

updateEdgeLabels :: Foldable t => t (Edge, b -> b) -> Gr a b -> Gr a b
updateEdgeLabels es g = foldl' (flip (uncurry updateEdgeLabel)) g es

delEdge :: Edge -> Gr a b -> Gr a b
delEdge (v,w) g = IM.adjust (\(ins,lab,outs) -> (IM.delete v ins,lab, outs)) w $  IM.adjust (\(ins,lab,outs) -> (ins,lab, IM.delete w outs)) v $ g

insEdge :: LEdge b -> Gr a b -> Gr a b
insEdge (v,w,l) g = IM.adjust  (\(ins,lab,outs) -> (IM.insert v l ins , lab, outs)) w $ IM.adjust (\(ins,lab,outs) -> (ins, lab, IM.insert w l outs)) v $ g

insEdges :: Foldable t => t (LEdge b) -> Gr a b -> Gr a b
insEdges es g = foldl' (flip insEdge) g es

undirected :: Gr a b -> Gr a b
undirected g = foldl' addReversedEdge g (edges g)
  where
    addReversedEdge graph (v, w, edge) = insEdge (w, v, edge) graph

match :: Node -> Gr a b ->  (Maybe (IntMap b, a, IntMap b), Gr a b)
match node g =
  case IM.lookup node g of
    Nothing -> (Nothing, g)
    Just (p, label, s) ->
      let !g1 = IM.delete node g
          !g2 = clearPred g1 node s
          !g3 = clearSucc g2 node p
      in (Just (p, label, s), g3)

clearSucc :: Gr a b -> Node -> IntMap x -> Gr a b
clearSucc g v = IM.differenceWith go g
  where
    go :: (IntMap b, a, IntMap b) -> x -> Maybe (IntMap b, a, IntMap b)
    go (ps, l, ss) _ =
      let !ss' = IM.delete v ss
      in Just (ps, l, ss')

clearPred :: Gr a b -> Node -> IntMap x -> Gr a b
clearPred g v = IM.differenceWith go g
  where
    go :: (IntMap b, a, IntMap b) -> x -> Maybe (IntMap b, a, IntMap b)
    go (ps, l, ss) _ =
      let !ps' = IM.delete v ps
      in Just (ps', l, ss)

suc :: (IntMap b, a,IntMap b) -> [Node]
suc (_,_,outs) = IM.keys outs

lsuc :: (IntMap b, a, IntMap b) -> [(Node, b)]
lsuc (_,_,outs) = IM.toList outs

pre :: (IntMap b, a, IntMap b) -> [Node]
pre (ins,_,_) = IM.keys ins

lpre :: (IntMap b, a, IntMap b) -> [(Node, b)]
lpre (ins,_,_) = IM.toList ins

neighbors :: (IntMap b, a, IntMap b) -> [Node]
neighbors (ins,_,outs) = IM.keys ins ++ IM.keys outs

lneighbors :: (IntMap b, a, IntMap b) -> [(Node, b)]
lneighbors (ins,_,outs) = IM.toList ins ++ IM.toList outs

dfsForest :: Gr a b -> Node -> [Tree Node]
dfsForest g src = go g [src]
  where
    go g [] = []
    go g (v:vs) = case match v g of
      (Nothing, _) -> go g vs
      (Just c, g') -> Node v (go g' (suc c)) : go g' vs

dfs :: Gr a b -> Node -> [Node]
dfs g src = go g [src]
  where
    go g [] = []
    go g (v:vs) = case match v g of
      (Nothing, _) -> go g vs
      (Just c, g') -> v : go g' ((suc c) ++ vs)


isConnected :: Gr a b -> Bool
isConnected g = length (dfs g (fst $ head $ IM.toList g))  == IM.size g


dfsPath :: Gr a b -> Node -> Node -> Maybe [Node]
dfsPath g src sink = fst $ go [src] g
  where
    go vs g | null vs || IM.null g = (Nothing,g)
    go (v:vs) g = case match v g of
      (Nothing, _) -> go vs g
      (Just c, g') -> 
        if v == sink then (Just [v], g')
        else case go (suc c) g' of
          (Nothing, g'') -> go vs g''
          (Just path, g'') -> (Just (v:path), g'')


bfsShortestPath :: Gr a b -> Node -> Node -> Maybe [Node]
bfsShortestPath g src sink = case paths of
  [] -> Nothing
  _ -> Just $ reverse $ head paths
  where
    paths = filter (\(w : _) -> w == sink) $ go g (Seq.singleton [src])

    go g q | Seq.null q || IM.null g = []
    go g q = case match v g of
      (Nothing, _) -> go g rest
      (Just c, g') -> p : go g' (rest Seq.>< Seq.fromList (map (:p) $ suc c))
      where
        (p@(v:_), rest) = (Seq.index q 0, Seq.drop 1 q)

level :: Gr a b -> Node -> [(Node, Int)]
level g src = leveln g (Seq.singleton (src, 0))

leveln :: Gr a b -> Seq (Node, Int) -> [(Node, Int)]
leveln _ Seq.Empty = []
leveln g _ | IM.null g = []
leveln g queue = case match v g of
    (Nothing, _) -> leveln g rest
    (Just c, g') -> (v, j) : leveln g' (rest Seq.>< Seq.fromList (zip (suc c) (repeat (j+1)) ))
  where ((v,j), rest) = (Seq.index queue 0, Seq.drop 1 queue)

ulevel :: Gr a b -> Node -> [(Node, Int)]
ulevel g src = uleveln g (Seq.singleton (src, 0))

uleveln :: Gr a b -> Seq (Node, Int) -> [(Node, Int)]
uleveln _ Seq.Empty = []
uleveln g _ | IM.null g = []
uleveln g queue = case match v g of
    (Nothing, _) -> uleveln g rest
    (Just c, g') -> (v, j) : uleveln g' (rest Seq.>< Seq.fromList (zip (neighbors c) (repeat (j+1)) ))
  where ((v,j), rest) = (Seq.index queue 0, Seq.drop 1 queue)

dijkstra :: (Ord b, Num b) => Gr a b -> Node -> IntMap b
dijkstra g s = go g (H.singleton (H.Entry 0 s)) 
  where
    go g h  | H.null h || IM.null g = IM.empty
    go g h = case match v g of
      (Nothing, g) -> go g h'
      (Just c, g') -> IM.insert v d $ go g' (H.union h' (H.fromList $ map (\(w, l) -> H.Entry (l + d) w) $ lsuc c))
      where Just (H.Entry d v,h') = H.uncons h

udijkstra :: (Ord b, Num b) => Gr a b -> Node -> IntMap b
udijkstra g s = go g (H.singleton (H.Entry 0 s)) 
  where
    go g h  | H.null h || IM.null g = IM.empty
    go g h = case match v g of
      (Nothing, g) -> go g h'
      (Just c, g') -> IM.insert v d $ go g' (H.union h' (H.fromList $ map (\(w, l) -> H.Entry (l + d) w) $ (lsuc c ++ lpre c)))
      where Just (H.Entry d v,h') = H.uncons h

dijkstraShortestPath :: (Ord p, Num p) => Gr a p -> Node -> Node -> [Node]
dijkstraShortestPath g src sink = reverse $ head $ filter (\(w:_)-> w==sink) $ go g (H.singleton (H.Entry 0 [src]))
  where
    go g h | H.null h || IM.null g = []
    go g h = case match v g of
      (Nothing, g) -> go g h'
      (Just c, g') -> p : go g' (foldr H.union h' (map (\(w, l) -> H.singleton (H.Entry (l + d) (w:p))) $ lsuc c))
      where Just (H.Entry d p@(v:_),h') = H.uncons h

udijkstraAllPath :: (Ord a1, Num a1, Show a1) => Gr a2 a1 -> Node -> p -> M.Map Node [[Node]]
udijkstraAllPath g src sink = go g (M.singleton (0,src) [[src]])
  where
    go g h | M.null h || IM.null g = M.empty
    go g h = case match v g of
      (Nothing, g) -> go g h'
      (Just c, g') ->  M.insert v paths $ go g' (M.unionWith (++) h' (M.fromList $ map (\(v', w') -> ((w+w',v'), map (v':) paths)) $ lneighbors c))
      where Just (((w,v), paths), h') = M.minViewWithKey h

sizeSubgraph:: Node -> Gr () () -> IntMap Int
sizeSubgraph k g  
  | IM.null g  = IM.singleton k 0
  | otherwise = case match k g of
      (Just (p,l,s), g') -> IM.insert k (maximum (0:subranks)) memo
        where
          memo = IM.unions $ map (\x-> sizeSubgraph x g') (IM.keys p ++ IM.keys s)
          subranks = map (\x-> memo IM.! x +1)  (IM.keys p ++ IM.keys s)
      (Nothing, g') -> IM.singleton k 0

-- 木の直径
uDiameterOfTree :: Gr a b -> Int
uDiameterOfTree g = 1+ (snd $ maximumBy (comparing snd ) $ ulevel g $ fst $ maximumBy (comparing snd) $ ulevel g $ fst $ head $ IM.toList g)


uprim :: (Ord b, Num b) => H.Heap (H.Entry b Edge) -> Gr a b -> [LEdge b]
uprim h g | H.size h == 0 || IM.null g = []
uprim h g =
    case match v g of
         (Just c,g')  -> (p,v,cost): uprim (H.union h' (H.fromList (map (\(s,cost')->H.Entry cost' (v,s)) $ lneighbors c))) g'
         (Nothing,g') -> uprim h' g'
    where Just (H.Entry cost (p,v), h') = H.uncons h

umst :: (Ord b, Num b) => Gr a b -> [LEdge b]
umst g = uprim (H.singleton $ H.Entry 0 (0,v)) g where (v,_) = IM.findMin g

prim :: (Ord b, Num b) => H.Heap (H.Entry b Edge) -> Gr a b -> [LEdge b]
prim h g | H.size h == 0 || IM.null g = []
prim h g =
    case match v g of
         (Just c,g')  -> (p,v,cost): prim (H.union h' (H.fromList (map (\(s,cost')->H.Entry cost' (v,s)) $ lsuc c))) g'
         (Nothing,g') -> prim h' g'
    where Just (H.Entry cost (p,v), h') = H.uncons h

mst :: (Ord b, Num b) => Gr a b -> [LEdge b]
mst g = prim (H.singleton $ H.Entry 0 (0,v)) g where (v,_) = IM.findMin g

4tsuzuru4tsuzuru

Segment Tree

@cojna, @frtn_r, @toyboot4e氏を参考。

import Control.Monad.ST
import Data.Monoid
import Data.Semigroup
import qualified Data.Vector.Generic               as G
import qualified Data.Vector.Generic.Mutable       as GM
import Control.Monad.Primitive
import Data.Bits
import Data.Function ( fix )

newtype SegTree mv s a = SegTree {getSegBody :: mv s a}

newSegTree :: (GM.MVector mv a, Monoid a, PrimMonad m) => Int ->  m (SegTree mv (PrimState m) a)
newSegTree n = SegTree <$> GM.replicate (2*n) mempty

buildSegTree :: (G.Vector v a, Monoid a, PrimMonad m) =>  v a -> m (SegTree (G.Mutable v) (PrimState m) a)
buildSegTree v = do
  let n = expandToPowerOfTwo $ G.length v
  tree <- GM.replicate (2*n) mempty
  G.unsafeCopy (GM.unsafeSlice n (G.length v) tree) v
  forM_ [n-1,n-2..1] $ \i -> do
    x <- GM.unsafeRead tree (2*i)
    y <- GM.unsafeRead tree (2*i+1)
    GM.unsafeWrite tree i (x <> y)
  return $ SegTree tree

setSegTree :: (GM.MVector mv a, Monoid a, PrimMonad m) => SegTree mv (PrimState m) a -> Int -> a -> m ()
setSegTree st@(SegTree tree) i0 v = do
  let !n = unsafeShiftR (GM.length tree)  1
      !i = n+i0
      !h = 63 - countLeadingZeros n
  GM.unsafeWrite tree (i) v
  forM_ [1..h] $ \j -> do 
    pull st (unsafeShiftR i j)

pull :: (GM.MVector mv a, Monoid a, PrimMonad m) => SegTree mv (PrimState m) a -> Int -> m ()
pull (SegTree tree) i = do
  x <- GM.unsafeRead tree (2*i)
  y <- GM.unsafeRead tree (2*i+1)
  GM.unsafeWrite tree i (x <> y)

getSegTree :: (GM.MVector mv a, Monoid a, PrimMonad m) => SegTree mv (PrimState m) a -> Int -> m a
getSegTree (SegTree tree) i = do
  let n = unsafeShiftR (GM.length tree)  1
  GM.unsafeRead tree (n+i)
  
-- 半開区間[l,r)のクエリ
prodSegTree :: (GM.MVector mv a, Monoid a, PrimMonad m) => SegTree mv (PrimState m) a -> Int -> Int -> m a
prodSegTree (SegTree tree) l0 r0 = do
  let !n = unsafeShiftR (GM.length tree)  1
      !l = n+l0
      !r = n+r0
  fix
    ( \loop !accL !accR !l' !r' -> do
        if l' < r'
          then do
            !accL' <-
              if l' .&. 1 == 1
                then  (accL <>) <$!> GM.unsafeRead tree l'
                else return accL
            !accR' <-
              if r' .&. 1 == 1
                then (<> accR) <$!> GM.unsafeRead tree (r' - 1)
                else return accR
            loop
              accL'
              accR'
              (unsafeShiftR (l' + l' .&. 1) 1)
              (unsafeShiftR (r' - r' .&. 1) 1)
          else return $! accL <> accR
    ) mempty mempty l r

prodAllSegTree :: (GM.MVector mv a, Monoid a, PrimMonad m) => SegTree mv (PrimState m) a -> m a
prodAllSegTree (SegTree tree) = GM.unsafeRead tree 1

expandToPowerOfTwo :: Int -> Int
expandToPowerOfTwo x | x<2 = 2
          | otherwise = 1 `unsafeShiftL` (finiteBitSize (x-1) - countLeadingZeros (x-1) )
4tsuzuru4tsuzuru

遅延セグ木

@cojna氏を基盤(というかほとんどそのまま)に、library checkerの最速実装とACLライブラリ内の実装を参考にした。
上記3つはほとんど似たような実装となっている。

遅延セグ木の数学的議論は、
SegmentTreeに載る代数的構造について
セグメント木がモノイドを必要とする理由
を参照すれば十分。
下記のコードの定義を使えば、

  1. f @ ( x <> y ) == f @ x <> f @ y
  2. (f<>g)@x == f@(g@x)== (f@・g@) x
    という、作用@に関して、2つの準同型に気をつける必要がある。

作用素から生成される作用を関手と考えることもできるが、内部構造の議論もする必要があるため、実装上のメリットはない。

Unbox化に関しては、@toyboot4e氏を参考に、unboxing-vectorではなく、vector-th-unboxを用いることとした。置換操作もUpdate x | NoUpdate (=1+X)とすれば、モノイドとして扱えるが、unboxing-vectorは余積は扱えないため。

更新区間に関しては、countTrailingZerosを使うのが分かりやすい。更新区間を覆う区間の親区間の作用素を伝播処理するときは、セグ木の高さから、2の指数+1の高さまで順に作用素を伝播させていけば良い。

なぜ2の指数なのか。
具体的に、1-indexedで[10,14)を例に考えてみる。
10の2進数表記は1010で、親をたどる作業は一つ右シフトすることに相当する。つまり、10(=1010)の親は101となる。逆に、親が持つ値は、子のモノイド積である。101は1010と1011のモノイド積だ。
なので、1010から始まる区間モノイド積を考える際は、101を考えれば十分で、1010まで作用素を伝播させる必要がない。つまり、右シフトを続けて、0が続く限り親をたどることができ、その親の値を確定させれば良いということになる。そのために、到達した親のひとつ上(=2の指数+1)の作用素までを一番上から伝播させる。
右端はどうなるだろうか?13(=1101)も同様に、110が1100と1101のモノイド積なので、110の値を確定させれば十分だ。つまり、左端と同様に1が続く限り親をたどり、到達した親の情報を確定させる。
なので、countTrailingOnesで親までの高さを求めれば良い。
HaskellではcountTrailingOnesがないため、右隣のcountTrailingZerosと高さが等しいことを利用している。
セグ木のグラフをよく見ると、左のtrailingOnesと右のtrailingZerosが常に等しいことがわかる。

最後に、遅延セグ木では、作用素と要素に関して、モノイドの使いまわしができるので、Haskellは他の言語よりも圧倒的に実装しやすく、見通しも良いと思う。(Unbox化以外)

import Data.Monoid
import Data.Semigroup
import qualified Data.Vector.Unboxed               as U
import qualified Data.Vector.Unboxed.Mutable       as UM
import Control.Monad.Primitive
import Data.Bits
import Data.Function ( fix )

import Control.Monad.ST

class (Monoid f) => MonoidAction f a where
    (<@>) :: f -> a -> a

data LazySegTree s a f = LazySegTree (UM.MVector s a) (UM.MVector s f)
 
newLazySegTree :: (PrimMonad m, U.Unbox a, U.Unbox f, Monoid a,Monoid f) => Int -> m (LazySegTree (PrimState m) a f)
newLazySegTree x = do
  let !n = expandToPowerOfTwo x
  tree <- UM.replicate (2*n) mempty
  lazy <- UM.replicate (n) mempty
  return $ LazySegTree tree lazy
 
buildLazySegTree :: (PrimMonad m, Monoid a, Monoid f, U.Unbox f,U.Unbox a) => U.Vector a -> m (LazySegTree (PrimState m) a f)
buildLazySegTree xs = do
  let !n = expandToPowerOfTwo $ U.length xs
  tree <- UM.replicate (2*n) mempty
  lazy <- UM.replicate n mempty
  U.unsafeCopy (UM.unsafeSlice n (U.length xs) tree) xs
  forM_ [n-1,n-2..1] $ \i -> do
    x <- UM.unsafeRead tree (2*i)
    y <- UM.unsafeRead tree (2*i+1)
    UM.unsafeWrite tree i (x <> y)
  return $ LazySegTree tree lazy
  
setLazySegTree :: (PrimMonad m, MonoidAction f a, Monoid a, U.Unbox a,U.Unbox f, Eq f) =>LazySegTree (PrimState m) a f -> Int -> a -> m ()
setLazySegTree st@(LazySegTree tree lazy) i0 x = do
  let !n = UM.length lazy
      !i = n+i0
      !h = 63 - countLeadingZeros n
  forM_ [h,(h -1) .. 1] $ \j -> do
    push st (unsafeShiftR i j)
  UM.unsafeWrite tree i x
  forM_ [1..h] $ \j -> do 
    pull st (unsafeShiftR i j)
 
getLazySegTree :: (PrimMonad m, MonoidAction f a, U.Unbox a,U.Unbox f, Eq f) => LazySegTree (PrimState m) a f -> Int -> m a
getLazySegTree st@(LazySegTree tree lazy) i = do
  let !n = UM.length lazy
      !h = 63 - countLeadingZeros n
  forM_ [h,(h -1) .. 1] $ \j -> do
    push st (unsafeShiftR i j)
  UM.unsafeRead tree (n+i)
 
prodLazySegTree :: (MonoidAction f a, Monoid a, PrimMonad m, U.Unbox a,U.Unbox f, Eq f) => LazySegTree (PrimState m) a f -> Int -> Int -> m a
prodLazySegTree st@(LazySegTree tree lazy) l0 r0 = do
  let !n = UM.length lazy
      !l = n+l0
      !r = n+r0
      !h = 63 - countLeadingZeros n
      !l_ctz = countTrailingZeros l
      !r_ctz = countTrailingZeros r
  forM_ [h,(h -1) .. (l_ctz+1)] $ \i -> do
    push st (unsafeShiftR l i)
  forM_ [h,(h -1) .. (r_ctz+1)] $ \i -> do
    push st (unsafeShiftR (r-1) i)
 
  fix
    ( \loop !accL !accR !l' !r' -> do
        if l' < r'
          then do
            !accL' <-
              if l' .&. 1 == 1
                then  (accL <>) <$!> UM.unsafeRead tree l'
                else return accL
            !accR' <-
              if r' .&. 1 == 1
                then (<> accR) <$!> UM.unsafeRead tree (r' - 1)
                else return accR
            loop
              accL'
              accR'
              (unsafeShiftR (l' + l' .&. 1) 1)
              (unsafeShiftR (r' - r' .&. 1) 1)
          else return $! accL <> accR
    ) mempty mempty l r
 
applyLazySegTree :: (PrimMonad m, MonoidAction f a, Monoid a, U.Unbox a,U.Unbox f, Eq f, Show f, Show a) =>LazySegTree (PrimState m) a f -> Int -> Int -> f -> m ()
applyLazySegTree st@(LazySegTree tree lazy) l0 r0 f =  do
  let !n = UM.length lazy
      !l = n+l0
      !r = n+r0
      !h = 63 - countLeadingZeros n
      !l_ctz = countTrailingZeros l
      !r_ctz = countTrailingZeros r
  forM_ [h,(h -1) .. (l_ctz+1)] $ \i -> do
    push st (unsafeShiftR l i)
  forM_ [h,(h -1) .. (r_ctz+1)] $ \i -> do
    push st (unsafeShiftR (r-1) i)
 
  fix
    ( \loop !l' !r' -> when (l' < r') $ do
        when (l' .&. 1 == 1) $ do
          applyAt st l' f
        when (r' .&. 1 == 1) $ do
          applyAt st (r' - 1) f
        loop
          (unsafeShiftR (l' + l' .&. 1) 1)
          (unsafeShiftR (r' - r' .&. 1) 1)
    ) l r
 
  forM_ [(l_ctz+1)..h] $ \i -> do
    pull st (unsafeShiftR l i)
  forM_ [(r_ctz+1)..h] $ \i -> do
    pull st (unsafeShiftR (r-1) i)

 

expandToPowerOfTwo :: Int -> Int
expandToPowerOfTwo x | x<2 = 2
          | otherwise = 1 `unsafeShiftL` (finiteBitSize (x-1) - countLeadingZeros (x-1) )


applyAt :: (PrimMonad m, MonoidAction f a, U.Unbox a,U.Unbox f) => LazySegTree (PrimState m) a f -> Int -> f -> m ()
applyAt (LazySegTree tree lazy) k f = do
  UM.unsafeModify tree (f <@>) k
  when (k < UM.length lazy) $ do
    UM.unsafeModify lazy (f <>) k
 
push :: (PrimMonad m, MonoidAction f a, U.Unbox f,U.Unbox a, Eq f) =>LazySegTree (PrimState m) a f -> Int -> m ()
push st@(LazySegTree tree lazy) i =  do
  when (i < UM.length lazy) $ do
    f <- UM.unsafeRead lazy i
    when (f /= mempty) $ do
      applyAt st (2*i) f
      applyAt st (2*i+1) f
      UM.unsafeWrite lazy i mempty
 
pull :: (PrimMonad m, U.Unbox a, Monoid a) => LazySegTree (PrimState m) a f -> Int -> m ()
pull st@(LazySegTree tree lazy) i = do
  x <- UM.unsafeRead tree (2*i)
  y <- UM.unsafeRead tree (2*i+1)
  UM.unsafeWrite tree i (x <> y)

Genericバージョン(Unbox化が必須ではない)

Generic Version
import Data.Monoid
import Data.Semigroup
import qualified Data.Vector.Generic               as G
import qualified Data.Vector.Generic.Mutable       as GM
import Control.Monad.Primitive
import Data.Bits
import Data.Function ( fix )

newLazySegTree :: (PrimMonad m, GM.MVector mv a, GM.MVector mv f, Monoid a,Monoid f) => Int -> m (LazySegTree mv (PrimState m) a f)
newLazySegTree x = do
  let !n = expandToPowerOfTwo x
  tree <- GM.replicate (2*n) mempty
  lazy <- GM.replicate (n) mempty
  return $ LazySegTree tree lazy
 
 
buildLazySegTree :: (PrimMonad m, Monoid a, Monoid f, G.Vector v a, G.Vector v f) => v a -> m (LazySegTree (G.Mutable v) (PrimState m) a f)
buildLazySegTree xs = do
  let !n = expandToPowerOfTwo $ G.length xs
  tree <- GM.replicate (2*n) mempty
  lazy <- GM.replicate n mempty
  G.unsafeCopy (GM.unsafeSlice n (G.length xs) tree) xs
  forM_ [n-1,n-2..1] $ \i -> do
    x <- GM.unsafeRead tree (2*i)
    y <- GM.unsafeRead tree (2*i+1)
    GM.unsafeWrite tree i (x <> y)
  return $ LazySegTree tree lazy
 
 

 
applyAt :: (PrimMonad m, MonoidAction f a, GM.MVector mv a,GM.MVector mv f) => LazySegTree mv (PrimState m) a f -> Int -> f -> m ()
applyAt (LazySegTree tree lazy) k f = do
  GM.unsafeModify tree (f <@>) k
  when (k < GM.length lazy) $ do
    GM.unsafeModify lazy (f <>) k
 
 
push :: (PrimMonad m, MonoidAction f a, GM.MVector mv a,GM.MVector mv f, Eq f) =>LazySegTree mv (PrimState m) a f -> Int -> m ()
push st@(LazySegTree tree lazy) i =  do
  when (i < GM.length lazy) $ do
    f <- GM.unsafeRead lazy i
    when (f /= mempty) $ do
      applyAt st (2*i) f
      applyAt st (2*i+1) f
      GM.unsafeWrite lazy i mempty
 
 
pull :: (PrimMonad m, Monoid a,GM.MVector mv a,GM.MVector mv f) => LazySegTree mv (PrimState m) a f -> Int -> m ()
pull st@(LazySegTree tree lazy) i = do
  x <- GM.unsafeRead tree (2*i)
  y <- GM.unsafeRead tree (2*i+1)
  GM.unsafeWrite tree i (x <> y)
 
setLazySegTree :: (PrimMonad m, MonoidAction f a, Monoid a, GM.MVector mv a,GM.MVector mv f, Eq f) =>LazySegTree mv (PrimState m) a f -> Int -> a -> m ()
setLazySegTree st@(LazySegTree tree lazy) i0 x = do
  let !n = GM.length lazy
      !i = n+i0
      !h = 63 - countLeadingZeros n
  forM_ [h,(h -1) .. 1] $ \j -> do
    push st (unsafeShiftR i j)
  GM.unsafeWrite tree i x
  forM_ [1..h] $ \j -> do 
    pull st (unsafeShiftR i j)
 
getLazySegTree :: (PrimMonad m, MonoidAction f a, GM.MVector mv a,GM.MVector mv f, Eq f) => LazySegTree mv (PrimState m) a f -> Int -> m a
getLazySegTree st@(LazySegTree tree lazy) i = do
  let !n = GM.length lazy
      !h = 63 - countLeadingZeros n
  forM_ [h,(h -1) .. 1] $ \j -> do
    push st (unsafeShiftR i j)
  GM.unsafeRead tree (n+i)
 
prodLazySegTree :: (MonoidAction f a, Monoid a, PrimMonad m, GM.MVector mv a, GM.MVector mv f, Eq f) => LazySegTree mv (PrimState m) a f -> Int -> Int -> m a
prodLazySegTree st@(LazySegTree tree lazy) l0 r0 = do
  let !n = GM.length lazy
      !l = n+l0
      !r = n+r0
      !h = 63 - countLeadingZeros n
      !l_ctz = countTrailingZeros l
      !r_ctz = countTrailingZeros r
  forM_ [h,(h -1) .. (l_ctz+1)] $ \i -> do
    push st (unsafeShiftR l i)
  forM_ [h,(h -1) .. (r_ctz+1)] $ \i -> do
    push st (unsafeShiftR r i)
 
  fix
    ( \loop !accL !accR !l' !r' -> do
        if l' < r'
          then do
            !accL' <-
              if l' .&. 1 == 1
                then  (accL <>) <$!> GM.unsafeRead tree l'
                else return accL
            !accR' <-
              if r' .&. 1 == 1
                then (<> accR) <$!> GM.unsafeRead tree (r' - 1)
                else return accR
            loop
              accL'
              accR'
              (unsafeShiftR (l' + l' .&. 1) 1)
              (unsafeShiftR (r' - r' .&. 1) 1)
          else return $! accL <> accR
    ) mempty mempty l r
 
applyLazySegTree :: (PrimMonad m, MonoidAction f a, Monoid a, GM.MVector mv a,GM.MVector mv f , Eq f, Show f, Show a) =>LazySegTree mv (PrimState m) a f -> Int -> Int -> f -> m ()
applyLazySegTree st@(LazySegTree tree lazy) l0 r0 f =  do
  let !n = GM.length lazy
      !l = n+l0
      !r = n+r0
      !h = 63 - countLeadingZeros n
      !l_ctz = countTrailingZeros l
      !r_ctz = countTrailingZeros r
  forM_ [h,(h -1) .. (l_ctz+1)] $ \i -> do
    push st (unsafeShiftR l i)
  forM_ [h,(h -1) .. (r_ctz+1)] $ \i -> do
    push st (unsafeShiftR (r-1) i)
 
  fix
    ( \loop !l' !r' -> when (l' < r') $ do
        when (l' .&. 1 == 1) $ do
          applyAt st l' f
        when (r' .&. 1 == 1) $ do
          applyAt st (r' - 1) f
        loop
          (unsafeShiftR (l' + l' .&. 1) 1)
          (unsafeShiftR (r' - r' .&. 1) 1)
    ) l r
 
  forM_ [(l_ctz+1)..h] $ \i -> do
    pull st (unsafeShiftR l i)
  forM_ [(r_ctz+1)..h] $ \i -> do
    pull st (unsafeShiftR (r-1) i)
 

expandToPowerOfTwo :: Int -> Int
expandToPowerOfTwo x | x<2 = 2
          | otherwise = 1 `unsafeShiftL` (finiteBitSize (x-1) - countLeadingZeros (x-1) )

4tsuzuru4tsuzuru

モノイド

{-# TemplateHaskell #-}
import Data.Monoid
import Data.Semigroup
import Data.Vector.Unboxed.Deriving ( derivingUnbox )
import Language.Haskell.TH

-- 加算作用素
data Add a = Add !a
  deriving (Show,Eq, Generic)

instance Num a => Semigroup (Add a) where
  Add x <> Add y = Add $ x + y

instance Num a => Monoid (Add a) where
  mempty = Add 0

derivingUnbox "Add"
  [t| forall a. (U.Unbox a ) => Add a -> a |]
  [| \(Add x) -> x |]
  [| Add |]

-- 区間和要素
data RangedSum a = RangedSum !a !a
  deriving (Show,Eq)

instance Num a => Semigroup (RangedSum a) where
  RangedSum len1 x <> RangedSum len2 y = RangedSum (len1+len2) (x+y)

instance Num a => Monoid (RangedSum a) where
  mempty = RangedSum 1 0

derivingUnbox "RangedSum"
  [t| forall a. (U.Unbox a ) => RangedSum a -> (a,a) |]
  [| \(RangedSum len x) -> (len,x) |]
  [| \(len,x) -> RangedSum len x |]

-- 区間更新作用素
data Update a = Update a | NoUpdate 
  deriving (Show,Eq)

instance Semigroup (Update a) where
  NoUpdate <> x = x
  x <> NoUpdate = x
  Update x <> Update y = Update $ x

instance Monoid (Update a) where
  mempty = NoUpdate

class Default a where
    def :: a
instance Default Int where
    def = 0

derivingUnbox "Update"
  [t| forall a. (U.Unbox a, Default a ) => Update a -> (Bool, a) |]
  [| \update -> case update of
        NoUpdate -> (False, def)
        Update x -> (True, x) |]
  [| \(b, x) -> if b then Update x else NoUpdate |]

-- アファイン変換作用素
data Affine a = Affine a a
  deriving (Show,Eq)

instance (Num a) => Semigroup (Affine a) where
  Affine a b <> Affine c d = Affine (a*c) (a*d+b)

instance (Num a) => Monoid (Affine a) where
  mempty = Affine 1 0

derivingUnbox "Affine"
  [t| forall a. (U.Unbox a ) => Affine a -> (a,a) |]
  [| \(Affine a b) -> (a,b) |]
  [| \(a,b) -> Affine a b |]

作用 (M x X -> X)

-- 区間加算・区間最小値
instance Num a => MonoidAction (Add a) (Min a) where
  Add x <@> Min y = Min $ x+y

-- 区間加算・区間和
instance (Num a) => MonoidAction (Add a) (RangedSum a) where
  (Add x) <@> (RangedSum len y) = RangedSum len (x * len + y)

-- 区間変更・区間最小値
instance MonoidAction (Update a) (Min a) where
  Update x <@> _ = Min x
  NoUpdate <@> x = x

-- 区間変更・区間和
instance (Num a) => MonoidAction (Update a) (RangedSum a) where
  Update x <@> RangedSum l y = RangedSum l (x*l)
  NoUpdate <@> x = x

-- 区間アファイン変換・区間和
instance (Num a) => MonoidAction (Affine a) (RangedSum a) where
  Affine a b <@> RangedSum l x = RangedSum l (a*x+b*l)
4tsuzuru4tsuzuru

BIT (Fenwick Tree)

import qualified Data.Vector.Unboxed               as U
import qualified Data.Vector.Unboxed.Mutable       as UM
import Data.Bits
import Control.Monad.Primitive

newtype FenwickTree s a = FenwickTree (UM.MVector s a)

newFenwickTree :: (PrimMonad m, Monoid a, U.Unbox a) => Int -> m (FenwickTree (PrimState m) a)
newFenwickTree n = FenwickTree <$> UM.replicate (n+1) mempty

buildFenwickTree :: (PrimMonad m, Monoid a, U.Unbox a) => U.Vector a -> m (FenwickTree (PrimState m) a)
buildFenwickTree xs = do
    let n = U.length xs
    ft <- UM.unsafeNew (n+1)  
    UM.unsafeWrite ft 0 mempty
    U.unsafeCopy (UM.tail ft) xs
    U.forM_ (U.enumFromN 1 n) $ \i -> do
        let j = i + (i .&. (- i))
        when (j <= n) $ do
          fti <- UM.unsafeRead ft i
          UM.unsafeModify ft (<> fti) j
    return $ FenwickTree ft   


mappendAt :: (PrimMonad m, Monoid a,U.Unbox a) => FenwickTree (PrimState m) a -> Int -> a -> m ()
mappendAt (FenwickTree t) i x = go i where
    go j | j >= UM.length t = return ()
         | otherwise = UM.modify t (<> x) j >> go (j + (j .&. (-j)))

-- [1,i]のモノイド積
mappendTo :: (PrimMonad m, Monoid a,U.Unbox a) => FenwickTree (PrimState m) a -> Int -> m a
mappendTo (FenwickTree t) i = go i mempty where
    go 0 acc = return acc
    go j acc = UM.read t j >>= \x -> go (j - (j .&. (-j))) (x <> acc )

転倒数

https://atcoder.jp/contests/chokudai_S001/tasks/chokudai_S001_j
BITの実装の方が早いと思われたが、まさかのSegTreeの方が高速だった。
ちなみに、Lazy SegTreeのMonoid Actionを()とすれば、普通のSegTreeとして使える。

転倒数の性質

BIT version

import Data.Semigroup
import Control.Monad.ST ( runST, ST )

inversionCount :: U.Vector Int -> Int
inversionCount xs = runST $ do
    let n = U.maximum xs
    ft <- newFenwickTree n :: ST s (FenwickTree s (Sum Int))
    ret <- U.foldM (\acc x -> do
        Sum inv <- mappendTo ft (x-1)
        mappendAt ft x (Sum 1)
        return $ acc + inv) 0 $ U.reverse xs
    return ret

SegTree version

instance MonoidAction () (Sum Int) where
  _ <@> x = x

inversionCountLST :: U.Vector Int -> Int
inversionCountLST xs = runST $ do
  let n = U.maximum xs + 1
  st <- newLazySegTree n :: ST s (LazySegTree s (Sum Int) ())
  res <- U.foldM  (\acc x -> do
      Sum inv <- prodLazySegTree st (x+1) n
      old <- getLazySegTree st x
      setLazySegTree st x (old <> Sum 1)
      return $ acc + inv
    ) 0  xs
  return res
4tsuzuru4tsuzuru

LIS

大きく分けて①二分探索か②セグ木を使うかの2種類の解法がある。
セグ木の場合は座標圧縮が必要な事が多い。
2つ目の実装は@cojna氏を参照した。

セグ木バージョン

-- SegTree・座標圧縮は上記参照

lis :: [Int] -> Int
lis xs = runST $ do
  let compressed = compressing xs
  st <- buildSegTree (VU.replicate (length compressed) (Max 0)) :: ST s (SegTree VU.MVector s (Max Int))
  forM_ compressed $ \a -> do
    v <- prodSegTree st 0 a
    setSegTree st a $ ((+1) <$> v) <> (Max 1)
  ans <- prodAllSegTree st
  return $ getMax ans

IntSetバージョン

import qualified Data.IntSet as IS

lis :: [Int] -> Int
lis xs = IS.size $ foldl' step IS.empty xs
  where
    step s x = case IS.lookupGE x s of
      Just y
        | x < y -> IS.insert x $ IS.delete y s
        | otherwise -> s
      Nothing -> IS.insert x s
4tsuzuru4tsuzuru

組み合わせ

stackoverflowより

combinationsOf :: Int -> [a] -> [[a]]
combinationsOf k as@(x:xs) | k == 0   = [[]]
                           | k == 1    = map pure as
                           | k == l    = pure as
                           | k >  l    = []
                           | otherwise = run (l-1) (k-1) as $ combinationsOf (k-1) xs
                             where
                             l = length as

                             run :: Int -> Int -> [a] -> [[a]] -> [[a]]
                             run n k ys cs | n == k    = map (ys ++) cs
                                           | otherwise = map (q:) cs ++ run (n-1) k qs (drop dc cs)
                                           where
                                           (q:qs) = take (n-k+1) ys
                                           dc     = product [(n-k+1)..(n-1)] `div` product [1..(k-1)]