
Building an EDSL in Haskell: LLVM Part — JIT Compilation


This article is for the 13th day of the Language Implementation Advent Calendar 2024.


In Building an EDSL in Haskell: StableName Edition, we saw how to recover sharing of computations using StableName.

This time, we will look at how to JIT compile the arithmetic DSL we created using LLVM. Sample code is provided in haskell-dsl-example/llvm.

Overview

How to Call LLVM

There are several ways to call LLVM from Haskell.

First, there is a method where you write out the LLVM IR to a file and call LLVM as a command. For example, GHC writes LLVM IR into .ll files and uses LLVM's opt and llc commands to obtain object files. With this method, the compiler only needs to handle text processing and file I/O, so there is no need to worry about tedious FFI.
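
That file-based workflow can be sketched in a few lines of Haskell. This is only an illustration of the idea, not what GHC actually runs: the flags and the compileViaFiles helper are made up for this sketch, and it assumes opt and llc are on the PATH.

```haskell
import System.Process (callProcess)

-- Command lines for a hypothetical two-stage pipeline.
-- The flags here are illustrative, not what GHC actually passes.
optCmd, llcCmd :: FilePath -> FilePath -> (String, [String])
optCmd inp out = ("opt", ["-O2", "-S", inp, "-o", out])
llcCmd inp out = ("llc", ["-filetype=obj", inp, "-o", out])

-- Optimize the .ll file with opt, then emit an object file with llc.
compileViaFiles :: FilePath -> FilePath -> IO ()
compileViaFiles llFile objFile = do
  let optimized = llFile ++ ".opt.ll"
  uncurry callProcess (optCmd llFile optimized)
  uncurry callProcess (llcCmd optimized objFile)
```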

Next, there is a method where you call LLVM as a C/C++ library. To call it from Haskell, FFI is required. This method is suitable for JIT compilation.

In this article, we will take the latter approach—calling LLVM via FFI. However, since writing the FFI part from scratch is a lot of work, we will use existing bindings.

LLVM Bindings

There are several bindings available to call LLVM from Haskell. Here, we will use the llvm-hs family.

The package structure of the llvm-hs family is as follows:

  • llvm-hs-pure: A pure Haskell part that does not depend on the C++ side. It allows you to construct LLVM IR.
  • llvm-hs: Bindings to LLVM implemented in C++ (FFI).
  • llvm-hs-pretty: A pure Haskell pretty printer, but it is unmaintained. Pretty printing itself can be done using the llvm-hs package (which depends on LLVM written in C++).

Only older versions of llvm-hs are available on Hackage, so we will use the versions from GitHub. At the time of writing, the llvm-15 branch, which supports LLVM 15, is the latest, and that is what we use. However, this branch does not seem to support GHC 9.8 or later (due to a conflict with the unzip function imported in LLVM.Prelude), and it does not build with newer versions of Cabal either; Cabal 3.10 or lower is required. The content of this article has therefore been verified with GHC 9.6.6 and Cabal 3.10. Incidentally, judging from the GitHub PRs for llvm-hs, there seem to be branches that have been made compatible with newer GHC and Cabal.

To fetch dependency packages from Git in a Cabal project, add source-repository-package stanzas to cabal.project as follows:

cabal.project
packages: ./*.cabal

source-repository-package
    type: git
    location: https://github.com/llvm-hs/llvm-hs.git
    tag: 5bca2c1a2a3aa98ecfb19181e7a5ebbf3e212b76
    subdir: llvm-hs-pure

source-repository-package
    type: git
    location: https://github.com/llvm-hs/llvm-hs.git
    tag: 5bca2c1a2a3aa98ecfb19181e7a5ebbf3e212b76
    subdir: llvm-hs

Building llvm-hs requires LLVM. Specifically, the llvm-config-15 or llvm-config command must be accessible during the configure phase. If you installed LLVM with Homebrew or similar, these are not on the PATH by default, so you can add the following to cabal.project.local so that llvm-hs can find llvm-config:

cabal.project.local
package llvm-hs
    extra-prog-path: /opt/homebrew/opt/llvm@15/bin

In the case of Homebrew, the specific location can be found with echo $(brew --prefix llvm@15)/bin.

Haddock for the llvm-15 branch of llvm-hs is not available on Hackage, so if you want to see the documentation, please git clone it yourself and run cabal haddock.

Reference Code

There are several samples in llvm-hs-examples, but the master branch is for LLVM 9, and otherwise there is only an llvm-12 branch, so they are slightly old.

The one closest to our arithmetic DSL is llvm-hs-examples/arith/Arith.hs.

Here, referring to arith/Arith.hs, I will present code compatible with LLVM 15.

Calling Compiled Functions

Suppose we JIT compiled a function and obtained a function pointer of type FunPtr. To call this from Haskell, we define a function using foreign import ccall "dynamic" like this:

foreign import ccall unsafe "dynamic"
  mkDoubleFn :: FunPtr (Double -> Double) -> (Double -> Double)

Using the mkDoubleFn function defined this way, you can convert a function pointer into a Haskell function.

Unsafe FFI has low overhead, but it is intended for processes that complete in a short time. For processes that take a long time, it might be better to use safe FFI. I touched on the types of FFI in my previous post "【Low-level Haskell】I want to get as close as possible to inline assembly even in Haskell (GHC)!" (Japanese).

With this method, the function's type must be fixed when the code is written. If you want to JIT compile functions that take an arbitrary number of arguments, you should either make the interface with Haskell a pointer to a struct or an array (for example, a function of type Ptr SomeStruct -> IO ()), or use a libffi binding.

Practice

Adding Features to the DSL

The DSL we created last time only supported basic arithmetic. Since I want to demonstrate function calls in the LLVM demo, I will add several mathematical functions here. I will also add a type parameter to Exp, which makes it explicit that the variable has type Double.

src/Exp.hs
data UnaryOp = Negate
             | Abs
             | Exp
             | Log
             | Sin
             | Cos
             deriving (Eq, Show)

data Exp a where
  Const :: a -> Exp a
  Var :: Exp Double
  Unary :: UnaryOp -> Exp Double -> Exp Double
  Add :: Exp Double -> Exp Double -> Exp Double
  Sub :: Exp Double -> Exp Double -> Exp Double
  Mul :: Exp Double -> Exp Double -> Exp Double
  Div :: Exp Double -> Exp Double -> Exp Double

deriving instance Show a => Show (Exp a)

instance Num (Exp Double) where
  (+) = Add
  (-) = Sub
  (*) = Mul
  fromInteger = Const . fromInteger
  negate = Unary Negate
  abs = Unary Abs
  signum = undefined

instance Fractional (Exp Double) where
  (/) = Div
  fromRational = Const . fromRational

instance Floating (Exp Double) where
  pi = Const pi
  exp = Unary Exp
  log = Unary Log
  sin = Unary Sin
  cos = Unary Cos
  asin = undefined
  acos = undefined
  atan = undefined
  sinh = undefined
  cosh = undefined
  asinh = undefined
  acosh = undefined
  atanh = undefined

eval :: Exp a -> Double -> a
eval (Const k) _ = k
eval Var x = x
eval (Unary Negate a) x = - eval a x
eval (Unary Abs a) x = abs (eval a x)
eval (Unary Exp a) x = exp (eval a x)
eval (Unary Log a) x = log (eval a x)
eval (Unary Sin a) x = sin (eval a x)
eval (Unary Cos a) x = cos (eval a x)
eval (Add a b) x = eval a x + eval b x
eval (Sub a b) x = eval a x - eval b x
eval (Mul a b) x = eval a x * eval b x
eval (Div a b) x = eval a x / eval b x
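
As a quick sanity check, the eval interpreter can be exercised directly. The following is a self-contained, trimmed-down sketch of the code above (arithmetic only, pragmas included), showing how the Num instance lets ordinary-looking Haskell expressions build syntax trees that eval then interprets:

```haskell
{-# LANGUAGE GADTs #-}
{-# LANGUAGE FlexibleInstances #-}

-- Trimmed-down version of Exp/eval for illustration only.
data Exp a where
  Const :: a -> Exp a
  Var   :: Exp Double
  Add   :: Exp Double -> Exp Double -> Exp Double
  Mul   :: Exp Double -> Exp Double -> Exp Double

instance Num (Exp Double) where
  (+) = Add
  (*) = Mul
  fromInteger = Const . fromInteger
  abs = undefined; signum = undefined; negate = undefined

eval :: Exp a -> Double -> a
eval (Const k) _ = k
eval Var x = x
eval (Add a b) x = eval a x + eval b x
eval (Mul a b) x = eval a x * eval b x

main :: IO ()
main = do
  let f x = (x + 1) * (x + 1)   -- ordinary-looking Haskell builds an Exp tree
  print (eval (f Var) 2.0)      -- interpret the tree at x = 2
```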

The part that recovers sharing is as follows:

src/ExpS.hs
type Id = Int -- Variable identifier

data Value a = ConstV a
             | VarV Id
             deriving Show

-- Expression on the right side of a let
data SimpleExp a where
  UnaryS :: UnaryOp -> Value Double -> SimpleExp Double
  AddS :: Value Double -> Value Double -> SimpleExp Double
  SubS :: Value Double -> Value Double -> SimpleExp Double
  MulS :: Value Double -> Value Double -> SimpleExp Double
  DivS :: Value Double -> Value Double -> SimpleExp Double

deriving instance Show (SimpleExp a)

-- Expression that can represent sharing
data ExpS a = Let Id (SimpleExp Double) (ExpS a)
            | Value (Value a)
            deriving Show

-- State: (Next identifier to use, List of variables and expressions defined so far, Mapping from StableName to variable ID)
type M = StateT (Int, [(Id, SimpleExp Double)], HM.HashMap Name Id) IO

recoverSharing :: Exp Double -> IO (ExpS Double)
recoverSharing expr = do
  -- Omitted

Please refer to the sample code repository for the complete source code.
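
To make the ExpS representation concrete: the shared subexpression in (x + 1) * (x + 1) should come out as a single let binding that is referenced twice. The following self-contained sketch hand-builds such a term and evaluates it; the simplified monomorphic types, the exact variable ids, and the evalS helper are assumptions of this sketch, not code from the repository (recoverSharing may number things differently):

```haskell
import qualified Data.IntMap.Strict as IntMap

type Id = Int

data Value = ConstV Double | VarV Id deriving Show

data SimpleExp = AddS Value Value | MulS Value Value deriving Show

data ExpS = Let Id SimpleExp ExpS | Value Value deriving Show

-- What sharing recovery should produce for (x + 1) * (x + 1),
-- with the input variable bound to id 0:
shared :: ExpS
shared =
  Let 1 (AddS (VarV 0) (ConstV 1)) $   -- x1 = x + 1
  Let 2 (MulS (VarV 1) (VarV 1)) $     -- x2 = x1 * x1  (x1 referenced twice)
  Value (VarV 2)

-- A small evaluator for ExpS, threading an environment of bound ids.
evalS :: ExpS -> Double -> Double
evalS e0 x = go (IntMap.singleton 0 x) e0
  where
    val _   (ConstV k) = k
    val env (VarV i)   = env IntMap.! i
    go env (Let i s body) =
      let v = case s of
                AddS a b -> val env a + val env b
                MulS a b -> val env a * val env b
      in go (IntMap.insert i v env) body
    go env (Value v) = val env v

main :: IO ()
main = print (evalS shared 2.0)  -- (2 + 1) * (2 + 1)
```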

Code Generation with llvm-hs-pure

The part that generates LLVM IR can be done with llvm-hs-pure, which is a pure Haskell library. I will present the code first:

src/Codegen.hs
import qualified LLVM.AST as AST (Module, Operand(ConstantOperand))
import qualified LLVM.AST.Constant as AST (Constant(Float))
import qualified LLVM.AST.Float as AST (SomeFloat(Double))
import qualified LLVM.AST.Type as AST (Type(FunctionType, resultType, argumentTypes, isVarArg), double)
import qualified LLVM.IRBuilder.Instruction as IR (fneg, call, fadd, fsub, fmul, fdiv, ret)
import qualified LLVM.IRBuilder.Module as IR (MonadModuleBuilder, ParameterName(ParameterName), buildModule, extern, function)
import qualified LLVM.IRBuilder.Monad as IR (MonadIRBuilder, named)

doubleFnType :: AST.Type
doubleFnType = AST.FunctionType
  { AST.resultType = AST.double
  , AST.argumentTypes = [AST.double]
  , AST.isVarArg = False
  }

codegen :: ExpS Double -> AST.Module
codegen expr = IR.buildModule "dsl.ll" $ do
  absFn <- IR.extern "llvm.fabs.f64" [AST.double] AST.double
  expFn <- IR.extern "llvm.exp.f64" [AST.double] AST.double
  logFn <- IR.extern "llvm.log.f64" [AST.double] AST.double
  sinFn <- IR.extern "llvm.sin.f64" [AST.double] AST.double
  cosFn <- IR.extern "llvm.cos.f64" [AST.double] AST.double

  let goValue :: IntMap.IntMap AST.Operand -> Value Double -> AST.Operand
      goValue _ (ConstV x) = AST.ConstantOperand $ AST.Float $ AST.Double x
      goValue env (VarV i) = env IntMap.! i
      goSimpleExp :: (IR.MonadIRBuilder m, IR.MonadModuleBuilder m) => IntMap.IntMap AST.Operand -> SimpleExp Double -> m AST.Operand
      -- goSimpleExp :: IntMap.IntMap AST.Operand -> SimpleExp Double -> LLVM.IRBuilder.Module.IRBuilderT LLVM.IRBuilder.Module.ModuleBuilder AST.Operand
      goSimpleExp env (UnaryS Negate x) = IR.fneg (goValue env x)
      goSimpleExp env (UnaryS Abs x) = IR.call doubleFnType absFn [(goValue env x, [])]
      goSimpleExp env (UnaryS Exp x) = IR.call doubleFnType expFn [(goValue env x, [])]
      goSimpleExp env (UnaryS Log x) = IR.call doubleFnType logFn [(goValue env x, [])]
      goSimpleExp env (UnaryS Sin x) = IR.call doubleFnType sinFn [(goValue env x, [])]
      goSimpleExp env (UnaryS Cos x) = IR.call doubleFnType cosFn [(goValue env x, [])]
      goSimpleExp env (AddS x y) = IR.fadd (goValue env x) (goValue env y)
      goSimpleExp env (SubS x y) = IR.fsub (goValue env x) (goValue env y)
      goSimpleExp env (MulS x y) = IR.fmul (goValue env x) (goValue env y)
      goSimpleExp env (DivS x y) = IR.fdiv (goValue env x) (goValue env y)
      goExp :: (IR.MonadIRBuilder m, IR.MonadModuleBuilder m) => IntMap.IntMap AST.Operand -> ExpS Double -> m AST.Operand
      -- goExp :: IntMap.IntMap AST.Operand -> ExpS Double -> LLVM.IRBuilder.Module.IRBuilderT LLVM.IRBuilder.Module.ModuleBuilder AST.Operand
      goExp env (Let i x body) = do
        v <- goSimpleExp env x `IR.named` SBS.toShort (BS.Char8.pack ("x" ++ show i))
        goExp (IntMap.insert i v env) body
      goExp env (Value v) = pure $ goValue env v

  let xparam :: IR.ParameterName
      xparam = IR.ParameterName "x"
  _ <- IR.function "f" [(AST.double, xparam)] AST.double $ \[arg] -> do
    result <- goExp (IntMap.singleton 0 arg) expr
    IR.ret result
  pure ()

The codegen function takes the syntax tree of our DSL and creates a module containing the LLVM IR. The module includes function definitions and declarations for external functions.

llvm-hs-pure provides an EDSL (monad) for constructing LLVM IR. This example is too simple to feel like a proper DSL, but it will be clearer in the auto-vectorization example discussed later.

An LLVM IR operand (variable or immediate value) has the type LLVM.AST.Operand. Variables are obtained as return values of functions corresponding to IR instructions.

By using the LLVM.IRBuilder.Monad.named function, you can assign names to defined variables and labels. In practice, suffixes like numbers are added, resulting in something like %x_0.

Pretty Printing the LLVM IR

Let's display the generated LLVM IR in text format. Since the llvm-hs-pretty package is not available, we will use the llvm-hs package. We use the LLVM.Module.moduleLLVMAssembly function.

pp/Main.hs
import qualified Data.ByteString as BS
import qualified LLVM.Context
import qualified LLVM.Module

main :: IO ()
main = do
  let f x = (x + 1)^10
  expr <- recoverSharing (f Var)
  let code = codegen expr
  LLVM.Context.withContext $ \context ->
    LLVM.Module.withModuleFromAST context code $ \mod' -> do
      asm <- LLVM.Module.moduleLLVMAssembly mod'
      BS.putStr asm

An example output is as follows:

; ModuleID = 'dsl.ll'
source_filename = "<string>"

declare double @llvm.fabs.f64(double)

declare double @llvm.exp.f64(double)

declare double @llvm.log.f64(double)

declare double @llvm.sin.f64(double)

declare double @llvm.cos.f64(double)

define double @f(double %x_0) {
  %x1_0 = fadd double %x_0, 1.000000e+00
  %x2_0 = fmul double %x1_0, %x1_0
  %x3_0 = fmul double %x2_0, %x2_0
  %x4_0 = fmul double %x3_0, %x3_0
  %x5_0 = fmul double %x4_0, %x2_0
  ret double %x5_0
}

It looks like it is working correctly.

JIT Compilation

JIT compilation is performed as follows:

app/Main.hs
import qualified LLVM.CodeGenOpt
import qualified LLVM.CodeModel
import qualified LLVM.Context
import qualified LLVM.Linking
import qualified LLVM.Module
import qualified LLVM.OrcJIT as JIT
import qualified LLVM.Relocation
import qualified LLVM.Passes
import qualified LLVM.Target

foreign import ccall unsafe "dynamic"
  mkDoubleFun :: FunPtr (Double -> Double) -> (Double -> Double)

withSimpleJIT :: NFData a => ExpS Double -> ((Double -> Double) -> a) -> IO (Maybe a)
withSimpleJIT expr doFun = do
  LLVM.Context.withContext $ \context -> do
    _ <- LLVM.Linking.loadLibraryPermanently Nothing
    LLVM.Module.withModuleFromAST context (codegen expr) $ \mod' -> do
      LLVM.Target.withHostTargetMachine LLVM.Relocation.PIC LLVM.CodeModel.JITDefault LLVM.CodeGenOpt.Default $ \targetMachine -> do
        asm <- LLVM.Module.moduleLLVMAssembly mod'
        putStrLn "*** Before optimization ***"
        BS.putStr asm
        putStrLn "***************************"

        let passSetSpec = LLVM.Passes.PassSetSpec
                          { LLVM.Passes.passes = [LLVM.Passes.CuratedPassSet 2]
                          , LLVM.Passes.targetMachine = Nothing -- Just targetMachine
                          }
        LLVM.Passes.runPasses passSetSpec mod'

        asm' <- LLVM.Module.moduleLLVMAssembly mod'
        putStrLn "*** After optimization ***"
        BS.putStr asm'
        putStrLn "**************************"

        tasm <- LLVM.Module.moduleTargetAssembly targetMachine mod'
        putStrLn "*** Target assembly ***"
        BS.putStr tasm
        putStrLn "***********************"

        JIT.withExecutionSession $ \executionSession -> do
          dylib <- JIT.createJITDylib executionSession "myDylib"
          JIT.withClonedThreadSafeModule mod' $ \threadSafeModule -> do
            objectLayer <- JIT.createRTDyldObjectLinkingLayer executionSession
            compileLayer <- JIT.createIRCompileLayer executionSession objectLayer targetMachine
            JIT.addDynamicLibrarySearchGeneratorForCurrentProcess compileLayer dylib
            JIT.addModule threadSafeModule dylib compileLayer

            sym <- JIT.lookupSymbol executionSession compileLayer dylib "f"
            case sym of
              Left (JIT.JITSymbolError err) -> do
                print err
                pure Nothing
              Right (JIT.JITSymbol fnAddr _jitSymbolFlags) -> do
                let fn = mkDoubleFun . castPtrToFunPtr $ wordPtrToPtr fnAddr
                Just <$> evaluate (force $ doFun fn)

Since I merely transcribed this from llvm-hs-examples/arith without fully understanding it, I cannot provide a detailed explanation. Nonetheless, I will explain a few points within the scope of my understanding.

The return value of the codegen function we just created (the LLVM.AST.Module module) is passed to LLVM.Module.withModuleFromAST.

As we did earlier, the text representation of the created LLVM IR can be obtained with the LLVM.Module.moduleLLVMAssembly function. Here, it is displayed both before and after optimization.

The generated machine-dependent assembly code can be obtained using the LLVM.Module.moduleTargetAssembly function.

In the llvm-hs-examples sample, CodeModel.Default is used as an argument for LLVM.Target.withHostTargetMachine, but this needs to be set to JITDefault; otherwise, external function calls (like exp or sin) will fail on AArch64.

LLVM.Passes.runPasses is what executes the optimization. This function destructively updates the module passed as an argument. Please note that this must be executed before entering JIT compilation; otherwise, it won't be reflected in the generated machine code. I struggled because I missed this, which caused the auto-vectorization mentioned later to be ineffective (no performance improvement was observed).

The function pointer created by JIT compilation can be obtained as fnAddr :: WordPtr. This is converted to a FunPtr and passed to a foreign import ccall "dynamic" function to convert it into a Haskell function.

The JIT-compiled function becomes invalid once we leave this scope. It is consumed by the callback doFun, and force from Control.DeepSeq ensures that the result is fully evaluated before the scope is exited.
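
The forcing pattern can be illustrated in isolation. In this sketch, withEphemeralFn and its plain Haskell fn are stand-ins for the JIT scope and the compiled function; only the evaluate (force ...) idiom is the point. Without it, a lazy result could still hold references to the function after the scope has closed.

```haskell
import Control.DeepSeq (NFData, force)
import Control.Exception (evaluate)

-- Stand-in for withSimpleJIT: the function passed to the callback is
-- only valid inside this scope, so the result must be fully evaluated
-- (not a thunk that still mentions fn) before we return.
withEphemeralFn :: NFData a => ((Double -> Double) -> a) -> IO a
withEphemeralFn doFun = do
  let fn x = (x + 1) * (x + 1)   -- pretend this is the JIT-compiled function
  evaluate (force (doFun fn))    -- deep-evaluate before leaving the scope

main :: IO ()
main = withEphemeralFn (\f -> map f [0, 1, 2]) >>= print
```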

Usage

The usage side looks like this:

app/Main.hs
main :: IO ()
main = do
  let f x = (x + 1)^10 * (x + 1) + cos x
  expr <- recoverSharing (f Var)
  result <- withSimpleJIT expr (\f' -> f' 1.0)
  case result of
    Nothing -> pure ()
    Just result' -> print result'

I deliberately made x + 1 appear twice to see the effect of optimization. However, since GHC's optimization would merge them via common subexpression elimination if enabled, I'll turn off GHC's optimization.

$ cabal run -O0 example-run
*** Before optimization ***
; ModuleID = 'dsl.ll'
source_filename = "<string>"

declare double @llvm.fabs.f64(double)

declare double @llvm.exp.f64(double)

declare double @llvm.log.f64(double)

declare double @llvm.sin.f64(double)

declare double @llvm.cos.f64(double)

define double @f(double %x_0) {
  %x1_0 = fadd double %x_0, 1.000000e+00
  %x2_0 = fmul double %x1_0, %x1_0
  %x3_0 = fmul double %x2_0, %x2_0
  %x4_0 = fmul double %x3_0, %x3_0
  %x5_0 = fmul double %x4_0, %x2_0
  %x6_0 = fadd double %x_0, 1.000000e+00
  %x7_0 = fmul double %x5_0, %x6_0
  %x8_0 = call double @llvm.cos.f64(double %x_0)
  %x9_0 = fadd double %x7_0, %x8_0
  ret double %x9_0
}
***************************
*** After optimization ***
; ModuleID = 'dsl.ll'
source_filename = "<string>"

declare double @llvm.cos.f64(double)

define double @f(double %x_0) local_unnamed_addr {
  %x1_0 = fadd double %x_0, 1.000000e+00
  %x2_0 = fmul double %x1_0, %x1_0
  %x3_0 = fmul double %x2_0, %x2_0
  %x4_0 = fmul double %x3_0, %x3_0
  %x5_0 = fmul double %x2_0, %x4_0
  %x7_0 = fmul double %x1_0, %x5_0
  %x8_0 = tail call double @llvm.cos.f64(double %x_0)
  %x9_0 = fadd double %x7_0, %x8_0
  ret double %x9_0
}
**************************

You can see that fadd double %x_0, 1.000000e+00, which appeared in two places before optimization, was consolidated into one after optimization.

The generated assembly and the result of the execution (at x = 1) are as follows:

*** Target assembly ***
	.section	__TEXT,__text,regular,pure_instructions
	.build_version macos, 15, 0
	.globl	_f
	.p2align	2
_f:
	.cfi_startproc
	stp	d9, d8, [sp, #-32]!
	.cfi_def_cfa_offset 32
	stp	x29, x30, [sp, #16]
	.cfi_offset w30, -8
	.cfi_offset w29, -16
	.cfi_offset b8, -24
	.cfi_offset b9, -32
	fmov	d1, #1.00000000
	fadd	d1, d0, d1
	fmul	d2, d1, d1
	fmul	d3, d2, d2
	fmul	d3, d3, d3
	fmul	d2, d2, d3
	fmul	d8, d1, d2
Lloh0:
	adrp	x8, _cos@GOTPAGE
Lloh1:
	ldr	x8, [x8, _cos@GOTPAGEOFF]
	blr	x8
	fadd	d0, d8, d0
	ldp	x29, x30, [sp, #16]
	ldp	d9, d8, [sp], #32
	ret
	.loh AdrpLdrGot	Lloh0, Lloh1
	.cfi_endproc

.subsections_via_symbols
***********************
2048.540302305868

Since this was executed on AArch64 macOS, AArch64 assembly source is produced, and an underscore is prefixed to the function name. Other environments would likely yield different results.

Using Auto-vectorization

We've seen that LLVM performs common subexpression elimination, but let's try other optimizations. For example, auto-vectorization is a feature quite distant from standard Haskell, but wouldn't it be easier to utilize if we generate LLVM IR ourselves like this?

Specifically, we define a function like the following:

// Pseudocode in C
void f(int size, double * restrict resultArray, const double *inputArray)
{
    for (int i = 0; i < size; ++i) {
        double x = inputArray[i];
        resultArray[i] = /* Calculation using x */;
    }
}

Then, the idea is that LLVM will transform the loop to use SIMD instructions through auto-vectorization. From a Haskell perspective, the function type would be Int32 -> Ptr Double -> Ptr Double -> IO (), but it's best to wrap it so it can be used as VS.Vector Double -> VS.Vector Double using storable vectors (Data.Vector.Storable). There are several variations of Data.Vector, but when taking addresses for use with FFI, storable vectors are generally used. Primitive vectors and the unboxed vectors that use them are difficult to use with FFI because their addresses can change due to GC (they can be used with FFI if you pin them upon allocation or stop the GC during FFI, but that's for advanced users).
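
The calling convention described above can be sketched without LLVM at all. Here a plain Haskell fakeLoopFn stands in for the JIT-compiled Int32 -> Ptr Double -> Ptr Double -> IO () function, and the wrapper uses base's Foreign.Marshal.Array on lists instead of storable vectors, just to show the pointer plumbing; both names are invented for this sketch:

```haskell
import Data.Int (Int32)
import Foreign.Marshal.Array (allocaArray, peekArray, pokeArray)
import Foreign.Ptr (Ptr)

-- Stand-in for the JIT-compiled loop: read n doubles from inputPtr,
-- write f(x) for each element into resultPtr.
fakeLoopFn :: Int32 -> Ptr Double -> Ptr Double -> IO ()
fakeLoopFn n resultPtr inputPtr = do
  xs <- peekArray (fromIntegral n) inputPtr
  pokeArray resultPtr (map (\x -> (x + 1) * (x + 1)) xs)

-- Wrap the pointer-level function as a list function (the real code
-- wraps it as VS.Vector Double -> VS.Vector Double instead).
applyOnList :: (Int32 -> Ptr Double -> Ptr Double -> IO ())
            -> [Double] -> IO [Double]
applyOnList fn xs =
  allocaArray n $ \inputPtr ->
    allocaArray n $ \resultPtr -> do
      pokeArray inputPtr xs
      fn (fromIntegral n) resultPtr inputPtr
      peekArray n resultPtr
  where n = length xs

main :: IO ()
main = applyOnList fakeLoopFn [1, 2, 3] >>= print
```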

The part that generates the LLVM IR for the loop is as follows:

src/LoopCodegen.hs
{-# LANGUAGE RecursiveDo #-}

codegen :: ExpS Double -> AST.Module
codegen expr = IR.buildModule "dsl.ll" $ do
  absFn <- IR.extern "llvm.fabs.f64" [AST.double] AST.double
  expFn <- IR.extern "llvm.exp.f64" [AST.double] AST.double
  logFn <- IR.extern "llvm.log.f64" [AST.double] AST.double
  sinFn <- IR.extern "llvm.sin.f64" [AST.double] AST.double
  cosFn <- IR.extern "llvm.cos.f64" [AST.double] AST.double

  let goValue :: IntMap.IntMap AST.Operand -> Value Double -> AST.Operand
      -- Omitted
      goSimpleExp :: (IR.MonadIRBuilder m, IR.MonadModuleBuilder m) => IntMap.IntMap AST.Operand -> SimpleExp Double -> m AST.Operand
      -- Omitted
      goExp :: (IR.MonadIRBuilder m, IR.MonadModuleBuilder m) => IntMap.IntMap AST.Operand -> ExpS Double -> m AST.Operand
      -- Omitted

  let sizeName :: IR.ParameterName
      sizeName = IR.ParameterName "size"
  let resultArrayName :: IR.ParameterName
      resultArrayName = IR.ParameterName "resultArray"
  let inputArrayName :: IR.ParameterName
      inputArrayName = IR.ParameterName "inputArray"
  _ <- IR.function "f" [(AST.i32, sizeName), (AST.ptr, resultArrayName), (AST.ptr, inputArrayName)] AST.void $ \[size, resultArray, inputArray] -> mdo
    prologue <- IR.block
    IR.br loop

    loop <- IR.block `IR.named` "loop"
    counter <- IR.phi [(IR.int32 0, prologue), (nextCounter, loopBody)] `IR.named` "counter"
    lt <- IR.icmp AST.SLT counter size `IR.named` "lt"
    IR.condBr lt loopBody epilogue

    loopBody <- IR.block `IR.named` "loopBody"
    xPtr <- IR.gep (AST.ArrayType 0 AST.double) inputArray [IR.int32 0, counter] `IR.named` "xPtr"
    x <- IR.load AST.double xPtr 0 `IR.named` "x"
    result <- goExp (IntMap.singleton 0 x) expr `IR.named` "result"
    resultPtr <- IR.gep (AST.ArrayType 0 AST.double) resultArray [IR.int32 0, counter] `IR.named` "resultPtr"
    IR.store resultPtr 0 result
    nextCounter <- IR.add counter (IR.int32 1) `IR.named` "nextCounter"
    IR.br loop

    epilogue <- IR.block `IR.named` "epilogue"
    IR.retVoid

  pure ()

I won't provide a detailed explanation here. Please consult the LLVM Language Reference Manual. Here are a few supplemental points:

  • Most functions in llvm-hs-pure that output LLVM IR instructions are fairly intuitive. However, note that LLVM's getelementptr is abbreviated as gep.
  • By using the mdo syntax from the RecursiveDo extension, we can refer to labels defined further down in the code.
  • The 0 passed to load and store represents alignment; it seems that specifying 0 causes the default for the type to be used.

The JIT compilation part is largely the same as before, but there are a few points to note:

loop/Main.hs
foreign import ccall unsafe "dynamic"
  mkDoubleArrayFun :: FunPtr (Int32 -> Ptr Double -> Ptr Double -> IO ()) -> (Int32 -> Ptr Double -> Ptr Double -> IO ())

withArrayJIT :: NFData a => ExpS Double -> ((VS.Vector Double -> VS.Vector Double) -> IO a) -> IO (Maybe a)
withArrayJIT expr doFun = do
  -- Omitted
        let passSetSpec = LLVM.Passes.PassSetSpec
                          { LLVM.Passes.passes = [LLVM.Passes.CuratedPassSet 2]
                          , LLVM.Passes.targetMachine = Just targetMachine
                          }
        LLVM.Passes.runPasses passSetSpec mod'

            -- Omitted

            sym <- JIT.lookupSymbol executionSession compileLayer dylib "f"
            case sym of
              Left (JIT.JITSymbolError err) -> do
                print err
                pure Nothing
              Right (JIT.JITSymbol fnAddr _jitSymbolFlags) -> do
                let ptrFn = mkDoubleArrayFun . castPtrToFunPtr $ wordPtrToPtr fnAddr
                    vecFn !inputVec = unsafePerformIO $ do
                      let !n = VS.length inputVec
                      resultVec <- VSM.unsafeNew n
                      VS.unsafeWith inputVec $ \inputPtr ->
                        VSM.unsafeWith resultVec $ \resultPtr ->
                          ptrFn (fromIntegral n) resultPtr inputPtr
                      VS.unsafeFreeze resultVec
                result <- doFun vecFn
                Just <$> evaluate (force result)

First, the function type changes: the callback doFun can now run an arbitrary IO action using the JIT-compiled function, and this works without problems.

Next, you must specify the targetMachine in passSetSpec. Without this, LLVM will only perform platform-independent optimizations, so vectorization will not be performed.

Finally, we do some manipulation to make the provided function appear as VS.Vector Double -> VS.Vector Double (the vecFn function).

The usage side looks like this:

loop/Main.hs
main :: IO ()
main = do
  let f x = (x + 1)^10 * (x + 1)
  expr <- recoverSharing (f Var)
  _ <- withArrayJIT expr $ \vf -> do
    print $ vf (VS.fromList [1..20])
  pure ()

I will also provide the execution results on AArch64 macOS:

*** Before optimization ***
; ModuleID = 'dsl.ll'
source_filename = "<string>"

declare double @llvm.fabs.f64(double)

declare double @llvm.exp.f64(double)

declare double @llvm.log.f64(double)

declare double @llvm.sin.f64(double)

declare double @llvm.cos.f64(double)

define void @f(i32 %size_0, ptr %resultArray_0, ptr %inputArray_0) {
  br label %loop_0

loop_0:                                           ; preds = %loopBody_0, %0
  %counter_0 = phi i32 [ 0, %0 ], [ %nextCounter_0, %loopBody_0 ]
  %lt_0 = icmp slt i32 %counter_0, %size_0
  br i1 %lt_0, label %loopBody_0, label %epilogue_0

loopBody_0:                                       ; preds = %loop_0
  %xPtr_0 = getelementptr [0 x double], ptr %inputArray_0, i32 0, i32 %counter_0
  %x_0 = load double, ptr %xPtr_0, align 8
  %x1_0 = fadd double %x_0, 1.000000e+00
  %x2_0 = fmul double %x1_0, %x1_0
  %x3_0 = fmul double %x2_0, %x2_0
  %x4_0 = fmul double %x3_0, %x3_0
  %x5_0 = fmul double %x4_0, %x2_0
  %x6_0 = fmul double %x5_0, %x1_0
  %resultPtr_0 = getelementptr [0 x double], ptr %resultArray_0, i32 0, i32 %counter_0
  store double %x6_0, ptr %resultPtr_0, align 8
  %nextCounter_0 = add i32 %counter_0, 1
  br label %loop_0

epilogue_0:                                       ; preds = %loop_0
  ret void
}
***************************
*** After optimization ***
; ModuleID = 'dsl.ll'
source_filename = "<string>"

; Function Attrs: argmemonly nofree norecurse nosync nounwind
define void @f(i32 %size_0, ptr nocapture writeonly %resultArray_0, ptr nocapture readonly %inputArray_0) local_unnamed_addr #0 {
  %lt_01 = icmp sgt i32 %size_0, 0
  br i1 %lt_01, label %loopBody_0.preheader, label %epilogue_0

loopBody_0.preheader:                             ; preds = %0
  %resultArray_03 = ptrtoint ptr %resultArray_0 to i64
  %inputArray_04 = ptrtoint ptr %inputArray_0 to i64
  %min.iters.check = icmp ult i32 %size_0, 4
  %1 = sub i64 %resultArray_03, %inputArray_04
  %diff.check = icmp ult i64 %1, 32
  %or.cond = select i1 %min.iters.check, i1 true, i1 %diff.check
  br i1 %or.cond, label %loopBody_0.preheader6, label %vector.ph

vector.ph:                                        ; preds = %loopBody_0.preheader
  %n.vec = and i32 %size_0, -4
  br label %vector.body

vector.body:                                      ; preds = %vector.body, %vector.ph
  %index = phi i32 [ 0, %vector.ph ], [ %index.next, %vector.body ]
  %2 = zext i32 %index to i64
  %3 = getelementptr [0 x double], ptr %inputArray_0, i64 0, i64 %2
  %wide.load = load <2 x double>, ptr %3, align 8
  %4 = getelementptr double, ptr %3, i64 2
  %wide.load5 = load <2 x double>, ptr %4, align 8
  %5 = fadd <2 x double> %wide.load, <double 1.000000e+00, double 1.000000e+00>
  %6 = fadd <2 x double> %wide.load5, <double 1.000000e+00, double 1.000000e+00>
  %7 = fmul <2 x double> %5, %5
  %8 = fmul <2 x double> %6, %6
  %9 = fmul <2 x double> %7, %7
  %10 = fmul <2 x double> %8, %8
  %11 = fmul <2 x double> %9, %9
  %12 = fmul <2 x double> %10, %10
  %13 = fmul <2 x double> %7, %11
  %14 = fmul <2 x double> %8, %12
  %15 = fmul <2 x double> %5, %13
  %16 = fmul <2 x double> %6, %14
  %17 = getelementptr [0 x double], ptr %resultArray_0, i64 0, i64 %2
  store <2 x double> %15, ptr %17, align 8
  %18 = getelementptr double, ptr %17, i64 2
  store <2 x double> %16, ptr %18, align 8
  %index.next = add nuw i32 %index, 4
  %19 = icmp eq i32 %index.next, %n.vec
  br i1 %19, label %middle.block, label %vector.body, !llvm.loop !0

middle.block:                                     ; preds = %vector.body
  %cmp.n = icmp eq i32 %n.vec, %size_0
  br i1 %cmp.n, label %epilogue_0, label %loopBody_0.preheader6

loopBody_0.preheader6:                            ; preds = %loopBody_0.preheader, %middle.block
  %counter_02.ph = phi i32 [ 0, %loopBody_0.preheader ], [ %n.vec, %middle.block ]
  br label %loopBody_0

loopBody_0:                                       ; preds = %loopBody_0.preheader6, %loopBody_0
  %counter_02 = phi i32 [ %nextCounter_0, %loopBody_0 ], [ %counter_02.ph, %loopBody_0.preheader6 ]
  %20 = zext i32 %counter_02 to i64
  %xPtr_0 = getelementptr [0 x double], ptr %inputArray_0, i64 0, i64 %20
  %x_0 = load double, ptr %xPtr_0, align 8
  %x1_0 = fadd double %x_0, 1.000000e+00
  %x2_0 = fmul double %x1_0, %x1_0
  %x3_0 = fmul double %x2_0, %x2_0
  %x4_0 = fmul double %x3_0, %x3_0
  %x5_0 = fmul double %x2_0, %x4_0
  %x6_0 = fmul double %x1_0, %x5_0
  %resultPtr_0 = getelementptr [0 x double], ptr %resultArray_0, i64 0, i64 %20
  store double %x6_0, ptr %resultPtr_0, align 8
  %nextCounter_0 = add nuw nsw i32 %counter_02, 1
  %lt_0 = icmp slt i32 %nextCounter_0, %size_0
  br i1 %lt_0, label %loopBody_0, label %epilogue_0, !llvm.loop !2

epilogue_0:                                       ; preds = %loopBody_0, %middle.block, %0
  ret void
}

attributes #0 = { argmemonly nofree norecurse nosync nounwind }

!0 = distinct !{!0, !1}
!1 = !{!"llvm.loop.isvectorized", i32 1}
!2 = distinct !{!2, !1}
**************************
*** Target assembly ***
	.section	__TEXT,__text,regular,pure_instructions
	.build_version macos, 15, 0
	.globl	_f
	.p2align	2
_f:
	cmp	w0, #1
	b.lt	LBB0_8
	mov	w8, #0
	cmp	w0, #4
	b.lo	LBB0_6
	sub	x9, x1, x2
	cmp	x9, #32
	b.lo	LBB0_6
	and	w8, w0, #0xfffffffc
	add	x9, x1, #16
	add	x10, x2, #16
	fmov.2d	v0, #1.00000000
	mov	x11, x8
LBB0_4:
	ldp	q1, q2, [x10, #-16]
	fadd.2d	v1, v1, v0
	fadd.2d	v2, v2, v0
	fmul.2d	v3, v1, v1
	fmul.2d	v4, v2, v2
	fmul.2d	v5, v3, v3
	fmul.2d	v6, v4, v4
	fmul.2d	v5, v5, v5
	fmul.2d	v6, v6, v6
	fmul.2d	v3, v3, v5
	fmul.2d	v4, v4, v6
	fmul.2d	v1, v1, v3
	fmul.2d	v2, v2, v4
	stp	q1, q2, [x9, #-16]
	add	x9, x9, #32
	add	x10, x10, #32
	subs	w11, w11, #4
	b.ne	LBB0_4
	cmp	w8, w0
	b.eq	LBB0_8
LBB0_6:
	mov	w8, w8
	fmov	d0, #1.00000000
LBB0_7:
	lsl	x9, x8, #3
	ldr	d1, [x2, x9]
	fadd	d1, d1, d0
	fmul	d2, d1, d1
	fmul	d3, d2, d2
	fmul	d3, d3, d3
	fmul	d2, d2, d3
	fmul	d1, d1, d2
	str	d1, [x1, x9]
	add	x8, x8, #1
	cmp	w8, w0
	b.lt	LBB0_7
LBB0_8:
	ret

.subsections_via_symbols
***********************
[2048.0,177147.0,4194304.0,4.8828125e7,3.62797056e8,1.977326743e9,8.589934592e9,3.1381059609e10,1.0e11,2.85311670611e11,7.43008370688e11,1.792160394037e12,4.049565169664e12,8.649755859375e12,1.7592186044416e13,3.4271896307633e13,6.4268410079232e13,1.16490258898219e14,2.048e14,3.50277500542221e14]

The SIMD available on my CPU (Apple M4 Pro) is the 128-bit wide NEON, which can process two Double values at a time. Furthermore, because the loop is unrolled twice, it processes four Double values in a single loop iteration.

On x86 systems, using AVX or AVX-512 might result in a different number of elements processed per iteration.
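To make the lane arithmetic explicit, here is a tiny back-of-the-envelope helper (a throwaway sketch, not part of the sample project): with 128-bit NEON, a vector holds two 64-bit `Double`s, and an unroll factor of two gives four elements per iteration; a single 512-bit AVX-512 vector already holds eight `Double`s.

```haskell
-- Elements of a given scalar type processed per loop iteration,
-- given the SIMD register width (bits) and the unroll factor.
elemsPerIter :: Int -> Int -> Int -> Int
elemsPerIter vectorBits scalarBits unrollFactor =
  (vectorBits `div` scalarBits) * unrollFactor

main :: IO ()
main = do
  -- NEON: 128-bit vectors, unrolled twice -> 4 Doubles per iteration
  print (elemsPerIter 128 64 2)
  -- AVX-512: 512-bit vectors -> 8 Doubles per vector, before any unrolling
  print (elemsPerIter 512 64 1)
```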

Utilizing the noalias Attribute

In the C pseudocode I presented earlier, I used restrict, but the LLVM IR I just generated doesn't use anything equivalent.

In LLVM, you can add the noalias attribute to function arguments (specifically LLVM.AST.ParameterAttribute.NoAlias in llvm-hs-pure), but LLVM.IRBuilder.Module.function doesn't allow specifying attributes for arguments. As a workaround, I will write the function definition part manually.

src/LoopCodegen.hs
codegenNoAlias :: ExpS Double -> AST.Module
codegenNoAlias expr = IR.buildModule "dsl.ll" $ do
  absFn <- IR.extern "llvm.fabs.f64" [AST.double] AST.double
  -- Omitted
  cosFn <- IR.extern "llvm.cos.f64" [AST.double] AST.double

  let goValue :: IntMap.IntMap AST.Operand -> Value Double -> AST.Operand
      -- Omitted
      goSimpleExp :: (IR.MonadIRBuilder m, IR.MonadModuleBuilder m) => IntMap.IntMap AST.Operand -> SimpleExp Double -> m AST.Operand
      -- Omitted
      goExp :: (IR.MonadIRBuilder m, IR.MonadModuleBuilder m) => IntMap.IntMap AST.Operand -> ExpS Double -> m AST.Operand
      -- Omitted

  let functionBody size resultArray inputArray = mdo
        prologue <- IR.block
        -- Omitted
        IR.retVoid

  ((sizeName, resultArrayName, inputArrayName), blocks) <- IR.runIRBuilderT IR.emptyIRBuilder $ do
    sizeName <- IR.fresh `IR.named` "size"
    resultArrayName <- IR.fresh `IR.named` "resultArray"
    inputArrayName <- IR.fresh `IR.named` "inputArray"
    functionBody (AST.LocalReference AST.i32 sizeName) (AST.LocalReference AST.ptr resultArrayName) (AST.LocalReference AST.ptr inputArrayName)
    pure (sizeName, resultArrayName, inputArrayName)

  let def = AST.GlobalDefinition AST.functionDefaults
            { AST.name = "f"
            , AST.parameters = ([AST.Parameter AST.i32 sizeName [], AST.Parameter AST.ptr resultArrayName [AST.NoAlias], AST.Parameter AST.ptr inputArrayName []], False)
            , AST.returnType = AST.void
            , AST.basicBlocks = blocks
            }
  IR.emitDefn def

  pure ()

The execution results using this are as follows:

*** Before optimization ***
; ModuleID = 'dsl.ll'
source_filename = "<string>"

declare double @llvm.fabs.f64(double)

declare double @llvm.exp.f64(double)

declare double @llvm.log.f64(double)

declare double @llvm.sin.f64(double)

declare double @llvm.cos.f64(double)

define void @f(i32 %size_0, ptr noalias %resultArray_0, ptr %inputArray_0) {
  br label %loop_0

loop_0:                                           ; preds = %loopBody_0, %0
  %counter_0 = phi i32 [ 0, %0 ], [ %nextCounter_0, %loopBody_0 ]
  %lt_0 = icmp slt i32 %counter_0, %size_0
  br i1 %lt_0, label %loopBody_0, label %epilogue_0

loopBody_0:                                       ; preds = %loop_0
  %xPtr_0 = getelementptr [0 x double], ptr %inputArray_0, i32 0, i32 %counter_0
  %x_0 = load double, ptr %xPtr_0, align 8
  %x1_0 = fadd double %x_0, 1.000000e+00
  %x2_0 = fmul double %x1_0, %x1_0
  %x3_0 = fmul double %x2_0, %x2_0
  %x4_0 = fmul double %x3_0, %x3_0
  %x5_0 = fmul double %x4_0, %x2_0
  %x6_0 = fmul double %x5_0, %x1_0
  %resultPtr_0 = getelementptr [0 x double], ptr %resultArray_0, i32 0, i32 %counter_0
  store double %x6_0, ptr %resultPtr_0, align 8
  %nextCounter_0 = add i32 %counter_0, 1
  br label %loop_0

epilogue_0:                                       ; preds = %loop_0
  ret void
}
***************************
*** After optimization ***
; ModuleID = 'dsl.ll'
source_filename = "<string>"

; Function Attrs: argmemonly nofree norecurse nosync nounwind
define void @f(i32 %size_0, ptr noalias nocapture writeonly %resultArray_0, ptr nocapture readonly %inputArray_0) local_unnamed_addr #0 {
  %lt_01 = icmp sgt i32 %size_0, 0
  br i1 %lt_01, label %loopBody_0.preheader, label %epilogue_0

loopBody_0.preheader:                             ; preds = %0
  %min.iters.check = icmp ult i32 %size_0, 4
  br i1 %min.iters.check, label %loopBody_0.preheader4, label %vector.ph

vector.ph:                                        ; preds = %loopBody_0.preheader
  %n.vec = and i32 %size_0, -4
  br label %vector.body

vector.body:                                      ; preds = %vector.body, %vector.ph
  %index = phi i32 [ 0, %vector.ph ], [ %index.next, %vector.body ]
  %1 = zext i32 %index to i64
  %2 = getelementptr [0 x double], ptr %inputArray_0, i64 0, i64 %1
  %wide.load = load <2 x double>, ptr %2, align 8
  %3 = getelementptr double, ptr %2, i64 2
  %wide.load3 = load <2 x double>, ptr %3, align 8
  %4 = fadd <2 x double> %wide.load, <double 1.000000e+00, double 1.000000e+00>
  %5 = fadd <2 x double> %wide.load3, <double 1.000000e+00, double 1.000000e+00>
  %6 = fmul <2 x double> %4, %4
  %7 = fmul <2 x double> %5, %5
  %8 = fmul <2 x double> %6, %6
  %9 = fmul <2 x double> %7, %7
  %10 = fmul <2 x double> %8, %8
  %11 = fmul <2 x double> %9, %9
  %12 = fmul <2 x double> %6, %10
  %13 = fmul <2 x double> %7, %11
  %14 = fmul <2 x double> %4, %12
  %15 = fmul <2 x double> %5, %13
  %16 = getelementptr [0 x double], ptr %resultArray_0, i64 0, i64 %1
  store <2 x double> %14, ptr %16, align 8
  %17 = getelementptr double, ptr %16, i64 2
  store <2 x double> %15, ptr %17, align 8
  %index.next = add nuw i32 %index, 4
  %18 = icmp eq i32 %index.next, %n.vec
  br i1 %18, label %middle.block, label %vector.body, !llvm.loop !0

middle.block:                                     ; preds = %vector.body
  %cmp.n = icmp eq i32 %n.vec, %size_0
  br i1 %cmp.n, label %epilogue_0, label %loopBody_0.preheader4

loopBody_0.preheader4:                            ; preds = %loopBody_0.preheader, %middle.block
  %counter_02.ph = phi i32 [ 0, %loopBody_0.preheader ], [ %n.vec, %middle.block ]
  br label %loopBody_0

loopBody_0:                                       ; preds = %loopBody_0.preheader4, %loopBody_0
  %counter_02 = phi i32 [ %nextCounter_0, %loopBody_0 ], [ %counter_02.ph, %loopBody_0.preheader4 ]
  %19 = zext i32 %counter_02 to i64
  %xPtr_0 = getelementptr [0 x double], ptr %inputArray_0, i64 0, i64 %19
  %x_0 = load double, ptr %xPtr_0, align 8
  %x1_0 = fadd double %x_0, 1.000000e+00
  %x2_0 = fmul double %x1_0, %x1_0
  %x3_0 = fmul double %x2_0, %x2_0
  %x4_0 = fmul double %x3_0, %x3_0
  %x5_0 = fmul double %x2_0, %x4_0
  %x6_0 = fmul double %x1_0, %x5_0
  %resultPtr_0 = getelementptr [0 x double], ptr %resultArray_0, i64 0, i64 %19
  store double %x6_0, ptr %resultPtr_0, align 8
  %nextCounter_0 = add nuw nsw i32 %counter_02, 1
  %lt_0 = icmp slt i32 %nextCounter_0, %size_0
  br i1 %lt_0, label %loopBody_0, label %epilogue_0, !llvm.loop !2

epilogue_0:                                       ; preds = %loopBody_0, %middle.block, %0
  ret void
}

attributes #0 = { argmemonly nofree norecurse nosync nounwind }

!0 = distinct !{!0, !1}
!1 = !{!"llvm.loop.isvectorized", i32 1}
!2 = distinct !{!2, !3, !1}
!3 = !{!"llvm.loop.unroll.runtime.disable"}
**************************
*** Target assembly ***
	.section	__TEXT,__text,regular,pure_instructions
	.build_version macos, 15, 0
	.globl	_f
	.p2align	2
_f:
	cmp	w0, #1
	b.lt	LBB0_8
	cmp	w0, #4
	b.hs	LBB0_3
	mov	w8, #0
	b	LBB0_6
LBB0_3:
	and	w8, w0, #0xfffffffc
	add	x9, x1, #16
	add	x10, x2, #16
	fmov.2d	v0, #1.00000000
	mov	x11, x8
LBB0_4:
	ldp	q1, q2, [x10, #-16]
	fadd.2d	v1, v1, v0
	fadd.2d	v2, v2, v0
	fmul.2d	v3, v1, v1
	fmul.2d	v4, v2, v2
	fmul.2d	v5, v3, v3
	fmul.2d	v6, v4, v4
	fmul.2d	v5, v5, v5
	fmul.2d	v6, v6, v6
	fmul.2d	v3, v3, v5
	fmul.2d	v4, v4, v6
	fmul.2d	v1, v1, v3
	fmul.2d	v2, v2, v4
	stp	q1, q2, [x9, #-16]
	add	x9, x9, #32
	add	x10, x10, #32
	subs	w11, w11, #4
	b.ne	LBB0_4
	cmp	w8, w0
	b.eq	LBB0_8
LBB0_6:
	mov	w8, w8
	fmov	d0, #1.00000000
LBB0_7:
	lsl	x9, x8, #3
	ldr	d1, [x2, x9]
	fadd	d1, d1, d0
	fmul	d2, d1, d1
	fmul	d3, d2, d2
	fmul	d3, d3, d3
	fmul	d2, d2, d3
	fmul	d1, d1, d2
	str	d1, [x1, x9]
	add	x8, x8, #1
	cmp	w8, w0
	b.lt	LBB0_7
LBB0_8:
	ret

.subsections_via_symbols
***********************
[2048.0,177147.0,4194304.0,4.8828125e7,3.62797056e8,1.977326743e9,8.589934592e9,3.1381059609e10,1.0e11,2.85311670611e11,7.43008370688e11,1.792160394037e12,4.049565169664e12,8.649755859375e12,1.7592186044416e13,3.4271896307633e13,6.4268410079232e13,1.16490258898219e14,2.048e14,3.50277500542221e14]

If you look closely, you can see that the noalias attribute is now attached to the result-array argument, and the runtime aliasing check that used to sit in the function prologue (the comparison of the two pointers before entering the vectorized loop) is gone.

Benchmark

Let's compare the performance between the process written in pure Haskell and the code generated by LLVM.

On the pure Haskell side, we prepare a version that computes (x + 1)^(10 :: Int) as-is and one where ^10 is manually unrolled. On the JIT side, we compare mapping a JIT-compiled Double -> Double function over a storable vector element by element against passing the entire array through FFI in a single call.

benchmark/Main.hs
import           Criterion.Main
import qualified Data.Vector.Storable as VS
-- (imports of the project modules providing Var, recoverSharing,
--  withSimpleJIT and withArrayJIT are omitted)

f :: Num a => a -> a
f x = (x + 1)^(10 :: Int)
{-# SPECIALIZE f :: Double -> Double #-}

g :: Num a => a -> a
g x = let y1 = x + 1
          y2 = y1 * y1
          y4 = y2 * y2
          y8 = y4 * y4
      in y8 * y2
{-# SPECIALIZE g :: Double -> Double #-}

main :: IO ()
main = do
  expr <- recoverSharing (f Var)
  let !input = VS.fromList [0..10000]
  _ <- withSimpleJIT expr $ \simpleF ->
    withArrayJIT expr $ \arrayF ->
      defaultMain
        [ bench "Haskell/map" $ whnf (VS.map f) input
        , bench "Haskell unrolled/map" $ whnf (VS.map g) input
        , bench "JIT/map" $ whnf (VS.map simpleF) input
        , bench "JIT/array" $ whnf arrayF input
        ]
  pure ()
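As a quick sanity check that the hand-unrolled version really computes the same thing as (x + 1)^10, here is a standalone sketch (not part of the benchmark). The unrolling is just exponentiation by squaring: y^10 = y^8 * y^2, with y^8 obtained by squaring three times.

```haskell
-- Reference version and hand-unrolled version, as in the benchmark.
f :: Num a => a -> a
f x = (x + 1) ^ (10 :: Int)

g :: Num a => a -> a
g x = let y1 = x + 1
          y2 = y1 * y1   -- y1^2
          y4 = y2 * y2   -- y1^4
          y8 = y4 * y4   -- y1^8
      in y8 * y2         -- y1^10

main :: IO ()
main = print (all (\x -> f x == g x) [0 .. 50 :: Integer])
```

Running this prints True: both versions agree on every tested input.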

Results on Apple M4 Pro

First, here are the results using GHC's NCG backend:

$ cabal-3.10.3.0 bench -w ghc-9.6.6 -O2 --builddir=dist-ncg
benchmarking Haskell/map
time                 26.91 μs   (26.88 μs .. 26.93 μs)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 26.90 μs   (26.85 μs .. 26.93 μs)
std dev              138.9 ns   (111.8 ns .. 185.2 ns)

benchmarking Haskell unrolled/map
time                 7.239 μs   (7.228 μs .. 7.251 μs)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 7.241 μs   (7.229 μs .. 7.254 μs)
std dev              41.28 ns   (33.48 ns .. 51.65 ns)

benchmarking JIT/map
time                 20.97 μs   (20.93 μs .. 21.00 μs)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 20.96 μs   (20.93 μs .. 20.99 μs)
std dev              102.1 ns   (79.48 ns .. 131.4 ns)

benchmarking JIT/array
time                 1.632 μs   (1.631 μs .. 1.634 μs)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 1.633 μs   (1.630 μs .. 1.635 μs)
std dev              7.347 ns   (5.794 ns .. 9.874 ns)

Here are the results using GHC's LLVM backend as well:

$ cabal-3.10.3.0 bench -w ghc-9.6.6 -O2 --ghc-options=-fllvm --builddir=dist-llvm
benchmarking Haskell/map
time                 3.181 μs   (3.175 μs .. 3.187 μs)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 3.180 μs   (3.173 μs .. 3.191 μs)
std dev              27.62 ns   (17.49 ns .. 48.06 ns)

benchmarking Haskell unrolled/map
time                 3.212 μs   (3.208 μs .. 3.217 μs)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 3.214 μs   (3.211 μs .. 3.217 μs)
std dev              11.45 ns   (9.470 ns .. 15.02 ns)

benchmarking JIT/map
time                 7.238 μs   (7.223 μs .. 7.252 μs)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 7.227 μs   (7.211 μs .. 7.250 μs)
std dev              61.64 ns   (40.19 ns .. 98.23 ns)

benchmarking JIT/array
time                 1.653 μs   (1.650 μs .. 1.658 μs)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 1.649 μs   (1.647 μs .. 1.653 μs)
std dev              8.994 ns   (6.586 ns .. 12.49 ns)

The approach of calling the JIT-compiled code element by element (JIT/map) could not match the unrolled ^10 in pure Haskell (Haskell unrolled/map), likely due to per-call FFI overhead or similar costs.

The approach of processing the whole array with JIT-compiled code (JIT/array) performed well: the time taken was about half that of the unrolled ^10 in pure Haskell built with GHC's LLVM backend (Haskell unrolled/map). Since the generated loop processes two elements at a time with SIMD, a factor of roughly two is just what we would expect. Generating the code ourselves and calling LLVM was worth the effort.
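Plugging the mean timings reported by criterion for the LLVM-backend run into a quick ratio check (a throwaway sketch; the two constants below are the means quoted above):

```haskell
main :: IO ()
main = do
  let unrolledMap = 3.214e-6  -- mean of "Haskell unrolled/map", in seconds
      jitArray    = 1.649e-6  -- mean of "JIT/array", in seconds
      speedup     = unrolledMap / jitArray
  -- With 2-wide SIMD, a speedup of roughly 2x is what we expect.
  print speedup
```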

Results on Ryzen 9 7940HS

I will also include the results for the Ryzen 9 7940HS (Zen 4), which supports AVX-512. The OS used was Ubuntu 22.04 on WSL2.

The assembly code generated for (x + 1)^10 * (x + 1) was as follows:

*** Target assembly ***
        .text
        .file   "0cstring0e"
        .section        .rodata.cst8,"aM",@progbits,8
        .p2align        3
.LCPI0_0:
        .quad   0x3ff0000000000000
        .text
        .globl  f
        .p2align        4, 0x90
        .type   f,@function
f:
        pushq   %rbp
        pushq   %r14
        pushq   %rbx
.L0$pb:
        leaq    .L0$pb(%rip), %rax
        movabsq $_GLOBAL_OFFSET_TABLE_-.L0$pb, %r14
        addq    %rax, %r14
        testl   %edi, %edi
        jle     .LBB0_11
        movabsq $.LCPI0_0@GOTOFF, %r8
        xorl    %r11d, %r11d
        cmpl    $32, %edi
        jb      .LBB0_9
        movl    %edi, %r11d
        andl    $-32, %r11d
        xorl    %ecx, %ecx
        leal    -32(%r11), %eax
        movl    %eax, %r9d
        shrl    $5, %r9d
        leal    1(%r9), %r10d
        cmpl    $224, %eax
        jb      .LBB0_5
        vbroadcastsd    (%r14,%r8), %zmm0
        movl    %r10d, %ebp
        andl    $-8, %ebp
        leaq    1984(%rsi), %rax
        leaq    1984(%rdx), %rbx
        xorl    %ecx, %ecx
        .p2align        4, 0x90
.LBB0_4:
        vaddpd  -1984(%rbx,%rcx,8), %zmm0, %zmm1
        vaddpd  -1920(%rbx,%rcx,8), %zmm0, %zmm2
        vaddpd  -1856(%rbx,%rcx,8), %zmm0, %zmm3
        vaddpd  -1792(%rbx,%rcx,8), %zmm0, %zmm4
        vmulpd  %zmm1, %zmm1, %zmm5
        vmulpd  %zmm2, %zmm2, %zmm6
        vmulpd  %zmm3, %zmm3, %zmm7
        vmulpd  %zmm4, %zmm4, %zmm8
        vmulpd  %zmm5, %zmm5, %zmm9
        vmulpd  %zmm6, %zmm6, %zmm10
        vmulpd  %zmm7, %zmm7, %zmm11
        vmulpd  %zmm8, %zmm8, %zmm12
        vmulpd  %zmm9, %zmm9, %zmm9
        vmulpd  %zmm10, %zmm10, %zmm10
        vmulpd  %zmm11, %zmm11, %zmm11
        vmulpd  %zmm12, %zmm12, %zmm12
        vmulpd  %zmm9, %zmm5, %zmm5
        vmulpd  %zmm10, %zmm6, %zmm6
        vmulpd  %zmm11, %zmm7, %zmm7
        vmulpd  %zmm12, %zmm8, %zmm8
        vmulpd  %zmm5, %zmm1, %zmm1
        vmulpd  %zmm6, %zmm2, %zmm2
        vmulpd  %zmm7, %zmm3, %zmm3
        vmulpd  %zmm8, %zmm4, %zmm4
        vmovupd %zmm1, -1984(%rax,%rcx,8)
        vmovupd %zmm2, -1920(%rax,%rcx,8)
        vmovupd %zmm3, -1856(%rax,%rcx,8)
        vmovupd %zmm4, -1792(%rax,%rcx,8)
        vaddpd  -1728(%rbx,%rcx,8), %zmm0, %zmm1
        vaddpd  -1664(%rbx,%rcx,8), %zmm0, %zmm2
        vaddpd  -1600(%rbx,%rcx,8), %zmm0, %zmm3
        vaddpd  -1536(%rbx,%rcx,8), %zmm0, %zmm4
        vmulpd  %zmm1, %zmm1, %zmm5
        vmulpd  %zmm2, %zmm2, %zmm6
        vmulpd  %zmm3, %zmm3, %zmm7
        vmulpd  %zmm4, %zmm4, %zmm8
        vmulpd  %zmm5, %zmm5, %zmm9
        vmulpd  %zmm6, %zmm6, %zmm10
        vmulpd  %zmm7, %zmm7, %zmm11
        vmulpd  %zmm8, %zmm8, %zmm12
        vmulpd  %zmm9, %zmm9, %zmm9
        vmulpd  %zmm10, %zmm10, %zmm10
        vmulpd  %zmm11, %zmm11, %zmm11
        vmulpd  %zmm12, %zmm12, %zmm12
        vmulpd  %zmm9, %zmm5, %zmm5
        vmulpd  %zmm10, %zmm6, %zmm6
        vmulpd  %zmm11, %zmm7, %zmm7
        vmulpd  %zmm12, %zmm8, %zmm8
        vmulpd  %zmm5, %zmm1, %zmm1
        vmulpd  %zmm6, %zmm2, %zmm2
        vmulpd  %zmm7, %zmm3, %zmm3
        vmulpd  %zmm8, %zmm4, %zmm4
        vmovupd %zmm1, -1728(%rax,%rcx,8)
        vmovupd %zmm2, -1664(%rax,%rcx,8)
        vmovupd %zmm3, -1600(%rax,%rcx,8)
        vmovupd %zmm4, -1536(%rax,%rcx,8)
        vaddpd  -1472(%rbx,%rcx,8), %zmm0, %zmm1
        vaddpd  -1408(%rbx,%rcx,8), %zmm0, %zmm2
        vaddpd  -1344(%rbx,%rcx,8), %zmm0, %zmm3
        vaddpd  -1280(%rbx,%rcx,8), %zmm0, %zmm4
        vmulpd  %zmm1, %zmm1, %zmm5
        vmulpd  %zmm2, %zmm2, %zmm6
        vmulpd  %zmm3, %zmm3, %zmm7
        vmulpd  %zmm4, %zmm4, %zmm8
        vmulpd  %zmm5, %zmm5, %zmm9
        vmulpd  %zmm6, %zmm6, %zmm10
        vmulpd  %zmm7, %zmm7, %zmm11
        vmulpd  %zmm8, %zmm8, %zmm12
        vmulpd  %zmm9, %zmm9, %zmm9
        vmulpd  %zmm10, %zmm10, %zmm10
        vmulpd  %zmm11, %zmm11, %zmm11
        vmulpd  %zmm12, %zmm12, %zmm12
        vmulpd  %zmm9, %zmm5, %zmm5
        vmulpd  %zmm10, %zmm6, %zmm6
        vmulpd  %zmm11, %zmm7, %zmm7
        vmulpd  %zmm12, %zmm8, %zmm8
        vmulpd  %zmm5, %zmm1, %zmm1
        vmulpd  %zmm6, %zmm2, %zmm2
        vmulpd  %zmm7, %zmm3, %zmm3
        vmulpd  %zmm8, %zmm4, %zmm4
        vmovupd %zmm1, -1472(%rax,%rcx,8)
        vmovupd %zmm2, -1408(%rax,%rcx,8)
        vmovupd %zmm3, -1344(%rax,%rcx,8)
        vmovupd %zmm4, -1280(%rax,%rcx,8)
        vaddpd  -1216(%rbx,%rcx,8), %zmm0, %zmm1
        vaddpd  -1152(%rbx,%rcx,8), %zmm0, %zmm2
        vaddpd  -1088(%rbx,%rcx,8), %zmm0, %zmm3
        vaddpd  -1024(%rbx,%rcx,8), %zmm0, %zmm4
        vmulpd  %zmm1, %zmm1, %zmm5
        vmulpd  %zmm2, %zmm2, %zmm6
        vmulpd  %zmm3, %zmm3, %zmm7
        vmulpd  %zmm4, %zmm4, %zmm8
        vmulpd  %zmm5, %zmm5, %zmm9
        vmulpd  %zmm6, %zmm6, %zmm10
        vmulpd  %zmm7, %zmm7, %zmm11
        vmulpd  %zmm8, %zmm8, %zmm12
        vmulpd  %zmm9, %zmm9, %zmm9
        vmulpd  %zmm10, %zmm10, %zmm10
        vmulpd  %zmm11, %zmm11, %zmm11
        vmulpd  %zmm12, %zmm12, %zmm12
        vmulpd  %zmm9, %zmm5, %zmm5
        vmulpd  %zmm10, %zmm6, %zmm6
        vmulpd  %zmm11, %zmm7, %zmm7
        vmulpd  %zmm12, %zmm8, %zmm8
        vmulpd  %zmm5, %zmm1, %zmm1
        vmulpd  %zmm6, %zmm2, %zmm2
        vmulpd  %zmm7, %zmm3, %zmm3
        vmulpd  %zmm8, %zmm4, %zmm4
        vmovupd %zmm1, -1216(%rax,%rcx,8)
        vmovupd %zmm2, -1152(%rax,%rcx,8)
        vmovupd %zmm3, -1088(%rax,%rcx,8)
        vmovupd %zmm4, -1024(%rax,%rcx,8)
        vaddpd  -960(%rbx,%rcx,8), %zmm0, %zmm1
        vaddpd  -896(%rbx,%rcx,8), %zmm0, %zmm2
        vaddpd  -832(%rbx,%rcx,8), %zmm0, %zmm3
        vaddpd  -768(%rbx,%rcx,8), %zmm0, %zmm4
        vmulpd  %zmm1, %zmm1, %zmm5
        vmulpd  %zmm2, %zmm2, %zmm6
        vmulpd  %zmm3, %zmm3, %zmm7
        vmulpd  %zmm4, %zmm4, %zmm8
        vmulpd  %zmm5, %zmm5, %zmm9
        vmulpd  %zmm6, %zmm6, %zmm10
        vmulpd  %zmm7, %zmm7, %zmm11
        vmulpd  %zmm8, %zmm8, %zmm12
        vmulpd  %zmm9, %zmm9, %zmm9
        vmulpd  %zmm10, %zmm10, %zmm10
        vmulpd  %zmm11, %zmm11, %zmm11
        vmulpd  %zmm12, %zmm12, %zmm12
        vmulpd  %zmm9, %zmm5, %zmm5
        vmulpd  %zmm10, %zmm6, %zmm6
        vmulpd  %zmm11, %zmm7, %zmm7
        vmulpd  %zmm12, %zmm8, %zmm8
        vmulpd  %zmm5, %zmm1, %zmm1
        vmulpd  %zmm6, %zmm2, %zmm2
        vmulpd  %zmm7, %zmm3, %zmm3
        vmulpd  %zmm8, %zmm4, %zmm4
        vmovupd %zmm1, -960(%rax,%rcx,8)
        vmovupd %zmm2, -896(%rax,%rcx,8)
        vmovupd %zmm3, -832(%rax,%rcx,8)
        vmovupd %zmm4, -768(%rax,%rcx,8)
        vaddpd  -704(%rbx,%rcx,8), %zmm0, %zmm1
        vaddpd  -640(%rbx,%rcx,8), %zmm0, %zmm2
        vaddpd  -576(%rbx,%rcx,8), %zmm0, %zmm3
        vaddpd  -512(%rbx,%rcx,8), %zmm0, %zmm4
        vmulpd  %zmm1, %zmm1, %zmm5
        vmulpd  %zmm2, %zmm2, %zmm6
        vmulpd  %zmm3, %zmm3, %zmm7
        vmulpd  %zmm4, %zmm4, %zmm8
        vmulpd  %zmm5, %zmm5, %zmm9
        vmulpd  %zmm6, %zmm6, %zmm10
        vmulpd  %zmm7, %zmm7, %zmm11
        vmulpd  %zmm8, %zmm8, %zmm12
        vmulpd  %zmm9, %zmm9, %zmm9
        vmulpd  %zmm10, %zmm10, %zmm10
        vmulpd  %zmm11, %zmm11, %zmm11
        vmulpd  %zmm12, %zmm12, %zmm12
        vmulpd  %zmm9, %zmm5, %zmm5
        vmulpd  %zmm10, %zmm6, %zmm6
        vmulpd  %zmm11, %zmm7, %zmm7
        vmulpd  %zmm12, %zmm8, %zmm8
        vmulpd  %zmm5, %zmm1, %zmm1
        vmulpd  %zmm6, %zmm2, %zmm2
        vmulpd  %zmm7, %zmm3, %zmm3
        vmulpd  %zmm8, %zmm4, %zmm4
        vmovupd %zmm1, -704(%rax,%rcx,8)
        vmovupd %zmm2, -640(%rax,%rcx,8)
        vmovupd %zmm3, -576(%rax,%rcx,8)
        vmovupd %zmm4, -512(%rax,%rcx,8)
        vaddpd  -448(%rbx,%rcx,8), %zmm0, %zmm1
        vaddpd  -384(%rbx,%rcx,8), %zmm0, %zmm2
        vaddpd  -320(%rbx,%rcx,8), %zmm0, %zmm3
        vaddpd  -256(%rbx,%rcx,8), %zmm0, %zmm4
        vmulpd  %zmm1, %zmm1, %zmm5
        vmulpd  %zmm2, %zmm2, %zmm6
        vmulpd  %zmm3, %zmm3, %zmm7
        vmulpd  %zmm4, %zmm4, %zmm8
        vmulpd  %zmm5, %zmm5, %zmm9
        vmulpd  %zmm6, %zmm6, %zmm10
        vmulpd  %zmm7, %zmm7, %zmm11
        vmulpd  %zmm8, %zmm8, %zmm12
        vmulpd  %zmm9, %zmm9, %zmm9
        vmulpd  %zmm10, %zmm10, %zmm10
        vmulpd  %zmm11, %zmm11, %zmm11
        vmulpd  %zmm12, %zmm12, %zmm12
        vmulpd  %zmm9, %zmm5, %zmm5
        vmulpd  %zmm10, %zmm6, %zmm6
        vmulpd  %zmm11, %zmm7, %zmm7
        vmulpd  %zmm12, %zmm8, %zmm8
        vmulpd  %zmm5, %zmm1, %zmm1
        vmulpd  %zmm6, %zmm2, %zmm2
        vmulpd  %zmm7, %zmm3, %zmm3
        vmulpd  %zmm8, %zmm4, %zmm4
        vmovupd %zmm1, -448(%rax,%rcx,8)
        vmovupd %zmm2, -384(%rax,%rcx,8)
        vmovupd %zmm3, -320(%rax,%rcx,8)
        vmovupd %zmm4, -256(%rax,%rcx,8)
        vaddpd  -192(%rbx,%rcx,8), %zmm0, %zmm1
        vaddpd  -128(%rbx,%rcx,8), %zmm0, %zmm2
        vaddpd  -64(%rbx,%rcx,8), %zmm0, %zmm3
        vaddpd  (%rbx,%rcx,8), %zmm0, %zmm4
        vmulpd  %zmm1, %zmm1, %zmm5
        vmulpd  %zmm2, %zmm2, %zmm6
        vmulpd  %zmm3, %zmm3, %zmm7
        vmulpd  %zmm4, %zmm4, %zmm8
        vmulpd  %zmm5, %zmm5, %zmm9
        vmulpd  %zmm6, %zmm6, %zmm10
        vmulpd  %zmm7, %zmm7, %zmm11
        vmulpd  %zmm8, %zmm8, %zmm12
        vmulpd  %zmm9, %zmm9, %zmm9
        vmulpd  %zmm10, %zmm10, %zmm10
        vmulpd  %zmm11, %zmm11, %zmm11
        vmulpd  %zmm12, %zmm12, %zmm12
        vmulpd  %zmm9, %zmm5, %zmm5
        vmulpd  %zmm10, %zmm6, %zmm6
        vmulpd  %zmm11, %zmm7, %zmm7
        vmulpd  %zmm12, %zmm8, %zmm8
        vmulpd  %zmm5, %zmm1, %zmm1
        vmulpd  %zmm6, %zmm2, %zmm2
        vmulpd  %zmm7, %zmm3, %zmm3
        vmulpd  %zmm8, %zmm4, %zmm4
        vmovupd %zmm1, -192(%rax,%rcx,8)
        vmovupd %zmm2, -128(%rax,%rcx,8)
        vmovupd %zmm3, -64(%rax,%rcx,8)
        vmovupd %zmm4, (%rax,%rcx,8)
        addq    $256, %rcx
        addl    $-8, %ebp
        jne     .LBB0_4
.LBB0_5:
        testb   $7, %r10b
        je      .LBB0_8
        vbroadcastsd    (%r14,%r8), %zmm0
        incb    %r9b
        movl    %ecx, %ecx
        leaq    192(%rsi,%rcx,8), %rax
        leaq    192(%rdx,%rcx,8), %rcx
        xorl    %ebx, %ebx
        movzbl  %r9b, %ebp
        andl    $7, %ebp
        shlq    $8, %rbp
        .p2align        4, 0x90
.LBB0_7:
        vaddpd  -192(%rcx,%rbx), %zmm0, %zmm1
        vaddpd  -128(%rcx,%rbx), %zmm0, %zmm2
        vaddpd  -64(%rcx,%rbx), %zmm0, %zmm3
        vaddpd  (%rcx,%rbx), %zmm0, %zmm4
        vmulpd  %zmm1, %zmm1, %zmm5
        vmulpd  %zmm2, %zmm2, %zmm6
        vmulpd  %zmm3, %zmm3, %zmm7
        vmulpd  %zmm4, %zmm4, %zmm8
        vmulpd  %zmm5, %zmm5, %zmm9
        vmulpd  %zmm6, %zmm6, %zmm10
        vmulpd  %zmm7, %zmm7, %zmm11
        vmulpd  %zmm8, %zmm8, %zmm12
        vmulpd  %zmm9, %zmm9, %zmm9
        vmulpd  %zmm10, %zmm10, %zmm10
        vmulpd  %zmm11, %zmm11, %zmm11
        vmulpd  %zmm12, %zmm12, %zmm12
        vmulpd  %zmm9, %zmm5, %zmm5
        vmulpd  %zmm10, %zmm6, %zmm6
        vmulpd  %zmm11, %zmm7, %zmm7
        vmulpd  %zmm12, %zmm8, %zmm8
        vmulpd  %zmm5, %zmm1, %zmm1
        vmulpd  %zmm6, %zmm2, %zmm2
        vmulpd  %zmm7, %zmm3, %zmm3
        vmulpd  %zmm8, %zmm4, %zmm4
        vmovupd %zmm1, -192(%rax,%rbx)
        vmovupd %zmm2, -128(%rax,%rbx)
        vmovupd %zmm3, -64(%rax,%rbx)
        vmovupd %zmm4, (%rax,%rbx)
        addq    $256, %rbx
        cmpl    %ebx, %ebp
        jne     .LBB0_7
.LBB0_8:
        cmpl    %edi, %r11d
        je      .LBB0_11
.LBB0_9:
        vmovsd  (%r14,%r8), %xmm0
        movl    %r11d, %eax
        .p2align        4, 0x90
.LBB0_10:
        vaddsd  (%rdx,%rax,8), %xmm0, %xmm1
        vmulsd  %xmm1, %xmm1, %xmm2
        vmulsd  %xmm2, %xmm2, %xmm3
        vmulsd  %xmm3, %xmm3, %xmm3
        vmulsd  %xmm3, %xmm2, %xmm2
        vmulsd  %xmm2, %xmm1, %xmm1
        vmovsd  %xmm1, (%rsi,%rax,8)
        incq    %rax
        cmpl    %edi, %eax
        jl      .LBB0_10
.LBB0_11:
        popq    %rbx
        popq    %r14
        popq    %rbp
        vzeroupper
        retq
.Lfunc_end0:
        .size   f, .Lfunc_end0-f

        .section        ".note.GNU-stack","",@progbits
***********************

On top of the SIMD use, the loop unrolling is quite aggressive. Do they have some sort of grudge against jumps and branches?

The benchmark results are as follows:

$ cabal-3.10.3.0 bench -w ghc-9.6.6 -O2 --builddir=dist-ncg
benchmarking Haskell/map
time                 38.65 μs   (38.35 μs .. 38.92 μs)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 38.38 μs   (37.96 μs .. 39.03 μs)
std dev              1.766 μs   (1.215 μs .. 3.018 μs)
variance introduced by outliers: 52% (severely inflated)

benchmarking Haskell unrolled/map
time                 13.79 μs   (13.70 μs .. 13.89 μs)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 13.70 μs   (13.60 μs .. 13.84 μs)
std dev              390.5 ns   (300.6 ns .. 520.6 ns)
variance introduced by outliers: 32% (moderately inflated)

benchmarking JIT/map
time                 17.89 μs   (17.75 μs .. 18.06 μs)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 17.92 μs   (17.81 μs .. 18.10 μs)
std dev              474.0 ns   (349.9 ns .. 644.7 ns)
variance introduced by outliers: 28% (moderately inflated)

benchmarking JIT/array
time                 1.485 μs   (1.467 μs .. 1.506 μs)
                     0.999 R²   (0.998 R² .. 0.999 R²)
mean                 1.487 μs   (1.473 μs .. 1.511 μs)
std dev              56.94 ns   (41.96 ns .. 85.39 ns)
variance introduced by outliers: 52% (severely inflated)
$ cabal-3.10.3.0 bench -w ghc-9.6.6 -O2 --ghc-options=-fllvm --builddir=dist-llvm
benchmarking Haskell/map
time                 4.237 μs   (4.207 μs .. 4.272 μs)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 4.227 μs   (4.195 μs .. 4.271 μs)
std dev              128.6 ns   (93.79 ns .. 171.3 ns)
variance introduced by outliers: 38% (moderately inflated)

benchmarking Haskell unrolled/map
time                 4.202 μs   (4.171 μs .. 4.238 μs)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 4.249 μs   (4.213 μs .. 4.309 μs)
std dev              153.8 ns   (110.9 ns .. 229.4 ns)
variance introduced by outliers: 47% (moderately inflated)

benchmarking JIT/map
time                 17.79 μs   (17.64 μs .. 17.96 μs)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 17.87 μs   (17.73 μs .. 18.20 μs)
std dev              671.3 ns   (342.2 ns .. 1.204 μs)
variance introduced by outliers: 44% (moderately inflated)

benchmarking JIT/array
time                 1.492 μs   (1.475 μs .. 1.514 μs)
                     0.999 R²   (0.998 R² .. 0.999 R²)
mean                 1.492 μs   (1.478 μs .. 1.511 μs)
std dev              53.86 ns   (37.95 ns .. 72.97 ns)
variance introduced by outliers: 49% (moderately inflated)

Comparing "Haskell unrolled/map" built with GHC's LLVM backend against the auto-vectorized, JIT-compiled version (JIT/array) gives 4.202 / 1.492 ≈ 2.8, so auto-vectorization made it about 2.8 times faster. Since AVX-512 holds eight Doubles per register, one might hope for 8-way parallelism, but Zen 4 is reported to execute 512-bit operations on narrower internal datapaths, with a full 512-bit implementation only arriving in Zen 5. Viewed as effectively 4-way parallelism, a 2.8x speedup is about what one would expect.

Compared with the Apple M4 Pro (both machines are mini-PCs with mobile CPUs), the Ryzen loses on scalar code, but with AVX-512 auto-vectorization it beats the M4 Pro.

Conclusion

I believe the techniques introduced here are used in accelerate-llvm (though I haven't looked closely). Accelerate often gives the impression of being GPU-focused, but it can also be used for CPUs. LLVM allows for both CPU and GPU code generation. It seems that the maintainers of Accelerate are also involved in llvm-hs.

To truly push a CPU to its limits, you need to utilize not just SIMD via auto-vectorization but also multi-core processing. I suppose that's doable if you put in the effort. For Haskell, it might be worth looking into libraries like repa or massiv.

Writing this article with llvm-hs involved quite a few pitfalls. If you ever feel like overthrowing Accelerate, please step over my dead body and press on.
