💊

SOPを使ってジェネリックにCase Analysis関数を実装する

2021/01/08に公開

case analysis関数が何かについては、まず kakkun61氏 の以下の記事を参照してください。

https://kakkun61.hatenablog.com/entry/2021/01/06/Case_Analysis_関数

case analysis関数はデータ型毎に定まる関数ですがよくよく見てみるとシンプルなルールで統一的に実装できそうです。

https://twitter.com/lotz84_/status/1346780048653619200

Haskellで同じ名前の関数を使って複数のデータ型を扱えるようにするには、モジュールを分けたり型クラスを利用したり工夫する必要があります。

https://twitter.com/ryotakameoka/status/1346796279955767299

この記事ではジェネリックプログラミングの考え方に基づき、様々なデータ型に対応した一つのcase analysis関数を実装してみたいと思います。

これから実装するgfold'(generic fold)という関数は以下のような振る舞いをするようになります。

> :t unFun . gfold' @Bool
unFun . gfold' @Bool :: Bool -> r -> r -> r

> :t unFun . gfold' @(Maybe Bool)
unFun . gfold' @(Maybe Bool) :: Maybe Bool -> r -> (Bool -> r) -> r

> :t unFun . gfold' @(Either Bool Int)
unFun . gfold' @(Either Bool Int)
  :: Either Bool Int -> (Bool -> r) -> (Int -> r) -> r

SOP

SOPは sums of products の略でデータ型を直積の直和に分解する考え方です。SOPの論文で解説されている実装をライブラリにしたものがgenerics-sopです。

直和や直積といったワードに慣れない方はまずは以下の記事を参照してみてください。

https://ryota-ka.hatenablog.com/entry/2018/07/09/110000

通常、代数的データ型の文脈で直和や直積と言うとEitherとタプル(,)を使って表現しますが、実際に使うデータ型では複数の型による直積と直和を考えることが多いためSOPでは n-ary sums を表すNSと n-ary products を表す NP という型を使います(SOPが sum"s" of product"s" の略であることを思い出してください)。どちらも型レベルリストによって直積あるいは直和される型を管理しています。

実際にSOPを使って得られるデータ型の表現を見てみましょう。

> from (Just 'a')
SOP (S (Z (I 'a' :* Nil)))

GHC.Generics と違ってメタ情報が含まれていないため、単純に直和と直積の構造を扱いたい場合はSOPを使うのがシンプルでしょう。

> :t from (Just 'a')
from (Just 'a') :: SOP I '[ '[], '[Char]]

型を見てみるとネストされた型レベルリストがあることが分かります。外側の型レベルリストが直和に、内側の型レベルリストが直積に対応しています。 これはMaybe型の表現なので'[]Nothingに、'[Char]Just Charに対応しているというわけです。詳しく実装を見てみましょう。

-- | ジェネリックなSOPの表現との相互変換を扱う型クラス
class ... => Generic (a :: Type) where
    type Code a :: [[Type]]
    from :: a -> Rep a
    to :: Rep a -> a

-- | 型aのジェネリックな表現
type Rep a = SOP I (Code a)

-- | 直積の直和を表す型
newtype SOP (f :: k -> Type) (xss :: [[k]]) = SOP (NS (NP f) xss)

-- | 恒等関手
newtype I a = I a

Genericsの型クラス制約に関しては気にしなくていいので省略しました。ライブラリは標準的なほとんどの型に対してGenericsのインスタンスを定義していますし、自分で定義した型に対してGenericsのインスタンスを自動的に導出することも可能です。

N個の型の直和を表すNSは以下の様に定義されています。

data NS :: (k -> Type) -> [k] -> Type where
  Z :: f x -> NS f (x ': xs)
  S :: NS f xs -> NS f (x ': xs)

ペアノの自然数のような実装になっていますね。これは型レベルリストは直和を表しているため値としてはリストに含まれるいずれかの型の値しか持っていないので何番目の型の値を持っているのかを表すために自然数のような実装になっています。例えば以下のようなCharIntStringの直和を表すことができます。

> Z (I 'a') :: NS I '[Char, Int, String]
Z (I 'a')
> S $ Z (I 1) :: NS I '[Char, Int, String]
S (Z (I 1))
> S . S $ Z (I "abc") :: NS I '[Char, Int, String]
S (S (Z (I "abc")))

次にN個の型の直積を表すNPの実装を見てみましょう。

data NP :: (k -> Type) -> [k] -> Type where
  Nil  :: NP f '[]
  (:*) :: f x -> NP f xs -> NP f (x ': xs)

これは言わずもがなへテロリストと同様の実装になっていますね。例えば以下のようなCharIntStringの直積を表すことができます。

> I 'a' :* I 1 :* I "abc" :* Nil :: NP I '[Char, Int, String]
I 'a' :* I 1 :* I "abc" :* Nil

ジェネリックなcase analysis関数の型

それではSOPを使ってジェネリックなcase analysis関数gfoldを実装していきましょう。

まず型aに関するcase analysis関数の型を考えてみましょう。

gfold :: (直和の各成分を処理する関数の直積) -> a -> r

このように書けるはずですがカッコ内の関数の数はaの形によって変わってくるため、このままでは表現することが難しいです。そこでaの表現Rep aSOP I xssだったとしてもう一度case analysis関数の型を考えてみましょう。

gfold :: SOP I xss -> Fun xss r

第一引数はaの表現です。返り値の型であるFunはこれから定義しますが、(直和の各成分を処理する関数の直積) -> rを表していると考えてください。Funは型族を扱うための型であり、方針としてはaのSOPの構造を反映したxssを使って直和の各成分を処理する関数の直積の型を型族によって求めようと考えます。

例えば型の直積'[a, b, c]からrへの関数はカリー化を考えるとa -> b -> c -> rとなるでしょう。これを型族で実装すると以下のように書けます。

type family FunP (xs :: [Type]) r where
    FunP '[] r = r
    FunP (x ': xs) r = x -> FunP xs r

更に型の直和'[a, b, c]からrへの関数は(a -> r, b -> r, c -> r)と書けるでしょう。case analysis関数とし最終的に必要なのはこの対応の逆(a -> r, b -> r, c -> r) -> '[a, b, c] -> rであり、'[a, b, c]はSOPで得られているので(a -> r, b -> r, c -> r) -> rを得るための型族を実装します。

type family FunS (xss :: [[Type]]) r where
    FunS '[] r = r
    FunS (xs ': xss) r = FunP xs r -> FunS xss r

内側のリストが直積であることも考慮に入れて実装にはFunPも利用しています。

最後にFunSは単射な型族ではないため型として取り扱うのが面倒です。なのでFunSを単純にnewtypeによってラップしたFunを用意しておきます。

newtype Fun xss r = Fun {unFun :: FunS xss r}

これで gfold の型は完成です。次は実装に進みましょう。

ジェネリックなcase analysis関数の実装

gfoldは型レベルリストを型変数に持つため型クラスを使って帰納的に定義するのが良さそうです。

class GFold (xss :: [[Type]]) where
    gfold :: SOP I xss -> Fun xss r

まずxssが空リストの場合を考えましょう。これは型としては値を持たないVoidに相当します。Voidのcase analysis関数は何も値を返さないundefinedので実装としては以下のようにすると良さそうです。

instance GFold '[] where
    gfold _ = Fun undefined

次にxssが空リストではなく値を持つ場合を考えていきましょう。

instance GFold xss => GFold (xs ': xss) where
    gfold (SOP (S xs)) = ...
    gfold (SOP (Z x))  = ...

帰納的な定義を考えているのでtailに相当する部分はGFoldのインスタンスになっていることを前提にしています。gfoldの実装は更に直和を表すNSが何かの後続Sであるのか、Zであるのかによって場合分けが行われます。後続Sである場合は帰納的にxsに対して再びgfoldを行うだけです。

    gfold (SOP (S xs)) = constFun (gfold (SOP xs))

ここでconstFunは型レベルリストの型を合わせるための関数です。

constFun :: Fun xss r -> Fun (xs ': xss) r
constFun (Fun f) = Fun $ const f

ようするにxsに対応する関数を無視するようにして型を合わせています。case analysisが直和で対応する関数だけ実行するという挙動がこれによって実現されるわけです。

次に直和で対応する関数を実行する実装に当たるZのケースを見てみましょう。

    gfold (SOP (Z x))  = embed (Fun $ \f -> apply f x)

embedapplyは未定義なので後述します。直和の型に対応する値xが存在した場合、Fun $ \f ->によって対応する関数を取り出します。fの型はa -> b -> c -> rの様になっていてxの型はNP I (xs :: [Type])の様になっているので、関数適用するための工夫が必要です。そのため以下のような型クラスを用意します。

class Apply (xs :: [Type]) where
    apply :: FunP xs r -> NP I xs -> r

instance Apply '[] where
    apply r _ = r

instance Apply xs => Apply (x ': xs) where
    apply f ((I x) :* xs) = apply (f x) xs

これを使ってapply f xとすることでfxによって評価することができます。

Fun $ \f -> apply f xの型は結局Fun '[xs] rになりますが、型を合わせるためにはFun (xs ': xss) rにする必要があります。そこで必要になるのがembedです。embedも型クラスによって以下のように実装されています。

class Embed (xss :: [[Type]]) where
    embed :: Fun (xs ': '[]) r -> Fun (xs ': xss) r

instance Embed '[] where
    embed = id

instance Embed xss => Embed (xs ': xss) where
    embed = flipFun . constFun . embed

flipFun :: Fun (xs ': ys ': xss) r -> Fun (ys ': xs ': xss) r
flipFun f = Fun $ \ys xs -> unFun f xs ys

つまりcase analysis関数として対応する関数で評価した後に並んでいる関数は全て無視するconstFun様な実装になっています。

以上により最終的なgfoldの実装は以下のようになります。

class GFold (xss :: [[Type]]) where
    gfold :: SOP I xss -> Fun xss r

instance GFold '[] where
    gfold _ = Fun undefined

instance (Apply xs, Embed xss, GFold xss) => GFold (xs ': xss) where
    gfold (SOP (S xs)) = constFun (gfold (SOP xs))
    gfold (SOP (Z x))  = embed (Fun $ \f -> apply f x)

constFun :: Fun xss r -> Fun (xs ': xss) r
constFun (Fun f) = Fun $ const f


class Apply (xs :: [Type]) where
    apply :: FunP xs r -> NP I xs -> r

instance Apply '[] where
    apply r _ = r

instance Apply xs => Apply (x ': xs) where
    apply f ((I x) :* xs) = apply (f x) xs


class Embed (xss :: [[Type]]) where
    embed :: Fun (xs ': '[]) r -> Fun (xs ': xss) r

instance Embed '[] where
    embed = id

instance Embed xss => Embed (xs ': xss) where
    embed = flipFun . constFun . embed

flipFun :: Fun (xs ': ys ': xss) r -> Fun (ys ': xs ': xss) r
flipFun f = Fun $ \ys xs -> unFun f xs ys

実装したgfoldを使ってジェネリックなcase analysis関数を実装しましょう。

gfold' :: (GFold (Code a), Generic a) => a -> Fun (Code a) r
gfold' = gfold . from

本当は gfold' = unFun . gfold . from としたいところですが、Fun の中身はaが与えられるまで決まらないのでこの実装ではコンパイルを通すことはできません。

冒頭に上げた例をもう一度見てみましょう。

> :t unFun . gfold' @Bool
unFun . gfold' @Bool :: Bool -> r -> r -> r

> :t unFun . gfold' @(Maybe Bool)
unFun . gfold' @(Maybe Bool) :: Maybe Bool -> r -> (Bool -> r) -> r

> :t unFun . gfold' @(Either Bool Int)
unFun . gfold' @(Either Bool Int)
  :: Either Bool Int -> (Bool -> r) -> (Int -> r) -> r

もちろん実際に実行することも可能です。

> (unFun . gfold') True 1 2
2

> (unFun . gfold') (Just 1) "empty" show
"1"

> (unFun . gfold') (Right "Haskell") (++ "??") (++ "!!")
"Haskell!!"

ここで実装したgfold'Genericのインスタンスにさえしてしまえば自前で実装した型にも適用することができます。


ところでリストについてはどうでしょうか?

> :t unFun . gfold' @[Int]
unFun . gfold' @[Int] :: [Int] -> r -> (Int -> [Int] -> r) -> r

リストのcase analysis関数を思い出すと

> :t GHC.OldList.foldr
GHC.OldList.foldr :: (a -> b -> b) -> b -> [a] -> b

となっていて処理を行う関数側にはリストの型が現れていないことが分かります。

これを実現するためにはRecursion SchemesにおけるCatamorphismという考え方に触れる必要があります。

cata :: (Base t a -> a) -> t -> a

リスト[a]Base [a]に対応する型は

data ListF a b = Nil | Cons a b

という型であり、ListFのcase analysis関数を考えると

(a -> b -> b) -> b -> ListF a b -> b

という型になることが分かります。これとcataを組み合わせれば

(a -> b -> b) -> b -> [a] -> b

となりリストのcase analysis関数と一致することが分かります。

これをジェネリックに作るためには

distFun :: (a -> Fun xss r) -> Fun xss (a -> r)
mapFun :: (a -> b) -> Fun xss a -> Fun xss b

のような関数を実装した上で

gcata = mapFun cata (distFun gfold)

と実装すれば良さそうです。

ただ実装は大変そうなので今回は方針に触れるだけで終わりにしたいと思います。


\読んでいただきありがとうございました!/
この記事が面白かったら いいね♡ をいただけると嬉しいです☺️
バッジを贈っていただければ次の記事を書くため励みになります🙌

Discussion