Building an EDSL in Haskell: LLVM Edition — JIT Compilation
Series:
- Building an EDSL in Haskell: atomicModifyIORef Edition — Using Automatic Differentiation as a Subject —
- Building an EDSL in Haskell: StableName Edition — Recovery of Sharing —
- Building an EDSL in Haskell: LLVM Edition — JIT Compilation — (This article)
- Building an EDSL in Haskell: SIMD Edition
This article is for the 13th day of the Language Implementation Advent Calendar 2024.
In Building an EDSL in Haskell: StableName Edition, we saw how to recover calculation sharing using StableName.
This time, we will look at how to JIT compile the arithmetic DSL we created using LLVM. Sample code is provided in haskell-dsl-example/llvm.
Overview
How to Call LLVM
There are several ways to call LLVM from Haskell.
First, there is a method where you write out the LLVM IR to a file and call LLVM as a command. For example, GHC writes LLVM IR into .ll files and uses LLVM's opt and llc commands to obtain object files. With this method, the compiler only needs to handle text processing and file I/O, so there is no need to worry about tedious FFI.
Next, there is a method where you call LLVM as a C/C++ library. To call it from Haskell, FFI is required. This method is suitable for JIT compilation.
In this article, we will take the latter approach—calling LLVM via FFI. However, since writing the FFI part from scratch is a lot of work, we will use existing bindings.
LLVM Bindings
There are several bindings available to call LLVM from Haskell. Here, we will use the llvm-hs family.
The package structure of the llvm-hs family is as follows:
- llvm-hs-pure: A pure Haskell part that does not depend on the C++ side. It allows you to construct LLVM IR.
- llvm-hs: Bindings to LLVM implemented in C++ (FFI).
- llvm-hs-pretty: A pure Haskell pretty printer, but it is unmaintained. Pretty printing itself can be done with the llvm-hs package (which depends on LLVM written in C++).
Since only older versions of llvm-hs are available on Hackage, we will use the versions from GitHub. At the time of writing, the llvm-15 branch, which supports LLVM 15, is the latest, and we will use it. However, the current version does not seem to support GHC 9.8 or later (due to a conflict with unzip imported in LLVM.Prelude). Also, newer versions of Cabal do not seem to work; it needs to be 3.10 or lower. Therefore, the content of this article has been verified to work with GHC 9.6.6/Cabal 3.10. Incidentally, looking at the GitHub PRs for llvm-hs, there seem to be versions that have been made compatible with newer GHC and Cabal.
To fetch dependency packages from Git in a Cabal project, describe the source-repository-package in cabal.project as follows:
packages: ./*.cabal
source-repository-package
type: git
location: https://github.com/llvm-hs/llvm-hs.git
tag: 5bca2c1a2a3aa98ecfb19181e7a5ebbf3e212b76
subdir: llvm-hs-pure
source-repository-package
type: git
location: https://github.com/llvm-hs/llvm-hs.git
tag: 5bca2c1a2a3aa98ecfb19181e7a5ebbf3e212b76
subdir: llvm-hs
Building llvm-hs requires LLVM. Specifically, the llvm-config-15 or llvm-config command must be accessible during the configure phase. If you use Homebrew or similar, the PATH to these is not set by default, so you can describe the following in cabal.project.local so that llvm-hs can find llvm-config:
package llvm-hs
extra-prog-path: /opt/homebrew/opt/llvm@15/bin
In the case of Homebrew, the specific location can be found with echo $(brew --prefix llvm@15)/bin.
Haddock for the llvm-15 branch of llvm-hs is not available on Hackage, so if you want to see the documentation, please git clone it yourself and run cabal haddock.
Reference Code
There are several samples in llvm-hs-examples, but the master branch is for LLVM 9, and otherwise there is only an llvm-12 branch, so they are slightly old.
The one closest to our arithmetic DSL is llvm-hs-examples/arith/Arith.hs.
Here, referring to arith/Arith.hs, I will present code compatible with LLVM 15.
Calling Compiled Functions
Suppose we JIT compiled a function and obtained a function pointer of type FunPtr. To call this from Haskell, we define a function using foreign import ccall "dynamic" like this:
foreign import ccall unsafe "dynamic"
mkDoubleFn :: FunPtr (Double -> Double) -> (Double -> Double)
Using the mkDoubleFn function defined this way, you can convert a function pointer into a Haskell function.
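Since we have no JIT yet at this point in the article, we can still check a "dynamic" import in isolation by pairing it with a "wrapper" import, which goes the other way (Haskell function → FunPtr). The sketch below is an illustration, not part of the JIT pipeline; note it uses a safe (default) dynamic call rather than unsafe, because the stand-in FunPtr calls back into Haskell, which an unsafe foreign call must not do.

```haskell
{-# LANGUAGE ForeignFunctionInterface #-}
module Main (main) where

import Foreign.Ptr (FunPtr, freeHaskellFunPtr)

-- "dynamic": convert a C function pointer into a callable Haskell function.
-- Safe (not unsafe) here, because our stand-in pointer re-enters Haskell.
foreign import ccall "dynamic"
  mkDoubleFn :: FunPtr (Double -> Double) -> (Double -> Double)

-- "wrapper" is the inverse of "dynamic": it turns a Haskell closure into a
-- C function pointer, letting us exercise mkDoubleFn without a real JIT.
foreign import ccall "wrapper"
  wrapDoubleFn :: (Double -> Double) -> IO (FunPtr (Double -> Double))

main :: IO ()
main = do
  fp <- wrapDoubleFn (\x -> x * x)  -- stand-in for a JIT-produced pointer
  print (mkDoubleFn fp 3.0)         -- call through the function pointer
  freeHaskellFunPtr fp              -- wrapper FunPtrs must be freed
```

With a real JIT, the FunPtr would instead come from the compiled symbol's address, as shown later in the article.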
Unsafe FFI has low overhead, but it is intended for processes that complete in a short time. For processes that take a long time, it might be better to use safe FFI. I touched on the types of FFI in my previous post "【Low-level Haskell】I want to get as close as possible to inline assembly even in Haskell (GHC)!" (Japanese).
In this method, the type of the function must be determined at the time of writing the code. If you want to JIT compile a function that supports an arbitrary number of arguments, you should either make the interface with Haskell a pointer to a struct or an array, such as a function of type Ptr SomeStruct -> IO (), or use some libffi binding.
Practice
Adding Features to the DSL
The DSL we created last time only supported basic arithmetic. However, since I want to demonstrate function calls for the LLVM demo, I will add several mathematical functions here. I will also add a type parameter so that we can see that the variable type is Double.
data UnaryOp = Negate
| Abs
| Exp
| Log
| Sin
| Cos
deriving (Eq, Show)
data Exp a where
Const :: a -> Exp a
Var :: Exp Double
Unary :: UnaryOp -> Exp Double -> Exp Double
Add :: Exp Double -> Exp Double -> Exp Double
Sub :: Exp Double -> Exp Double -> Exp Double
Mul :: Exp Double -> Exp Double -> Exp Double
Div :: Exp Double -> Exp Double -> Exp Double
deriving instance Show a => Show (Exp a)
instance Num (Exp Double) where
(+) = Add
(-) = Sub
(*) = Mul
fromInteger = Const . fromInteger
negate = Unary Negate
abs = Unary Abs
signum = undefined
instance Fractional (Exp Double) where
(/) = Div
fromRational = Const . fromRational
instance Floating (Exp Double) where
pi = Const pi
exp = Unary Exp
log = Unary Log
sin = Unary Sin
cos = Unary Cos
asin = undefined
acos = undefined
atan = undefined
sinh = undefined
cosh = undefined
asinh = undefined
acosh = undefined
atanh = undefined
eval :: Exp a -> Double -> a
eval (Const k) _ = k
eval Var x = x
eval (Unary Negate a) x = - eval a x
eval (Unary Abs a) x = abs (eval a x)
eval (Unary Exp a) x = exp (eval a x)
eval (Unary Log a) x = log (eval a x)
eval (Unary Sin a) x = sin (eval a x)
eval (Unary Cos a) x = cos (eval a x)
eval (Add a b) x = eval a x + eval b x
eval (Sub a b) x = eval a x - eval b x
eval (Mul a b) x = eval a x * eval b x
eval (Div a b) x = eval a x / eval b x
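To see the instances and eval working together, here is a condensed, self-contained cut of the definitions above (unary operators and the Fractional/Floating instances are omitted, and negate is rewritten as multiplication by -1 for brevity):

```haskell
{-# LANGUAGE GADTs #-}
module Main (main) where

-- Condensed version of the DSL above, just enough to run eval.
data Exp a where
  Const :: a -> Exp a
  Var   :: Exp Double
  Add   :: Exp Double -> Exp Double -> Exp Double
  Mul   :: Exp Double -> Exp Double -> Exp Double

instance Num (Exp Double) where
  (+) = Add
  (*) = Mul
  fromInteger = Const . fromInteger
  negate e = Const (-1) * e          -- simplified; the article uses Unary Negate
  a - b = a + negate b
  abs    = error "not needed in this sketch"
  signum = error "not needed in this sketch"

eval :: Exp a -> Double -> a
eval (Const k) _ = k
eval Var       x = x
eval (Add a b) x = eval a x + eval b x
eval (Mul a b) x = eval a x * eval b x

main :: IO ()
main = do
  let f x = (x + 1) ^ (10 :: Int)  -- ordinary Haskell code builds a syntax tree
  print (eval (f Var) 1.0)         -- interpret it at x = 1
```

Writing `f Var` applies the ordinary-looking function to the Var constructor, so the overloaded operators build a tree that eval (and later the code generator) can traverse.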
The part that recovers sharing is as follows:
type Id = Int -- Variable identifier
data Value a = ConstV a
| VarV Id
deriving Show
-- Expression on the right side of a let
data SimpleExp a where
UnaryS :: UnaryOp -> Value Double -> SimpleExp Double
AddS :: Value Double -> Value Double -> SimpleExp Double
SubS :: Value Double -> Value Double -> SimpleExp Double
MulS :: Value Double -> Value Double -> SimpleExp Double
DivS :: Value Double -> Value Double -> SimpleExp Double
deriving instance Show (SimpleExp a)
-- Expression that can represent sharing
data ExpS a = Let Id (SimpleExp Double) (ExpS a)
| Value (Value a)
deriving Show
-- State: (Next identifier to use, List of variables and expressions defined so far, Mapping from StableName to variable ID)
type M = StateT (Int, [(Id, SimpleExp Double)], HM.HashMap Name Id) IO
recoverSharing :: Exp Double -> IO (ExpS Double)
recoverSharing expr = do
-- Omitted
Please refer to the sample code repository for the complete source code.
Code Generation with llvm-hs-pure
The part that generates LLVM IR can be done with llvm-hs-pure, which is a pure Haskell library. I will present the code first:
import qualified LLVM.AST as AST (Module, Operand(ConstantOperand))
import qualified LLVM.AST.Constant as AST (Constant(Float))
import qualified LLVM.AST.Float as AST (SomeFloat(Double))
import qualified LLVM.AST.Type as AST (Type(FunctionType, resultType, argumentTypes, isVarArg), double)
import qualified LLVM.IRBuilder.Instruction as IR (fneg, call, fadd, fsub, fmul, fdiv, ret)
import qualified LLVM.IRBuilder.Module as IR (MonadModuleBuilder, ParameterName(ParameterName), buildModule, extern, function)
import qualified LLVM.IRBuilder.Monad as IR (MonadIRBuilder, named)
doubleFnType :: AST.Type
doubleFnType = AST.FunctionType
{ AST.resultType = AST.double
, AST.argumentTypes = [AST.double]
, AST.isVarArg = False
}
codegen :: ExpS Double -> AST.Module
codegen expr = IR.buildModule "dsl.ll" $ do
absFn <- IR.extern "llvm.fabs.f64" [AST.double] AST.double
expFn <- IR.extern "llvm.exp.f64" [AST.double] AST.double
logFn <- IR.extern "llvm.log.f64" [AST.double] AST.double
sinFn <- IR.extern "llvm.sin.f64" [AST.double] AST.double
cosFn <- IR.extern "llvm.cos.f64" [AST.double] AST.double
let goValue :: IntMap.IntMap AST.Operand -> Value Double -> AST.Operand
goValue _ (ConstV x) = AST.ConstantOperand $ AST.Float $ AST.Double x
goValue env (VarV i) = env IntMap.! i
goSimpleExp :: (IR.MonadIRBuilder m, IR.MonadModuleBuilder m) => IntMap.IntMap AST.Operand -> SimpleExp Double -> m AST.Operand
-- goSimpleExp :: IntMap.IntMap AST.Operand -> SimpleExp Double -> LLVM.IRBuilder.Module.IRBuilderT LLVM.IRBuilder.Module.ModuleBuilder AST.Operand
goSimpleExp env (UnaryS Negate x) = IR.fneg (goValue env x)
goSimpleExp env (UnaryS Abs x) = IR.call doubleFnType absFn [(goValue env x, [])]
goSimpleExp env (UnaryS Exp x) = IR.call doubleFnType expFn [(goValue env x, [])]
goSimpleExp env (UnaryS Log x) = IR.call doubleFnType logFn [(goValue env x, [])]
goSimpleExp env (UnaryS Sin x) = IR.call doubleFnType sinFn [(goValue env x, [])]
goSimpleExp env (UnaryS Cos x) = IR.call doubleFnType cosFn [(goValue env x, [])]
goSimpleExp env (AddS x y) = IR.fadd (goValue env x) (goValue env y)
goSimpleExp env (SubS x y) = IR.fsub (goValue env x) (goValue env y)
goSimpleExp env (MulS x y) = IR.fmul (goValue env x) (goValue env y)
goSimpleExp env (DivS x y) = IR.fdiv (goValue env x) (goValue env y)
goExp :: (IR.MonadIRBuilder m, IR.MonadModuleBuilder m) => IntMap.IntMap AST.Operand -> ExpS Double -> m AST.Operand
-- goExp :: IntMap.IntMap AST.Operand -> ExpS Double -> LLVM.IRBuilder.Module.IRBuilderT LLVM.IRBuilder.Module.ModuleBuilder AST.Operand
goExp env (Let i x body) = do
v <- goSimpleExp env x `IR.named` SBS.toShort (BS.Char8.pack ("x" ++ show i))
goExp (IntMap.insert i v env) body
goExp env (Value v) = pure $ goValue env v
let xparam :: IR.ParameterName
xparam = IR.ParameterName "x"
_ <- IR.function "f" [(AST.double, xparam)] AST.double $ \[arg] -> do
result <- goExp (IntMap.singleton 0 arg) expr
IR.ret result
pure ()
The codegen function takes the syntax tree of our DSL and creates a module containing the LLVM IR. The module includes function definitions and declarations for external functions.
llvm-hs-pure provides an EDSL (monad) for constructing LLVM IR. This example is too simple to feel like a proper DSL, but it will be clearer in the auto-vectorization example discussed later.
An LLVM IR operand (variable or immediate value) has the type LLVM.AST.Operand. Variables are obtained as return values of functions corresponding to IR instructions.
By using the LLVM.IRBuilder.Monad.named function, you can assign names to defined variables and labels. In practice, suffixes like numbers are added, resulting in something like %x_0.
Pretty Printing the LLVM IR
Let's display the generated LLVM IR in text format. Since the llvm-hs-pretty package is not available, we will use the llvm-hs package. We use the LLVM.Module.moduleLLVMAssembly function.
import qualified Data.ByteString as BS
import qualified LLVM.Context
import qualified LLVM.Module
main :: IO ()
main = do
let f x = (x + 1)^10
expr <- recoverSharing (f Var)
let code = codegen expr
LLVM.Context.withContext $ \context ->
LLVM.Module.withModuleFromAST context code $ \mod' -> do
asm <- LLVM.Module.moduleLLVMAssembly mod'
BS.putStr asm
An example output is as follows:
; ModuleID = 'dsl.ll'
source_filename = "<string>"
declare double @llvm.fabs.f64(double)
declare double @llvm.exp.f64(double)
declare double @llvm.log.f64(double)
declare double @llvm.sin.f64(double)
declare double @llvm.cos.f64(double)
define double @f(double %x_0) {
%x1_0 = fadd double %x_0, 1.000000e+00
%x2_0 = fmul double %x1_0, %x1_0
%x3_0 = fmul double %x2_0, %x2_0
%x4_0 = fmul double %x3_0, %x3_0
%x5_0 = fmul double %x4_0, %x2_0
ret double %x5_0
}
It looks like it is working correctly.
JIT Compilation
JIT compilation is performed as follows:
import qualified LLVM.CodeGenOpt
import qualified LLVM.CodeModel
import qualified LLVM.Context
import qualified LLVM.Linking
import qualified LLVM.Module
import qualified LLVM.OrcJIT as JIT
import qualified LLVM.Relocation
import qualified LLVM.Passes
import qualified LLVM.Target
foreign import ccall unsafe "dynamic"
mkDoubleFun :: FunPtr (Double -> Double) -> (Double -> Double)
withSimpleJIT :: NFData a => ExpS Double -> ((Double -> Double) -> a) -> IO (Maybe a)
withSimpleJIT expr doFun = do
LLVM.Context.withContext $ \context -> do
_ <- LLVM.Linking.loadLibraryPermanently Nothing
LLVM.Module.withModuleFromAST context (codegen expr) $ \mod' -> do
LLVM.Target.withHostTargetMachine LLVM.Relocation.PIC LLVM.CodeModel.JITDefault LLVM.CodeGenOpt.Default $ \targetMachine -> do
asm <- LLVM.Module.moduleLLVMAssembly mod'
putStrLn "*** Before optimization ***"
BS.putStr asm
putStrLn "***************************"
let passSetSpec = LLVM.Passes.PassSetSpec
{ LLVM.Passes.passes = [LLVM.Passes.CuratedPassSet 2]
, LLVM.Passes.targetMachine = Nothing -- Just targetMachine
}
LLVM.Passes.runPasses passSetSpec mod'
asm' <- LLVM.Module.moduleLLVMAssembly mod'
putStrLn "*** After optimization ***"
BS.putStr asm'
putStrLn "**************************"
tasm <- LLVM.Module.moduleTargetAssembly targetMachine mod'
putStrLn "*** Target assembly ***"
BS.putStr tasm
putStrLn "***********************"
JIT.withExecutionSession $ \executionSession -> do
dylib <- JIT.createJITDylib executionSession "myDylib"
JIT.withClonedThreadSafeModule mod' $ \threadSafeModule -> do
objectLayer <- JIT.createRTDyldObjectLinkingLayer executionSession
compileLayer <- JIT.createIRCompileLayer executionSession objectLayer targetMachine
JIT.addDynamicLibrarySearchGeneratorForCurrentProcess compileLayer dylib
JIT.addModule threadSafeModule dylib compileLayer
sym <- JIT.lookupSymbol executionSession compileLayer dylib "f"
case sym of
Left (JIT.JITSymbolError err) -> do
print err
pure Nothing
Right (JIT.JITSymbol fnAddr _jitSymbolFlags) -> do
let fn = mkDoubleFun . castPtrToFunPtr $ wordPtrToPtr fnAddr
Just <$> evaluate (force $ doFun fn)
Since I merely transcribed this from llvm-hs-examples/arith without fully understanding it, I cannot provide a detailed explanation. Nonetheless, I will explain a few points within the scope of my understanding.
The return value of the codegen function we created earlier (a value of type LLVM.AST.Module) is passed to LLVM.Module.withModuleFromAST.
As we did earlier, the text representation of the created LLVM IR can be obtained with the LLVM.Module.moduleLLVMAssembly function. Here, it is displayed both before and after optimization.
The generated machine-dependent assembly code can be obtained using the LLVM.Module.moduleTargetAssembly function.
In the llvm-hs-examples sample, CodeModel.Default is used as an argument for LLVM.Target.withHostTargetMachine, but this needs to be set to JITDefault; otherwise, external function calls (like exp or sin) will fail on AArch64.
LLVM.Passes.runPasses is what executes the optimization. This function destructively updates the module passed as an argument. Please note that this must be executed before entering JIT compilation; otherwise, it won't be reflected in the generated machine code. I struggled because I missed this, which caused the auto-vectorization mentioned later to be ineffective (no performance improvement was observed).
The function pointer created by JIT compilation can be obtained as fnAddr :: WordPtr. This is converted to a FunPtr and passed to a foreign import ccall "dynamic" function to convert it into a Haskell function.
The JIT-compiled function becomes invalid once control leaves the scope. The function is consumed by the callback doFun, and force from Control.DeepSeq is used to ensure its result is fully evaluated before the scope is exited.
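This scoping pattern can be sketched without LLVM at all. Below, a hypothetical withTempFn stands in for withSimpleJIT: pretend fn is only valid inside the callback, the way a JIT-produced pointer is, so the result must be fully forced before returning.

```haskell
module Main (main) where

import Control.DeepSeq (NFData, force)
import Control.Exception (evaluate)

-- Stand-in for withSimpleJIT's scoping: imagine fn points into JIT memory
-- that is released as soon as this function returns.
withTempFn :: NFData a => ((Double -> Double) -> a) -> IO a
withTempFn doFun = do
  let fn x = x + 1                 -- stand-in for the JIT-compiled function
  evaluate (force (doFun fn))      -- force: no unevaluated thunk may escape

main :: IO ()
main = withTempFn (\f -> map f [1, 2, 3]) >>= print
```

Without the force, the returned value could be a thunk that calls fn only after the scope (and, in the real code, the JIT'd machine code) is gone.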
Usage
The usage side looks like this:
main :: IO ()
main = do
let f x = (x + 1)^10 * (x + 1) + cos x
expr <- recoverSharing (f Var)
result <- withSimpleJIT expr (\f' -> f' 1.0)
case result of
Nothing -> pure ()
Just result' -> print result'
I deliberately made x + 1 appear twice to see the effect of optimization. However, since GHC's optimization would merge them via common subexpression elimination if enabled, I'll turn off GHC's optimization.
$ cabal run -O0 example-run
*** Before optimization ***
; ModuleID = 'dsl.ll'
source_filename = "<string>"
declare double @llvm.fabs.f64(double)
declare double @llvm.exp.f64(double)
declare double @llvm.log.f64(double)
declare double @llvm.sin.f64(double)
declare double @llvm.cos.f64(double)
define double @f(double %x_0) {
%x1_0 = fadd double %x_0, 1.000000e+00
%x2_0 = fmul double %x1_0, %x1_0
%x3_0 = fmul double %x2_0, %x2_0
%x4_0 = fmul double %x3_0, %x3_0
%x5_0 = fmul double %x4_0, %x2_0
%x6_0 = fadd double %x_0, 1.000000e+00
%x7_0 = fmul double %x5_0, %x6_0
%x8_0 = call double @llvm.cos.f64(double %x_0)
%x9_0 = fadd double %x7_0, %x8_0
ret double %x9_0
}
***************************
*** After optimization ***
; ModuleID = 'dsl.ll'
source_filename = "<string>"
declare double @llvm.cos.f64(double)
define double @f(double %x_0) local_unnamed_addr {
%x1_0 = fadd double %x_0, 1.000000e+00
%x2_0 = fmul double %x1_0, %x1_0
%x3_0 = fmul double %x2_0, %x2_0
%x4_0 = fmul double %x3_0, %x3_0
%x5_0 = fmul double %x2_0, %x4_0
%x7_0 = fmul double %x1_0, %x5_0
%x8_0 = tail call double @llvm.cos.f64(double %x_0)
%x9_0 = fadd double %x7_0, %x8_0
ret double %x9_0
}
**************************
You can see that fadd double %x_0, 1.000000e+00, which appeared in two places before optimization, was consolidated into one after optimization.
The generated assembly and the result of the execution (at x = 1) are as follows:
*** Target assembly ***
.section __TEXT,__text,regular,pure_instructions
.build_version macos, 15, 0
.globl _f
.p2align 2
_f:
.cfi_startproc
stp d9, d8, [sp, #-32]!
.cfi_def_cfa_offset 32
stp x29, x30, [sp, #16]
.cfi_offset w30, -8
.cfi_offset w29, -16
.cfi_offset b8, -24
.cfi_offset b9, -32
fmov d1, #1.00000000
fadd d1, d0, d1
fmul d2, d1, d1
fmul d3, d2, d2
fmul d3, d3, d3
fmul d2, d2, d3
fmul d8, d1, d2
Lloh0:
adrp x8, _cos@GOTPAGE
Lloh1:
ldr x8, [x8, _cos@GOTPAGEOFF]
blr x8
fadd d0, d8, d0
ldp x29, x30, [sp, #16]
ldp d9, d8, [sp], #32
ret
.loh AdrpLdrGot Lloh0, Lloh1
.cfi_endproc
.subsections_via_symbols
***********************
2048.540302305868
Since this was executed on AArch64 macOS, AArch64 assembly source is produced, and an underscore is prefixed to the function name. Other environments would likely yield different results.
Using Auto-vectorization
We've seen that LLVM performs common subexpression elimination, but let's try other optimizations. For example, auto-vectorization is a feature quite distant from standard Haskell, but wouldn't it be easier to utilize if we generate LLVM IR ourselves like this?
Specifically, we define a function like the following:
// Pseudocode in C
void f(int size, double * restrict resultArray, const double *inputArray)
{
for (int i = 0; i < size; ++i) {
double x = inputArray[i];
resultArray[i] = /* Calculation using x */;
}
}
Then, the idea is that LLVM will transform the loop to use SIMD instructions through auto-vectorization. From a Haskell perspective, the function type would be Int32 -> Ptr Double -> Ptr Double -> IO (), but it's best to wrap it so it can be used as VS.Vector Double -> VS.Vector Double using storable vectors (Data.Vector.Storable). There are several variations of Data.Vector, but when taking addresses for use with FFI, storable vectors are generally used. Primitive vectors and the unboxed vectors that use them are difficult to use with FFI because their addresses can change due to GC (they can be used with FFI if you pin them upon allocation or stop the GC during FFI, but that's for advanced users).
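The wrapping just described can be sketched with only base, using plain lists in place of storable vectors and a plain Haskell function (fakePtrFn, a hypothetical name) standing in for the JIT-compiled loop; real code would use Data.Vector.Storable's unsafeWith instead of withArray:

```haskell
module Main (main) where

import Data.Int (Int32)
import Foreign.Marshal.Array (allocaArray, peekArray, pokeArray, withArray)
import Foreign.Ptr (Ptr)

-- Stand-in for the JIT-compiled loop: resultArray[i] = inputArray[i] + 1.
fakePtrFn :: Int32 -> Ptr Double -> Ptr Double -> IO ()
fakePtrFn n resultPtr inputPtr = do
  xs <- peekArray (fromIntegral n) inputPtr
  pokeArray resultPtr (map (+ 1) xs)

-- Wrap a pointer-based function as a list-based one, the same idea as the
-- vector-based wrapper in the article, but base-only for this sketch.
wrapListFn :: (Int32 -> Ptr Double -> Ptr Double -> IO ())
           -> [Double] -> IO [Double]
wrapListFn ptrFn xs =
  withArray xs $ \inputPtr ->          -- marshal input into a C array
    allocaArray n $ \resultPtr -> do   -- temporary output buffer
      ptrFn (fromIntegral n) resultPtr inputPtr
      peekArray n resultPtr            -- copy results back to Haskell
  where n = length xs

main :: IO ()
main = wrapListFn fakePtrFn [1, 2, 3] >>= print
```

Storable vectors avoid the copy in and out that withArray/peekArray perform here, since their payload already lives in pinned memory that can be handed to C directly.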
The part that generates the LLVM IR for the loop is as follows:
{-# LANGUAGE RecursiveDo #-}
codegen :: ExpS Double -> AST.Module
codegen expr = IR.buildModule "dsl.ll" $ do
absFn <- IR.extern "llvm.fabs.f64" [AST.double] AST.double
expFn <- IR.extern "llvm.exp.f64" [AST.double] AST.double
logFn <- IR.extern "llvm.log.f64" [AST.double] AST.double
sinFn <- IR.extern "llvm.sin.f64" [AST.double] AST.double
cosFn <- IR.extern "llvm.cos.f64" [AST.double] AST.double
let goValue :: IntMap.IntMap AST.Operand -> Value Double -> AST.Operand
-- Omitted
goSimpleExp :: (IR.MonadIRBuilder m, IR.MonadModuleBuilder m) => IntMap.IntMap AST.Operand -> SimpleExp Double -> m AST.Operand
-- Omitted
goExp :: (IR.MonadIRBuilder m, IR.MonadModuleBuilder m) => IntMap.IntMap AST.Operand -> ExpS Double -> m AST.Operand
-- Omitted
let sizeName :: IR.ParameterName
sizeName = IR.ParameterName "size"
let resultArrayName :: IR.ParameterName
resultArrayName = IR.ParameterName "resultArray"
let inputArrayName :: IR.ParameterName
inputArrayName = IR.ParameterName "inputArray"
_ <- IR.function "f" [(AST.i32, sizeName), (AST.ptr, resultArrayName), (AST.ptr, inputArrayName)] AST.void $ \[size, resultArray, inputArray] -> mdo
prologue <- IR.block
IR.br loop
loop <- IR.block `IR.named` "loop"
counter <- IR.phi [(IR.int32 0, prologue), (nextCounter, loopBody)] `IR.named` "counter"
lt <- IR.icmp AST.SLT counter size `IR.named` "lt"
IR.condBr lt loopBody epilogue
loopBody <- IR.block `IR.named` "loopBody"
xPtr <- IR.gep (AST.ArrayType 0 AST.double) inputArray [IR.int32 0, counter] `IR.named` "xPtr"
x <- IR.load AST.double xPtr 0 `IR.named` "x"
result <- goExp (IntMap.singleton 0 x) expr `IR.named` "result"
resultPtr <- IR.gep (AST.ArrayType 0 AST.double) resultArray [IR.int32 0, counter] `IR.named` "resultPtr"
IR.store resultPtr 0 result
nextCounter <- IR.add counter (IR.int32 1) `IR.named` "nextCounter"
IR.br loop
epilogue <- IR.block `IR.named` "epilogue"
IR.retVoid
pure ()
I won't provide a detailed explanation here. Please consult the LLVM Language Reference Manual. Here are a few supplemental points:
- Most functions in llvm-hs-pure that output LLVM IR instructions are fairly intuitive. However, note that LLVM's getelementptr is abbreviated as gep.
- By using the mdo syntax from the RecursiveDo extension, we can refer to labels defined further down in the code.
- The 0 passed to load and store represents alignment; it seems that specifying 0 causes the default for the type to be used.
The JIT compilation part is largely the same as before, but there are a few points to note:
foreign import ccall unsafe "dynamic"
mkDoubleArrayFun :: FunPtr (Int32 -> Ptr Double -> Ptr Double -> IO ()) -> (Int32 -> Ptr Double -> Ptr Double -> IO ())
withArrayJIT :: NFData a => ExpS Double -> ((VS.Vector Double -> VS.Vector Double) -> IO a) -> IO (Maybe a)
withArrayJIT expr doFun = do
-- Omitted
let passSetSpec = LLVM.Passes.PassSetSpec
{ LLVM.Passes.passes = [LLVM.Passes.CuratedPassSet 2]
, LLVM.Passes.targetMachine = Just targetMachine
}
LLVM.Passes.runPasses passSetSpec mod'
-- Omitted
sym <- JIT.lookupSymbol executionSession compileLayer dylib "f"
case sym of
Left (JIT.JITSymbolError err) -> do
print err
pure Nothing
Right (JIT.JITSymbol fnAddr _jitSymbolFlags) -> do
let ptrFn = mkDoubleArrayFun . castPtrToFunPtr $ wordPtrToPtr fnAddr
vecFn !inputVec = unsafePerformIO $ do
let !n = VS.length inputVec
resultVec <- VSM.unsafeNew n
VS.unsafeWith inputVec $ \inputPtr ->
VSM.unsafeWith resultVec $ \resultPtr ->
ptrFn (fromIntegral n) resultPtr inputPtr
VS.unsafeFreeze resultVec
result <- doFun vecFn
Just <$> evaluate (force result)
First, the function type changes. Here, the callback can execute an arbitrary IO action that uses the JIT-compiled function, which works well in practice.
Next, you must specify the targetMachine in passSetSpec. Without this, LLVM will only perform platform-independent optimizations, so vectorization will not be performed.
Finally, we do some manipulation to make the provided function appear as VS.Vector Double -> VS.Vector Double (the vecFn function).
The usage side looks like this:
main :: IO ()
main = do
let f x = (x + 1)^10 * (x + 1)
expr <- recoverSharing (f Var)
_ <- withArrayJIT expr $ \vf -> do
print $ vf (VS.fromList [1..20])
pure ()
I will also provide the execution results on AArch64 macOS:
*** Before optimization ***
; ModuleID = 'dsl.ll'
source_filename = "<string>"
declare double @llvm.fabs.f64(double)
declare double @llvm.exp.f64(double)
declare double @llvm.log.f64(double)
declare double @llvm.sin.f64(double)
declare double @llvm.cos.f64(double)
define void @f(i32 %size_0, ptr %resultArray_0, ptr %inputArray_0) {
br label %loop_0
loop_0: ; preds = %loopBody_0, %0
%counter_0 = phi i32 [ 0, %0 ], [ %nextCounter_0, %loopBody_0 ]
%lt_0 = icmp slt i32 %counter_0, %size_0
br i1 %lt_0, label %loopBody_0, label %epilogue_0
loopBody_0: ; preds = %loop_0
%xPtr_0 = getelementptr [0 x double], ptr %inputArray_0, i32 0, i32 %counter_0
%x_0 = load double, ptr %xPtr_0, align 8
%x1_0 = fadd double %x_0, 1.000000e+00
%x2_0 = fmul double %x1_0, %x1_0
%x3_0 = fmul double %x2_0, %x2_0
%x4_0 = fmul double %x3_0, %x3_0
%x5_0 = fmul double %x4_0, %x2_0
%x6_0 = fmul double %x5_0, %x1_0
%resultPtr_0 = getelementptr [0 x double], ptr %resultArray_0, i32 0, i32 %counter_0
store double %x6_0, ptr %resultPtr_0, align 8
%nextCounter_0 = add i32 %counter_0, 1
br label %loop_0
epilogue_0: ; preds = %loop_0
ret void
}
***************************
*** After optimization ***
; ModuleID = 'dsl.ll'
source_filename = "<string>"
; Function Attrs: argmemonly nofree norecurse nosync nounwind
define void @f(i32 %size_0, ptr nocapture writeonly %resultArray_0, ptr nocapture readonly %inputArray_0) local_unnamed_addr #0 {
%lt_01 = icmp sgt i32 %size_0, 0
br i1 %lt_01, label %loopBody_0.preheader, label %epilogue_0
loopBody_0.preheader: ; preds = %0
%resultArray_03 = ptrtoint ptr %resultArray_0 to i64
%inputArray_04 = ptrtoint ptr %inputArray_0 to i64
%min.iters.check = icmp ult i32 %size_0, 4
%1 = sub i64 %resultArray_03, %inputArray_04
%diff.check = icmp ult i64 %1, 32
%or.cond = select i1 %min.iters.check, i1 true, i1 %diff.check
br i1 %or.cond, label %loopBody_0.preheader6, label %vector.ph
vector.ph: ; preds = %loopBody_0.preheader
%n.vec = and i32 %size_0, -4
br label %vector.body
vector.body: ; preds = %vector.body, %vector.ph
%index = phi i32 [ 0, %vector.ph ], [ %index.next, %vector.body ]
%2 = zext i32 %index to i64
%3 = getelementptr [0 x double], ptr %inputArray_0, i64 0, i64 %2
%wide.load = load <2 x double>, ptr %3, align 8
%4 = getelementptr double, ptr %3, i64 2
%wide.load5 = load <2 x double>, ptr %4, align 8
%5 = fadd <2 x double> %wide.load, <double 1.000000e+00, double 1.000000e+00>
%6 = fadd <2 x double> %wide.load5, <double 1.000000e+00, double 1.000000e+00>
%7 = fmul <2 x double> %5, %5
%8 = fmul <2 x double> %6, %6
%9 = fmul <2 x double> %7, %7
%10 = fmul <2 x double> %8, %8
%11 = fmul <2 x double> %9, %9
%12 = fmul <2 x double> %10, %10
%13 = fmul <2 x double> %7, %11
%14 = fmul <2 x double> %8, %12
%15 = fmul <2 x double> %5, %13
%16 = fmul <2 x double> %6, %14
%17 = getelementptr [0 x double], ptr %resultArray_0, i64 0, i64 %2
store <2 x double> %15, ptr %17, align 8
%18 = getelementptr double, ptr %17, i64 2
store <2 x double> %16, ptr %18, align 8
%index.next = add nuw i32 %index, 4
%19 = icmp eq i32 %index.next, %n.vec
br i1 %19, label %middle.block, label %vector.body, !llvm.loop !0
middle.block: ; preds = %vector.body
%cmp.n = icmp eq i32 %n.vec, %size_0
br i1 %cmp.n, label %epilogue_0, label %loopBody_0.preheader6
loopBody_0.preheader6: ; preds = %loopBody_0.preheader, %middle.block
%counter_02.ph = phi i32 [ 0, %loopBody_0.preheader ], [ %n.vec, %middle.block ]
br label %loopBody_0
loopBody_0: ; preds = %loopBody_0.preheader6, %loopBody_0
%counter_02 = phi i32 [ %nextCounter_0, %loopBody_0 ], [ %counter_02.ph, %loopBody_0.preheader6 ]
%20 = zext i32 %counter_02 to i64
%xPtr_0 = getelementptr [0 x double], ptr %inputArray_0, i64 0, i64 %20
%x_0 = load double, ptr %xPtr_0, align 8
%x1_0 = fadd double %x_0, 1.000000e+00
%x2_0 = fmul double %x1_0, %x1_0
%x3_0 = fmul double %x2_0, %x2_0
%x4_0 = fmul double %x3_0, %x3_0
%x5_0 = fmul double %x2_0, %x4_0
%x6_0 = fmul double %x1_0, %x5_0
%resultPtr_0 = getelementptr [0 x double], ptr %resultArray_0, i64 0, i64 %20
store double %x6_0, ptr %resultPtr_0, align 8
%nextCounter_0 = add nuw nsw i32 %counter_02, 1
%lt_0 = icmp slt i32 %nextCounter_0, %size_0
br i1 %lt_0, label %loopBody_0, label %epilogue_0, !llvm.loop !2
epilogue_0: ; preds = %loopBody_0, %middle.block, %0
ret void
}
attributes #0 = { argmemonly nofree norecurse nosync nounwind }
!0 = distinct !{!0, !1}
!1 = !{!"llvm.loop.isvectorized", i32 1}
!2 = distinct !{!2, !1}
**************************
*** Target assembly ***
.section __TEXT,__text,regular,pure_instructions
.build_version macos, 15, 0
.globl _f
.p2align 2
_f:
cmp w0, #1
b.lt LBB0_8
mov w8, #0
cmp w0, #4
b.lo LBB0_6
sub x9, x1, x2
cmp x9, #32
b.lo LBB0_6
and w8, w0, #0xfffffffc
add x9, x1, #16
add x10, x2, #16
fmov.2d v0, #1.00000000
mov x11, x8
LBB0_4:
ldp q1, q2, [x10, #-16]
fadd.2d v1, v1, v0
fadd.2d v2, v2, v0
fmul.2d v3, v1, v1
fmul.2d v4, v2, v2
fmul.2d v5, v3, v3
fmul.2d v6, v4, v4
fmul.2d v5, v5, v5
fmul.2d v6, v6, v6
fmul.2d v3, v3, v5
fmul.2d v4, v4, v6
fmul.2d v1, v1, v3
fmul.2d v2, v2, v4
stp q1, q2, [x9, #-16]
add x9, x9, #32
add x10, x10, #32
subs w11, w11, #4
b.ne LBB0_4
cmp w8, w0
b.eq LBB0_8
LBB0_6:
mov w8, w8
fmov d0, #1.00000000
LBB0_7:
lsl x9, x8, #3
ldr d1, [x2, x9]
fadd d1, d1, d0
fmul d2, d1, d1
fmul d3, d2, d2
fmul d3, d3, d3
fmul d2, d2, d3
fmul d1, d1, d2
str d1, [x1, x9]
add x8, x8, #1
cmp w8, w0
b.lt LBB0_7
LBB0_8:
ret
.subsections_via_symbols
***********************
[2048.0,177147.0,4194304.0,4.8828125e7,3.62797056e8,1.977326743e9,8.589934592e9,3.1381059609e10,1.0e11,2.85311670611e11,7.43008370688e11,1.792160394037e12,4.049565169664e12,8.649755859375e12,1.7592186044416e13,3.4271896307633e13,6.4268410079232e13,1.16490258898219e14,2.048e14,3.50277500542221e14]
The SIMD available on my CPU (Apple M4 Pro) is the 128-bit wide NEON, which can process two Double values at a time. Furthermore, because the loop is unrolled twice, it processes four Double values in a single loop iteration.
On x86 systems, using AVX or AVX-512 might result in a different number of elements processed per iteration.
Utilizing the noalias Attribute
In the C pseudocode I presented earlier, I used restrict, but the LLVM IR I just generated doesn't use anything equivalent.
In LLVM, you can add the noalias attribute to function arguments (specifically LLVM.AST.ParameterAttribute.NoAlias in llvm-hs-pure), but LLVM.IRBuilder.Module.function doesn't allow specifying attributes for arguments. As a workaround, I will write the function definition part manually.
codegenNoAlias :: ExpS Double -> AST.Module
codegenNoAlias expr = IR.buildModule "dsl.ll" $ do
  absFn <- IR.extern "llvm.fabs.f64" [AST.double] AST.double
  -- Omitted
  cosFn <- IR.extern "llvm.cos.f64" [AST.double] AST.double
  let goValue :: IntMap.IntMap AST.Operand -> Value Double -> AST.Operand
      -- Omitted
      goSimpleExp :: (IR.MonadIRBuilder m, IR.MonadModuleBuilder m) => IntMap.IntMap AST.Operand -> SimpleExp Double -> m AST.Operand
      -- Omitted
      goExp :: (IR.MonadIRBuilder m, IR.MonadModuleBuilder m) => IntMap.IntMap AST.Operand -> ExpS Double -> m AST.Operand
      -- Omitted
  let functionBody size resultArray inputArray = mdo
        prologue <- IR.block
        -- Omitted
        IR.retVoid
  ((sizeName, resultArrayName, inputArrayName), blocks) <- IR.runIRBuilderT IR.emptyIRBuilder $ do
    sizeName <- IR.fresh `IR.named` "size"
    resultArrayName <- IR.fresh `IR.named` "resultArray"
    inputArrayName <- IR.fresh `IR.named` "inputArray"
    functionBody (AST.LocalReference AST.i32 sizeName) (AST.LocalReference AST.ptr resultArrayName) (AST.LocalReference AST.ptr inputArrayName)
    pure (sizeName, resultArrayName, inputArrayName)
  let def = AST.GlobalDefinition AST.functionDefaults
        { AST.name = "f"
        , AST.parameters = ([AST.Parameter AST.i32 sizeName [], AST.Parameter AST.ptr resultArrayName [AST.NoAlias], AST.Parameter AST.ptr inputArrayName []], False)
        , AST.returnType = AST.void
        , AST.basicBlocks = blocks
        }
  IR.emitDefn def
  pure ()
The execution results using this are as follows:
*** Before optimization ***
; ModuleID = 'dsl.ll'
source_filename = "<string>"
declare double @llvm.fabs.f64(double)
declare double @llvm.exp.f64(double)
declare double @llvm.log.f64(double)
declare double @llvm.sin.f64(double)
declare double @llvm.cos.f64(double)
define void @f(i32 %size_0, ptr noalias %resultArray_0, ptr %inputArray_0) {
br label %loop_0
loop_0: ; preds = %loopBody_0, %0
%counter_0 = phi i32 [ 0, %0 ], [ %nextCounter_0, %loopBody_0 ]
%lt_0 = icmp slt i32 %counter_0, %size_0
br i1 %lt_0, label %loopBody_0, label %epilogue_0
loopBody_0: ; preds = %loop_0
%xPtr_0 = getelementptr [0 x double], ptr %inputArray_0, i32 0, i32 %counter_0
%x_0 = load double, ptr %xPtr_0, align 8
%x1_0 = fadd double %x_0, 1.000000e+00
%x2_0 = fmul double %x1_0, %x1_0
%x3_0 = fmul double %x2_0, %x2_0
%x4_0 = fmul double %x3_0, %x3_0
%x5_0 = fmul double %x4_0, %x2_0
%x6_0 = fmul double %x5_0, %x1_0
%resultPtr_0 = getelementptr [0 x double], ptr %resultArray_0, i32 0, i32 %counter_0
store double %x6_0, ptr %resultPtr_0, align 8
%nextCounter_0 = add i32 %counter_0, 1
br label %loop_0
epilogue_0: ; preds = %loop_0
ret void
}
***************************
*** After optimization ***
; ModuleID = 'dsl.ll'
source_filename = "<string>"
; Function Attrs: argmemonly nofree norecurse nosync nounwind
define void @f(i32 %size_0, ptr noalias nocapture writeonly %resultArray_0, ptr nocapture readonly %inputArray_0) local_unnamed_addr #0 {
%lt_01 = icmp sgt i32 %size_0, 0
br i1 %lt_01, label %loopBody_0.preheader, label %epilogue_0
loopBody_0.preheader: ; preds = %0
%min.iters.check = icmp ult i32 %size_0, 4
br i1 %min.iters.check, label %loopBody_0.preheader4, label %vector.ph
vector.ph: ; preds = %loopBody_0.preheader
%n.vec = and i32 %size_0, -4
br label %vector.body
vector.body: ; preds = %vector.body, %vector.ph
%index = phi i32 [ 0, %vector.ph ], [ %index.next, %vector.body ]
%1 = zext i32 %index to i64
%2 = getelementptr [0 x double], ptr %inputArray_0, i64 0, i64 %1
%wide.load = load <2 x double>, ptr %2, align 8
%3 = getelementptr double, ptr %2, i64 2
%wide.load3 = load <2 x double>, ptr %3, align 8
%4 = fadd <2 x double> %wide.load, <double 1.000000e+00, double 1.000000e+00>
%5 = fadd <2 x double> %wide.load3, <double 1.000000e+00, double 1.000000e+00>
%6 = fmul <2 x double> %4, %4
%7 = fmul <2 x double> %5, %5
%8 = fmul <2 x double> %6, %6
%9 = fmul <2 x double> %7, %7
%10 = fmul <2 x double> %8, %8
%11 = fmul <2 x double> %9, %9
%12 = fmul <2 x double> %6, %10
%13 = fmul <2 x double> %7, %11
%14 = fmul <2 x double> %4, %12
%15 = fmul <2 x double> %5, %13
%16 = getelementptr [0 x double], ptr %resultArray_0, i64 0, i64 %1
store <2 x double> %14, ptr %16, align 8
%17 = getelementptr double, ptr %16, i64 2
store <2 x double> %15, ptr %17, align 8
%index.next = add nuw i32 %index, 4
%18 = icmp eq i32 %index.next, %n.vec
br i1 %18, label %middle.block, label %vector.body, !llvm.loop !0
middle.block: ; preds = %vector.body
%cmp.n = icmp eq i32 %n.vec, %size_0
br i1 %cmp.n, label %epilogue_0, label %loopBody_0.preheader4
loopBody_0.preheader4: ; preds = %loopBody_0.preheader, %middle.block
%counter_02.ph = phi i32 [ 0, %loopBody_0.preheader ], [ %n.vec, %middle.block ]
br label %loopBody_0
loopBody_0: ; preds = %loopBody_0.preheader4, %loopBody_0
%counter_02 = phi i32 [ %nextCounter_0, %loopBody_0 ], [ %counter_02.ph, %loopBody_0.preheader4 ]
%19 = zext i32 %counter_02 to i64
%xPtr_0 = getelementptr [0 x double], ptr %inputArray_0, i64 0, i64 %19
%x_0 = load double, ptr %xPtr_0, align 8
%x1_0 = fadd double %x_0, 1.000000e+00
%x2_0 = fmul double %x1_0, %x1_0
%x3_0 = fmul double %x2_0, %x2_0
%x4_0 = fmul double %x3_0, %x3_0
%x5_0 = fmul double %x2_0, %x4_0
%x6_0 = fmul double %x1_0, %x5_0
%resultPtr_0 = getelementptr [0 x double], ptr %resultArray_0, i64 0, i64 %19
store double %x6_0, ptr %resultPtr_0, align 8
%nextCounter_0 = add nuw nsw i32 %counter_02, 1
%lt_0 = icmp slt i32 %nextCounter_0, %size_0
br i1 %lt_0, label %loopBody_0, label %epilogue_0, !llvm.loop !2
epilogue_0: ; preds = %loopBody_0, %middle.block, %0
ret void
}
attributes #0 = { argmemonly nofree norecurse nosync nounwind }
!0 = distinct !{!0, !1}
!1 = !{!"llvm.loop.isvectorized", i32 1}
!2 = distinct !{!2, !3, !1}
!3 = !{!"llvm.loop.unroll.runtime.disable"}
**************************
*** Target assembly ***
.section __TEXT,__text,regular,pure_instructions
.build_version macos, 15, 0
.globl _f
.p2align 2
_f:
cmp w0, #1
b.lt LBB0_8
cmp w0, #4
b.hs LBB0_3
mov w8, #0
b LBB0_6
LBB0_3:
and w8, w0, #0xfffffffc
add x9, x1, #16
add x10, x2, #16
fmov.2d v0, #1.00000000
mov x11, x8
LBB0_4:
ldp q1, q2, [x10, #-16]
fadd.2d v1, v1, v0
fadd.2d v2, v2, v0
fmul.2d v3, v1, v1
fmul.2d v4, v2, v2
fmul.2d v5, v3, v3
fmul.2d v6, v4, v4
fmul.2d v5, v5, v5
fmul.2d v6, v6, v6
fmul.2d v3, v3, v5
fmul.2d v4, v4, v6
fmul.2d v1, v1, v3
fmul.2d v2, v2, v4
stp q1, q2, [x9, #-16]
add x9, x9, #32
add x10, x10, #32
subs w11, w11, #4
b.ne LBB0_4
cmp w8, w0
b.eq LBB0_8
LBB0_6:
mov w8, w8
fmov d0, #1.00000000
LBB0_7:
lsl x9, x8, #3
ldr d1, [x2, x9]
fadd d1, d1, d0
fmul d2, d1, d1
fmul d3, d2, d2
fmul d3, d3, d3
fmul d2, d2, d3
fmul d1, d1, d2
str d1, [x1, x9]
add x8, x8, #1
cmp w8, w0
b.lt LBB0_7
LBB0_8:
ret
.subsections_via_symbols
***********************
[2048.0,177147.0,4194304.0,4.8828125e7,3.62797056e8,1.977326743e9,8.589934592e9,3.1381059609e10,1.0e11,2.85311670611e11,7.43008370688e11,1.792160394037e12,4.049565169664e12,8.649755859375e12,1.7592186044416e13,3.4271896307633e13,6.4268410079232e13,1.16490258898219e14,2.048e14,3.50277500542221e14]
If you look closely, you can see that the noalias attribute is added to the argument, and one branch in the function prologue has been removed.
Benchmark
Let's compare the performance between the process written in pure Haskell and the code generated by LLVM.
For the pure Haskell side, we prepare one version that computes (x + 1)^(10 :: Int) as-is and one where ^10 is manually unrolled. For the JIT side, we compare two calling styles: mapping a JIT-compiled Double -> Double function over a storable vector element by element, versus passing the entire array across FFI in a single call.
import Criterion.Main
import qualified Data.Vector.Storable as VS

f :: Num a => a -> a
f x = (x + 1)^(10 :: Int)
{-# SPECIALIZE f :: Double -> Double #-}

g :: Num a => a -> a
g x = let y1 = x + 1
          y2 = y1 * y1
          y4 = y2 * y2
          y8 = y4 * y4
      in y8 * y2
{-# SPECIALIZE g :: Double -> Double #-}

main :: IO ()
main = do
  expr <- recoverSharing (f Var)
  let !input = VS.fromList [0..10000]
  _ <- withSimpleJIT expr $ \simpleF ->
    withArrayJIT expr $ \arrayF ->
      defaultMain
        [ bench "Haskell/map" $ whnf (VS.map f) input
        , bench "Haskell unrolled/map" $ whnf (VS.map g) input
        , bench "JIT/map" $ whnf (VS.map simpleF) input
        , bench "JIT/array" $ whnf arrayF input
        ]
  pure ()
Results on Apple M4 Pro
First, here are the results using GHC's NCG backend:
$ cabal-3.10.3.0 bench -w ghc-9.6.6 -O2 --builddir=dist-ncg
benchmarking Haskell/map
time 26.91 μs (26.88 μs .. 26.93 μs)
1.000 R² (1.000 R² .. 1.000 R²)
mean 26.90 μs (26.85 μs .. 26.93 μs)
std dev 138.9 ns (111.8 ns .. 185.2 ns)
benchmarking Haskell unrolled/map
time 7.239 μs (7.228 μs .. 7.251 μs)
1.000 R² (1.000 R² .. 1.000 R²)
mean 7.241 μs (7.229 μs .. 7.254 μs)
std dev 41.28 ns (33.48 ns .. 51.65 ns)
benchmarking JIT/map
time 20.97 μs (20.93 μs .. 21.00 μs)
1.000 R² (1.000 R² .. 1.000 R²)
mean 20.96 μs (20.93 μs .. 20.99 μs)
std dev 102.1 ns (79.48 ns .. 131.4 ns)
benchmarking JIT/array
time 1.632 μs (1.631 μs .. 1.634 μs)
1.000 R² (1.000 R² .. 1.000 R²)
mean 1.633 μs (1.630 μs .. 1.635 μs)
std dev 7.347 ns (5.794 ns .. 9.874 ns)
Here are the results using GHC's LLVM backend as well:
$ cabal-3.10.3.0 bench -w ghc-9.6.6 -O2 --ghc-options=-fllvm --builddir=dist-llvm
benchmarking Haskell/map
time 3.181 μs (3.175 μs .. 3.187 μs)
1.000 R² (1.000 R² .. 1.000 R²)
mean 3.180 μs (3.173 μs .. 3.191 μs)
std dev 27.62 ns (17.49 ns .. 48.06 ns)
benchmarking Haskell unrolled/map
time 3.212 μs (3.208 μs .. 3.217 μs)
1.000 R² (1.000 R² .. 1.000 R²)
mean 3.214 μs (3.211 μs .. 3.217 μs)
std dev 11.45 ns (9.470 ns .. 15.02 ns)
benchmarking JIT/map
time 7.238 μs (7.223 μs .. 7.252 μs)
1.000 R² (1.000 R² .. 1.000 R²)
mean 7.227 μs (7.211 μs .. 7.250 μs)
std dev 61.64 ns (40.19 ns .. 98.23 ns)
benchmarking JIT/array
time 1.653 μs (1.650 μs .. 1.658 μs)
1.000 R² (1.000 R² .. 1.000 R²)
mean 1.649 μs (1.647 μs .. 1.653 μs)
std dev 8.994 ns (6.586 ns .. 12.49 ns)
The approach of calling the JIT-compiled code element by element (JIT/map) could not match the unrolled ^10 in pure Haskell (Haskell unrolled/map), likely due to per-element function-call overhead.
The approach of processing the whole array with JIT-compiled code (JIT/array) performed well. The time taken was about half that of the unrolled ^10 in pure Haskell using GHC's LLVM backend (Haskell unrolled/map). Since it processes two elements simultaneously with SIMD, this result is as expected. It was worth the effort of generating the code manually and calling LLVM.
Results on Ryzen 9 7940HS
I will also include the results for the Ryzen 9 7940HS (Zen 4), which supports AVX-512. The OS used was Ubuntu 22.04 on WSL2.
The assembly code generated for (x + 1)^10 * (x + 1) was as follows:
*** Target assembly ***
.text
.file "0cstring0e"
.section .rodata.cst8,"aM",@progbits,8
.p2align 3
.LCPI0_0:
.quad 0x3ff0000000000000
.text
.globl f
.p2align 4, 0x90
.type f,@function
f:
pushq %rbp
pushq %r14
pushq %rbx
.L0$pb:
leaq .L0$pb(%rip), %rax
movabsq $_GLOBAL_OFFSET_TABLE_-.L0$pb, %r14
addq %rax, %r14
testl %edi, %edi
jle .LBB0_11
movabsq $.LCPI0_0@GOTOFF, %r8
xorl %r11d, %r11d
cmpl $32, %edi
jb .LBB0_9
movl %edi, %r11d
andl $-32, %r11d
xorl %ecx, %ecx
leal -32(%r11), %eax
movl %eax, %r9d
shrl $5, %r9d
leal 1(%r9), %r10d
cmpl $224, %eax
jb .LBB0_5
vbroadcastsd (%r14,%r8), %zmm0
movl %r10d, %ebp
andl $-8, %ebp
leaq 1984(%rsi), %rax
leaq 1984(%rdx), %rbx
xorl %ecx, %ecx
.p2align 4, 0x90
.LBB0_4:
vaddpd -1984(%rbx,%rcx,8), %zmm0, %zmm1
vaddpd -1920(%rbx,%rcx,8), %zmm0, %zmm2
vaddpd -1856(%rbx,%rcx,8), %zmm0, %zmm3
vaddpd -1792(%rbx,%rcx,8), %zmm0, %zmm4
vmulpd %zmm1, %zmm1, %zmm5
vmulpd %zmm2, %zmm2, %zmm6
vmulpd %zmm3, %zmm3, %zmm7
vmulpd %zmm4, %zmm4, %zmm8
vmulpd %zmm5, %zmm5, %zmm9
vmulpd %zmm6, %zmm6, %zmm10
vmulpd %zmm7, %zmm7, %zmm11
vmulpd %zmm8, %zmm8, %zmm12
vmulpd %zmm9, %zmm9, %zmm9
vmulpd %zmm10, %zmm10, %zmm10
vmulpd %zmm11, %zmm11, %zmm11
vmulpd %zmm12, %zmm12, %zmm12
vmulpd %zmm9, %zmm5, %zmm5
vmulpd %zmm10, %zmm6, %zmm6
vmulpd %zmm11, %zmm7, %zmm7
vmulpd %zmm12, %zmm8, %zmm8
vmulpd %zmm5, %zmm1, %zmm1
vmulpd %zmm6, %zmm2, %zmm2
vmulpd %zmm7, %zmm3, %zmm3
vmulpd %zmm8, %zmm4, %zmm4
vmovupd %zmm1, -1984(%rax,%rcx,8)
vmovupd %zmm2, -1920(%rax,%rcx,8)
vmovupd %zmm3, -1856(%rax,%rcx,8)
vmovupd %zmm4, -1792(%rax,%rcx,8)
vaddpd -1728(%rbx,%rcx,8), %zmm0, %zmm1
vaddpd -1664(%rbx,%rcx,8), %zmm0, %zmm2
vaddpd -1600(%rbx,%rcx,8), %zmm0, %zmm3
vaddpd -1536(%rbx,%rcx,8), %zmm0, %zmm4
vmulpd %zmm1, %zmm1, %zmm5
vmulpd %zmm2, %zmm2, %zmm6
vmulpd %zmm3, %zmm3, %zmm7
vmulpd %zmm4, %zmm4, %zmm8
vmulpd %zmm5, %zmm5, %zmm9
vmulpd %zmm6, %zmm6, %zmm10
vmulpd %zmm7, %zmm7, %zmm11
vmulpd %zmm8, %zmm8, %zmm12
vmulpd %zmm9, %zmm9, %zmm9
vmulpd %zmm10, %zmm10, %zmm10
vmulpd %zmm11, %zmm11, %zmm11
vmulpd %zmm12, %zmm12, %zmm12
vmulpd %zmm9, %zmm5, %zmm5
vmulpd %zmm10, %zmm6, %zmm6
vmulpd %zmm11, %zmm7, %zmm7
vmulpd %zmm12, %zmm8, %zmm8
vmulpd %zmm5, %zmm1, %zmm1
vmulpd %zmm6, %zmm2, %zmm2
vmulpd %zmm7, %zmm3, %zmm3
vmulpd %zmm8, %zmm4, %zmm4
vmovupd %zmm1, -1728(%rax,%rcx,8)
vmovupd %zmm2, -1664(%rax,%rcx,8)
vmovupd %zmm3, -1600(%rax,%rcx,8)
vmovupd %zmm4, -1536(%rax,%rcx,8)
vaddpd -1472(%rbx,%rcx,8), %zmm0, %zmm1
vaddpd -1408(%rbx,%rcx,8), %zmm0, %zmm2
vaddpd -1344(%rbx,%rcx,8), %zmm0, %zmm3
vaddpd -1280(%rbx,%rcx,8), %zmm0, %zmm4
vmulpd %zmm1, %zmm1, %zmm5
vmulpd %zmm2, %zmm2, %zmm6
vmulpd %zmm3, %zmm3, %zmm7
vmulpd %zmm4, %zmm4, %zmm8
vmulpd %zmm5, %zmm5, %zmm9
vmulpd %zmm6, %zmm6, %zmm10
vmulpd %zmm7, %zmm7, %zmm11
vmulpd %zmm8, %zmm8, %zmm12
vmulpd %zmm9, %zmm9, %zmm9
vmulpd %zmm10, %zmm10, %zmm10
vmulpd %zmm11, %zmm11, %zmm11
vmulpd %zmm12, %zmm12, %zmm12
vmulpd %zmm9, %zmm5, %zmm5
vmulpd %zmm10, %zmm6, %zmm6
vmulpd %zmm11, %zmm7, %zmm7
vmulpd %zmm12, %zmm8, %zmm8
vmulpd %zmm5, %zmm1, %zmm1
vmulpd %zmm6, %zmm2, %zmm2
vmulpd %zmm7, %zmm3, %zmm3
vmulpd %zmm8, %zmm4, %zmm4
vmovupd %zmm1, -1472(%rax,%rcx,8)
vmovupd %zmm2, -1408(%rax,%rcx,8)
vmovupd %zmm3, -1344(%rax,%rcx,8)
vmovupd %zmm4, -1280(%rax,%rcx,8)
vaddpd -1216(%rbx,%rcx,8), %zmm0, %zmm1
vaddpd -1152(%rbx,%rcx,8), %zmm0, %zmm2
vaddpd -1088(%rbx,%rcx,8), %zmm0, %zmm3
vaddpd -1024(%rbx,%rcx,8), %zmm0, %zmm4
vmulpd %zmm1, %zmm1, %zmm5
vmulpd %zmm2, %zmm2, %zmm6
vmulpd %zmm3, %zmm3, %zmm7
vmulpd %zmm4, %zmm4, %zmm8
vmulpd %zmm5, %zmm5, %zmm9
vmulpd %zmm6, %zmm6, %zmm10
vmulpd %zmm7, %zmm7, %zmm11
vmulpd %zmm8, %zmm8, %zmm12
vmulpd %zmm9, %zmm9, %zmm9
vmulpd %zmm10, %zmm10, %zmm10
vmulpd %zmm11, %zmm11, %zmm11
vmulpd %zmm12, %zmm12, %zmm12
vmulpd %zmm9, %zmm5, %zmm5
vmulpd %zmm10, %zmm6, %zmm6
vmulpd %zmm11, %zmm7, %zmm7
vmulpd %zmm12, %zmm8, %zmm8
vmulpd %zmm5, %zmm1, %zmm1
vmulpd %zmm6, %zmm2, %zmm2
vmulpd %zmm7, %zmm3, %zmm3
vmulpd %zmm8, %zmm4, %zmm4
vmovupd %zmm1, -1216(%rax,%rcx,8)
vmovupd %zmm2, -1152(%rax,%rcx,8)
vmovupd %zmm3, -1088(%rax,%rcx,8)
vmovupd %zmm4, -1024(%rax,%rcx,8)
vaddpd -960(%rbx,%rcx,8), %zmm0, %zmm1
vaddpd -896(%rbx,%rcx,8), %zmm0, %zmm2
vaddpd -832(%rbx,%rcx,8), %zmm0, %zmm3
vaddpd -768(%rbx,%rcx,8), %zmm0, %zmm4
vmulpd %zmm1, %zmm1, %zmm5
vmulpd %zmm2, %zmm2, %zmm6
vmulpd %zmm3, %zmm3, %zmm7
vmulpd %zmm4, %zmm4, %zmm8
vmulpd %zmm5, %zmm5, %zmm9
vmulpd %zmm6, %zmm6, %zmm10
vmulpd %zmm7, %zmm7, %zmm11
vmulpd %zmm8, %zmm8, %zmm12
vmulpd %zmm9, %zmm9, %zmm9
vmulpd %zmm10, %zmm10, %zmm10
vmulpd %zmm11, %zmm11, %zmm11
vmulpd %zmm12, %zmm12, %zmm12
vmulpd %zmm9, %zmm5, %zmm5
vmulpd %zmm10, %zmm6, %zmm6
vmulpd %zmm11, %zmm7, %zmm7
vmulpd %zmm12, %zmm8, %zmm8
vmulpd %zmm5, %zmm1, %zmm1
vmulpd %zmm6, %zmm2, %zmm2
vmulpd %zmm7, %zmm3, %zmm3
vmulpd %zmm8, %zmm4, %zmm4
vmovupd %zmm1, -960(%rax,%rcx,8)
vmovupd %zmm2, -896(%rax,%rcx,8)
vmovupd %zmm3, -832(%rax,%rcx,8)
vmovupd %zmm4, -768(%rax,%rcx,8)
vaddpd -704(%rbx,%rcx,8), %zmm0, %zmm1
vaddpd -640(%rbx,%rcx,8), %zmm0, %zmm2
vaddpd -576(%rbx,%rcx,8), %zmm0, %zmm3
vaddpd -512(%rbx,%rcx,8), %zmm0, %zmm4
vmulpd %zmm1, %zmm1, %zmm5
vmulpd %zmm2, %zmm2, %zmm6
vmulpd %zmm3, %zmm3, %zmm7
vmulpd %zmm4, %zmm4, %zmm8
vmulpd %zmm5, %zmm5, %zmm9
vmulpd %zmm6, %zmm6, %zmm10
vmulpd %zmm7, %zmm7, %zmm11
vmulpd %zmm8, %zmm8, %zmm12
vmulpd %zmm9, %zmm9, %zmm9
vmulpd %zmm10, %zmm10, %zmm10
vmulpd %zmm11, %zmm11, %zmm11
vmulpd %zmm12, %zmm12, %zmm12
vmulpd %zmm9, %zmm5, %zmm5
vmulpd %zmm10, %zmm6, %zmm6
vmulpd %zmm11, %zmm7, %zmm7
vmulpd %zmm12, %zmm8, %zmm8
vmulpd %zmm5, %zmm1, %zmm1
vmulpd %zmm6, %zmm2, %zmm2
vmulpd %zmm7, %zmm3, %zmm3
vmulpd %zmm8, %zmm4, %zmm4
vmovupd %zmm1, -704(%rax,%rcx,8)
vmovupd %zmm2, -640(%rax,%rcx,8)
vmovupd %zmm3, -576(%rax,%rcx,8)
vmovupd %zmm4, -512(%rax,%rcx,8)
vaddpd -448(%rbx,%rcx,8), %zmm0, %zmm1
vaddpd -384(%rbx,%rcx,8), %zmm0, %zmm2
vaddpd -320(%rbx,%rcx,8), %zmm0, %zmm3
vaddpd -256(%rbx,%rcx,8), %zmm0, %zmm4
vmulpd %zmm1, %zmm1, %zmm5
vmulpd %zmm2, %zmm2, %zmm6
vmulpd %zmm3, %zmm3, %zmm7
vmulpd %zmm4, %zmm4, %zmm8
vmulpd %zmm5, %zmm5, %zmm9
vmulpd %zmm6, %zmm6, %zmm10
vmulpd %zmm7, %zmm7, %zmm11
vmulpd %zmm8, %zmm8, %zmm12
vmulpd %zmm9, %zmm9, %zmm9
vmulpd %zmm10, %zmm10, %zmm10
vmulpd %zmm11, %zmm11, %zmm11
vmulpd %zmm12, %zmm12, %zmm12
vmulpd %zmm9, %zmm5, %zmm5
vmulpd %zmm10, %zmm6, %zmm6
vmulpd %zmm11, %zmm7, %zmm7
vmulpd %zmm12, %zmm8, %zmm8
vmulpd %zmm5, %zmm1, %zmm1
vmulpd %zmm6, %zmm2, %zmm2
vmulpd %zmm7, %zmm3, %zmm3
vmulpd %zmm8, %zmm4, %zmm4
vmovupd %zmm1, -448(%rax,%rcx,8)
vmovupd %zmm2, -384(%rax,%rcx,8)
vmovupd %zmm3, -320(%rax,%rcx,8)
vmovupd %zmm4, -256(%rax,%rcx,8)
vaddpd -192(%rbx,%rcx,8), %zmm0, %zmm1
vaddpd -128(%rbx,%rcx,8), %zmm0, %zmm2
vaddpd -64(%rbx,%rcx,8), %zmm0, %zmm3
vaddpd (%rbx,%rcx,8), %zmm0, %zmm4
vmulpd %zmm1, %zmm1, %zmm5
vmulpd %zmm2, %zmm2, %zmm6
vmulpd %zmm3, %zmm3, %zmm7
vmulpd %zmm4, %zmm4, %zmm8
vmulpd %zmm5, %zmm5, %zmm9
vmulpd %zmm6, %zmm6, %zmm10
vmulpd %zmm7, %zmm7, %zmm11
vmulpd %zmm8, %zmm8, %zmm12
vmulpd %zmm9, %zmm9, %zmm9
vmulpd %zmm10, %zmm10, %zmm10
vmulpd %zmm11, %zmm11, %zmm11
vmulpd %zmm12, %zmm12, %zmm12
vmulpd %zmm9, %zmm5, %zmm5
vmulpd %zmm10, %zmm6, %zmm6
vmulpd %zmm11, %zmm7, %zmm7
vmulpd %zmm12, %zmm8, %zmm8
vmulpd %zmm5, %zmm1, %zmm1
vmulpd %zmm6, %zmm2, %zmm2
vmulpd %zmm7, %zmm3, %zmm3
vmulpd %zmm8, %zmm4, %zmm4
vmovupd %zmm1, -192(%rax,%rcx,8)
vmovupd %zmm2, -128(%rax,%rcx,8)
vmovupd %zmm3, -64(%rax,%rcx,8)
vmovupd %zmm4, (%rax,%rcx,8)
addq $256, %rcx
addl $-8, %ebp
jne .LBB0_4
.LBB0_5:
testb $7, %r10b
je .LBB0_8
vbroadcastsd (%r14,%r8), %zmm0
incb %r9b
movl %ecx, %ecx
leaq 192(%rsi,%rcx,8), %rax
leaq 192(%rdx,%rcx,8), %rcx
xorl %ebx, %ebx
movzbl %r9b, %ebp
andl $7, %ebp
shlq $8, %rbp
.p2align 4, 0x90
.LBB0_7:
vaddpd -192(%rcx,%rbx), %zmm0, %zmm1
vaddpd -128(%rcx,%rbx), %zmm0, %zmm2
vaddpd -64(%rcx,%rbx), %zmm0, %zmm3
vaddpd (%rcx,%rbx), %zmm0, %zmm4
vmulpd %zmm1, %zmm1, %zmm5
vmulpd %zmm2, %zmm2, %zmm6
vmulpd %zmm3, %zmm3, %zmm7
vmulpd %zmm4, %zmm4, %zmm8
vmulpd %zmm5, %zmm5, %zmm9
vmulpd %zmm6, %zmm6, %zmm10
vmulpd %zmm7, %zmm7, %zmm11
vmulpd %zmm8, %zmm8, %zmm12
vmulpd %zmm9, %zmm9, %zmm9
vmulpd %zmm10, %zmm10, %zmm10
vmulpd %zmm11, %zmm11, %zmm11
vmulpd %zmm12, %zmm12, %zmm12
vmulpd %zmm9, %zmm5, %zmm5
vmulpd %zmm10, %zmm6, %zmm6
vmulpd %zmm11, %zmm7, %zmm7
vmulpd %zmm12, %zmm8, %zmm8
vmulpd %zmm5, %zmm1, %zmm1
vmulpd %zmm6, %zmm2, %zmm2
vmulpd %zmm7, %zmm3, %zmm3
vmulpd %zmm8, %zmm4, %zmm4
vmovupd %zmm1, -192(%rax,%rbx)
vmovupd %zmm2, -128(%rax,%rbx)
vmovupd %zmm3, -64(%rax,%rbx)
vmovupd %zmm4, (%rax,%rbx)
addq $256, %rbx
cmpl %ebx, %ebp
jne .LBB0_7
.LBB0_8:
cmpl %edi, %r11d
je .LBB0_11
.LBB0_9:
vmovsd (%r14,%r8), %xmm0
movl %r11d, %eax
.p2align 4, 0x90
.LBB0_10:
vaddsd (%rdx,%rax,8), %xmm0, %xmm1
vmulsd %xmm1, %xmm1, %xmm2
vmulsd %xmm2, %xmm2, %xmm3
vmulsd %xmm3, %xmm3, %xmm3
vmulsd %xmm3, %xmm2, %xmm2
vmulsd %xmm2, %xmm1, %xmm1
vmovsd %xmm1, (%rsi,%rax,8)
incq %rax
cmpl %edi, %eax
jl .LBB0_10
.LBB0_11:
popq %rbx
popq %r14
popq %rbp
vzeroupper
retq
.Lfunc_end0:
.size f, .Lfunc_end0-f
.section ".note.GNU-stack","",@progbits
***********************
Beyond the use of SIMD, the loop unrolling is quite aggressive. Does LLVM hold some sort of grudge against jumps and branches?
The benchmark results are as follows:
$ cabal-3.10.3.0 bench -w ghc-9.6.6 -O2 --builddir=dist-ncg
benchmarking Haskell/map
time 38.65 μs (38.35 μs .. 38.92 μs)
0.999 R² (0.999 R² .. 1.000 R²)
mean 38.38 μs (37.96 μs .. 39.03 μs)
std dev 1.766 μs (1.215 μs .. 3.018 μs)
variance introduced by outliers: 52% (severely inflated)
benchmarking Haskell unrolled/map
time 13.79 μs (13.70 μs .. 13.89 μs)
0.999 R² (0.999 R² .. 1.000 R²)
mean 13.70 μs (13.60 μs .. 13.84 μs)
std dev 390.5 ns (300.6 ns .. 520.6 ns)
variance introduced by outliers: 32% (moderately inflated)
benchmarking JIT/map
time 17.89 μs (17.75 μs .. 18.06 μs)
0.999 R² (0.999 R² .. 1.000 R²)
mean 17.92 μs (17.81 μs .. 18.10 μs)
std dev 474.0 ns (349.9 ns .. 644.7 ns)
variance introduced by outliers: 28% (moderately inflated)
benchmarking JIT/array
time 1.485 μs (1.467 μs .. 1.506 μs)
0.999 R² (0.998 R² .. 0.999 R²)
mean 1.487 μs (1.473 μs .. 1.511 μs)
std dev 56.94 ns (41.96 ns .. 85.39 ns)
variance introduced by outliers: 52% (severely inflated)
$ cabal-3.10.3.0 bench -w ghc-9.6.6 -O2 --ghc-options=-fllvm --builddir=dist-llvm
benchmarking Haskell/map
time 4.237 μs (4.207 μs .. 4.272 μs)
0.999 R² (0.999 R² .. 1.000 R²)
mean 4.227 μs (4.195 μs .. 4.271 μs)
std dev 128.6 ns (93.79 ns .. 171.3 ns)
variance introduced by outliers: 38% (moderately inflated)
benchmarking Haskell unrolled/map
time 4.202 μs (4.171 μs .. 4.238 μs)
0.999 R² (0.999 R² .. 1.000 R²)
mean 4.249 μs (4.213 μs .. 4.309 μs)
std dev 153.8 ns (110.9 ns .. 229.4 ns)
variance introduced by outliers: 47% (moderately inflated)
benchmarking JIT/map
time 17.79 μs (17.64 μs .. 17.96 μs)
0.999 R² (0.999 R² .. 1.000 R²)
mean 17.87 μs (17.73 μs .. 18.20 μs)
std dev 671.3 ns (342.2 ns .. 1.204 μs)
variance introduced by outliers: 44% (moderately inflated)
benchmarking JIT/array
time 1.492 μs (1.475 μs .. 1.514 μs)
0.999 R² (0.998 R² .. 0.999 R²)
mean 1.492 μs (1.478 μs .. 1.511 μs)
std dev 53.86 ns (37.95 ns .. 72.97 ns)
variance introduced by outliers: 49% (moderately inflated)
Comparing "Haskell unrolled/map" using GHC's LLVM backend with the auto-vectorized, JIT-compiled version (JIT/array), we get 4.202 / 1.492 ≈ 2.82, so auto-vectorization made it about 2.8 times faster. Since this is AVX-512, one might naively expect 8-way parallelism for Double, but Zen 4 is said to execute 512-bit operations at 256-bit internal width, with full 512-bit datapaths only arriving in Zen 5. Treating it as effectively 4-way parallelism, 2.8x is perhaps about what one should expect.
For scalar code, the Ryzen loses to the Apple M4 Pro (both machines are mini PCs with mobile CPUs), but with AVX-512 auto-vectorization it comes out ahead of the M4 Pro.
Conclusion
I believe the techniques introduced here are used in accelerate-llvm (though I haven't looked closely). Accelerate often gives the impression of being GPU-focused, but it can also be used for CPUs. LLVM allows for both CPU and GPU code generation. It seems that the maintainers of Accelerate are also involved in llvm-hs.
To truly push a CPU to its limits, you need to utilize not just SIMD via auto-vectorization but also multi-core processing. I suppose that's doable if you put in the effort. For Haskell, it might be worth looking into libraries like repa or massiv.
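As a rough illustration of the multi-core direction, here is a minimal, base-only sketch (my own addition, not from the article): split the input into chunks and process each chunk on its own thread. In the article's setting, each worker would hand its slice of the array to the JIT-compiled kernel instead of a pure Haskell function, and one would use a real library such as massiv rather than hand-rolled forkIO plumbing.

```haskell
import Control.Concurrent (forkIO)
import Control.Concurrent.MVar (newEmptyMVar, putMVar, takeMVar)
import Control.Exception (evaluate)
import Data.List (foldl')

-- Split a list into chunks of at most n elements.
chunksOf :: Int -> [a] -> [[a]]
chunksOf _ [] = []
chunksOf n xs = let (h, t) = splitAt n xs in h : chunksOf n t

-- Map f over xs, one thread per chunk; results come back in order.
parMapChunked :: Int -> (Double -> Double) -> [Double] -> IO [Double]
parMapChunked chunkSize f xs = do
  vars <- mapM
    (\chunk -> do
      v <- newEmptyMVar
      _ <- forkIO $ do
        let ys = map f chunk
        _ <- evaluate (foldl' (+) 0 ys)  -- force the chunk before handing it back
        putMVar v ys
      pure v)
    (chunksOf chunkSize xs)
  concat <$> mapM takeMVar vars

main :: IO ()
main = do
  rs <- parMapChunked 4 (\x -> (x + 1) ^ (11 :: Int)) [0 .. 9]
  print (take 3 rs)  -- prints [1.0,2048.0,177147.0]
```

Compile with -threaded and run with +RTS -N to actually use multiple cores; without it the code is still correct, just sequential.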
I used llvm-hs to write this article, and there were quite a few pitfalls along the way. If you ever feel like dethroning Accelerate, feel free to step over my corpse and press onward.