Open46

OCaml で競技プログラミングをやっている時の知見

lmdexprlmdexpr

repo https://github.com/lmdexpr/atcoder
タイトル通りのことをメモっていく

なお、現在、OCaml が使えるコンテストはそもそも atcoder しかないと思われる

lmdexprlmdexpr

Codeforcesも使えるっぽいけど、バージョンが低いし、Core がない?

lmdexprlmdexpr

https://atcoder.jp/contests/abc276/submissions?f.Task=abc276_e&f.LanguageName=OCaml&f.Status=&f.User=lmdexpr
なんか知らんけど input_line_exn が死ぬというやつ

競プロやるなら scanf 使っておくのが丸い
また、二個目からの入力に対して " %d" などとして、頭にスペースを入れるのも有名(?)なハック
これで改行も含めた空白区切りを読み飛ばしてくれる

今思うと、input_line_exn も入力に対して改行文字が無いとかそういう感じだろうか
あんまり検証する気もない

lmdexprlmdexpr

https://atcoder.jp/contests/abc277/submissions/36456530 は TLE で、
https://atcoder.jp/contests/abc277/submissions/36457621 なら AC になる

つまり、 ref を使って mutable として使うと遅くて、再帰にして immutable で使うと圧倒的に速い

おそらく末尾再帰最適化によるものだと思う
ref によるオブジェクトのコピーもあるのかもだけど、多分コンパイラの最適化が ref を使うと効かないんじゃないかなあ(自信はない)

lmdexprlmdexpr

Atcoder では iter package ( https://ocaml.org/p/iter/1.2.1/doc/Iter/index.html ) が使用可能で、まあ、便利というか結構ないと困るレベルなので多用する。
ところで、 https://atcoder.jp/contests/typical90/tasks/typical90_bd の回答中に謎挙動を見た。

items (* Iter.t *)
|> Iter.flat_map (fun j -> ...)
|> Iter.filter (fun j -> ...)
|> Iter.map (fun j -> printf "debug print %d" j; j)

と、すると、出力には debug print の文字がなかった。
これを、

let items = 
  items (* Iter.t *)
  |> Iter.flat_map (fun j -> ...)
  |> Iter.filter (fun j -> ...)
in
items |> Iter.iter (fun j -> printf "debug print %d" j);
items

と書き直すと普通に動いた。
副作用が消えているように見えて不思議。

lmdexprlmdexpr

今思うと、Iter は継続になってるだけなので前者は実際に処理が走ってないだけだな、と思った

lmdexprlmdexpr

modulo 計算に便利系

let modulo = 1000000007
let (+%) a b = (a + b) % modulo
and (-%) a b = (a - b + modulo) % modulo
and( *%) a b = a * b % modulo
lmdexprlmdexpr

functor にした版
ちょっと遅くなるのがネック

module Modulo (M : sig include Int_intf.S val modulo : t end) = struct
  open M

  let rec inverse ?(b = modulo) ?(u = one) ?(v = zero) a =
    if b = zero then (u % modulo + modulo) % modulo
    else
      let t = a / b in
      let a, b = b, a - t * b in
      let u, v = v, u - t * v in
      inverse ~b ~u ~v a
  let power a b =
    Array.init num_bits ~f:Fn.id
    |> Array.fold ~init:(one, a) ~f:(fun (p, q) i ->
        if b land (one lsl i) <> zero then p * q % modulo, q * q % modulo
        else
          p, q * q % modulo
      )
    |> Tuple2.get1

  let ( + ) a b = (a + b) % modulo
  let ( * ) a b = (a * b) % modulo
  let ( / ) a b = a * inverse b
  let ( ** ) a b = power a b
end

使用例

module Modulo998244353 = Modulo (struct include Int let modulo = 998244353 end)
lmdexprlmdexpr

Iter 使うのに便利系

let (let+) x k = Iter.flat_map k x
and (let*) x k = Iter.map k x
and (let^) x k = Iter.filter_map k x

最後はfilterにすることもある

lmdexprlmdexpr

Iter がない環境でそれっぽいことをする用

let (--) start end_ k =
  for i = start to end_ do
    k i
  done
let (let+) seq f k = seq (fun x -> f x k)
let (let*) seq f k = seq (fun x -> k (f x))
let (let^) seq p k = seq (fun x -> if p x then k x)

ただし、最後は filter

lmdexprlmdexpr

結局 map がなくてダルくなるので普通に mutable 許容する方が楽かもしれない

lmdexprlmdexpr

Graph 用

module Graph = struct
  include Hashtbl
  let push v = function
    | None     -> Iter.singleton v
    | Some acc -> Iter.cons v acc
  let push g v u = update g v ~f:(push u)
  let connect g v u = push g v u; push g u v
 
  let around g v = find g v |> Option.value ~default:Iter.empty
end
let g = Graph.create ~size:n (module Int)
let () =
  for _ = 1 to m do
    scanf " %d %d" @@ Graph.connect g
  done

Iter じゃなくて普通に List でも良いが、結構 Iter で使うことの方が多かった。
connectは無向グラフ用。

lmdexprlmdexpr

演算子系。
適当書きで使ったことないので動くか知らないやつ。

let (.!())   = Graph.around
let (.!()<-) = Graph.push
let (.%()<-) = Graph.connect
lmdexprlmdexpr
module DirectedGraph = struct
  type t = {
    size : int;
    normal  : (int, int Iter.t) Hashtbl.t;
    reverse : (int, int Iter.t) Hashtbl.t;
  }
  let create ~size = {
    size;
    normal  = Hashtbl.create ~size (module Int);
    reverse = Hashtbl.create ~size (module Int);
  }

  let push v = function
    | None     -> Iter.singleton v 
    | Some acc -> Iter.snoc acc v
  let push g v u = Hashtbl.update g v ~f:(push u)
  let push g v u = push g.normal v u; push g.reverse u v

  let next g v = Hashtbl.find g.normal  v |> Option.value ~default:Iter.empty
  let pred g v = Hashtbl.find g.reverse v |> Option.value ~default:Iter.empty

  let strongly_connected_components g =
    let iterate visit f = 
      Iter.fold 
        (fun acc v -> if visit.(v) then acc else Iter.cons (f v) acc)
        Iter.empty
    in
    let step1 = 
      let visit = Array.init (g.size + 1) ~f:(const false) in
      let rec dfs acc v =
        visit.(v) <- true;
        next g v
        |> Iter.filter (fun u -> not visit.(u))
        |> Iter.fold dfs acc
        |> Iter.cons v
      in
      iterate visit @@ dfs Iter.empty
    in
    let step2 = 
      let visit = Array.init (g.size + 1) ~f:(const false) in
      let rec dfs acc v =
        visit.(v) <- true;
        pred g v
        |> Iter.filter (fun u -> not visit.(u))
        |> Iter.fold dfs (Iter.cons v acc)
      in
      iterate visit @@ dfs Iter.empty
    in
    Iter.(1 -- g.size) |> step1 |> Iter.flatten |> step2
end
lmdexprlmdexpr

めぐる式二分探索((l, r] でもつというやつ)

let rec binsearch ~ok left right =
  if abs (right - left) <= 1L then right
  else
    let mid = (right + left) / 2L in
    let left, right = if ok mid then left, mid else mid, right in
    binsearch ~ok left right

なんか微妙に使いこなせてなくて結構な確率でバグらせる。
でも、Core の binary_search で足りない時がごく稀にあり、使わざるを得ない。

lmdexprlmdexpr
let lower_bound, upper_bound =
  let arr_binsearch meth x = 
    Array.binary_search a ~compare meth x |> Option.value ~default:n
  in
  arr_binsearch `First_greater_than_or_equal_to,
  arr_binsearch `First_strictly_greater_than
lmdexprlmdexpr

segment tree

module SegmentTree = struct
  type m                = int
  let idm : m           = 0
  let mul : m -> m -> m = max

  type t = { tree: m array; size: int }

  let create ~len : t = { tree = Array.create ~len:(2 * len) idm; size = len }

  let product { tree; size } l r =
    let rec product lp rp l r =
      if r <= l
      then mul lp rp
      else
        product
          (if l mod 2 = 0 then lp else mul lp tree.(l))
          (if r mod 2 = 0 then rp else mul tree.(r - 1) rp)
          ((l + 1) / 2) (r / 2)
    in
    product idm idm (l + size) (r + size)

  let update { tree; size } i x =
    tree.(i + size) <- x;
    let left i = 2 * i and right i = 2 * i + 1 in
    let rec propagate i =
      if 0 < i then begin
        tree.(i) <- mul tree.(left i) tree.(right i);
        propagate (i / 2)
      end
    in
    propagate ((i + size) / 2)
end
lmdexprlmdexpr

C++ などで std::set が二分探索の代わりに使えることがある。(検索の実装が実質的に二分探索になっているため)
OCaml にも当然 Set があり、使えるのだが atcoder の環境というか Core を open した環境ではちょっとした罠がある。

以下、問題のネタバレなので一応隠す

Core を open した提出(TLE) https://atcoder.jp/contests/abc265/submissions/38911563
標準ライブラリの Set を使った提出(AC) https://atcoder.jp/contests/abc265/submissions/38911641

使用例だけ抜き出すとこんな感じ。

(* 標準ライブラリ環境 *)
module SI = Set.Make(Int)
let targets : SI.t = (* ... *)
let find x = SI.find_opt x targets |> Option.is_some

(* Core 環境 *)
open Core
let targets : Int.Set.t = (* ... *)
let find x = Int.Set.find targets ~f:((=) x) |> Option.is_some

これだけ見ても分かるし、実装を見ると特に分かりやすいが、Core では便利のために ~f を受け取ることが出来、この型は'a -> boolとなっている。
つまり、二分探索になっていないのである。
よって前者は O(logN) だが、後者は O(N) の実装になっており、致命傷になる。

また、余談だが、どちらの環境でもSetにはmem関数が存在しており、これはどちらとも高速に動作する(O(logN))のため、今回のケースならSet.memを使いさえすれば良い。

同様に問題のネタバレなので一応隠す

Int.Set.mem を使って、更にExtended indexing operators も使ってみた例
https://atcoder.jp/contests/abc265/submissions/38912026

lmdexprlmdexpr

ちなみに普通に Set にも binary_search があり、使える

lmdexprlmdexpr
let remove_range set l r =
  let compare = Int64.compare in
  let rec go acc set l =
    match Set.binary_search set ~compare `First_greater_than_or_equal_to l with
    | Some x when Int64.(x <= r) -> 
      go Set.(add acc x) (Set.remove set x) x
    | _ -> acc, set
  in
  go Int64.Set.empty set l

こういうことがしたい日もある(あった)

lmdexprlmdexpr
module Iter = struct
  include Iter
  let zip x y = flat_map (fun x -> map (fun y -> x, y) y) x
  let( * ) = zip
end
 
module Bit_all = struct
  open Iter
  let on x i = x land (1 lsl (i - 1)) <> 0
 
  let bits len = 0 -- (1 lsl len - 1)
  let by_bits x = map (fun bits -> (1 -- x) |> filter (on bits) |> map (fun x -> x - 1) |> to_array)
  let start x = by_bits x (bits x)
end

Bit 全探索
Iter の方は Bit 全探索じゃないときも欲しいけど、別になくても困らない

使い方は

Iter.(
    Bit_all.start a * Bit_all.start b
  )
|> Iter.filter (fun (i, j) -> ...)
|> Iter.iter (fun (i, j) -> ...)

みたいな感じ

lmdexprlmdexpr

module までするとだるい時はままあって、下みたいなコードをコピペするでも良い

let x = [| (* ... *) |]
let n = Array.length x

let ans =
  Iter.(0 -- (1 lsl n - 1))
  |> Iter.map (fun bits ->
      Iter.(0 -- (n - 1))
      |> Iter.filter (fun i -> bits land (1 lsl i) <> 0)
      |> Iter.map (Array.get x)
      |> Iter.to_array
    )
  |> Iter.filter_map (function
      | [| (* ... *) |] -> Some ( (* ... *) )
      | _ -> None
    )
lmdexprlmdexpr
module PI = struct
  type t = int * int
  let compare = Tuple2.compare ~cmp1:Int.compare ~cmp2:Int.compare
  let sexp_of_t = Tuple2.sexp_of_t sexp_of_int sexp_of_int
  let t_of_sexp = Tuple2.t_of_sexp int_of_sexp int_of_sexp
end
module SP = Set.Make(PI)

意外と使うけどちょっと面倒な pair の set

lmdexprlmdexpr
module type M = sig
  type t
  val compare : t -> t -> int
  val sexp_of_t : t -> Sexp.t
  val t_of_sexp : Sexp.t -> t
end
module Tuple2 = struct
  include Tuple2
  module Make (M1: M) (M2: M) = struct
    type t = M1.t * M2.t
    let compare = compare ~cmp1:M1.compare ~cmp2:M2.compare
    let sexp_of_t = sexp_of_t M1.sexp_of_t M2.sexp_of_t
    let t_of_sexp = t_of_sexp M1.t_of_sexp M2.t_of_sexp
  end
end
module S = Set.Make (Tuple2.Make (Int) (Int64))

汎用版

lmdexprlmdexpr
module Memo = struct
  include Memo
  let recursive m f =
    let h = Hashtbl.create m in
    let rec g x = Hashtbl.update_and_return h x ~f:(function
      | Some v -> v
      | None   -> f g x
    )
    in g
end

メモ化再帰(Core 環境)

lmdexprlmdexpr

使い方

let fib = Memo.recursive @@ fun self -> function
  | n when n < 2 -> 1
  | n -> self (n - 1) + self (n - 2)
lmdexprlmdexpr
let eratosthenes n =
  let sieve = Array.init n ~f:(const 1) in
  let rec eratosthenes ?(acc=Iter.empty) = function
    | []      -> acc
    | x :: xs ->
      if sieve.(x) = 0 then eratosthenes ~acc xs
      else begin
        List.range ~start:`inclusive ~stop:`inclusive ~stride:x (x*x) (n - 1)
        |> List.iter ~f:(fun x -> sieve.(x) <- 0);
        eratosthenes ~acc:(Iter.snoc acc @@ Int64.of_int x) xs
      end
  in
  let primes = eratosthenes List.(range 2 n) |> Iter.to_array in
  for i = 1 to Array.length sieve - 1 do
    sieve.(i) <- sieve.(i) + sieve.(i - 1);
  done;
  primes, sieve

let primes, prime_count = eratosthenes 300_005

エラトステネスの篩

lmdexprlmdexpr
let prime_count = Array.init 300_005 ~f:(const 0)
let rec sieve ?(acc=Iter.empty) = function
  | []      -> acc
  | x :: xs ->
    prime_count.(x) <- 1;
    List.filter ~f:(fun y -> y % x <> 0) xs
    |> sieve ~acc:(Iter.snoc acc @@ Int64.of_int x)
 
let primes = sieve (List.range 2 300_005) |> Iter.to_array
 
let () =
  for i = 1 to Array.length prime_count - 1 do
    prime_count.(i) <- prime_count.(i) + prime_count.(i - 1);
  done

なぜか TLE したバージョン
詳しいことは正直分かってないが、List.filterによる List の再生成が遅いとか……?

lmdexprlmdexpr

推測でしかないが、tailrec になってないとかあるのかなと思った
match とかするとダメな時がある……?

lmdexprlmdexpr
let (.!()<-) a i v = a.(i) <- max a.(i) v

こういうのがあると助かる命がたまにある

lmdexprlmdexpr
let chmin i v = dp.(i) <- min dp.(i) v
let chmax i j v = dp.(i).(j) <- max dp.(i).(j) v

こういうのでもいい
解説に寄せるならこの辺

lmdexprlmdexpr
module Heap = struct
  include Batteries.Heap
  let singleton v = add v empty
  let pop_min heap =
    if size heap = 0 then None
    else
      Some (find_min heap, del_min heap)
end

Heap (priority queue)
最近の Core にはなく、core_kernel の方で実装されていたりする

使えない認識なので Batteries から持ってくる

lmdexprlmdexpr
module Heap = struct
  module X = struct
    type t = int * int * int
    let compare (a, _, _) (b, _, _) = Int.compare a b
  end

  type t = Leaf | Node of t * X.t * t * int

  let empty = Leaf

  let singleton k = Node (Leaf, k, Leaf, 1)
  let rank = function Leaf -> 0 | Node (_,_,_,r) -> r

  let rec merge t1 t2 =  
    match t1, t2 with
    | Leaf, t | t, Leaf -> t
    | Node (_, k1, _, _), Node (_, k2, _, _) 
      when 0 < X.compare k1 k2 -> merge t2 t1
    | Node (l, k, r, _), _ ->
      let r = merge r t2 in
      let rank_left = rank l and rank_right = rank r in
      let l, r, rank =
        if rank_left >= rank_right 
        then l, r, rank_right
        else r, l, rank_left
      in
      Node (l, k, r, rank + 1)

  let insert t x = merge (singleton x) t
  let find_min = function
    | Leaf              -> None
    | Node (_, k, _, _) -> Some k
  let del_min = function
    | Leaf              -> empty
    | Node (l, _, r, _) -> merge l r
end

実装してみたやつ
特に計算量的に有利になったりはあまりしない認識

lmdexprlmdexpr
module Array = struct
  include Array
  let rec reverse a ~start ~stop =
    if start < stop then begin
      Array.swap a start stop;
      reverse a ~start:(start + 1) ~stop:(stop - 1)
    end
end

module Permutation (M: sig type t val compare: t -> t -> int end) = struct
  let next a ~l ~r =
    let downto_loop ~start ~stop ~p ~proc =
      Iter.(start --^ stop) |> Fn.flip Iter.fold_while false
      @@ fun _ i -> if p i then (proc i; true, `Stop) else false, `Continue
    in
    let change_to_next_permutation a ~l ~r =
      downto_loop ~start:(r - 1) ~stop:l
        ~p:(fun i -> M.compare a.(l) a.(i) < 0)
        ~proc:(fun i ->
            Array.swap a l i;
            Array.reverse a ~start:(l + 1) ~stop:(r - 1)
          )
    in
    downto_loop ~start:(r - 2) ~stop:l
      ~p:(fun i -> 
        M.compare a.(i) a.(i + 1) < 0 && 
        change_to_next_permutation a ~l:i ~r
      )
      ~proc:ignore

  let fold arr n ~f ~acc =
    let arr = Array.copy arr in
    Array.sort arr ~compare:M.compare;
    let rec permutations acc =
      let acc = f acc arr in
      let found_next = next arr ~l:0 ~r:n in
      if found_next then permutations acc else acc
    in
    permutations acc
end

module PermutationChar = Permutation (Char)

Permutation
これがないと死ぬ(死んだ)

lmdexprlmdexpr
let a = [| 1; 2; 3; 4; |]

let cumsum a =
  let paired f a b = let r = f a b in r, r in
  Array.folding_map a ~init:0 ~f:(paired Int.bit_xor)

let cumsum = cumsum a

let () =
  Array.iter cumsum ~f:(printf "%d ")

(* output: 1 3 6 10 *) 

累積和
初期値が含まれないことに注意

lmdexprlmdexpr
module Fenwick = struct
  module M = Int

  type t = { n : int; a : M.t array }

  let create n = { n; a = Array.create ~len:n M.zero }

  let add t i x =
    let rec go i = if i < t.n then (
      t.a.(i) <- M.(t.a.(i) + x);
      go (i lor (i + 1))
    ) in
    go i

  let sum t l r =
    let rec go i acc = 
      if i < 0 then acc
      else 
        go (i land (i + 1) - 1) M.(acc + t.a.(i))
    in
    M.(go Int.(pred r) zero - go Int.(pred l) zero)
end