型安全で高速な連鎖行列積の計算
この記事は Haskell Advent Calendar 2021 の22日目の記事です。
次のような3つの行列の積を考えてみましょう。
行列積は結合律が成り立つので
しかし計算に必要は演算の回数はどうでしょうか。まず
反対に
計算する順番は与えられた行列積にどのような括弧つけるのかに対応しています。括弧を付ける組み合わせの数は3つの行列の積では2通りしかありませんでしたが、掛け合わせる行列の数が増えるにつれて指数関数的に増加していきます。なので全ての組み合わせを列挙して行列積のコストが一番低いパターンを探索するのは困難です。しかし
ところでHaskellではベクトルや行列の型にそれらの次元に関する情報を持たせるができます。そこで行列積を行う行列の組み合わせが与えられた時に、型にある次元の情報から最適な行列積を行う順番を自動的に計算しその通りに行列積を実行するような関数を実装することはできないでしょうか。
行列と連鎖行列の型
この記事では線形代数ライブラリとしてhmatrixのNumeric.LinearAlgebra.Staticを用います。このライブラリで
data L m n = ...
と定義されています。
行列積を行う対象となる複数の行列の組、すなわち連鎖行列を表す型を考えてみましょう。最も簡単に思いつくのは行列の単純なリストとして表現することでしょう。
type MCList m n = [L m n]
ここでMCList
はMatrix Chain Listの意味です。しかしこれはうまくいきません。なぜならこの型ではリストの要素は全て
そこでGADTを用いて独自のデータ型を作ってみましょう[2]。
data MCList m n where
S :: L m n -> MCList m n
(:.) :: KnownNat k => L m k -> MCList k n -> MCList m n
infixr 5 :.
このMCList m n
は行列積を行なった結果がS
は行列を1つ取ってMCList
を構築する値コンストラクタであり、(:.)
は計算結果が(:.)
は中間にある次元k
を存在型で隠蔽してしまうので結果の型MCList m n
にk
が現れることはありません。中間にある次元を明示的に管理しないことでコードが煩雑になることを防いでいます。ただし単なる存在型だと後の実装で扱いに困ってしまうのでKnownNat
のインスタンスであるという制約だけつけています。
このMCList
を使って単純に右から順番に行列積を計算する関数を作ってみましょう。
naiveMcm :: (KnownNat m, KnownNat n) => MCList m n -> L m n
naiveMcm (S a) = a
naiveMcm (a :. b) = a <> naiveMcm b
(<>)
はhmatrixが提供する行列積を行う演算子です。この演算子は
(<>) :: L m k -> L k n -> L m n
という型をしていて、掛け合わせる行列の次元が正しく揃っていないとコンパイルができないようになっています。naiveMcm
の実装ではGADTのパターンマッチが行われることで行列a
の列の次元と連鎖行列b
の行の次元が一致することが保証されるので上記のコードは問題なくコンパイルすることができるのです。
型の情報から連鎖行列問題を解く
連鎖行列問題を解くために、まずは連鎖行列の型から行列の次元をリストとして取り出す関数を実装してみましょう。
num :: forall n a. (KnownNat n, Num a) => a
num = fromIntegral $ natVal (Proxy @n)
dims :: forall m n. (KnownNat m, KnownNat n) => MCList m n -> [Int]
dims (S a) = [num @m, num @n]
dims (_ :. mcm) = num @m : dims mcm
num
はTypeApplications
を利用して型レベルの自然数から値を取り出すための補助関数です。dims
はMCList
を再帰的に辿って行列の次元のリストを作る関数です。例えばdims
に与えると[p, q, r]
というリストが返ってきます。このリストの長さは与えた連鎖行列に含まれる行列の数より1つ長くなっていることに注意してください。
いよいよ動的計画法を使って行列積の最小コストを求めてみましょう。ただしよくある競技プログラミングの問題とは違って今回は最終的に行列積自体を計算することが目的です。なので最小コストを求めると同時にそれを達成する計算順序も求める必要があります。計算順序は二分木を使って表現することができます。このことは以下のようにイメージすると良いでしょう。
計算順序を表すデータ型として形状の情報のみを持つ二分木の型を用意しておきます。
data Tree = Leaf | Node Tree Tree
deriving (Eq, Ord)
動的計画法を実装する方針は、i番目の行列からj番目の行列までの積を求める最小コストをボトムアップに計算していくというものです。
iからj番目の行列積を考える時、間にあるk番目で区切った場合にかかる計算コストは、
- iからk番目までの行列積の最小コスト
- k+1からj番目までの行列積の最小コスト
- i,k+1,j+1番目の次元の積
これらの和になります。最後でkとjに1が足されているのは次元のリストが積を行う数のリストより1つ長くなっていることに対応しています。
これをkをiからj-1まで順番に変化させて計算し、その中で一番小さいコストがiからj番目の行列積の最小コストになるというわけです。
この計算を行なった値を保存するために2次元の配列を用意します。計算は積を行う行列の数が少ない方から、つまり2次元配列の対角線から順番に行なっていきます。
minCost :: (KnownNat m, KnownNat n) => MCList m n -> (Int, Tree)
minCost xs = runST $ do
let ds = dims xs
n = length ds - 1
costs = listArray (0,n) ds :: UArray Int Int
indices = [(x, x+offset) | offset <- [0..n-1], x <- [0..n-1-offset]]
table <- newArray_ ((0,0), (n-1,n-1)) :: ST s (STArray s (Int, Int) (Int, Tree))
for_ indices $ \(i,j) ->
if i == j then writeArray table (i,j) (0, Leaf)
else do
candidates <- for [i..j-1] $ \k -> do
(cik, tik) <- readArray table (i,k)
(ckj, tkj) <- readArray table (k+1,j)
pure (cik + ckj + costs!i * costs!(k+1) * costs!(j+1), Node tik tkj)
writeArray table (i, j) $ minimum candidates
readArray table (0, n-1)
このminCost
を使えば行列積のコストが最小になる計算順序を求めることができます。
計算順序に沿って行列積を計算する
連鎖行列MCList
とそれを最小コストで計算する計算順序Tree
が手に入ったので、実際に計算順序に沿って行列積を計算する方法を考えてみましょう。そのためには単なるリストではなく計算順序も反映した連鎖行列の型を考えなければいけません。これを以下のような型を持った木として実装します。
data MCTree m n where
L :: L m n -> MCTree m n
N :: KnownNat k => MCTree m k -> MCTree k n -> MCTree m n
MCTree m n
は計算結果が
mcmTree :: (KnownNat m, KnownNat n) => MCTree m n -> L m n
mcmTree (L a) = a
mcmTree (N l r) = mcmTree l <> mcmTree r
木構造を用いることで計算順序が期待通りに反映されているのが分かります。
それでは最も重要なMCTree
の作り方について見ていきましょう。いきなり複雑な型で考えるのは難しいので、まずは普通のリストを普通の木に沿って木を組み立てる関数を考えてみます。
data Tree' a = Leaf' a
| Node' (Tree' a) (Tree' a)
buildLT :: [a] -> Tree' () -> (Tree' a, [a])
buildLT (a:as) (Leaf' _) = (Leaf' a, as)
buildLT as (Node' l r) =
let (l', as') = buildLT as l
(r', rest) = buildLT as' r
in (Node' l' r', rest)
buildLT
はリスト[a]
の要素を構造Tree' ()
に沿って組み立てた木Tree' a
を計算する関数です。返り値が組み立てた木Tree' a
と余った要素のリスト[a]
のタプルになっているので再帰的に木を組み立てていくことができます。buildLT
の実装でNode'
でパターンマッチされた部分を見てみると、まずNode'
の左側の木l
と与えられたリスト全体as
で木の組み立てを行い、余った要素のリストas'
を使って右側の木r
で木の組み立てを行っているという処理の流れです。
同様の処理をMCList
とMCTree
で行おうとすると1つ困ったことが起こります。木を組み立てる途中で連鎖行列を組み立た木と余りの要素のリストに分轄するのですが、この分割を行った際に境界にある次元が木とリストで一致していることを保証する必要が出てくるのです。これを解決するために以下のようなデータ型を新たに定義します。
data SomeTreeList m n where
NoRest :: MCTree m n -> SomeTreeList m n
SomeTreeList :: KnownNat k => (MCTree m k, MCList k n) -> SomeTreeList m n
NoRest
は余りの要素のリストが空になった場合を表します。SomeTreeList
は木とリストのペアになっていますが存在型によって間の次元k
が一致していることを保証するようにできています。これを利用すればMCList
とTree
から木を組み立てる関数は以下のように実装することができます。
buildMCTree :: (KnownNat m, KnownNat n) => MCList m n -> Tree -> SomeTreeList m n
buildMCTree (S a) Leaf = NoRest (L a)
buildMCTree (a :. b) Leaf = SomeTreeList (L a, b)
buildMCTree as (Node l r) =
case buildMCTree as l of
(NoRest _) -> undefined
(SomeTreeList (l', as')) ->
case buildMCTree as' r of
(NoRest r') -> NoRest (N l' r')
(SomeTreeList (r', rest)) -> (SomeTreeList (N l' r', rest))
動作はbuildLT
の場合とほぼ同じです。組み立た木を使って行列積を行う関数を実装してみましょう。
mcm :: (KnownNat m, KnownNat n) => MCList m n -> L m n
mcm xs =
let (_, parenthesis) = minCost xs
in case buildMCTree xs parenthesis of
(NoRest tree) -> mcmTree tree
(SomeTreeList (tree, _)) -> undefined
buildMCTree
とmcm
にundefined
が現れる箇所が1つずつあります。これは与えられたMCList
とTree
の長さが一致している保証が型レベルで与えられていないためで、それぞれ木が短かった場合と長すぎた場合に対応しています。さらに型レベルプログラミングを進めることでこのundefined
を無くせる可能性はありますが、今回は計算順序の木の構築はmcm
の中で完結しており、木とリストの長さは必ず一致するので深くは追求しないことにします。
最後に、実装したmcm
を使うことで実際に行列積の計算が速くなるのか確認してみましょう。
withTime :: IO a -> IO ()
withTime action = do
start <- getCurrentTime
action
end <- getCurrentTime
putStrLn $ formatTime defaultTimeLocale "Time: %-3Ess" (diffUTCTime end start)
main :: IO ()
main = do
putStrLn "Generating random matrices"
!a <- randn @100 @500
!b <- randn @500 @1000
!c <- randn @1000 @5000
!d <- randn @5000 @10000
let m = a :. b :. c :. S d
putStrLn "# mcm"
withTime . putStrLn $ "norm: " ++ show (norm_2 (mcm m))
putStrLn "# naiveMcm"
withTime . putStrLn $ "norm: " ++ show (norm_2 (naiveMcm m))
この例ではnaiveMcm
により右から計算すると計算コストは
となりますが、mcm
を用いて左から計算することで
となり計算コストをちょうど10分の1に減らすことができます。実行してみると、
Generating random matrices
# mcm
norm: 8072309.181602492
Time: 0.474s
# naiveMcm
norm: 8072309.181602493
Time: 3.209s
となり、もちろん順序計算などのオーバーヘッドがあるため10分の1にはなりませんが、かなり高速に計算できるようになりました👏
今回実装したコードはこちらのgistにアップロードしています。
\読んでいただきありがとうございました!/
この記事が面白かったら いいね♡ をいただけると嬉しいです☺️
バッジを贈っていただければ次の記事を書くため励みになります🙌
-
この問題は更に効率的に解くことができ
で解けることが知られています https://en.wikipedia.org/wiki/Matrix_chain_multiplication#More_efficient_algorithms ↩︎O(n\log n) -
今回はGADTを使いましたがヘテロリストとsingletonsを使った実装も可能です(少々煩雑にはなりますが) ↩︎
Discussion