💭

型推論器の実装③ 副作用のあるアルゴリズム

2022/11/25に公開

型推論器の実装① Hindley-Milner 型システム型推論器の実装② Algorithm Mに引き続き、本稿では副作用のある型推論アルゴリズムを説明する。

[Milner,1978]において、Algorithm W は健全性の証明をするために導入したアルゴリズムであり、効率的ではないと述べられている。それは置換を度々計算しなければならないためである。そこで同論文では、Algorithm W に加えて、副作用を持つより効率的なアルゴリズムである Algorithm J を紹介している。Algorithm J は単一化において、置換を生成するのではなく、その型変数自体に型を記録することによって、計算量を減らしている。
本稿ではこの Algorithm J を実装するが、単一化アルゴリズムなどについては、[Jones,2005]のコードを参考にし、推論関数のみ[Milner,1978]をもとに実装する。

ところで、Algorithm W を副作用によって効率化したアルゴリズムが Algorithm J であるならば、Algorithm M 由来のそれもあっていいように思うが、これには調べた限り、名前がついていないようなので、とりあえず Algorithm Z と名付けてみることにする。

これまでに出てきた型推論アルゴリズムを表にまとめる。

副作用\推論関数 型と置換を生成 型を内部に伝播
なし Algorithm W Algorithm M
あり Algorithm J (Algorithm Z)

なぜ Algorithm W や J と名付けたのかわからないが、 W から J に 13 文字シフトしているので、M から 13 文字シフトして Z と名付けた。

Algorithm J の実装

今回もAlgorithm Wからの差分のみを示す。
完成したコードはこちら
https://github.com/ksrky/type-inference/tree/master/src/algorithmJ

構文

型の構文に変更を加えた。

algorithmJ/Syntax.hs
data Type
        = TyVar TyVar
        | TyCon TyCon
        | TyFun Tau Tau
        | TyAll [TyVar] Tau
        | TyMeta MetaTv
        deriving (Eq, Show)

type Tau = Type
type Sigma = Type

newtype TyVar = BoundTv Name deriving (Eq, Ord, Show)

data TyCon = TUnit deriving (Eq, Show)

type Uniq = Int

TyVar の中身を単なる数値から文字列に置き換え、TyMeta を追加した。TyVar は出力時の型変数として現れ、TyMeta は推論中の書き換え可能なメタ変数となる。そのため Algorithm W や M における TyVar の実質的役割は TyMeta が担っていると言えるだろう。この変更は Algorithm J にとって必ずしも必要な変更とは言えないが、このように TyVar と TyMeta を分離することにより、型注釈を許したときに、プログラマによって注釈された型変数と内部で生成された型変数を区別できるようになる。

MetaTv の中身は以下のようになっている。

algorithmJ/Syntax.hs
data MetaTv = MetaTv Uniq (IORef (Maybe Tau))

instance Eq MetaTv where
        (MetaTv u1 _) == (MetaTv u2 _) = u1 == u2

instance Ord MetaTv where
        MetaTv u1 _ `compare` MetaTv u2 _ = u1 `compare` u2

instance Show MetaTv where
        show (MetaTv u _) = "$" ++ show u

等値性を判定するためのユニークな数値と、書き換え可能な値として型を保持している。Maybe TauNothingのときは自由変数であり、Just Tのときには然るべきタイミングで型 T に置き換えられる。
MetaTv の生成、読み込み、書き換えはそれぞれnewMetaTvreadMetaTvwriteMetaTvによって行われる。

algorithmJ/Monad.hs
-- | Creating, reading and writing IORef
newInfRef :: MonadIO m => a -> Infer m (IORef a)
newInfRef v = lift (liftIO $ newIORef v)

readInfRef :: MonadIO m => IORef a -> Infer m a
readInfRef r = lift (liftIO $ readIORef r)

writeInfRef :: MonadIO m => IORef a -> a -> Infer m ()
writeInfRef r v = lift (liftIO $ writeIORef r v)

-- | Creating, reading and writing MetaTv
newMetaTv :: MonadIO m => Infer m MetaTv
newMetaTv = MetaTv <$> newUniq <*> newInfRef Nothing

readMetaTv :: MonadIO m => MetaTv -> Infer m (Maybe Tau)
readMetaTv (MetaTv _ ref) = readInfRef ref

writeMetaTv :: MonadIO m => MetaTv -> Tau -> Infer m ()
writeMetaTv (MetaTv _ ref) ty = writeInfRef ref (Just ty)

単一化

Algorithm W や Algorithm M と異なるのはunify関数が置換を生成しない点である。その代わりにMetaTvの中に型を書き込むという副作用をもたらす。

algorithmJ/Unify.hs
unify :: (MonadFail m, MonadIO m) => Tau -> Tau -> Infer m ()
unify (TyVar tv1) (TyVar tv2) | tv1 == tv2 = return ()
unify (TyCon tc1) (TyCon tc2) | tc1 == tc2 = return ()
unify (TyFun arg1 res1) (TyFun arg2 res2) = do
        unify arg1 arg2
        unify res1 res2
unify (TyMeta tv1) (TyMeta tv2) | tv1 == tv2 = return ()
unify (TyMeta tv) ty = unifyVar tv ty
unify ty (TyMeta tv) = unifyVar tv ty
unify ty1 ty2 = failInf $ hsep ["Cannot unify types:", squotes $ pretty ty1, "with", squotes $ pretty ty2]

unifyVar :: (MonadFail m, MonadIO m) => MetaTv -> Tau -> Infer m ()
unifyVar tv1 ty2@(TyMeta tv2) = do
        mb_ty1 <- readMetaTv tv1
        mb_ty2 <- readMetaTv tv2
        case (mb_ty1, mb_ty2) of
                (Just ty1, _) -> unify ty1 ty2
                (Nothing, Just ty2) -> unify (TyMeta tv1) ty2
                (Nothing, Nothing) -> writeMetaTv tv1 ty2
unifyVar tv1 ty2 = do
        occursCheck tv1 ty2
        writeMetaTv tv1 ty2

occursCheck :: (MonadFail m, MonadIO m) => MetaTv -> Tau -> Infer m ()
occursCheck tv1 ty2 = do
        tvs2 <- getMetaTvs ty2
        when (tv1 `S.member` tvs2) $ failInf $ hsep ["Infinite type:", squotes $ pretty ty2]

Zonking

単一化における出現検査(occurs check)では、型からメタ変数を取り出すgetMetaTvsが必要があり、これにも副作用が伴う。

algorithmJ/Utils.hs
getMetaTvs :: MonadIO m => Type -> Infer m (S.Set MetaTv)
getMetaTvs ty = do
        ty' <- zonkType ty
        return (metaTvs ty')

metaTvs :: Type -> S.Set MetaTv
metaTvs TyVar{} = S.empty
metaTvs TyCon{} = S.empty
metaTvs (TyFun arg res) = metaTvs arg `S.union` metaTvs res
metaTvs (TyAll _ ty) = metaTvs ty
metaTvs (TyMeta tv) = S.singleton tv

zonkTypeMetaTvが型を保持している場合、TyMetaを取り除いて、その型に置き換える。この操作は Haskell の型推論などで、(なぜかわからないが) Zonking と呼ばれている。
Zonking については、ghc のGHC.Tc.Utils.TcMTypeで概要が述べられている。

Note [What is zonking?]
GHC relies heavily on mutability in the typechecker for efficient operation.
For this reason, throughout much of the type checking process meta type
variables (the MetaTv constructor of TcTyVarDetails) are represented by mutable
variables (known as TcRefs).

Zonking is the process of ripping out these mutable variables and replacing them
with a real Type. This involves traversing the entire type expression, but the
interesting part of replacing the mutable variables occurs in zonkTyVarOcc.

algorithmJ/Utils.hs
zonkType :: MonadIO m => Type -> Infer m Type
zonkType (TyVar tv) = return (TyVar tv)
zonkType (TyCon tc) = return (TyCon tc)
zonkType (TyFun arg res) = do
        arg' <- zonkType arg
        res' <- zonkType res
        return (TyFun arg' res')
zonkType (TyAll tvs ty) = do
        ty' <- zonkType ty
        return (TyAll tvs ty')
zonkType (TyMeta tv) = do
        mb_ty <- readMetaTv tv
        case mb_ty of
                Nothing -> return (TyMeta tv)
                Just ty -> do
                        ty' <- zonkType ty
                        writeMetaTv tv ty'
                        return ty'

型推論

副作用を扱うための技術的問題はクリアしたので、あとは[Milner,1978]の Algorithm J の定義に沿って実装すれば型推論アルゴリズムが得られる。


algorithmJ/Infer.hs
inferTau :: (MonadFail m, MonadIO m) => Term -> Infer m Tau
inferTau (TmLit LUnit) = return $ TyCon TUnit
inferTau (TmVar n) = do
        sigma <- lookupEnv n
        instantiate sigma
inferTau (TmApp fun arg) = do
        fun_ty <- inferTau fun
        arg_ty <- inferTau arg
        res_ty <- newTyVar
        unify fun_ty (TyFun arg_ty res_ty)
        return res_ty
inferTau (TmAbs var body) = do
        var_ty <- newTyVar
        body_ty <- extendEnv var var_ty (inferTau body)
        return $ TyFun var_ty body_ty
inferTau (TmLet var rhs body) = do
        var_ty <- inferTau rhs
        var_sigma <- generalize var_ty
        extendEnv var var_sigma (inferTau body)

一般化

一般化も若干修正が必要である。それは、フレッシュな名前を生成し、TyMetaTyVarに置き換えるためである。

algorithmJ/Infer.hs
generalize :: MonadIO m => Tau -> Infer m Sigma
generalize ty = do
        env_tvs <- mapM getMetaTvs =<< getEnvTypes
        res_tvs <- getMetaTvs ty
        let all_tvs = res_tvs `S.difference` mconcat env_tvs
        if null all_tvs then return ty else quantify (S.toList all_tvs) ty

quantify :: MonadIO m => [MetaTv] -> Tau -> Infer m Sigma
quantify tvs ty = do
        let new_bndrs = take (length tvs) allBinders
        zipWithM_ writeMetaTv tvs (map TyVar new_bndrs)
        ty' <- zonkType ty
        return $ TyAll new_bndrs ty'

allBinders :: [TyVar]
allBinders =
        [BoundTv [x] | x <- ['a' .. 'z']]
        ++ [BoundTv (x : show i) | i <- [1 :: Integer ..], x <- ['a' .. 'z']]

テスト

algoeithmJ/Main.hs
main :: IO ()
main = forM_ tests $ \t -> do
        ty <- inferType t
        putDoc $ pretty ty <> line

tests :: [Term]
tests =
        [ TmAbs "x" (TmVar "x") -- \x -> x
        , TmAbs "f" (TmAbs "x" (TmApp (TmVar "f") (TmVar "x"))) -- \f x -> f x
        , TmAbs "f" (TmLet "x" (TmLit LUnit) (TmApp (TmVar "f") (TmVar "x"))) -- \f -> let x = () in f x
        , TmAbs "x" (TmApp (TmLit LUnit) (TmVar "x")) -- \x -> () x
        ]
出力
∀a. a -> a
∀a b. (a -> b) -> a -> b
∀a. (() -> a) -> a
algorithmJ: user error (Cannot unify types: '()' with '$0 -> $1')

Algorithm J, Z のプログラムでは、フレッシュな名前生成を行うようにしたため、ついでに結果やエラーの表示をわかりやすくするために、Prettyprinter を導入した。

Algorithm M の副作用ありバージョン(Algorithm Z)

Algorithm J から inferTau を修正し、checkTau を追加するだけでよい。

algoeitmZ/Infer.hs
inferTau :: (MonadFail m, MonadIO m) => Term -> Infer m Tau
inferTau t = do
        exp_ty <- newTyVar
        checkTau t exp_ty
        return exp_ty

checkTau :: (MonadFail m, MonadIO m) => Term -> Tau -> Infer m ()
checkTau (TmLit LUnit) exp_ty = unify exp_ty (TyCon TUnit)
checkTau (TmVar n) exp_ty = do
        sigma <- lookupEnv n
        tau <- instantiate sigma
        unify exp_ty tau
checkTau (TmApp fun arg) exp_ty = do
        arg_ty <- newTyVar
        checkTau fun (TyFun arg_ty exp_ty)
        checkTau arg arg_ty
checkTau (TmAbs var body) exp_ty = do
        var_ty <- newTyVar
        body_ty <- newTyVar
        unify exp_ty (TyFun var_ty body_ty)
        extendEnv var var_ty $ checkTau body body_ty
checkTau (TmLet var rhs body) exp_ty = do
        var_ty <- inferSigma rhs
        extendEnv var var_ty $ checkTau body exp_ty

参考文献

Discussion