クイックソートを学び直す
『アルゴリズム図鑑』を参考に自力で実装したもの
def sort(a)
if a.size <= 1
return a
end
pivot = a[a.size / 2]
l = a.find_all { |e| e < pivot }
r = a.find_all { |e| e > pivot }
sort(l) + [pivot] + sort(r)
end
動いてるように見えたが
a = [3, 6, 1, 4, 2, 5, 0] # => [3, 6, 1, 4, 2, 5, 0]
sort(a) # => [0, 1, 2, 3, 4, 5, 6]
重複があると減った!?
a = [2, 0, 1, 1, 0, 2] # => [2, 0, 1, 1, 0, 2]
sort(a) # => [0, 1, 2]
重複を維持する
def sort(a)
if a.size <= 1
return a
end
pivot = a[a.size / 2]
l = a.find_all { |e| e < pivot }
m = a.find_all { |e| e == pivot } # pivot と同じ値のものは複数あるかもしれない
r = a.find_all { |e| e > pivot }
sort(l) + m + sort(r)
end
これで問題なさそう。
a = [2, 0, 1, 1, 0, 2] # => [2, 0, 1, 1, 0, 2]
sort(a) # => [0, 0, 1, 1, 2, 2]
ただ find_all 呼びすぎな気がする。
リファクタリング後
def sort(a)
a ||= []
if a.size <= 1
return a
end
pivot = a[a.size / 2]
g = a.group_by { |e| e <=> pivot }
sort(g[-1]) + g[0] + sort(g[1])
end
とくに壊れてなさそう。
a = [2, 0, 1, 1, 0, 2] # => [2, 0, 1, 1, 0, 2]
sort(a) # => [0, 0, 1, 1, 2, 2]
乱択版にしてみる
def sort(a)
a ||= []
if a.size <= 1
return a
end
pivot = a.sample # ← pivot をランダムに決める
g = a.group_by { |e| e <=> pivot }
sort(g[-1]) + g[0] + sort(g[1])
end
問題なさそう。
a = [2, 0, 1, 1, 0, 2] # => [2, 0, 1, 1, 0, 2]
sort(a) # => [0, 0, 1, 1, 2, 2]
これで完成したように思えたが一般的に見かけるクイックソートのコードとはどうも違うような気がする。『アルゴリズム図鑑』に出会う前、実装しようとして挫折したんだから Ruby で書けるのを差し引いてもこんなに簡単なわけがない。これは本当にクイックソートなんだろうか?
調べてみると一般的にソートは入力の配列を「その場で」作業場として利用し、破壊してしまう方法を指すらしくそれをかっこつけて in-place というらしい。不変可変は immutable と mutable という言葉が一般的なんだから in-place ではなく mutable でいいんじゃないかという気がするのだけどソート業界の意味合いとは少し違うようだ。なお in-place の反対は immutable ではなく out-of-place や not-in-place という (と Wikipedia に書いてあった)
個人的には常に入力を破壊しない immutable out-of-place にしてほしいけどとりあえず in-place な方法もやってみる。
in-place 版
初見ではさっぱりだったので『問題解決力を鍛える!アルゴリズムとデータ構造』を参考にした。
# r は右端を出たとこではなく右端自体を指した方がわかりやすいと思った
def sort(a, l = 0, r = a.size.pred)
if r <= l
return
end
# pivot の選択方法にさまざまなパターンがある
pi = (l + r) / 2 # 中央選択
pivot = a[pi]
# pivot を右端に退けておく
# 理由は in-place で作業する上で pivot が途中にあると邪魔だから
a[pi], a[r] = a[r], a[pi]
# [小さい数たち, 大きい数たち, pivot] みたいな構造を作る
# i, j どちらも左側から進む
i = l # i はストアする位置 (入れると右に進む)
(l..r.pred).each do |j| # 比較位置の j は pivot の「隣」まで進む
if a[j] < pivot # pivot より小さいなら左側に行ってほしいので
a[i], a[j] = a[j], a[i] # i の位置と交換して
$swap += 1
i += 1 # 次のためにポインタをずらす
end
end
# pivot を真ん中あたりに戻す
a[i], a[r] = a[r], a[i]
# すると不思議なことに
# [小さい数たち, pivot, 大きい数たち]
# な構造になっている
sort(a, l, i.pred) # pivot 未満の左グループ
sort(a, i.next, r) # pivot 以上の右グループ (超過ではない)
end
$swap = 0
a = [3, 1, 0, 2, 3, 1, 0, 2] # => [3, 1, 0, 2, 3, 1, 0, 2]
sort(a)
a # => [0, 0, 1, 1, 2, 2, 3, 3]
$swap # => 5
pivot を右端に寄せておくとかそんなん思いつかん。
ここに来て分かったことはクイックソートが難しいのではなく in-place にするのが難しいということだった。だからアルゴリズムを学び、実装する際には「仕組みをそのまま実装する方法」と「メモリを節約して実装する方法」を別々にして段階的に進めた方がよさそう。いきなり後者から入るのはハードルが高すぎる。
ChatGPT の工夫と欠点
とくに条件も付けずにクイックソートのコードを提示してもらい、リファクタリングしたのがこれだけど新しい工夫と発見があった。
def sort(a, l = 0, r = a.size.pred)
if r <= l
return
end
# pivot は a[r] にあると決め打ちしてた
i = l
(l..r).each do |j| # j は pivot の位置まで来る
if a[j] <= a[r] # 最後に pivot <= pivot が成立する
a[i], a[j] = a[j], a[i]
$swap += 1
i += 1
end
end
# 上の入れ替えによって右端にいた pivot は真ん中らへん(iの位置)に移動している
sort(a, l, i - 2) # pivot 以下の左グループ
sort(a, i, r) # pivot 超過の右グループ
end
$swap = 0
a = [3, 1, 0, 2, 3, 1, 0, 2] # => [3, 1, 0, 2, 3, 1, 0, 2]
sort(a)
a # => [0, 0, 1, 1, 2, 2, 3, 3]
$swap # => 12
- 工夫
- pivot は右端とする
- 最初から右端にあるので右端に移動する処理が省かれている
- pivot を真ん中あたりに移動する処理がない? → ある
- pivot と比較する範囲に pivot が含まれているため最後に
pivot <= pivot
が成立してこっそり移動している
- pivot と比較する範囲に pivot が含まれているため最後に
- pivot は右端とする
- 欠点
- 偏りのあるデータに弱い
- pivot が右端のせいで昇順や降順のデータに対して効率的に分割できない
- 同じ値が多いと遅くなる
- pivot との比較が
<=
のため同じ値が多いと無駄に交換回数が増える
- pivot との比較が
- 偏りのあるデータに弱い
あとで気づいたけどアプリ版アルゴリズム図鑑の解説もこれに似た方法になっている。おそらく入門者向けに簡潔に伝えたかったと思われる。
アプリ版アルゴリズム図鑑の「実験」機能風のポインタの動き
交換する位置を示すポインタがそれぞれ両端から中央に寄ってくる。
def sort(a, l = 0, r = a.size.pred)
if l >= r
return
end
pivot = a[(l + r) / 2]
i = l
j = r
loop do
# 左側で pivot 以上の値(の位置)を見つける
until a[i] >= pivot
i += 1
end
# 右側で pivot 以下の値(の位置)を見つける
until a[j] <= pivot
j -= 1
end
# 完全に交差していたら終わる
if i > j
break
end
# 左にある pivot 以上の値と、右にある pivot 未満の値を入れ替える
a[j], a[i] = a[i], a[j]
i += 1
j -= 1
end
sort(a, l, j)
sort(a, i, r)
end
a = [4, 6, 1, 5, 3, 8, 7, 5] # => [4, 6, 1, 5, 3, 8, 7, 5]
sort(a)
a # => [1, 3, 4, 5, 5, 6, 7, 8]
左側から寄ってくるポインタ i は pivot 以上の値を指し、右側から寄ってくるポインタ j は pivot 以下の値を指す。i と j がクロスしたら終わる──というロジックは完全に左右対称になるので、視覚的に動作を確認しようとしたときに、両方左から進むタイプよりもわかりやすいかもしれない。コードも規則的で読みやすい。
ただ i と j を動かす似たような処理を二箇所に書かないといけないのでそこの重複が匂ってしまうOAOO信者には向かないかもしれない。
再帰は偏りのあるデータに弱い
右端選択のクイックソートに
def sort(a, l = 0, r = a.size.pred)
if r <= l
return
end
pi = r # 右端選択
pivot = a[pi]
a[pi], a[r] = a[r], a[pi]
i = l
(l..r.pred).each do |j|
if a[j] < pivot
a[i], a[j] = a[j], a[i]
i += 1
end
end
a[i], a[r] = a[r], a[i]
sort(a, l, i.pred)
sort(a, i.next, r)
end
N = 10000 の昇順の配列を渡すと
begin
sort(10000.times.to_a)
rescue SystemStackError => error
error # => #<SystemStackError: stack level too deep>
end
偏りすぎて死ぬ。
末尾再帰最適化も
RubyVM::InstructionSequence.compile(<<~CODE, __FILE__, __dir__, __LINE__, tailcall_optimization: true).eval
def sort(a, l = 0, r = a.size.pred)
if r <= l
return
end
pi = r # 右端選択
pivot = a[pi]
a[pi], a[r] = a[r], a[pi]
i = l
(l..r.pred).each do |j|
if a[j] < pivot
a[i], a[j] = a[j], a[i]
i += 1
end
end
a[i], a[r] = a[r], a[i]
sort(a, l, i.pred)
sort(a, i.next, r)
end
CODE
begin
sort(10000.times.to_a)
rescue SystemStackError => error
error # => #<SystemStackError: stack level too deep>
end
効かない。
末尾に2つあるせい?
再帰を使わない方法
これは
def sort(a, l = 0, r = a.size.pred)
# ...
sort(a, l, i.pred)
sort(a, i.next, r)
end
これに置き換えることができる。
def sort(a)
stack = []
stack.push([0, a.size.pred])
until stack.empty?
l, r = stack.pop
# ...
stack << [l, i.pred]
stack << [i.next, r]
end
end
なのでこれになる。
def sort(a)
stack = []
stack.push([0, a.size.pred])
until stack.empty?
l, r = stack.pop
# ---------------------------------------- ↓ここから
if r <= l
next
end
pi = r # 右端選択
pivot = a[pi]
a[pi], a[r] = a[r], a[pi]
i = l
(l..r.pred).each do |j|
if a[j] < pivot
a[i], a[j] = a[j], a[i]
i += 1
end
end
a[i], a[r] = a[r], a[i]
# ---------------------------------------- ↑ここまでは同じ
stack << [l, i.pred]
stack << [i.next, r]
end
end
これで死ななくなる。
a = 10000.times.to_a
sort(a)
a == a.sort # => true
pivot のいろんな選択方法
a[l..r]
を対象としたとき
(l + r) / 2
[
l,
(l + r) / 2,
r,
].sort_by { |i| a[i] }.at(1)
[
l,
l + (r - l) * 1 / 4,
(l + r) / 2,
l + (r - l) * 3 / 4,
r,
].sort_by { |i| a[i] }.at(2)
(l..r).sort_by { |i| a[i] }.at((r - l) / 2)
rand(l..r)
l
r
(l..r).min_by { |i| a[i] }
(l..r).max_by { |i| a[i] }
pivot の選択方法と処理時間の関係
N=5000 単位:ms
選択方法 | ランダム | 昇順 | 降順 | sin | sin(2θ) | 同値 |
---|---|---|---|---|---|---|
中央 | 26 | 12 | 20 | 551 | 386 | 1869 |
3点中央 | 29 | 13 | 39 | 907 | 443 | 1873 |
5点中央 | 24 | 14 | 23 | 22 | 416 | 1870 |
完全中央 | 28 | 17 | 25 | 25 | 27 | 2733 |
乱択 | 26 | 16 | 24 | 25 | 25 | 1873 |
左端 | 25 | 1869 | 2081 | 909 | 447 | 1868 |
右端 | 26 | 2286 | 2076 | 911 | 470 | 1893 |
最小 | 2512 | 2509 | 2506 | 2510 | 2508 | 2507 |
最大 | 2910 | 2930 | 2915 | 2330 | 1402 | 2487 |
pivot の選択方法 x データの並び の可視化
中央
中央 x ランダム = 289 ns (比較: 727, 交換: 332, 書込: 0)
中央 x 昇順 = 157 ns (比較: 543, 交換: 126, 書込: 0)
中央 x 降順 = 223 ns (比較: 578, 交換: 302, 書込: 0)
中央 x sin = 404 ns (比較: 1314, 交換: 530, 書込: 0)
中央 x sin(2θ) = 289 ns (比較: 740, 交換: 483, 書込: 0)
3点中央
3点中央 x ランダム = 248 ns (比較: 593, 交換: 302, 書込: 0)
3点中央 x 昇順 = 183 ns (比較: 543, 交換: 126, 書込: 0)
3点中央 x 降順 = 365 ns (比較: 741, 交換: 468, 書込: 0)
3点中央 x sin = 450 ns (比較: 1039, 交換: 740, 書込: 0)
3点中央 x sin(2θ) = 395 ns (比較: 758, 交換: 554, 書込: 0)
5点中央
5点中央 x ランダム = 303 ns (比較: 574, 交換: 304, 書込: 0)
5点中央 x 昇順 = 202 ns (比較: 543, 交換: 126, 書込: 0)
5点中央 x 降順 = 254 ns (比較: 557, 交換: 318, 書込: 0)
5点中央 x sin = 236 ns (比較: 551, 交換: 292, 書込: 0)
5点中央 x sin(2θ) = 336 ns (比較: 717, 交換: 482, 書込: 0)
完全中央
完全中央 x ランダム = 309 ns (比較: 543, 交換: 300, 書込: 0)
完全中央 x 昇順 = 226 ns (比較: 543, 交換: 126, 書込: 0)
完全中央 x 降順 = 297 ns (比較: 543, 交換: 294, 書込: 0)
完全中央 x sin = 272 ns (比較: 547, 交換: 274, 書込: 0)
完全中央 x sin(2θ) = 294 ns (比較: 584, 交換: 317, 書込: 0)
乱択
乱択 x ランダム = 270 ns (比較: 754, 交換: 328, 書込: 0)
乱択 x 昇順 = 185 ns (比較: 677, 交換: 96, 書込: 0)
乱択 x 降順 = 238 ns (比較: 668, 交換: 290, 書込: 0)
乱択 x sin = 244 ns (比較: 708, 交換: 306, 書込: 0)
乱択 x sin(2θ) = 261 ns (比較: 783, 交換: 322, 書込: 0)
左端
左端 x ランダム = 243 ns (比較: 747, 交換: 290, 書込: 0)
左端 x 昇順 = 851 ns (比較: 5049, 交換: 198, 書込: 0)
左端 x 降順 = 933 ns (比較: 5049, 交換: 148, 書込: 0)
左端 x sin = 445 ns (比較: 1166, 交換: 783, 書込: 0)
左端 x sin(2θ) = 327 ns (比較: 835, 交換: 549, 書込: 0)
右端
右端 x ランダム = 219 ns (比較: 683, 交換: 222, 書込: 0)
右端 x 昇順 = 963 ns (比較: 5049, 交換: 0, 書込: 0)
右端 x 降順 = 890 ns (比較: 5049, 交換: 50, 書込: 0)
右端 x sin = 414 ns (比較: 1096, 交換: 726, 書込: 0)
右端 x sin(2θ) = 301 ns (比較: 759, 交換: 500, 書込: 0)
最小
最小 x ランダム = 1146 ns (比較: 5049, 交換: 194, 書込: 0)
最小 x 昇順 = 1144 ns (比較: 5049, 交換: 198, 書込: 0)
最小 x 降順 = 1171 ns (比較: 5049, 交換: 148, 書込: 0)
最小 x sin = 1164 ns (比較: 5049, 交換: 198, 書込: 0)
最小 x sin(2θ) = 1141 ns (比較: 5049, 交換: 198, 書込: 0)
最大
最大 x ランダム = 1264 ns (比較: 5049, 交換: 94, 書込: 0)
最大 x 昇順 = 1240 ns (比較: 5049, 交換: 0, 書込: 0)
最大 x 降順 = 1250 ns (比較: 5049, 交換: 50, 書込: 0)
最大 x sin = 985 ns (比較: 2328, 交換: 1418, 書込: 0)
最大 x sin(2θ) = 872 ns (比較: 1425, 交換: 1089, 書込: 0)
Discussion