🦥

TAOCPのソートアルゴリズムをHaskellで実装する

2022/12/25に公開

Haskellでクイックソートは簡潔に書ける

QuickSort.Wow

src/QuickSort/Wow.hs
sort :: Ord a => [a] -> [a]
sort = \case
	[] -> []
	x : xs -> sort ss ++ [x] ++ sort bs
		where (ss, bs) = partition (< x) xs

リストの中身を先頭の値以下のものと、それより大きいものとにわけるという作業を再帰的に行っている。

「おお! すごい。さすがHaskell」と思う。みんなそう思う。僕もそう思った。でも、これは全然「クイック」ではない。もちろんO(n^2)のアルゴリズムと比べたら速いのだけど、「クイックソート」のアイデンティティは(平均で見れば)、ヒープソートとかマージソートとかよりも「速い」というところにある(んじゃないかな)。

ここで、おなじように実装されたマージソート(?)を見てみよう。

MergeSort.Wow

src/MergeSort/Wow.hs
sort :: Ord a => [a] -> [a]
sort = \case
	[] -> []
	xs -> head . head . dropWhile multi . iterate pairs $ (: []) <$> xs

multi :: [a] -> Bool
multi = \case [_] -> False; _ -> True

pairs :: Ord a => [[a]] -> [[a]]
pairs = \case [] -> []; [xs] -> [xs]; (xs : ys : xss) -> merge xs ys : pairs xss

merge :: Ord a => [a] -> [a] -> [a]
merge [] bs = bs; merge as [] = as
merge aa@(a : as) ba@(b : bs) = bool (a : merge as ba) (b : merge aa bs) (a > b)

まずリストを「値ひとつからなるリスト」のリストに変換する((: []) <$> xs)。そのリストのリストについて、ふたつずつのペアについて、それぞれmergeを行う(pairs)。その操作をくりかえす(iterate pairs)。操作の結果のリストから「すべてがmergeされてひとつになった結果」を選ぶ(head . dropWhile multi)。そこから結果のリストを取り出す(head)。

これも由緒正しいマージソートとは異なるけれど、「マージによるソート」ではある。これらの「クイックソート(?)」と「マージソート」の速度を比較してみよう。

縦軸は並びかえにかかった時間を、要素数をNとしてN log Nで除算している。横軸は対数グラフとしてあり、1000要素から100万要素までとしてある。Nの値を変えながらの試行を8回行った。結果を見ると、「クイックソート(?)」が全然クイックではないことがわかる。

クイックソートとは

「Haskellのシンプルなクイックソートの実装」を「クイックソートと呼んでもいいのか」という話は、数年おきに何度もくりかえされている話題だ。「おおまかな考えかた」はおなじなので、「そう呼んでもいいのでは」とも思ったが確信がなかったので、TAOCP(The Art of Computer Programming)のVolume 3を参照した。TAOCPだとQuicksortは5.2.2のAlgorithm Qとして掲載されている。で、5.2.2は「Sorting by Exchanging」つまり「交換によるソート」だ。この時点で、例のコードはクイックソートとは呼びづらいと感じる。ぜんぜん「交換によるソート」じゃないので。さらに、「クイックソートが速い理由」について、以下のような一文がある(P.121)。

Its speed is primarily due to the fact that the inner loops, in steps Q3 and Q4, are extremely short

「クイックソートが速いのは、アルゴリズムのQ3とQ4の「内側のループ」が非常に簡潔であるからだ」と書かれている。Q3, Q4はつぎのような部分だ。

  • 配列の前のほうから要素を読んでいき、pivot以上の値を見つける
  • 配列の後ろのほうから要素を読んでいき、pivot以下の値を見つける

で、その見つけたふたつの値の位置を交換する。Q3は以下のように実装できる。

  • i番目の値を読み込む
  • レジスタの値と比較する
    • より小さければiを1増やしくりかえす
    • そうでなければループを終了する

Q4も同様だ。この操作は非常に単純だ。実行回数が大きくなりがちな内側のループが、このように非常に簡潔な処理になっている。これがクイックソートが「クイック」である理由であり、クイックソートの肝とも言えるだろう。で、この「肝」の部分が実装されていないので、「Haskellのシンプルなクイックソート」はクイックソートとは呼べないと思われる。

TAOCPのソートアルゴリズム

TAOCPに掲載されているクイックソートであれば「由緒正しい」クイックソートと言えるだろう。なのでHaskellで、そのアルゴリズムを実装してみることにする。まったくおなじになることを目指すのではないが、「肝」の部分は変えずに、また流れも保ったまま実装してみたい。

まずは練習として、より簡単なO(n^2)である3つのアルゴリズムを実装してみよう。

O(n^2)のソートアルゴリズム

O(n^2)のソートアルゴリズムのうち、つぎのものを実装してみる。

  • 単純挿入ソート
  • バブルソート
  • 単純選択ソート

単純挿入ソート

単純挿入ソートは、トランプで説明するとつぎのようになる。

  • 1枚トランプをめくる
  • 手元にトランプがあれば、新しくめくったトランプをそのなかの適切な場所にいれる
  • すべてのトランプをめくり終えるまでくりかえす

単純挿入ソートはTAOCPの5.2.1のアルゴリズムS(P.80)に掲載されている。

  • アルゴリズムS. レコードR1, ..., RNは整列される。ソートが終わったときキーK1<=...<=KNとなる
  • S1. [jについてのくりかえし] j = 2, 3, ..., NについてS2からS5をくりかえして終了
  • S2. [i, K, Rの初期化] i <- j - 1, K <- Kj, R <- Rj
  • S3. [KとKiの比較] もし K >= Ki ならS5に行く
  • S4. [Riを移動し、iを1減らす] Ri+1 <- Ri, i <- i - 1とする。もし i > 0 ならS3に行く
  • S5. [Ri+1にRを代入する]

レコードR1, ..., RNをキーK1, ..., KNについてソートしている。これをHaskellで実装してみよう。TAOCPのアルゴリズムでは、並べかえる値(Ri)と比較の対象となる値(Ki)とを区別しているが、今回の記事では並べかえる値そのものを比較の対象にすることにする。

StraightInsertionSort

src/StraightInsertionSort.hs
insertionSort :: Ord a => [a] -> [a]
insertionSort ks =
	runST $ (>>) <$> isort n <*> getElems =<< newListArray (1, n) ks
	where n = length ks

isort :: Ord a => Int -> STArray s Int a -> ST s ()
isort n ks = for_ [2 .. n] \j -> insert ks (j - 1) =<< readArray ks j

insert :: Ord a => STArray s Int a -> Int -> a -> ST s ()
insert ks i k = readArray ks i >>= \ki -> if k >= ki
	then writeArray ks (i + 1) k
	else do	writeArray ks (i + 1) ki
		if i - 1 > 0 then insert ks (i - 1) k else writeArray ks i k

STArrayはSTモナドで使える配列だ。STモナドには、その内側では状態変化を許すけれど、その状態変化を外側には見せない仕組みがあり、外側からは純粋に見えるようになっている。

(>>) <$> isort n <*> getElems =<< newListArray (1, n) ks

の部分は、わかりやすく書くと、つぎのようになる。

    do  a <- newListArray (1, n) ks
        isort n a
        getElems a

newListArrayで1からnまでのインデックスを持つ配列を生成して、それをisortでソートして、getElemsで結果をリストにして返している。この全体に対してrunSTすることで、「状態変化をともないリストを返すモナド」を「リストを返す純粋な関数」に変換している。

関数isortで配列のソートを行っている。jの値を2からnに変えながら、関数insertを引数(j - 1)と配列ksのj番目の要素を引数にして呼び出している。

関数insertは第2引数iから、さかのぼりながら第3引数kを挿入する場所を探す。値kiは配列ksのi番目の要素である。値kがkiより大きければ値kが挿入される場所は配列のi+1番目なので、そこに値kを書き込む。そうでなければ、i-1番目から、さらにさかのぼって挿入される場所を探す。iが1ならば、それ以上さかのぼらずに、そこが値kを挿入する位置になる。

バブルソート

  • アルゴリズムB (バブルソート).
  • B1. [BOUNDの初期化] BOUND <- N とする.
  • B2. [jについてのくりかえし] t <- 0 とする. B3をj = 1, 2, ..., BOUND - 1についてくりかえし, B4に行く.
  • B3. [比較と交換 Rj:Rj+1] もし Kj > Kj+1 なら Rj <-> Rj+1 の交換をし t <- j とする.
  • B4. [交換はさらに必要か?] もし t = 0 ならアルゴリズムは終了. そうでなければ BOUND <- t としてB2に行く.

BubbleSort

src/BubbleSort.hs
bubbleSort :: Ord a => [a] -> [a]
bubbleSort ks = runST $ (>>) <$> bsort n <*> getElems =<< newListArray (1, n) ks
	where n = length ks

bsort :: Ord a => Int -> STArray s Int a -> ST s ()
bsort bound ks = swap ks 0 [1 .. bound - 1] >>= \t ->
	bool (bsort t ks) (pure ()) (t == 0)

swap :: Ord a => STArray s Int a -> Int -> [Int] -> ST s Int
swap ks t = \case
	[] -> pure t
	(j : js) -> readArray ks j >>= \kj -> readArray ks (j + 1) >>= \kj1 ->
		if kj > kj1
			then do	writeArray ks j kj1; writeArray ks (j + 1) kj
				swap ks j js
			else swap ks t js

関数bsortの引数boundは「どこまでソートが終わっているか」を示す。関数bsortでは1, 2, 3 ..., bound - 1までの値のリストを作り、関数swapの第3引数にあたえる。第2引数tには初期値として0をあたえておく。引数tは「最後に交換した位置」を示している。関数bsortでは関数swapの第2引数tの最後の値を返り値として受け取る。そして、それが0でなければ、その値を引数boundの値として再帰的に関数bsortを呼び出す。

関数swapでは配列ksの隣り合う2つの値が逆順になっていたときに、その2つの値の位置を交換したうえで、その位置を引数tとして記憶している。

単純選択ソート

  • アルゴリズムS (単純選択ソート)
  • S1. [jについてのくりかえし] S2とS3を j = N, N - 1, ..., 2 についてくりかえす.
  • S2. [max(K1, ..., Kj)を探す] Kj, Kj-1, ..., K1 から最大のものを探す. それを Ki としたとき, iは最大になるようにする.
  • S3. [Rjとの交換] Ri <-> Rj の交換をする.

StraightSelectionSort

src/StraightSelectionSort.hs
selectionSort :: Ord a => [a] -> [a]
selectionSort ks =
	runST $ (>>) <$> ssort n <*> getElems =<< newListArray (1, n) ks
	where n = length ks

ssort :: Ord a => Int -> STArray s Int a -> ST s ()
ssort n ks = for_ [n, n - 1 .. 2] \j -> findMax ks j (j - 1) >>= \i -> do
	ki <- readArray ks i; kj <- readArray ks j
	writeArray ks i kj; writeArray ks j ki

findMax :: Ord a => STArray s Int a -> Int -> Int -> ST s Int
findMax ks i k = readArray ks i >>= \ki -> readArray ks k >>= \kk -> do
	let	i' = bool k i (ki >= kk)
	bool (pure i') (findMax ks i' (k - 1)) (k > 1)

関数ssortでは変数jにn, n - 1, ..., 2をつぎつぎに代入していき、jより小さいインデックスについて、対応する要素が最大になるものをiとして、配列のi番目とj番目とを入れ換えている。関数findMaxは引数kの値を小さくしていきながら、インデックスiに対応する値kiよりも大きな値になる要素のインデックスがあればインデックスiをそれで置き換えている。

O(n^2)のソートアルゴリズムの速度の比較

要素数Nを横軸に、ソートにかかった時間をN ^ 2で除算した値を縦軸としてある。今回の実装だと、単純挿入ソートのほうが単純選択ソートよりも3倍弱効率的になっている。

O(n log n)のソートアルゴリズム

O(n log n)のソートアルゴリズムのうち、つぎのものを実装してみる。

  • クイックソート
  • ヒープソート
  • マージソート

クイックソート

今回の導入となったクイックソートだ。本当のクイックソートを実装してみる。TAOCPに掲載されているアルゴリズムでは、ソートする区間のサイズがM以下になったら、クイックソートは終了とし、全体を単純挿入ソートとする実装になっている。

  • アルゴリズムQ (クイックソート).
    • K0 = -∞ と kN+1= ∞ とがあり K0 <= Ki <= KN+1 (1 <= i <= N) となっているものとする
  • Q1. [初期化] もし N <= M なら手順Q9に行く。そうでないならスタックを空にし、l <- 1, r <- N とする。
  • Q2. [新しい段階の開始] i <- l, j <- r + 1 とする。 K <- Kl とする。
  • Q3. [Ki : K の比較] iを1増やし、つぎにもし Ki < K ならばこの手順をくりかえす。
  • Q4. [K : Kj の比較] jを1減らし、つぎにもし K < Kj ならばこの手順をくりかえす。
  • Q5. [i : j を調べる] もし j <= i ならば Rl <-> Rj の交換をして手順Q7に行く。
  • Q6. [交換] Ri <-> Rj の交換をして手順Q3にもどる。
  • Q7. [スタックに積む] もし r - j >= j - l >= M なら (j + 1, r) をスタックの1番上に挿入し、r <- j - 1 としQ2に行く。もし j - l >= r - j >= M なら (l, j - 1)をスタックの1番上に挿入し、l <- j + 1 としQ2に行く。そうでないなら、もし r - j > M >= j - l なら l <- j + 1 としQ2に行く。もし j - l > M >= r - j なら r <- j - 1 としてQ2に行く。
  • Q8. [スタックから取り出す] もしスタックが空でないなら、1番上の要素(l', r')を削除し、 l <- l', r <- r' とし手順Q2に行く。
  • Q9. [単純挿入ソート] j = 2, 3, ..., N について、もし Kj-1 > Kj なら、つぎの処理を行う: K <- Kj, R <- Rj, i <- j - 1 とし、つぎに Ki <= K になるまで1回以上 Ri+1 <- Ri, i <- i - 1 とし、つぎに Ri+1 <- R とする。

Haskellでの実装はつぎのようになる。

QuickSort.Taocp

src/QuickSort/Taocp.hs
quicksortM :: (Ord a, Bounded a) => Int -> [a] -> [a]
quicksortM m ks =
	init . tail $ runST $ (>>) <$> qsort m n <*> getElem =<< array n ks
	where n = length ks

array :: Bounded a => Int -> [a] -> ST s (STArray s Int a)
array n ks = newArray_ (0, n + 1) >>= \a -> do
	writeArray a 0 minBound; writeArray a (n + 1) maxBound
	a <$ uncurry (writeArray a) `mapM_` zip [1 ..] ks

qsort :: Ord a => Int -> Int -> STArray s Int a -> ST s ()
qsort m n ks = do
	when (n > m) $ stage m ks 1 n
	when (m > 1) $ isort n ks

stage :: Ord a => Int -> STArray s Int a -> Int -> Int -> ST s ()
stage m ks l r = readArray ks l >>= \k -> do
	(j, kj) <- exchange ks k (l + 1) r
	writeArray ks l kj; writeArray ks j k
	case (r - j >= j - l, r - j > m, j - l > m) of
		(True, _, True) -> stage m ks l (j - 1) >> stage m ks (j + 1) r
		(False, True, _) -> stage m ks (j + 1) r >> stage m ks l (j - 1)
		(_, True, False) -> stage m ks (j + 1) r
		(_, False, True) -> stage m ks l (j - 1)
		_ -> pure ()

exchange ks k i j = do
	(i', ki) <- fromL i (< k) ks
	(j', kj) <- fromR j (k <) ks
	if j' <= i' then pure (j', kj) else do
		writeArray ks i' kj; writeArray ks j' ki
		exchange ks k (i' + 1) (j' - 1)

fromL, fromR :: Int (a -> Bool) -> STArray s Int a -> ST s (Int, a)
fromL i p ks =
	readArray ks i >>= \k -> bool (pure (i, k)) (fromL (i + 1) p ks) (p k)

fromR j p ks =
	readArray ks j >>= \k -> bool (pure (j, k)) (fromR (j - 1) p ks) (p k)

関数arrayはリストから配列を作る。リストの要素は配列の1からNまでに置かれる。アルゴリズムの都合上、配列の左端は最小値、右端は最大値になっていたほうがいい。なので、0からn+1までの配列を作り、位置0にminBoundを、位置n+1にmaxBoundをそれぞれ置いている。

関数qsortは関数stageとisortを順に呼び出している。要素数がM以下であればクイックソートは不要で単純挿入ソートのみでいい。またMが1の場合は単純挿入ソートは必要ない。

関数stageがクイックソートの本体だ。引数l, rはそれぞれソートする範囲の左端と右端だ。まずは左端の要素を読み出して、それを値kとする。関数exchangeを呼び出し、l+1からrまでの範囲について、値k以下のものは左に、値k以上のものは右に集める。関数exchangeは最終的な位置jの値とその場所の要素の値を返り値として返す。なので、それと左端(位置l)の要素とを交換することで、左端の要素を適切な位置に移動させることができる。

関数fromLとfromRとは、それぞれ左と右から条件pを満たす値を読み飛ばしていき、最初に条件pを満たさなくなった位置とその要素とを返す。関数exchangeはそれらの関数を使って、左から見ていってはじめに値k以上になる位置の値と、右から見ていって初めに値k以下になる位置の値とを交換していく。そして位置j'がi'以下になったところで終わりにする。

Mの値による速度の比較

上のアルゴリズムは、それぞれの区間に含まれる要素数がM以下になったところで、クイックソートの本体のアルゴリズムは終了とし、全体を単純挿入ソートでソートするようになっている。要素数が少なくなってくるとクイックソートの本体のアルゴリズムよりも単純挿入ソートのほうが効率が良くなるからだ。TAOCPで説明に使われる仮想的な機械であるMIXではMは9が効率的となっている。実際にMの値によって実行時間がどのように変化するかを調べてみた。1000から10万までのランダムな長さの、ランダムな値のリストに対して、Mを1から256のあいだで変化させたときの実行時間を調べた。縦軸は要素数Nに対してN log Nで除算した値になっている。横軸は対数にした。

これを見ると僕の環境ではM = 32あたりが効率的であるようだ。すべてをクイックソートの本体のアルゴリズムでソートした場合(M = 1)と比較すると、1.4倍ほどの効率になっているように見える。

なので、以下のように関数quicksortを定義する。

src/QuickSort/Taocp.hs
quickSort :: (Ord a, Bounded a) => [a] -> [a]
quickSort = quicksortM 32

ヒープソート

  • アルゴリズムH (ヒープソート).
  • H1. [初期化] l <- [N / 2 の整数部分] + 1, r <- N とする
  • H2. [lまたはrを減らす] もし l > 1 なら、 l <- l - 1, R <- Rl, K <- Kl とする。そうでないなら、 R <- Rr, K <- Kr, Rr <- R1, r <- r - 1 とし、もしそれで r = 1 となるなら、 R1 <- R としアルゴリズムを終了する。
  • H3. [シフトアップの準備] j <- l とする。
  • H4. [下に進める] i <- j, j <- 2j とする。もし j < r なら、そのまま手順H5に行く。もし j = r なら手順H6に行き、もし j > r ならH8に行く。
  • H5. [大きいほうの子を選ぶ] もし Kj < Kj+1 ならば、 j <- j + 1 とする
  • H6. [Kより大きいかどうか] もし K >= Kj なら、H8に行く。
  • H7. [それを上に上げる] Ri <- Rj とし、H4にもどる。
  • H8. [Rを置く] Ri <- R とする。H2にもどる。

H2のところは、2つのフェーズにわけることができる。値lを[N / 2の整数部分]+1から1まで減らしていく部分と、値rをNから1まで減らしていく部分だ。アルゴリズムのそれぞれの段階で、位置lからrまでがヒープ木になっている。つまり、初めのフェーズでは位置lを1まで減らしているので、全体をヒープ木にしていると考えることができる。そして、つぎのフェーズでは位置rをNから1に減らしていくので、ヒープ木を右から崩していっていると考えることができる。

HeapSort.TaocpFree

src/HeapSort/TaocpFree.hs
heapsort :: Ord a => [a] -> [a]
heapsort ks = runST $ (>>) <$> hsort m n <*> getElems =<< newListArray (1, n) ks
	where m = n `div` 2; n = length ks

hsort :: Ord a => Int -> Int -> STArray s Int a -> ST s ()
hsort m n ks = do
	for_ [m, m - 1 .. 1] \l -> shiftup ks l n =<< readArray ks l
	for_ [n, n - 1 .. 2] \r -> readArray ks r >>= \k -> do
		writeArray ks r =<< readArray ks 1
		shiftup ks 1 (r - 1) k

shiftup :: Ord a => STArray s Int a -> Int -> Int -> a -> ST s ()
shiftup ks j r k
	| c <= r = readArray ks c >>= \kc -> do
		(e, ke) <- if c < r
			then (<$> readArray ks d) \kd ->
				bool (c, kc) (d, kd) (kc < kd)
			else pure (c, kc)
		if k >= ke
			then writeArray ks j k
			else writeArray ks j ke >> shiftup ks e r k
	| otherwise = writeArray ks j k
	where c = 2 * j; d = c + 1

関数hsortでは

  • l = m, m - 1, ..., 1について、配列のl番目の値を関数shiftupでヒープ木に追加し、
  • r = n, n - 1, ..., 2について、配列のr番目の値を値kとして保存してから
    • ヒープ木の一番上の値を配列のr番目に書き込み
    • 値kを、1要素ぶん小さくなったヒープ木に追加している

関数shiftupは位置jからrまでをヒープ木とみなして、ヒープ木の1番上の値を削除して、値kを追加する関数だ。やっていることは、わりと単純だがヒープ木の末端で、つぎの3通りの場合分けのためにすこし複雑になっている。

  • 子要素がない場合 (c > r)
  • 左の子要素だけある場合 (c == r)
  • 左右の子要素がある場合 (c < r)

子要素がない場合は、関数shiftupのガードのotherwiseの部分であり、単に配列のj番目(ヒープ木の頂点)に値kを書き込めばいい。子要素がある場合には、まず左の子要素を読み出し(readArray ks r c)、それを値kcとする。さらに、もしc < rであれば右の子要素も存在するので、それを読み込み(readArray ks d)、値kdとする。値kcとkdとを比較してkcのほうが大きければ、e = c, ke = kcとし、そうでないなら、e = d, ke = kdとする(bool (c, kc) (d, kd) (kc < kd))。値kがke以上であれば、そのまま値kを位置jに書き込めばいい。もし値kがkeより小さければ値keを位置jに書き込んだうえで、位置eからrをヒープ木とみなして値kを追加する(shiftup ks e r k)操作を再帰的に行う。

マージソート

  • アルゴリズムN (自然2方向マージソート).
  • N1. [初期化] s <- 0 とする.
  • N2. [道の準備] もし s = 0 なら, i <- 1, j <- N, k <- N + 1, l <- 2N とし; もし s = 1 なら, i <- N + 1, j <- 2N, k <- 1, l <- N とする. d <- 1, f <- 1 とする.
  • N3. [KiとKjの比較] もし Ki > Kj なら, 手順N8に行く. もし i = j なら Rk <- Ri としてN13に行く.
  • N4. [Riを送る] Rk <- Ri, k <- k + d とする
  • N5. [小さくなる?] i を1増やす. そしてもし Ki-1 <= Ki なら, 手順N3にもどる.
  • N6. [Rjを送る] Rk <- Rj, k <- k + d とする.
  • N7. [小さくなる?] j を1減らす. もし kj+1 <= Kj なら, 手順N6にもどる; そうでないなら手順N12に進む.
  • N8. [Rjを送る] Rk <- Rj, k <- k + d とする.
  • N9. [小さくなる?] j を1減らす. もし Kj+1 <= Kj なら, 手順N3にもどる.
  • N10. [Riを送る] Rk <- Ri, k <- k + d とする.
  • N11. [小さくなる?] i を1増やす. もし Ki-1 <= Ki なら, 手痛N10にもどる.
  • N12. [どちら側かをいれかえる] f <- 0, d <- -d とし, k <-> l のいれかえをする. 手順N3にもどる.
  • N13. [領域をいれかえる] もし f = 0 なら, s <- 1 - s としN2にもどる. そうでないならソートは完了; もし s = 0 なら, (R1, ..., RN) <- (RN+1, ..., R2N) とする.

まずは、関数copyを定義する。

Data.Array.Tools

src/Data/Array/Tools
{-# INLINE copy #-}
copy :: (MArray a e m, Ix i) => a i e -> i -> i -> m ()
copy a d s = writeArray a d =<< readArray a s

関数copyは配列aの位置sの要素を位置dにコピーする。{-# INLINE copy #-}とすることで、関数copyはインライン展開される。こうしとかないと、関数の実行のたびに多相性の解決が必要になり、効率が目に見えて悪化してしまう。

MergeSort.Natural

src/MergeSort/Natural.hs
naturalSort :: Ord a => [a] -> [a]
naturalSort ks = take n
	$ runST $ (>>) <$> nsort n False <*> getElem =<< prepareArray n ks
	where n = length ks

prepareArray :: Int -> [a] -> ST s (STArray s Int a)
prepareArray n ks = newArray_ (1, n * 2) >>= \a ->
	a <$ uncurry (writeArray a) `mapM_` zip [1 ..] ks

nsort :: Ord a => Int -> Bool -> STArray s Int a -> ST s ()
nsort n s ks = inner ks i j k l 1 True >>= bool
	(nsort n (not s) ks) (unless s $ for_ [1 .. n] \m -> copy ks m (n + m))
	where (i, j, k, l) = bool (1, n, n + 1, 2 * n) (n + 1, 2 * n, 1, n) s

inner :: Ord a =>
	STArray s Int a -> Int -> Int -> Int -> Int -> Int -> Bool -> ST s Bool
inner ks i j k l d f
	| i == j = f <$ copy ks k i
	| otherwise = readArray ks i >>= \ki -> readArray ks j >>= \kj ->
		if ki > kj
		then transR ks i j k d >>= \case
			Nothing -> inner ks i j' k' l d f
			Just (ii, kk) -> inner ks ii j' l kk (- d) False
		else transL ks i j k d >>= \case
			Nothing -> inner ks i' j k' l d f
			Just (jj, kk) -> inner ks i' jj l kk(- d) False
	where i' = i + 1; j' = j - 1; k' = k + d

transL :: Ord a =>
	STArray s Int a -> Int -> Int -> Int -> Int -> ST s (Maybe (Int, Int))
transL ks i j k d = readArray ks i >>= \ki ->
	writeArray ks k ki >> readArray ks i' >>=
		bool (Just <$> flushR ks j k' d) (pure Nothing) . (ki <=)
	where i' = i + 1; k' = k + d

flushR :: Ord a => STArray s Int a -> Int -> Int -> Int -> ST s (Int, Int)
flushR ks j k d = readArray ks j >>= \kj ->
	writeArray ks k kj >> readArray ks j' >>=
		bool (pure (j', k')) (flushR ks j' k' d) . (kj <=)
	where j' = j - 1; k' = k + d

transR :: Ord a =>
	STArray s Int a -> Int -> Int -> Int -> Int -> ST s (Maybe (Int, Int))
transR ks i j k d = readArray ks j >>= \kj ->
	writeArray ks k kj >> readArray ks j' >>=
		bool (Just <$> flushL ks i k' d) (pure Nothing) . (kj <=)
	where j' = j - 1; k' = k + d

flushL :: Ord a => STArray s Int a -> Int -> Int -> Int -> ST s (Int, Int)
flushL ks i k d = readArray ks i >>= \ki ->
	writeArray ks k ki >> readArray ks i' >>=
		bool (pure (i', k')) (flushL ks i' k' d) . (ki <=)
	where i' = i + 1; k' = k + d

関数prepareArrayでは位置1から2Nまでの配列を用意している。これは、位置1からNまでの配列と、位置N+1から2Nまでの配列の2つを用意したのと同じことだ。この2つの配列を、コピー元とコピー先として交互に使うことになる。

関数nsortは関数innerをソートが終了するまで呼び出し続ける。ソートが終了したかどうかは関数innerの引数fで示される。関数innerは返り値として、引数fの最後の値を返す。もし終了してなければ(関数innerの返り値はFalse)、関数nsort自身を引数sを真偽を反転させて呼び出す。引数sを反転することでコピー先とコピー元を入れ換えることになる。ソートが終了していた場合、もし引数sがFalseならば(unless s)、コピー先が配列の後半になっているので、その要素をすべて前半にコピーする(for_ [1 .. n] \m -> copy ks m (n + m))。

関数innerはそれぞれのペアごとにマージを行う。ペアとは、値が大きくなっていく順に連続して並ぶ部分をブロックとして、先頭と末尾からそれぞれひとつずつブロックを取り出したものである。それをくりかえすことで複数のペアを作ることができる。

関数innerの実際の動きは、つぎのようになる。引数iとjとはそれぞれ先頭と末尾から順に進んできた、問題とする要素の位置を示す。位置iとjがおなじときは、それ以上マージするものがないので、その位置の要素をコピー先に送る(copy ks k i)。そうでない場合には、位置iの値kiと位置jの値kjとを比較して、小さいほうをコピー先に送る。たとえばkiのほうが小さい場合、else節のほうに進み関数transLが呼ばれる。

関数transLは左側のブロックから1要素をコピー先に送る。もしそれがブロックの末尾だったら、つまり配列のつぎの要素の値が、より小さい値であったら、右側のブロックの残りの値をすべて送る(flushR ks j k' d)。左側のブロックから送った値がブロックの末尾でなければ、関数transLはNothingを返す、末尾だった場合には新しいjとkの値をJust (j, k)の形で返す。

関数flushRは右側のブロックの残りの要素をすべて送る。「つぎの値(readArray ks j'の返す値)」が、より小さな値になるまで、つぎつぎと要素をコピー先に送り続ける。

関数transRとflushLはそれぞれ、transLとflushRの左右を入れ換えたものになっている。

O(n log n)のソートアルゴリズムの速度の比較

実行効率がO(n log n)になる3つのソートアルゴリズムの速度を比較する。

縦軸は実行時間をn log nで除算したもの。横軸は対数とした。TAOCPを参考に実装した3つの関数と標準ライブラリのData.List.sortとを比較している。

標準ライブラリのソート関数は要素数1万くらいまではいいが、それを越えるとどんどん遅くなっていく。標準ライブラリは「なんちゃってマージソート」アルゴリズムで、リストを大量に生成するので、GCに時間がかかっているのではないだろうか。

クイックソートは確かにクイックだ。マージソートもまあそれなりに速い。ヒープソートはNが10万くらいまではいいが、それ以上になるとO(n ^ 2)くらいの感じで遅くなっている。よくはわからないが、多分メモリのキャッシュミス率の問題なのではないだろうか。クイックソートやマージソートのメモリの使いかたとして、ある時間範囲では限られた範囲のメモリしか使わないため、メモリのキャッシュの使用効率がいいのに対して、ヒープソートでは短時間で広い範囲のメモリを使用するので、キャッシュの使用効率が悪いということだと思う。

ヒープソートはアルゴリズムとしてきれいで僕は好きなのだけど、「メモリは連続的に使われることが多い」という前提で効率化された現在のキャッシュのありかたにそぐわない。不遇なアルゴリズムと言えるんじゃないかな。

まとめ

「Haskellのあれをクイックソートと呼んでいいか」という話題がくりかえされるのは、ありがたいことだ。それはつまり、Haskellに入門する人が常に一定数存在するということだからだ。なので、うんざりすることなく「こういう点はクイックソートと呼べるけど、こういう点はそうは呼びにくい」という話を、相手への敬意を失うことなくしていくべきだろう。

TAOCPのソートアルゴリズムをいくつかHaskellで実装し、速度を比較してみた。今回はまにあわなかったが、他にやりたいことがいくつかあった。

  • 再帰がある程度深くなったところでクイックソートからヒープソートに切り換えるイントロソート
  • クイックソートの応用で中央値などを効率的に求めることができる「中央値の中央値」アルゴリズム
  • ある程度の大きさにリストを分けてヒープソートをして、ソートされたリストをマージするアルゴリズム(ヒープソートは要素数がそれほど大きくなければかなり高速なので)

機会があったら実装してみたい。

Discussion