Working with Haskell's PrimMonad Part 2: A Brainfuck Interpreter Case Study

In the previous article, "Working with Haskell's PrimMonad Part 1," I introduced the basic pitfalls of using PrimMonad in Haskell and ways to work around them.

Last time, the core of the computation fit entirely inside ST. This time we look at the case where part of the computation consists of arbitrary actions supplied by the user.

The sample code is available on GitHub at haskell-primmonad-example/brainfuck.

Subject: Brainfuck Interpreter

Brainfuck is a programming language that is very easy to implement. Let's try writing an interpreter for it in Haskell.

You can refer to GitHub for the full code, but I will extract the important parts here.

data Instruction = Arith !Word8 -- +-
                 | Pointer !Int -- ><
                 | Input        -- ,
                 | Output       -- .
                 | BeginLoop [Instruction] -- If value is 0, move to execution of [Instruction]
                 | EndLoop [Instruction] -- If value is not 0, move to execution of [Instruction]
                 deriving (Eq, Show)

parseAll :: String -> [Instruction]
parseAll = ... -- Omitted

data State s = State { pointer :: !Int
                     , array   :: !(VUM.MVector s Word8)
                     }

newState :: PrimMonad m => Int -> m (State (PrimState m))
newState !size = do
  a <- VUM.replicate size 0
  pure $ State { pointer = 0, array = a }
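
Since parseAll is omitted, it may help to see one way the loop instructions could end up pointing at each other. The following is only a sketch under my own assumptions, not the parser from the sample repository: Node, parseNodes, simple, and flatten are hypothetical names, unbalanced brackets are not reported, and consecutive instructions are not merged (which the real parser may well do). It first builds a tree, then flattens it into one instruction stream in which every EndLoop refers back to the start of its own loop body.

```haskell
import Data.Word (Word8)

-- Same constructors as the article's Instruction type. Deriving Show is
-- left out on purpose: the parsed structure is cyclic, so showing it
-- would not terminate.
data Instruction = Arith !Word8
                 | Pointer !Int
                 | Input
                 | Output
                 | BeginLoop [Instruction]
                 | EndLoop [Instruction]

-- Hypothetical intermediate tree; not part of the article's code.
data Node = Simple Instruction | Loop [Node]

-- Parse up to the matching ']' (or the end of input) and return the
-- unparsed remainder. Unbalanced brackets are not reported; a stray ']'
-- at the top level simply ends the program.
parseNodes :: String -> ([Node], String)
parseNodes [] = ([], [])
parseNodes (c : cs) = case c of
  ']' -> ([], cs)
  '[' -> let (body, rest) = parseNodes cs
             (nodes, rest') = parseNodes rest
         in (Loop body : nodes, rest')
  _   -> let (nodes, rest) = parseNodes cs
         in (maybe nodes (\i -> Simple i : nodes) (simple c), rest)

simple :: Char -> Maybe Instruction
simple '+' = Just (Arith 1)
simple '-' = Just (Arith 255)      -- adding 255 to a Word8 subtracts 1
simple '>' = Just (Pointer 1)
simple '<' = Just (Pointer (-1))
simple ',' = Just Input
simple '.' = Just Output
simple _   = Nothing               -- every other character is a comment

-- flatten nodes k: the instruction stream for nodes, followed by k.
flatten :: [Node] -> [Instruction] -> [Instruction]
flatten [] k = k
flatten (Simple i : ns) k = i : flatten ns k
flatten (Loop body : ns) k =
  let after = flatten ns k
      -- Knot-tying: the loop body ends with an EndLoop that points back
      -- to the start of the same body.
      bodyInsns = flatten body (EndLoop bodyInsns : after)
  in BeginLoop after : bodyInsns

parseAll :: String -> [Instruction]
parseAll src = flatten (fst (parseNodes src)) []
```

With this representation the interpreter never searches for a matching bracket at run time: both jump targets are already stored in the instruction itself.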

The main loop of the interpreter can be written as follows:

runIO :: [Instruction] -> State RealWorld -> IO (State RealWorld)
runIO insns State { pointer = initialPointer, array = array } = do
  let loop !pointer [] = pure pointer
      loop !pointer (insn : insns) = do
        case insn of
          Arith delta -> VUM.modify array (+ delta) pointer >> loop pointer insns
          Pointer delta -> loop (pointer + delta) insns
          Output -> do
            !c <- VUM.read array pointer
            putChar (chr $ fromIntegral c)
            loop pointer insns
          Input -> do
            !c <- getChar `catch` \(_ :: IOError) -> pure '\xFF'
            VUM.write array pointer (fromIntegral $ ord c)
            loop pointer insns
          BeginLoop alt -> do
            !v <- VUM.read array pointer
            if v == 0
              then loop pointer alt
              else loop pointer insns
          EndLoop alt -> do
            !v <- VUM.read array pointer
            if v == 0
              then loop pointer insns
              else loop pointer alt
  !pointer <- loop initialPointer insns
  pure $ State { pointer = pointer, array = array }

Now, in this implementation, I/O is hard-coded directly using putChar and getChar. However, sometimes you might want to redirect I/O to a file or read/write through a GUI. Alternatively, you might want to use the ContT monad to convert the Brainfuck program into a function like String -> m String.

Therefore, we will make the actions executed by Output and Input injectable from the outside as actions of type Word8 -> m () and m Word8. Since we need to read and write to the array, the monad shouldn't be just any monad, but one that satisfies the PrimMonad constraint. The generalized code looks like this:

runPM :: PrimMonad m => (Word8 -> m ()) -> m Word8 -> [Instruction] -> State (PrimState m) -> m (State (PrimState m))
runPM put get insns State { pointer = initialPointer, array = array } = do
  let loop !pointer [] = pure pointer
      loop !pointer (insn : insns) = do
        case insn of
          Arith delta -> VUM.modify array (+ delta) pointer >> loop pointer insns
          Pointer delta -> loop (pointer + delta) insns
          Output -> do
            !c <- VUM.read array pointer
            put c
            loop pointer insns
          Input -> do
            !c <- get
            VUM.write array pointer c
            loop pointer insns
          BeginLoop alt -> do
            !v <- VUM.read array pointer
            if v == 0
              then loop pointer alt
              else loop pointer insns
          EndLoop alt -> do
            !v <- VUM.read array pointer
            if v == 0
              then loop pointer insns
              else loop pointer alt
  !pointer <- loop initialPointer insns
  pure $ State { pointer = pointer, array = array }
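
Because the only requirement on m is PrimMonad, the same code can also run outside IO. The sketch below illustrates this with a deliberately small stand-in for runPM so that it stays self-contained: runningSum, Pure, pureOutput, pureInput, and runPure are hypothetical names of mine, and the mutable cell is a MutVar from the primitive package rather than the article's vector. The injected actions read from and write to lists held in a StateT layered over ST.

```haskell
import Control.Monad.Primitive (PrimMonad)
import Control.Monad.ST (ST, runST)
import Control.Monad.Trans.State.Strict (StateT, execStateT, get, put, modify')
import Data.Primitive.MutVar (newMutVar, readMutVar, modifyMutVar')
import Data.Word (Word8)

-- Hypothetical stand-in for runPM: read n bytes through the injected
-- `input` action, keep a running sum in a mutable cell, and emit the
-- sum after every byte through the injected `output` action.
runningSum :: PrimMonad m => (Word8 -> m ()) -> m Word8 -> Int -> m ()
runningSum output input n = do
  acc <- newMutVar 0
  let loop k
        | k <= 0 = pure ()
        | otherwise = do
            b <- input
            modifyMutVar' acc (+ b)
            readMutVar acc >>= output
            loop (k - 1)
  loop n

-- Pure driver: the state is (remaining input, output in reverse order).
type Pure s = StateT ([Word8], [Word8]) (ST s)

pureOutput :: Word8 -> Pure s ()
pureOutput b = modify' (\(i, o) -> (i, b : o))

pureInput :: Pure s Word8
pureInput = do
  (i, o) <- get
  case i of
    []       -> pure 0xFF   -- same end-of-input convention as the article
    (b : bs) -> put (bs, o) >> pure b

-- The whole run becomes an ordinary pure function.
runPure :: [Word8] -> [Word8]
runPure inp = runST $ do
  (_, out) <- execStateT (runningSum pureOutput pureInput (length inp)) (inp, [])
  pure (reverse out)
```

Here runPure [1, 2, 3] evaluates to [1, 3, 6]. The same runningSum also works directly in IO, for example as runningSum print (pure 1) 3.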

Just like last time, we will also prepare a version with INLINABLE specified:

{-# INLINABLE runPMInlinable #-}
runPMInlinable :: PrimMonad m => (Word8 -> m ()) -> m Word8 -> [Instruction] -> State (PrimState m) -> m (State (PrimState m))
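
For readers who have not seen the pragma in context, here is a minimal sketch of how INLINABLE is attached to a PrimMonad-polymorphic function. The function sumTo and its pure wrapper are hypothetical and unrelated to the interpreter; the SPECIALIZE line is an additional option that I am showing for comparison, not something the article's code is known to use.

```haskell
import Control.Monad.Primitive (PrimMonad)
import Control.Monad.ST (runST)
import Data.Primitive.MutVar (newMutVar, readMutVar, modifyMutVar')

-- INLINABLE keeps the definition available to other modules, so that a
-- call site where m is known can get its own specialized copy.
{-# INLINABLE sumTo #-}
-- SPECIALIZE asks for a copy at a fixed type in the defining module.
{-# SPECIALIZE sumTo :: Int -> IO Int #-}
sumTo :: PrimMonad m => Int -> m Int
sumTo n = do
  acc <- newMutVar 0
  let loop i
        | i > n     = pure ()
        | otherwise = modifyMutVar' acc (+ i) >> loop (i + 1)
  loop 1
  readMutVar acc

-- A call site at a concrete monad (here ST).
sumToPure :: Int -> Int
sumToPure n = runST (sumTo n)
```

Whether specialization actually happens depends on the optimization level, so it is worth confirming with a measurement, as the article does later.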

Aside: Utilizing the ContT Monad Transformer

As an aside, I have implemented the part about "wanting to use the ContT monad to convert a Brainfuck program into a function like String -> m String." You can try it out as cps-bf in the sample program.

In this case, the result of the interpreter can be retrieved as a value of the following type:

data Answer s = NeedMoreInput String (String -> ST s (Answer s))
              | Done String

If the execution completes (without requiring input), it returns Done (output), and if additional input is required, it returns NeedMoreInput (the output so far) (a function to continue).

The code looks like this:

type K s = ContT (Answer s) (ST s)
type M s = ReaderT (STRef s String, STRef s String, String -> K s String) (K s)

put :: Word8 -> M s ()
put c = do
  let c' = chr $ fromIntegral c
  (_, bufOut, _) <- ask
  lift $ lift $ modifySTRef' bufOut (c' :)

get :: M s Word8
get = do
  (bufIn, bufOut, needMoreInput) <- ask
  input <- lift $ lift $ readSTRef bufIn
  case input of
    c:cs -> do
      lift $ lift $ writeSTRef bufIn cs
      pure $ fromIntegral $ ord c
    [] -> do
      output <- fmap reverse $ lift $ lift $ readSTRef bufOut
      lift $ lift $ writeSTRef bufOut []
      newInput <- lift $ needMoreInput output
      case newInput of
        c:cs -> do
          lift $ lift $ writeSTRef bufIn cs
          pure $ fromIntegral $ ord c
        [] -> pure 0xFF

interpret :: M s (Brainfuck.State s) -> ST s (Answer s)
interpret run = do
  bufIn <- newSTRef []
  bufOut <- newSTRef []
  let action = callCC $ \k -> do
        let needMoreInput output = callCC $ \l -> k (NeedMoreInput output (\newInput -> evalContT (l newInput)))
        _finalState <- runReaderT run (bufIn, bufOut, needMoreInput)
        output <- fmap reverse $ lift $ readSTRef bufOut
        pure $ Done output
  evalContT action

main :: IO ()
main = do
  args <- getArgs
  ans0 <- case args of
    "pm":programName:_ -> do
      program <- Brainfuck.parseAll <$> readFile programName
      state <- Brainfuck.newState 67108864
      stToIO $ interpret $ Brainfuck.runPM put get program state
    "pmi":programName:_ -> do
      program <- Brainfuck.parseAll <$> readFile programName
      state <- Brainfuck.newState 67108864
      stToIO $ interpret $ Brainfuck.runPMInlinable put get program state
    "mixed":programName:_ -> do
      program <- Brainfuck.parseAll <$> readFile programName
      state <- Brainfuck.newState 67108864
      stToIO $ interpret $ Brainfuck.runPMMixed put get program state
    _ -> do hPutStrLn stderr "usage: cps-bf pm/pmi/mixed <program.bf>"
            exitFailure
  -- Usage example
  case ans0 of
    Done s -> putStrLn $ "output[0]: " ++ s
    NeedMoreInput s more -> do
      putStrLn $ "output[0]: " ++ s
      ans1 <- stToIO $ more "Hello!"
      case ans1 of
        Done s1 -> putStrLn $ "output[1]: " ++ s1
        NeedMoreInput s1 _ -> do
          putStrLn $ "output[1]: " ++ s1

This is only a demo; the takeaway is simply that ContT is powerful.
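
The suspend-and-resume mechanism inside interpret can be easier to follow in isolation. The sketch below strips it down to the two nested callCC calls: the outer one provides an early exit, and the inner one captures "the rest of the computation" so it can be handed back to the caller. Step, addTwo, start, and feed are hypothetical names, and the base monad is left abstract (the article uses ST because the interpreter mutates an array).

```haskell
import Control.Monad.Trans.Cont (callCC, evalContT)
import Data.Functor.Identity (Identity, runIdentity)

-- Result of a run: either finished, or paused and waiting for an Int.
data Step m = Finished Int | Paused (Int -> m (Step m))

-- A computation that asks for two numbers and returns their sum.
-- `ask` plays the role of the injected input action.
addTwo :: Monad m => m Int -> m Int
addTwo ask = do
  a <- ask
  b <- ask
  pure (a + b)

start :: Monad m => m (Step m)
start = evalContT $ callCC $ \exit -> do
  -- Every time input is needed, capture the rest of the computation as
  -- `resume` and leave through `exit` with a Paused value.
  let ask = callCC $ \resume -> exit (Paused (\x -> evalContT (resume x)))
  total <- addTwo ask
  pure (Finished total)

-- Feed inputs one at a time until the computation finishes.
feed :: [Int] -> Step Identity -> Maybe Int
feed _        (Finished n) = Just n
feed (x : xs) (Paused k)   = feed xs (runIdentity (k x))
feed []       (Paused _)   = Nothing   -- ran out of input
```

With these definitions, feed [1, 2] (runIdentity start) evaluates to Just 3, while feed [1] (runIdentity start) gives Nothing because the computation is still waiting for its second input.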

Checking Execution Speed

Let's run a non-trivial Brainfuck program and measure the execution time. For this, we will use a program that performs FizzBuzz up to 15. Since writing this by hand is difficult, it was written in C and compiled using ELVM's 8cc.js. The resulting Brainfuck program is available as fizzbuzz15.bf in the sample Git repository.

// Compile this into Brainfuck using ELVM+8cc.js
#include <stdio.h>

// printf is slow, so we'll write our own integer-to-string conversion
void putint(int n)
{
    if (n == 0) {
        putchar('0');
        return;
    } else if (n < 0) {
        putchar('-');
        n = -n;
    }
    char buf[10]; // This is a quick hack, so let's hope it doesn't overflow
    char *p = buf;
    while (n != 0) {
        *p++ = '0' + (n % 10);
        n = n / 10;
    }
    while (p != buf) {
        putchar(*--p);
    }
}

int main(void)
{
    for (int i = 1; i <= 15; i++) {
        if (i % 5) {
            if (i % 3) {
                // printf("%d\n", i);
                putint(i);
                putchar('\n');
            } else {
                puts("Fizz");
            }
        } else {
            puts("FizzBuzz" + i * i % 3 * 4);
        }
    }
    return 0;
}

In my environment (Apple M4 Pro), running this with the initial runIO implementation took about 35 seconds.

$ cabal exec -O2 time simple-bf io fizzbuzz15.bf
1
2
Fizz
4
Buzz
Fizz
7
8
Fizz
Buzz
11
Fizz
13
14
FizzBuzz
       35.62 real        35.28 user         0.30 sys

As in the previous article, runPM is slower when it is not specialized, while runPMInlinable, marked INLINABLE, can be specialized and runs faster. For reference, the unspecialized runPM took about 190 seconds in my environment.

$ cabal exec -O2 time simple-bf pm fizzbuzz15.bf 
1
2
Fizz
4
Buzz
Fizz
7
8
Fizz
Buzz
11
Fizz
13
14
FizzBuzz
      190.13 real       188.15 user         1.85 sys

Improving the Speed of the ContT Version

On the other hand, the interpreter using ContT took 40 seconds even after specialization.

$ cabal exec -O2 time cps-bf pmi fizzbuzz15.bf
output[0]: 1
2
Fizz
4
Buzz
Fizz
7
8
Fizz
Buzz
11
Fizz
13
14
FizzBuzz

       40.84 real        40.12 user         0.37 sys

Is this an unavoidable cost of using ContT? Can we improve it a little more while still enjoying the power of ContT?

If we consider how the interpreter operates, the execution of arithmetic instructions, pointer instructions, and loops (e.g., +-><[] in Brainfuck) likely occurs far more frequently than I/O instructions (e.g., ,. in Brainfuck). Therefore, let's consider executing instructions other than I/O in the ST monad and only executing I/O instructions in the provided monad. The code would look like this:

runPMMixed :: forall m. PrimMonad m => (Word8 -> m ()) -> m Word8 -> [Instruction] -> State (PrimState m) -> m (State (PrimState m))
runPMMixed put get insns State { pointer = initialPointer, array = array } = do
      -- loopST executes instructions other than I/O in the ST monad.
      -- When I/O or the end of the program is encountered, it returns the execution state at that point (remaining instructions and pointer position).
  let loopST :: Int -> [Instruction] -> ST (PrimState m) ([Instruction], Int)
      loopST !pointer [] = pure ([], pointer)
      loopST !pointer insns0@(insn : insns) = do
        case insn of
          Arith delta -> VUM.modify array (+ delta) pointer >> loopST pointer insns
          Pointer delta -> loopST (pointer + delta) insns
          Output -> pure (insns0, pointer)
          Input -> pure (insns0, pointer)
          BeginLoop alt -> do
            !v <- VUM.read array pointer
            if v == 0
              then loopST pointer alt
              else loopST pointer insns
          EndLoop alt -> do
            !v <- VUM.read array pointer
            if v == 0
              then loopST pointer insns
              else loopST pointer alt
      -- loop executes I/O in the provided m monad.
  let loop !pointer [] = pure pointer
      loop !pointer (Output : insns) = do
        !c <- VUM.read array pointer
        put c
        continue pointer insns
      loop !pointer (Input : insns) = do
        !c <- get
        VUM.write array pointer c
        continue pointer insns
      loop !pointer insns = continue pointer insns
      continue pointer insns = do
        (insns', pointer') <- stToPrim $ loopST pointer insns
        loop pointer' insns'
  !pointer <- loop initialPointer insns
  pure $ State { pointer = pointer, array = array }

Using this implementation, the execution completed in 30 seconds, which is about the same as when written directly in IO:

$ cabal exec -O2 time cps-bf mixed fizzbuzz15.bf
output[0]: 1
2
Fizz
4
Buzz
Fizz
7
8
Fizz
Buzz
11
Fizz
13
14
FizzBuzz

       30.90 real        30.61 user         0.25 sys

In other words, even in situations where you must execute actions provided by the user, if the majority of the processing can be done in ST, you can minimize the increase in cost due to abstraction by writing that part in ST.
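
The same split can be applied outside of interpreters. The following sketch, with hypothetical names sumWithReports and reportsPure and a MutVar instead of a vector, keeps the tight loop in ST and returns to the abstract monad m only to call a user-supplied action. It relies on stToPrim from Control.Monad.Primitive, just like runPMMixed.

```haskell
{-# LANGUAGE ScopedTypeVariables #-}
import Control.Monad.Primitive (PrimMonad, PrimState, stToPrim)
import Control.Monad.ST (ST, runST)
import Data.Primitive.MutVar (newMutVar, readMutVar, modifyMutVar')

-- Sum 1..limit in a mutable cell, calling `report` with the running
-- total after each chunk of at most `every` additions. The callback is
-- a stand-in for the rare I/O instructions. Assumes every >= 1.
sumWithReports :: forall m. PrimMonad m => (Int -> m ()) -> Int -> Int -> m Int
sumWithReports report every limit = do
  acc <- newMutVar 0
  let -- Hot path: written against ST, independent of the caller's m.
      addRange :: Int -> Int -> ST (PrimState m) ()
      addRange from to
        | from > to = pure ()
        | otherwise = modifyMutVar' acc (+ from) >> addRange (from + 1) to
      -- Slow path: runs in m, and only once per chunk.
      loop from
        | from > limit = readMutVar acc
        | otherwise = do
            let to = min limit (from + every - 1)
            stToPrim (addRange from to)
            readMutVar acc >>= report
            loop (to + 1)
  loop 1

-- Usage in ST: collect the reported totals in a list.
reportsPure :: Int -> Int -> (Int, [Int])
reportsPure every limit = runST $ do
  logRef <- newMutVar []
  total <- sumWithReports (\t -> modifyMutVar' logRef (t :)) every limit
  reports <- readMutVar logRef
  pure (total, reverse reports)
```

For example, reportsPure 4 10 evaluates to (55, [10, 36, 55]): three chunks, three callbacks, and ten additions that never touch the abstract monad.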

Analogies in JIT Compilation

In the world of JIT compilation, returning control from JIT-compiled code back to the interpreter for certain reasons (such as when JIT compilation assumptions are no longer valid or when encountering a process that cannot be JIT-compiled) is often called deoptimization or bailout. The runPMMixed function in this article can be viewed in a similar light: "returning control from a high-speed interpreter written in the ST monad to an interpreter over an abstract m due to some reason (needing to execute actions provided by the user)." In that sense, it could be considered a variant of deoptimization or bailout.
