Haskellで型レベルパーサー

2023/04/10に公開

動作確認環境はghc-9.6.1です(`TypeData`拡張を使っているため、GHC 9.6以降が必要です)。
GHCはChar kindが9.2で入ったので型レベルパーサーが書ける様になってるんですよね。ということで書いてみました。

BNFは以下です:

expr   ::= term *('+' term)
term   ::= factor *('*' factor)
factor ::= '(' expr ')' | nat
nat    ::= '0' | '1' | '2' | ... | '9'

雑にコードをそのまま貼ります。

{-# Language TypeData #-}
{-# Language DataKinds #-}
{-# Language TypeFamilies #-}
{-# Language UndecidableInstances #-}
{-# Language TypeApplications #-}
{-# Language MagicHash #-}
{-# LANGUAGE NoStarIsType #-}

import GHC.TypeLits
import GHC.Exts
import Data.Ord
import Data.Type.Bool
import Data.Type.Equality

-- Type-level abstract syntax tree of an arithmetic expression.
-- Declared with `type data`, so the constructors live in the type namespace.
type data Expr
  = Literal Nat    -- a natural-number leaf
  | Add Expr Expr  -- sum of two sub-expressions
  | Mul Expr Expr  -- product of two sub-expressions

-- Eval Expr
-- Folds an Expr tree down to the natural number it denotes.
-- (`*` is multiplication here because NoStarIsType is enabled.)
type family Eval (expr :: Expr) :: Nat where
  Eval (Literal value) = value
  Eval (Add lhs rhs) = Eval lhs + Eval rhs
  Eval (Mul lhs rhs) = Eval lhs * Eval rhs
-- Evaluates the expression held by a successful parse result.
-- A failed parse is turned into a compile-time TypeError carrying the message.
type family EvalResult (r :: Result Expr) :: Nat where
  EvalResult (Err reason) = TypeError (Text ("EvalResult failed: " <> reason))
  EvalResult (Ok _ ast) = Eval ast

-- Parser Result
-- Input wraps the part of the source text that has not been consumed yet.
type data Input = In Symbol

-- Failure description carried by Err.
type Message = Symbol

-- Outcome of a parser step:
--   Ok  - the remaining input together with the parsed value
--   Err - a message describing the failure
type data Result a
  = Ok Input a
  | Err Message

-- Expr Parser
-- expr ::= term *('+' term)
-- Parses the leading term, then hands it to the '+' loop (ParseExpr_tail).
type family ParseExpr (i :: Input) :: Result Expr where
  ParseExpr src = ParseExpr_1 (ParseTerm (SkipWs src))
-- Continuation: forward a failure, or start the loop with the first term.
type family ParseExpr_1 (r :: Result Expr) :: Result Expr where
  ParseExpr_1 (Err reason) = Err reason
  ParseExpr_1 (Ok rest lhs) = ParseExpr_tail lhs rest

-- Loop for the `*('+' term)` part; `leftOp` is the expression built so far.
type family ParseExpr_tail (leftOp :: Expr) (i :: Input) :: Result Expr where
  ParseExpr_tail acc rest = ParseExpr_tail_0 acc (LookAhead (SkipWs rest))
-- Step 0: inspect the peeked character. Anything other than '+' (or the end
-- of the input) ends the loop and returns the accumulated expression.
type family ParseExpr_tail_0 (ctx :: Expr) (r :: Result (Maybe Char)) :: Result Expr where
  ParseExpr_tail_0 _ (Err reason) = Err reason
  ParseExpr_tail_0 acc (Ok rest 'Nothing) = Ok rest acc
  ParseExpr_tail_0 acc (Ok rest ('Just ch)) =
    If (ch == '+')
      (ParseExpr_tail_1 acc rest)
      (Ok rest acc)
-- Step 1: skip blanks and consume one character (the '+' seen by the look-ahead).
type family ParseExpr_tail_1 (ctx :: Expr) (i :: Input) :: Result Expr where
  ParseExpr_tail_1 acc rest = ParseExpr_tail_2 acc (NextChar (SkipWs rest))
-- Step 2: parse the term on the right-hand side of the operator.
type family ParseExpr_tail_2 (ctx :: Expr) (r :: Result Char) :: Result Expr where
  ParseExpr_tail_2 _ (Err reason) = Err reason
  ParseExpr_tail_2 acc (Ok rest _) = ParseExpr_tail_3 acc (ParseTerm (SkipWs rest))
-- Step 3: fold the new term into the accumulator (left-associative) and loop.
type family ParseExpr_tail_3 (ctx :: Expr) (r :: Result Expr) :: Result Expr where
  ParseExpr_tail_3 _ (Err reason) = Err reason
  ParseExpr_tail_3 acc (Ok rest rhs) = ParseExpr_tail (Add acc rhs) (SkipWs rest)

-- Term Parser
-- term ::= factor *('*' factor)
-- Parses the leading factor, then hands it to the '*' loop (ParseTerm_tail).
type family ParseTerm (i :: Input) :: Result Expr where
  ParseTerm src = ParseTerm_1 (ParseFactor (SkipWs src))
-- Continuation: forward a failure, or start the loop with the first factor.
type family ParseTerm_1 (r :: Result Expr) :: Result Expr where
  ParseTerm_1 (Err reason) = Err reason
  ParseTerm_1 (Ok rest lhs) = ParseTerm_tail lhs rest

-- Loop for the `*('*' factor)` part; `leftOp` is the expression built so far.
type family ParseTerm_tail (leftOp :: Expr) (i :: Input) :: Result Expr where
  ParseTerm_tail acc rest = ParseTerm_tail_0 acc (LookAhead (SkipWs rest))
-- Step 0: inspect the peeked character. Anything other than '*' (or the end
-- of the input) ends the loop and returns the accumulated expression.
type family ParseTerm_tail_0 (ctx :: Expr) (r :: Result (Maybe Char)) :: Result Expr where
  ParseTerm_tail_0 _ (Err reason) = Err reason
  ParseTerm_tail_0 acc (Ok rest 'Nothing) = Ok rest acc
  ParseTerm_tail_0 acc (Ok rest ('Just ch)) =
    If (ch == '*')
      (ParseTerm_tail_1 acc rest)
      (Ok rest acc)
-- Step 1: skip blanks and consume one character (the '*' seen by the look-ahead).
type family ParseTerm_tail_1 (ctx :: Expr) (i :: Input) :: Result Expr where
  ParseTerm_tail_1 acc rest = ParseTerm_tail_2 acc (NextChar (SkipWs rest))
-- Step 2: parse the factor on the right-hand side of the operator.
type family ParseTerm_tail_2 (ctx :: Expr) (r :: Result Char) :: Result Expr where
  ParseTerm_tail_2 _ (Err reason) = Err reason
  ParseTerm_tail_2 acc (Ok rest _) = ParseTerm_tail_3 acc (ParseFactor (SkipWs rest))
-- Step 3: fold the new factor into the accumulator (left-associative) and loop.
type family ParseTerm_tail_3 (ctx :: Expr) (r :: Result Expr) :: Result Expr where
  ParseTerm_tail_3 _ (Err reason) = Err reason
  ParseTerm_tail_3 acc (Ok rest rhs) = ParseTerm_tail (Mul acc rhs) (SkipWs rest)

-- Factor Parser
-- factor ::= '(' expr ')' | nat
--
-- Tries the parenthesised alternative first. Only when the opening '(' is
-- absent does it backtrack to the original input and try a literal.
type family ParseFactor (i :: Input) :: Result Expr where
  ParseFactor inp = ParseFactor_1 inp (ParseChar '(' (SkipWs inp))
-- `ctx` is the input as it was before '(' was attempted, kept for backtracking.
type family ParseFactor_1 (ctx :: Input) (r :: Result Char) :: Result Expr where
  ParseFactor_1 orig (Ok inp _) = ParseFactor_2 orig (ParseExpr (SkipWs inp))
  ParseFactor_1 orig (Err _) = ParseNatLit orig
-- Once '(' has been consumed the literal alternative can no longer match
-- (the original input starts with '('), so a failure of the inner expression
-- is reported as-is instead of being replaced by a misleading
-- "Expected a digit, actual=(" message.
type family ParseFactor_2 (ctx :: Input) (r :: Result Expr) :: Result Expr where
  ParseFactor_2 orig (Ok inp expr) = ParseFactor_3 '(orig, expr) (ParseChar ')' (SkipWs inp))
  ParseFactor_2 _ (Err msg) = Err msg
-- Same reasoning for a missing ')': surface the real error.
type family ParseFactor_3 (ctx :: (Input, Expr)) (r :: Result Char) :: Result Expr where
  ParseFactor_3 '(_, expr) (Ok inp _) = Ok inp expr
  ParseFactor_3 _ (Err msg) = Err msg

-- Lit Parser
-- Reads a single digit (after skipping blanks) and wraps it in a Literal node.
type family ParseNatLit (i :: Input) :: Result Expr where
  ParseNatLit src = ParseNatLit_1 (ParseDigit (SkipWs src))
type family ParseNatLit_1 (r :: Result Nat) :: Result Expr where
  ParseNatLit_1 (Err reason) = Err reason
  ParseNatLit_1 (Ok rest digit) = Ok rest (Literal digit)


-- Parser Utilities
-- Peeks at the first character of the input ('Nothing when there is none).
-- The input itself is returned unchanged.
type family LookAhead (i :: Input) :: Result (Maybe Char) where
  LookAhead (In text) = Ok (In text) (MaybeFst (UnconsSymbol text))

-- Consumes exactly the character `c` from the front of the input.
-- Yields Ok with the remaining input on a match, Err otherwise.
type family ParseChar (c :: Char) (i :: Input) :: Result Char where
  ParseChar c input = ParseChar' c (NextChar input)
-- Continuation of ParseChar working on the result of NextChar.
type family ParseChar' (c :: Char) (r :: Result Char) :: Result Char where
  ParseChar' expected (Ok input actual) =
    If (expected == actual)
      (Ok input expected)
      (Err ("ParseChar failed: expected=" <> CharToSymbol expected <> ", actual=" <> CharToSymbol actual))
  -- NextChar fails on exhausted input. Without this equation the family would
  -- be stuck there, and callers matching on Err could never backtrack.
  ParseChar' _ (Err msg) = Err msg

-- Pops the first character off the input; fails when the input is empty.
type family NextChar (i :: Input) :: Result Char where
  NextChar (In text) = NextChar' (UnconsSymbol text)
type family NextChar' (m :: Maybe (Char, Symbol)) :: Result Char where
  NextChar' 'Nothing = Err "NextChar failed: input is empty"
  NextChar' ('Just '(ch, remaining)) = Ok (In remaining) ch

-- Drops leading blanks from the input. Only the space character ' ' is skipped.
type family SkipWs (r :: Input) :: Input where
  SkipWs src = SkipChars0 ' ' src

-- Drops zero or more leading occurrences of the character `c` from the input.
type family SkipChars0 (c :: Char) (i :: Input) :: Input where
  SkipChars0 skip (In text) =
    If (MaybeIsChar skip (Head text))
      (SkipChars0 skip (SkipChars0_1 (UnconsSymbol text)))
      (In text)
-- Drops the first character. The 'Nothing equation is not expected to be the
-- selected result, because the guard above only picks this branch when a
-- first character exists.
type family SkipChars0_1 (m :: Maybe (Char, Symbol)) :: Input where
  SkipChars0_1 'Nothing = TypeError (Text "impossible")
  SkipChars0_1 ('Just '(_, remaining)) = In remaining

-- 'True exactly when the optional character is present and equal to `c`.
type family MaybeIsChar (c :: Char) (m :: Maybe Char) :: Bool where
  MaybeIsChar _ 'Nothing = 'False
  MaybeIsChar wanted ('Just found) = wanted == found

-- Consumes the literal string `target` from the front of the input.
-- On success the result carries `target` and the remaining input; the first
-- mismatching character produces the Err reported by ParseChar.
type family ParseString (target :: Symbol) (i :: Input) :: Result Symbol where
  -- The characters to match come from `target`. Expanding the input instead
  -- would compare the input with itself and succeed on anything.
  ParseString target (In s) = ParseString_1 target (SymbolToChars target) (In s)
-- Matches the remaining expected characters one by one.
type family ParseString_1 (target :: Symbol) (cs :: [Char]) (i :: Input) :: Result Symbol where
  ParseString_1 target '[] (In i) = Ok (In i) target
  ParseString_1 target (c ': cs) input = ParseString_2 target cs (ParseChar c input)
-- Continues with the next expected character, or forwards the failure.
type family ParseString_2 (target :: Symbol) (cs :: [Char]) (r :: Result Char) :: Result Symbol where
  ParseString_2 target cs (Ok input _) = ParseString_1 target cs input
  ParseString_2 _ _ (Err msg) = Err msg

-- Consumes one character, which must be a member of the list `cs`.
type family OneOf (cs :: [Char]) (i :: Input) :: Result Char where
  OneOf allowed src = OneOf' allowed (NextChar src)
type family OneOf' (cs :: [Char]) (r :: Result Char) :: Result Char where
  OneOf' _ (Err reason) = Err reason
  OneOf' allowed (Ok rest ch) =
    If (Elem ch allowed)
      (Ok rest ch)
      (Err ("Expected one of [....], actual=" <> CharToSymbol ch))

-- Consumes one decimal digit and yields its numeric value.
type family ParseDigit (i :: Input) :: Result Nat where
  ParseDigit input = ParseDigit' (NextChar input)
-- Continuation of ParseDigit working on the result of NextChar.
type family ParseDigit' (r :: Result Char) :: Result Nat where
  ParseDigit' (Ok input c) =
    If (IsDigit c)
      (Ok input (DigitToNat c))
      (Err ("Expected a digit, actual=" <> CharToSymbol c))
  -- Exhausted input: forward the NextChar failure instead of getting stuck,
  -- so that callers receive a proper Err.
  ParseDigit' (Err msg) = Err msg

-- Builds the one-character Symbol consisting of the given character.
type family CharToSymbol (c :: Char) :: Symbol where
  CharToSymbol ch = ConsSymbol ch ""

-- 'True for the ten decimal digit characters, 'False for anything else.
type family IsDigit (c :: Char) :: Bool where
  IsDigit ch =
    Elem ch
      '[ '0', '1', '2', '3', '4'
       , '5', '6', '7', '8', '9'
       ]

-- Maps a decimal digit character to the natural number it denotes.
-- There is deliberately no equation for other characters, so the application
-- stays unreduced for them; ParseDigit' guards the call with IsDigit.
type family DigitToNat (c :: Char) :: Nat where
  DigitToNat '0' = 0
  DigitToNat '1' = 1
  DigitToNat '2' = 2
  DigitToNat '3' = 3
  DigitToNat '4' = 4
  DigitToNat '5' = 5
  DigitToNat '6' = 6
  DigitToNat '7' = 7
  DigitToNat '8' = 8
  DigitToNat '9' = 9

-- Membership test for a character in a type-level list of characters.
type family Elem (c :: Char) (cs :: [Char]) :: Bool where
  Elem needle (candidate ': others) =
    If (needle == candidate)
      'True
      (Elem needle others)
  Elem _ '[] = 'False

-- Prelude
-- First character of a Symbol, wrapped in Maybe.
type family Head (s :: Symbol) :: Maybe Char where
  Head text = MaybeFst (UnconsSymbol text)
-- First component of a promoted pair.
type family Fst (t :: (a, b)) :: a where
  Fst '(x, _) = x
-- Second component of a promoted pair.
type family Snd (t :: (a, b)) :: b where
  Snd '(_, y) = y

-- Maybe
-- Projects the first component out of an optional pair.
type family MaybeFst (m :: Maybe (a, b)) :: Maybe a where
  MaybeFst 'Nothing = 'Nothing
  MaybeFst ('Just '(x, _)) = 'Just x
-- Projects the second component out of an optional pair.
type family MaybeSnd (m :: Maybe (a, b)) :: Maybe b where
  MaybeSnd 'Nothing = 'Nothing
  MaybeSnd ('Just '(_, y)) = 'Just y

-- Symbol
-- Infix shorthand for AppendSymbol, right-associative at precedence 6.
infixr 6 <>
type family (s1 :: Symbol) <> (s2 :: Symbol) :: Symbol where
  front <> back = AppendSymbol front back

-- Expands a Symbol into the list of its characters, in order.
type family SymbolToChars (s :: Symbol) :: [Char] where
  SymbolToChars text = SymbolToChars' (UnconsSymbol text)
-- Worker operating on the result of UnconsSymbol.
type family SymbolToChars' (m :: Maybe (Char, Symbol)) :: [Char] where
  SymbolToChars' 'Nothing = '[]
  SymbolToChars' ('Just '(ch, remaining)) = ch ': SymbolToChars' (UnconsSymbol remaining)


-- Prints the value of each sample expression. Parsing and evaluation happen
-- during type checking; at run time only the resulting naturals are printed.
main :: IO ()
main =
  mapM_ print
    [ natVal' (proxy# @(EvalResult (ParseExpr (In " 4 "))))
    , natVal' (proxy# @(EvalResult (ParseExpr (In "4 + 3"))))
    , natVal' (proxy# @(EvalResult (ParseExpr (In "4 * 3"))))
    , natVal' (proxy# @(EvalResult (ParseExpr (In "(3 + 2) * (2 + 1)"))))
    ]

case, let, 高階関数がない状態でのプログラミングは非常につらいということが判明しました。
確かに書けなくはないけどこれは少し厳しいですね。

Discussion