🚀

クイックソートを学び直す

2023/08/05に公開

『アルゴリズム図鑑』を参考に自力で実装したもの

https://www.amazon.co.jp/dp/479817243X

def sort(a)
  if a.size <= 1
    return a
  end
  pivot = a[a.size / 2]
  l = a.find_all { |e| e < pivot }
  r = a.find_all { |e| e > pivot }
  sort(l) + [pivot] + sort(r)
end

動いてるように見えたが

a = [3, 6, 1, 4, 2, 5, 0]  # => [3, 6, 1, 4, 2, 5, 0]
sort(a)                    # => [0, 1, 2, 3, 4, 5, 6]

重複があると減った!?

a = [2, 0, 1, 1, 0, 2]  # => [2, 0, 1, 1, 0, 2]
sort(a)                 # => [0, 1, 2]

重複を維持する

def sort(a)
  if a.size <= 1
    return a
  end
  pivot = a[a.size / 2]
  l = a.find_all { |e| e < pivot  }
  m = a.find_all { |e| e == pivot } # pivot と同じ値のものは複数あるかもしれない
  r = a.find_all { |e| e > pivot  }
  sort(l) + m + sort(r)
end

これで問題なさそう。

a = [2, 0, 1, 1, 0, 2]  # => [2, 0, 1, 1, 0, 2]
sort(a)                 # => [0, 0, 1, 1, 2, 2]

ただ find_all 呼びすぎな気がする。

リファクタリング後

def sort(a)
  a ||= []
  if a.size <= 1
    return a
  end
  pivot = a[a.size / 2]
  g = a.group_by { |e| e <=> pivot }
  sort(g[-1]) + g[0] + sort(g[1])
end

とくに壊れてなさそう。

a = [2, 0, 1, 1, 0, 2]  # => [2, 0, 1, 1, 0, 2]
sort(a)                 # => [0, 0, 1, 1, 2, 2]

乱択版にしてみる

def sort(a)
  a ||= []
  if a.size <= 1
    return a
  end
  pivot = a.sample                  # ← pivot をランダムに決める
  g = a.group_by { |e| e <=> pivot }
  sort(g[-1]) + g[0] + sort(g[1])
end

問題なさそう。

a = [2, 0, 1, 1, 0, 2]  # => [2, 0, 1, 1, 0, 2]
sort(a)                 # => [0, 0, 1, 1, 2, 2]

これで完成したように思えたが一般的に見かけるクイックソートのコードとはどうも違うような気がする。『アルゴリズム図鑑』に出会う前、実装しようとして挫折したんだから Ruby で書けるのを差し引いてもこんなに簡単なわけがない。これは本当にクイックソートなんだろうか?

調べてみると一般的にソートは入力の配列を「その場で」作業場として利用し、破壊してしまう方法を指すらしくそれをかっこつけて in-place というらしい。不変可変は immutable と mutable という言葉が一般的なんだから in-place ではなく mutable でいいんじゃないかという気がするのだけどソート業界の意味合いとは少し違うようだ。なお in-place の反対は immutable ではなく out-of-place や not-in-place という (と Wikipedia に書いてあった)

個人的には常に入力を破壊しない immutable out-of-place にしてほしいけどとりあえず in-place な方法もやってみる。

in-place 版

初見ではさっぱりだったので『問題解決力を鍛える!アルゴリズムとデータ構造』を参考にした。

https://www.amazon.co.jp/dp/4065128447

# r は右端を出たとこではなく右端自体を指した方がわかりやすいと思った
def sort(a, l = 0, r = a.size.pred)
  if r <= l
    return
  end

  # pivot の選択方法にさまざまなパターンがある
  pi = (l + r) / 2              # 中央選択

  pivot = a[pi]

  # pivot を右端に退けておく
  # 理由は in-place で作業する上で pivot が途中にあると邪魔だから
  a[pi], a[r] = a[r], a[pi]

  # [小さい数たち, 大きい数たち, pivot] みたいな構造を作る
  # i, j どちらも左側から進む
  i = l                         # i はストアする位置 (入れると右に進む)
  (l..r.pred).each do |j|       # 比較位置の j は pivot の「隣」まで進む
    if a[j] < pivot             # pivot より小さいなら左側に行ってほしいので
      a[i], a[j] = a[j], a[i]   # i の位置と交換して
      $swap += 1
      i += 1                    # 次のためにポインタをずらす
    end
  end

  # pivot を真ん中あたりに戻す
  a[i], a[r] = a[r], a[i]

  # すると不思議なことに
  # [小さい数たち, pivot, 大きい数たち]
  # な構造になっている

  sort(a, l, i.pred)            # pivot 未満の左グループ
  sort(a, i.next, r)            # pivot 以上の右グループ (超過ではない)
end
$swap = 0
a = [3, 1, 0, 2, 3, 1, 0, 2]  # => [3, 1, 0, 2, 3, 1, 0, 2]
sort(a)
a                             # => [0, 0, 1, 1, 2, 2, 3, 3]
$swap                         # => 5

pivot を右端に寄せておくとかそんなん思いつかん。

ここに来て分かったことはクイックソートが難しいのではなく in-place にするのが難しいということだった。だからアルゴリズムを学び、実装する際には「仕組みをそのまま実装する方法」と「メモリを節約して実装する方法」を別々にして段階的に進めた方がよさそう。いきなり後者から入るのはハードルが高すぎる。

ChatGPT の工夫と欠点

とくに条件も付けずにクイックソートのコードを提示してもらい、リファクタリングしたのがこれだけど新しい工夫と発見があった。

def sort(a, l = 0, r = a.size.pred)
  if r <= l
    return
  end

  # pivot は a[r] にあると決め打ちしてた

  i = l
  (l..r).each do |j|            # j は pivot の位置まで来る
    if a[j] <= a[r]             # 最後に pivot <= pivot が成立する
      a[i], a[j] = a[j], a[i]
      $swap += 1
      i += 1
    end
  end

  # 上の入れ替えによって右端にいた pivot は真ん中らへん(iの位置)に移動している

  sort(a, l, i - 2)             # pivot 以下の左グループ
  sort(a, i, r)                 # pivot 超過の右グループ
end
$swap = 0
a = [3, 1, 0, 2, 3, 1, 0, 2]  # => [3, 1, 0, 2, 3, 1, 0, 2]
sort(a)
a                             # => [0, 0, 1, 1, 2, 2, 3, 3]
$swap                         # => 12
  • 工夫
    • pivot は右端とする
      • 最初から右端にあるので右端に移動する処理が省かれている
    • pivot を真ん中あたりに移動する処理がない? → ある
      • pivot と比較する範囲に pivot が含まれているため最後に pivot <= pivot が成立してこっそり移動している
  • 欠点
    • 偏りのあるデータに弱い
      • pivot が右端のせいで昇順や降順のデータに対して効率的に分割できない
    • 同じ値が多いと遅くなる
      • pivot との比較が <= のため同じ値が多いと無駄に交換回数が増える

あとで気づいたけどアプリ版アルゴリズム図鑑の解説もこれに似た方法になっている。おそらく入門者向けに簡潔に伝えたかったと思われる。

アプリ版アルゴリズム図鑑の「実験」機能風のポインタの動き

交換する位置を示すポインタがそれぞれ両端から中央に寄ってくる。

def sort(a, l = 0, r = a.size.pred)
  if l >= r
    return
  end

  pivot = a[(l + r) / 2]

  i = l
  j = r
  loop do
    # 左側で pivot 以上の値(の位置)を見つける
    until a[i] >= pivot
      i += 1
    end

    # 右側で pivot 以下の値(の位置)を見つける
    until a[j] <= pivot
      j -= 1
    end

    # 完全に交差していたら終わる
    if i > j
      break
    end

    # 左にある pivot 以上の値と、右にある pivot 未満の値を入れ替える
    a[j], a[i] = a[i], a[j]
    i += 1
    j -= 1
  end

  sort(a, l, j)
  sort(a, i, r)
end
a = [4, 6, 1, 5, 3, 8, 7, 5]  # => [4, 6, 1, 5, 3, 8, 7, 5]
sort(a)
a                             # => [1, 3, 4, 5, 5, 6, 7, 8]

左側から寄ってくるポインタ i は pivot 以上の値を指し、右側から寄ってくるポインタ j は pivot 以下の値を指す。i と j がクロスしたら終わる──というロジックは完全に左右対称になるので、視覚的に動作を確認しようとしたときに、両方左から進むタイプよりもわかりやすいかもしれない。コードも規則的で読みやすい。

ただ i と j を動かす似たような処理を二箇所に書かないといけないのでそこの重複が匂ってしまうOAOO信者には向かないかもしれない。

再帰は偏りのあるデータに弱い

右端選択のクイックソートに

def sort(a, l = 0, r = a.size.pred)
  if r <= l
    return
  end
  pi = r                        # 右端選択
  pivot = a[pi]
  a[pi], a[r] = a[r], a[pi]
  i = l
  (l..r.pred).each do |j|
    if a[j] < pivot
      a[i], a[j] = a[j], a[i]
      i += 1
    end
  end
  a[i], a[r] = a[r], a[i]
  sort(a, l, i.pred)
  sort(a, i.next, r)
end

N = 10000 の昇順の配列を渡すと

begin
  sort(10000.times.to_a)
rescue SystemStackError => error
  error  # => #<SystemStackError: stack level too deep>
end

偏りすぎて死ぬ。

末尾再帰最適化も

RubyVM::InstructionSequence.compile(<<~CODE, __FILE__, __dir__, __LINE__, tailcall_optimization: true).eval
  def sort(a, l = 0, r = a.size.pred)
    if r <= l
      return
    end
    pi = r                        # 右端選択
    pivot = a[pi]
    a[pi], a[r] = a[r], a[pi]
    i = l
    (l..r.pred).each do |j|
      if a[j] < pivot
        a[i], a[j] = a[j], a[i]
        i += 1
      end
    end
    a[i], a[r] = a[r], a[i]
    sort(a, l, i.pred)
    sort(a, i.next, r)
  end
CODE
begin
  sort(10000.times.to_a)
rescue SystemStackError => error
  error  # => #<SystemStackError: stack level too deep>
end

効かない。

末尾に2つあるせい?

再帰を使わない方法

これは

def sort(a, l = 0, r = a.size.pred)
  # ...
  sort(a, l, i.pred)
  sort(a, i.next, r)
end

これに置き換えることができる。

def sort(a)
  stack = []
  stack.push([0, a.size.pred])
  until stack.empty?
    l, r = stack.pop
    # ...
    stack << [l, i.pred]
    stack << [i.next, r]
  end
end

なのでこれになる。

def sort(a)
  stack = []
  stack.push([0, a.size.pred])
  until stack.empty?
    l, r = stack.pop

    # ---------------------------------------- ↓ここから
    if r <= l
      next
    end
    pi = r                      # 右端選択
    pivot = a[pi]
    a[pi], a[r] = a[r], a[pi]
    i = l
    (l..r.pred).each do |j|
      if a[j] < pivot
        a[i], a[j] = a[j], a[i]
        i += 1
      end
    end
    a[i], a[r] = a[r], a[i]
    # ---------------------------------------- ↑ここまでは同じ

    stack << [l, i.pred]
    stack << [i.next, r]
  end
end

これで死ななくなる。

a = 10000.times.to_a
sort(a)
a == a.sort  # => true

pivot のいろんな選択方法

a[l..r] を対象としたとき

中央
(l + r) / 2
3点中央
[
  l,
  (l + r) / 2,
  r,
].sort_by { |i| a[i] }.at(1)
5点中央
[
  l,
  l + (r - l) * 1 / 4,
  (l + r) / 2,
  l + (r - l) * 3 / 4,
  r,
].sort_by { |i| a[i] }.at(2)
完全中央
(l..r).sort_by { |i| a[i] }.at((r - l) / 2)
乱択
rand(l..r)
左端
l
右端
r
最小
(l..r).min_by { |i| a[i] }
最大
(l..r).max_by { |i| a[i] }

pivot の選択方法と処理時間の関係

N=5000 単位:ms

選択方法 ランダム 昇順 降順 sin sin(2θ) 同値
中央 26 12 20 551 386 1869
3点中央 29 13 39 907 443 1873
5点中央 24 14 23 22 416 1870
完全中央 28 17 25 25 27 2733
乱択 26 16 24 25 25 1873
左端 25 1869 2081 909 447 1868
右端 26 2286 2076 911 470 1893
最小 2512 2509 2506 2510 2508 2507
最大 2910 2930 2915 2330 1402 2487

pivot の選択方法 x データの並び の可視化

中央


中央 x ランダム = 289 ns (比較: 727, 交換: 332, 書込: 0)

中央 x 昇順 = 157 ns (比較: 543, 交換: 126, 書込: 0)

中央 x 降順 = 223 ns (比較: 578, 交換: 302, 書込: 0)

中央 x sin = 404 ns (比較: 1314, 交換: 530, 書込: 0)

中央 x sin(2θ) = 289 ns (比較: 740, 交換: 483, 書込: 0)

3点中央


3点中央 x ランダム = 248 ns (比較: 593, 交換: 302, 書込: 0)

3点中央 x 昇順 = 183 ns (比較: 543, 交換: 126, 書込: 0)

3点中央 x 降順 = 365 ns (比較: 741, 交換: 468, 書込: 0)

3点中央 x sin = 450 ns (比較: 1039, 交換: 740, 書込: 0)

3点中央 x sin(2θ) = 395 ns (比較: 758, 交換: 554, 書込: 0)

5点中央


5点中央 x ランダム = 303 ns (比較: 574, 交換: 304, 書込: 0)

5点中央 x 昇順 = 202 ns (比較: 543, 交換: 126, 書込: 0)

5点中央 x 降順 = 254 ns (比較: 557, 交換: 318, 書込: 0)

5点中央 x sin = 236 ns (比較: 551, 交換: 292, 書込: 0)

5点中央 x sin(2θ) = 336 ns (比較: 717, 交換: 482, 書込: 0)

完全中央


完全中央 x ランダム = 309 ns (比較: 543, 交換: 300, 書込: 0)

完全中央 x 昇順 = 226 ns (比較: 543, 交換: 126, 書込: 0)

完全中央 x 降順 = 297 ns (比較: 543, 交換: 294, 書込: 0)

完全中央 x sin = 272 ns (比較: 547, 交換: 274, 書込: 0)

完全中央 x sin(2θ) = 294 ns (比較: 584, 交換: 317, 書込: 0)

乱択


乱択 x ランダム = 270 ns (比較: 754, 交換: 328, 書込: 0)

乱択 x 昇順 = 185 ns (比較: 677, 交換: 96, 書込: 0)

乱択 x 降順 = 238 ns (比較: 668, 交換: 290, 書込: 0)

乱択 x sin = 244 ns (比較: 708, 交換: 306, 書込: 0)

乱択 x sin(2θ) = 261 ns (比較: 783, 交換: 322, 書込: 0)

左端


左端 x ランダム = 243 ns (比較: 747, 交換: 290, 書込: 0)

左端 x 昇順 = 851 ns (比較: 5049, 交換: 198, 書込: 0)

左端 x 降順 = 933 ns (比較: 5049, 交換: 148, 書込: 0)

左端 x sin = 445 ns (比較: 1166, 交換: 783, 書込: 0)

左端 x sin(2θ) = 327 ns (比較: 835, 交換: 549, 書込: 0)

右端


右端 x ランダム = 219 ns (比較: 683, 交換: 222, 書込: 0)

右端 x 昇順 = 963 ns (比較: 5049, 交換: 0, 書込: 0)

右端 x 降順 = 890 ns (比較: 5049, 交換: 50, 書込: 0)

右端 x sin = 414 ns (比較: 1096, 交換: 726, 書込: 0)

右端 x sin(2θ) = 301 ns (比較: 759, 交換: 500, 書込: 0)

最小


最小 x ランダム = 1146 ns (比較: 5049, 交換: 194, 書込: 0)

最小 x 昇順 = 1144 ns (比較: 5049, 交換: 198, 書込: 0)

最小 x 降順 = 1171 ns (比較: 5049, 交換: 148, 書込: 0)

最小 x sin = 1164 ns (比較: 5049, 交換: 198, 書込: 0)

最小 x sin(2θ) = 1141 ns (比較: 5049, 交換: 198, 書込: 0)

最大


最大 x ランダム = 1264 ns (比較: 5049, 交換: 94, 書込: 0)

最大 x 昇順 = 1240 ns (比較: 5049, 交換: 0, 書込: 0)

最大 x 降順 = 1250 ns (比較: 5049, 交換: 50, 書込: 0)

最大 x sin = 985 ns (比較: 2328, 交換: 1418, 書込: 0)

最大 x sin(2θ) = 872 ns (比較: 1425, 交換: 1089, 書込: 0)

Discussion