Elixirでダイクストラ(Dijkstra)法 + AtCoderで使ってみる

8 min read読了の目安(約7600字

ダイクストラ法というアルゴリズムを実装してみました。
また、実際にダイクストラ法を使う問題(ABC035 D - トレジャーハント)を解いてみたので
そのコードも書き残します。
(実際は逆で、この問題を解くために実装したので、各所で少々苦戦しました...)

ダイクストラ法

最短経路問題を解くことができるアルゴリズムです。辺が負の重みを持つ場合には使えません。

参考

https://ja.wikipedia.org/wiki/ダイクストラ法
http://www.deqnotes.net/acmicpc/dijkstra/

ダイクストラ法の実装

最初に全体のコードです。

defmodule Dijkstra do
  @inf 1_000_000_000

  def dijkstra(graph, n, start) do
    dist = Enum.reduce(1..n, %{}, fn i, acc -> Map.put(acc, i, @inf) end)
           |> Map.put(start, 0)
    solve(graph, put({}, {0, start}), dist)
  end
  
  defp solve(_, {}, dist) do
    dist
  end
  defp solve(g, que, dist) do
    {cost, v} = min(que)
    new_que = remove_min(que)
    if Map.get(dist, v) < cost do
      solve(g, new_que, dist)
    else
      edges = Map.get(g, v)
      {q, d} = update_que_and_dist(edges, cost, g, new_que, dist)
      solve(g, q, d)
    end
  end

  defp update_que_and_dist([], _, _, que, dist) do
    {que, dist}
  end
  defp update_que_and_dist(nil, _, _, que, dist) do
    {que, dist}
  end
  defp update_que_and_dist([{next, c}|tail]=_edges, cost, g, que, dist) do
    {new_dist, new_que} =
      if cost + c < Map.get(dist, next) do
        d = Map.put(dist, next, cost+c)
        q = put(que, {cost+c, next})
        {d, q}
      else
        {dist, que}
      end
    update_que_and_dist(tail, cost, g, new_que, new_dist)
  end

  
  ##### priority-queue #####

  def put(t, v) do
    merge(build(v), t)
  end

  def min({}), do: {}
  def min({{_, v}, _, _}), do: v

  def remove_min({}), do: {}
  def remove_min({_, l, r}), do: merge(l, r)

  defp merge(t1, {}), do: t1
  defp merge({}, t2), do: t2
  defp merge({{_, {key1, _}=v1}, l1, r1} = t1, {{_, {key2, _}=v2}, l2, r2} = t2) do
    if key1 < key2 do
      build(v1, l1, merge(r1, t2))
    else
      build(v2, l2, merge(t1, r2))
    end
  end

  defp rank({}), do: 0
  defp rank({{rank, _}, _, _}), do: rank

  defp build(v), do: build(v, {}, {})
  defp build(v, l, r) do
    if rank(l) >= rank(r) do
      {{rank(r)+1, v}, l, r}
    else
      {{rank(l)+1, v}, r, l}
    end
  end
end

dijkstra/3

# graph:
  # 例: %{1 => [{2, 3}, ...], 2 => [{4, 1}, ...], ... }
  # keyにノード、valueにタプルのリスト
  # タプルは、{keyから伸びてるノード, ノード間の辺のコスト}
# n:
  # グラフのノードの数
# start:
  # 最短経路のスタートノードの番号
  def dijkstra(graph, n, start) do
    # dist:
    #  各ノードとそのノードへの最短距離を格納する
    #  %{1 => 0, 2 => @inf, 3 => @inf, ... }
    #  keyにノード、valueにkeyまでの最小コスト
    dist = Enum.reduce(1..n, %{}, fn i, acc -> Map.put(acc, i, @inf) end)
           |> Map.put(start, 0)

    # solve(graph, priority-queue, dist):
    #  第二引数はpriority-queueで、{ノードまでのコスト, ノード}を入れる
    #  初期状態としてノードにスタートノードの番号、初期位置なのでコストには0を入れる
    #  これ以降priority-queueには、未訪問かつ現在のノードに隣接するノードを入れていく
    solve(graph, put({}, {0, start}), dist)
  end

solve/3

# 以下vは現在のノードを示しています。

# 第二引数のpriority-queueが空になるまで探索を続ける
  defp solve(_, {}, dist) do
    dist
  end
  defp solve(g, que, dist) do
    # コストが最小のノードを取得
    {cost, v} = min(que)
    # priority-queueから最小ノードの要素を除く
    new_que = remove_min(que)
    # 現時点でのノードvまでの最小コストと、今queueから取ったcostを比較
    if Map.get(dist, v) < cost do
      solve(g, new_que, dist)
    else
      # 新しく取ったcostのほうが小さい、もしくは同じなら
      # 現時点でのノードvまでの最小コストを更新していく
      # [{vから伸びてるノード, その間のコスト}, ...]
      # (グラフにvから伸びてるノードが存在しない場合はnilになる)
      edges = Map.get(g, v)

      # edgesを調べていく
      {q, d} = update_que_and_dist(edges, cost, g, new_que, dist)
      solve(g, q, d)
    end
  end

update_que_and_dist/5

# vから伸びてるノードを探索し終えたら、更新されたqueueとdistを返す
  defp update_que_and_dist([], _, _, que, dist) do
    {que, dist}
  end
  defp update_que_and_dist(nil, _, _, que, dist) do
    {que, dist}
  end
  defp update_que_and_dist([{next, c}|tail]=_edges, cost, g, que, dist) do
    {new_dist, new_que} =
      # edgesから取ったvから次のノード(next)のコスト(c)+スタートからvまでの最小ノード(cost)と、
      # 現時点での、スタートからnextへのコストを比較
      if cost + c < Map.get(dist, next) do
        # 新たなコストのほうが小さかったら
        # vの次のノードnextについて最小コストを更新
        d = Map.put(dist, next, cost+c)
        q = put(que, {cost+c, next})
        {d, q}
      else
        {dist, que}
      end
    update_que_and_dist(tail, cost, g, new_que, new_dist)
  end

Priority_Queue の実装

ここで一番悩みました。最初は自前で実装しようとしましたが、どうやっても削除、追加、参照のいずれかの操作がO(n)になってしまい断念。次にErlangのmoduleにあるgb_treesをつかってみましたが、提出してみるとTLE(他の部分に問題があったのかもしれない)。
そして結論、Leftist Tree というデータ構造にたどり着きました。

参考

http://typeocaml.com/2015/03/12/heap-leftist-tree/

Time complexity
1.get_min: O(1)
2.insert: O(logn)
3.delete_min: O(logn)
4.merge: O(logn)

なんとlinked_listで実装できる上に、上記の計算時間で操作できるらしい。たまげた。

実装は以下のコードを拝借、一部改変しました。

https://github.com/sotojuan/lheap/blob/master/lib/lheap.ex

問題を解いてみる

ABC035 D - トレジャーハント

以下、提出したコードです。

defmodule Main do
  @inf 1_000_000_000

  def init(list) do
    Enum.reduce(list, %{}, fn [a, b, c], acc -> Map.merge(acc, %{a => [{b, c}]}, fn _, v1, [v2] -> [v2 | v1] end) end)
  end

  def init_r(list) do
    Enum.reduce(list, %{}, fn [a, b, c], acc -> Map.merge(acc, %{b => [{a, c}]}, fn _, v1, [v2] -> [v2 | v1] end) end)
  end

  def dijkstra(g, n, s) do

    dist = Enum.reduce(1..n, %{}, fn i, acc -> Map.put(acc, i, @inf) end)
           |> Map.put(s, 0)

    solve(g, put({}, {0, s}), dist)
  end

  defp solve(_, {}, dist) do
    dist
  end
  defp solve(g, que, dist) do
    {cost, v} = min(que)
    new_que = remove_min(que)
    if Map.get(dist, v) < cost do
      solve(g, new_que, dist)
    else
      edges = Map.get(g, v)
      {q, d} = update_que_and_dist(edges, cost, g, new_que, dist)
      solve(g, q, d)
    end
  end

  defp update_que_and_dist([], _, _, que, dist) do
    {que, dist}
  end
  defp update_que_and_dist(nil, _, _, que, dist) do
    {que, dist}
  end
  defp update_que_and_dist([{next, c}|tail], cost, g, que, dist) do
    {new_dist, new_que} =
      if cost + c < Map.get(dist, next) do
        d = Map.put(dist, next, cost+c)
        q = put(que, {cost+c, next})
        {d, q}
      else
        {dist, que}
      end
    update_que_and_dist(tail, cost, g, new_que, new_dist)
  end

  def read_line(ptn) do
    {:ok, list} = :io.fread('', ptn)
    list
  end

  def read_all(ptn, n) do
    {:ok, list} = :io.fread('', List.duplicate(ptn, n) |> List.flatten())
    list
  end

  def main() do
    [n, m, t] = read_line('~d~d~d')
    a = read_all('~d', n) |> Enum.with_index(1)
    v = read_all('~d~d~d', m) |> Enum.chunk_every(3)

    res   = init(v) |> dijkstra(n, 1)
    res_r = init_r(v) |> dijkstra(n, 1)

    default = elem(hd(a), 0) * t
    ans =
      Enum.reduce(a, default, fn {x, idx}, acc ->
        rem = t - (Map.get(res, idx)+Map.get(res_r, idx))
        max(acc, rem*x)
      end)

    IO.puts(ans)
  end


  ##### priority-queue #####

  def put(t, v) do
    merge(build(v), t)
  end

  def min({}), do: {}
  def min({{_, v}, _, _}), do: v

  def remove_min({}), do: {}
  def remove_min({_, l, r}), do: merge(l, r)

  defp merge(t1, {}), do: t1
  defp merge({}, t2), do: t2
  defp merge({{_, {key1, _}=v1}, l1, r1} = t1, {{_, {key2, _}=v2}, l2, r2} = t2) do
    if key1 < key2 do
      build(v1, l1, merge(r1, t2))
    else
      build(v2, l2, merge(t1, r2))
    end
  end

  defp rank({}), do: 0
  defp rank({{rank, _}, _, _}), do: rank

  defp build(v), do: build(v, {}, {})
  defp build(v, l, r) do
    if rank(l) >= rank(r) do
      {{rank(r)+1, v}, l, r}
    else
      {{rank(l)+1, v}, r, l}
    end
  end
end

おわり

誤り、改善点などありましたらご指摘ください🙇‍♂️
他に良い実装方法がありましたらぜひ教えてください!