💨

HaskellでEDSLを作る:SIMD編

2024/12/20に公開

シリーズ:

この記事はHaskell Advent Calendar 2024の20日目の記事です。


HaskellでEDSLを作る:LLVM編ではLLVMを使って自動ベクトル化を行い、SIMD命令を活用しました。一方で、GHCにはSIMD命令をHaskellから直接使うプリミティブ型と関数があります。これらを活用できないでしょうか?

例によってサンプルコードはhaskell-dsl-example/simdに置いています。

SIMDとは

SIMDとはsingle instruction, multiple dataの略で、一つの命令で複数のデータを扱う技術を指します。典型的には、複数の値を保持できるSIMDレジスターがあり、専用の命令を使うとそれらをまとめて処理できます。

Haskell風に書くと、こういうデータ型があると考えてください:

data FloatX4 = FloatX4 !Float !Float !Float !Float
data DoubleX2 = DoubleX2 !Double !Double

そして、一つの命令で FloatX4 同士の要素ごとの足し算ができます。Haskell風のコードで書くと、

plusFloatX4 :: FloatX4 -> FloatX4 -> FloatX4
plusFloatX4 (FloatX4 a b c d) (FloatX4 a' b' c' d') = FloatX4 (a + a') (b + b') (c + c') (d + d')

という処理がCPUに組み込まれていて1命令で実行できる感じです。

典型的には、SIMDレジスターの幅はアーキテクチャーによって128ビットあるいは256ビットあるいは512ビットという風に決まっています。ですので、ベクトルあたりの要素数は型によって変わります。128ビット幅のレジスターを持つx86 SSEやArm NEONであれば、Float は4要素、Double なら2要素、とという具合です。

GHCのSIMDプリミティブ

前にHaskell/GHCのSIMDについて考えるという記事を書いたので、そちらも参照してください。

GHCは7.8.1以降でSIMDプリミティブを実装しています。具体的には、以下のプリミティブ型と、

-- 128ビット版
type Int8X16#
type Int16X8#
type Int32X4#
type Int64X2#
type Word8X16#
type Word16X8#
type Word32X4#
type Word64X2#
type FloatX4#
type DoubleX2#

-- 256ビット版、512ビット版もある

いくつかのプリミティブ関数です:

packFloatX4# :: (# Float#, Float#, Float#, Float# #) -> FloatX4#
unpackFloatX4# :: FloatX4# -> (# Float#, Float#, Float#, Float# #)
insertFloatX4# :: FloatX4# -> Float# -> Int# -> FloatX4#
broadcastFloatX4# :: Float# -> FloatX4#
plusFloatX4# :: FloatX4# -> FloatX4# -> FloatX4#
minusFloatX4# :: FloatX4# -> FloatX4# -> FloatX4#
timesFloatX4# :: FloatX4# -> FloatX4# -> FloatX4#
divideFloatX4# :: FloatX4# -> FloatX4# -> FloatX4#
negateFloatX4# :: FloatX4# -> FloatX4# -> FloatX4#
indexFloatX4Array# :: ByteArray# -> Int# -> FloatX4#
readFloatX4Array# :: MutableByteArray# s -> Int# -> State# s -> (# State# s, FloatX4# #)
writeFloatX4Array# :: MutableByteArray# s -> Int# -> FloatX4# -> State# s -> State# s
indexFloatX4OffAddr# :: Addr# -> Int# -> FloatX4#
readFloatX4OffAddr# :: Addr# -> Int# -> State# s -> (# State# s, FloatX4# #)
writeFloatX4OffAddr# :: Addr# -> Int# -> FloatX4# -> State# s -> State# s
indexFloatArrayAsFloatX4# :: ByteArray# -> Int# -> FloatX4#
readFloatArrayAsFloatX4# :: MutableByteArray# s -> Int# -> State# s -> (# State# s, FloatX4# #)
writeFloatArrayAsFloatX4# :: MutableByteArray# s -> Int# -> FloatX4# -> State# s -> State# s
indexFloatOffAddrAsFloatX4# :: Addr# -> Int# -> FloatX4#
readFloatOffAddrAsFloatX4# :: Addr# -> Int# -> State# s -> (# State# s, FloatX4# #)
writeFloatOffAddrAsFloatX4# :: Addr# -> Int# -> FloatX4# -> State# s -> State# s

256ビット、512ビットのベクトル型は常に256ビット、512ビットのSIMDレジスターに対応します。つまり、x86だとAVXやAVX-512を必要とし、「SSEレジスター2本」にはなりません。

注意しないといけないのは、これらは使えるバックエンドに制限があります。具体的には、

  • x86: GHC 7.8以降のLLVMバックエンド、またはGHC 9.12以降のNCGバックエンド(一部の型のみ)
  • AArch64: GHC 9.8以降のLLVMバックエンド

が必要です。

ちなみに、GHC 9.12ではいくつかプリミティブ関数が追加されました(min/max/FMA/shuffle)。

まあ、アレですね。型や関数の名前が # で終わっているのは上級者向けの目印です。一般ユーザーが使うには、何らかの形でこれらをラップする必要があります。

どうラップするか

既存のラッパー

HackageにSIMDプリミティブのラッパーがいくつか上がっています。

simdパッケージは、以下のようなインターフェースを提供しています:

module Data.SIMD.SIMD4 where

class SIMD4 a where
  data X4 a
  plusX4 :: X4 a -> X4 a -> X4 a
  indexArrayAsX4 :: ByteArray -> Int -> X4 a
  broadcastX4 :: a -> X4 a
  -- 略

class SIMD4 a => SIMD4Float a where
  divideX4 :: X4 a -> X4 a -> X4 a

instance (Show a, SIMD4 a) => Show (X4 a)
instance (Num a, SIMD4 a) => Num (X4 a)
instance (Fractional a, SIMD4Float a) => Fractional (X4 a)

instance SIMD4 Float
instance SIMD4 Double
-- 略

vectorizeUnboxedX4 :: (SIMD4 a, VU.Unbox a) => VU.Vector a -> VU.Vector (X4 a)
unVectorizeUnboxedX4 :: (SIMD4 a, VU.Unbox a) => VU.Vector (X4 a) -> VU.Vector a
vectorizeStorableX4 :: (SIMD4 a, Storable a, Storable (X4 a)) => VS.Vector a -> VS.Vector (X4 a)
unVectorizeStorableX4 :: (SIMD4 a, Storable a, Storable (X4 a)) => VS.Vector (X4 a) -> VS.Vector a

要素数に応じた型族を使っているのが特徴です。ただ、このライブラリーでは X4 DoubleDoubleX4# のラッパーになっています。AVXがないと使えないってことですね。

Vector a を処理する際は Vector (X4 a) にキャストするようですが、端数があると処理できません。

もう一つ、primitive-simdパッケージは(型族ではない)個別の型を提供しています。

module Data.Primitive.SIMD where

class (Num v, Real (Elem v)) => SIMDVector v where
  type Elem v
  type ElemTuple v
  broadcastVector :: Elem v -> v
  mapVector :: (Elem v -> Elem v) -> v -> v
  -- ...

data FloatX4
data FloatX8
data FloatX16
data DoubleX2
data DoubleX4
data DoubleX8
-- ...

DoubleX4 型は常に DoubleX4# 型のラッパーかというとそういうわけでもなく、デフォルトでは DoubleX2# を2つ使い、パッケージフラグで使用するベクトルの最大長をno-vec, vec256, vec512という風に選択することができます。

スカラーの配列 Vector a の処理に関するサポートは特になさそうです。

そういうわけで、simdパッケージもprimitive-simdパッケージも一長一短です。特に、一般の Vector Float などを処理する手段がなさそうなのは欠点だと思います。なので、ここでは新しいラッパーを作ることにします。

新しく作るラッパー

まず、simdパッケージの「型族を使う」というアイディアは良さそうなので採用したいです。つまり、ベクトルのビット数ではなく、要素数に応じて型族を作るというアイディアです。また、primitive-simdパッケージの「ベクトル長が短いアーキテクチャでは DoubleX4DoubleX2# を2つ使って表現する」というアイディアも良さそうなので採用します。

-- 使えるベクトル長が128ビットの場合
data X4 a
data instance X4 Float = FloatX4 FloatX4#
data instance X4 Double = DoubleX2X2 DoubleX2# DoubleX2#

型クラスのインスタンスについても考えます。simdパッケージでは SIMD4 というでかいクラスがあり、そこに四則演算も含めていました。しかし、四則演算は個別の型クラスに分けた方が良いと私は思います。一つ考えられるのは、次のようなクラス構成です:

class BroadcastX4 a where
  broadcastX4 :: a -> X4 a

instance BroadcastX4 Float
instance BroadcastX4 Double

class BroadcastX4 a => NumX4 a where
  plusX4 :: X4 a -> X4 a -> X4 a
  -- 略

instance NumX4 Float
instance NumX4 Double

instance (Num a, NumX4 a) => Num (X4 a) where
  (+) = plusX4
  -- ...

が、X4 の部分は他の要素数も考慮して一般化するべきかもしれません:

class Broadcast f a where
  broadcast :: a -> f a

instance Broadcast X4 Float
instance Broadcast X4 Double

class Broadcast f a => NumF f a where
  plusF :: f a -> f a -> f a
  minusF :: f a -> f a -> f a
  timesF :: f a -> f a -> f a
  absF :: f a -> f a
  signumF :: f a -> f a
  negateF :: f a -> f a

instance NumF X4 Float
instance NumF X4 Double

instance (NumF X4 a, Num a) => Num (X4 a) where
  (+) = plusF
  -- ...

Data.Vector.Storable.Vector a を扱えるように、Storable 系のクラスとインスタンスも用意します。

class Storable a => StorableF f a where
  peekElemOffF :: Ptr a -> Int -> IO (f a)
  pokeElemOffF :: Ptr a -> Int -> f a -> IO ()

instance StorableF X4 Float
instance StorableF X4 Double

storable vectorに対して「4要素ごとに処理して」「端数も処理する」map関数は次のように書けます:

mapStorable :: (StorableF X4 a, StorableF X4 b) => (X4 a -> X4 b) -> (a -> b) -> VS.Vector a -> VS.Vector b
mapStorable fv f !v = unsafePerformIO $ do
  let !n = VS.length v
  !result <- VSM.unsafeNew n
  VS.unsafeWith v $ \ !inputPtr ->
    VSM.unsafeWith result $ \ !resultPtr -> do
      let loopVector !i | i + 4 > n = loopScalar i
                        | otherwise = do
                          !a <- peekElemOffF inputPtr i
                          pokeElemOffF resultPtr i (fv a)
                          loopVector (i + 4)
          loopScalar !i | i >= n = pure ()
                        | otherwise = do
                          !a <- peekElemOff inputPtr i
                          pokeElemOff resultPtr i (f a)
                          loopScalar (i + 1)
      loopVector 0
  VS.unsafeFreeze result
{-# INLINE mapStorable #-}

簡単のため、fusion系の最適化は実装していません。

使う側は

mapStorable (\v -> (v + 1)^(10 :: Int)) (\x -> (x + 1)^(10 :: Int)) (VS.fromList [0..10] :: VS.Vector Float)

という風になります。

端数を統一的に扱う

上記の mapStorable 関数には、SIMDによる4要素ごとの処理 X4 a -> X4 b と、端数の処理 a -> b を別々に渡す必要があります。どうせ演算子オーバーロードで同じように書けるので、冗長ですね。

例えば f x = (x + 1)^10 という関数なら関数の型は Num a => a -> a と一般化できるので、

mapStorable' :: (StorableF X4 a, Num a, NumF X4 a) => (forall v. Num v => v -> v) -> VS.Vector a -> VS.Vector a
mapStorable' f = mapStorable f f

とできます。しかし、これでは Num 制約を map 関数で特別扱いすることになってしまい、汎用性が低いです。FractionalBits 系の操作を使いたくなった時に困りそうです。

そこで、別のやり方を考えます。具体的には、端数部分を処理する関数の型に細工を加えて、Identity a -> Identity b と考えます。すると、ベクトル部分の関数も端数部分の関数も f a -> f b という形になり、都合がよさそうです。つまり、X4Identity をインスタンスとするような何らかの型クラス SIMD によって

mapStorable' :: (StorableF X4 a, StorableF X4 b) => (forall f. SIMD f => f a -> f b) -> VS.Vector a -> VS.Vector b
mapStorable' f = mapStorable f (runIdentity . f . Identity)

とするのです。

SIMD f 制約がどういうものであるべきか考えます。例えば、fNum (f Float)Num (f Double) などのインスタンスを持つ必要があります。理想を言えば任意の Num a に対して Num (f a) があって欲しいです。つまり

class (forall a. Num a => Num (f a)) => SIMD f
instance SIMD Identity
instance SIMD X4

となって欲しいですが、これは不可能です。FloatDouble に対してしか Num (X4 a) は定義されていませんからね。ですので、それも制約に加えます。

class (forall a. (Num a, NumF X4 a) => Num (f a)) => SIMD f
instance SIMD Identity
instance SIMD X4

8要素、16要素のSIMD型を実装すると、(Num a, NumF X4 a) の部分は (Num a, NumF X4 a, NumF X8 a, NumF X16 a) という風に増えていきます。ですので、この部分をまとめて

type NumElement a = (Num a, NumF X4 a) -- X8, X16も増えるかもしれない
class (forall a. NumElement a => Num (f a)) => SIMD f
instance SIMD Identity
instance SIMD X4

とします。Num 以外のクラスに対応させる場合は

type NumElement a = (Num a, NumF X4 a) -- X8, X16も増えるかもしれない
type FractionalElement a = (Fractional a, FractionalF X4 a)
type FloatingElement a = (Floating a, FloatingF X4 a)
type MonoidElement a = (Monoid a, MonoidF X4 a)
class ( forall a. NumElement a => Num (f a)
      , forall a. FractionalElement a => Fractional (f a)
      , forall a. FloatingElement a => Floating (f a)
      , forall a. MonoidElement a => Monoid (f a)
      ) => SIMD f
instance SIMD Identity
instance SIMD X4

という風になります。

利用者側は

f :: Num a => a -> a
f x = (x + 1)^10

mapStorable' f (VS.fromList [0..10])

と書いても良いですし、

g :: (SIMD f, NumElement a) => f a -> f a
g x = (x + 1)^10

mapStorable' g (VS.fromList [0..10])

と書くこともできます。

サンプルコードを動かす

サンプルコードはhaskell-dsl-example/simdに置いています。

$ git clone https://github.com/minoki/haskell-dsl-example.git
$ cd haskell-dsl-example/simd

私のメイン環境はApple Silicon Macなので、GHC 9.8以降でLLVMバックエンドが使えるようにしておく必要があります。Haskellの環境構築2023にも書きましたが、Homebrewを使っている人は

$ OPT=$(brew --prefix llvm@15)/bin/opt LLC=$(brew --prefix llvm@15)/bin/llc ghcup install ghc --force 9.8.4

という感じでGHCを入れ直す必要があるかもしれません。

LLVMバックエンドを使うにはGHCに -fllvm オプションを渡す必要があります。Cabalなら --ghc-options=-fllvm です。毎回指定するのはだるいので、最初に

$ cabal configure -w ghc-9.8.4 --ghc-options=-fllvm

を実行しておくと良いかもしれません。このコマンドによって cabal.project.local というファイルができ、設定内容が書き込まれます。

cabal run

$ cabal run
[1.0,1024.0,59049.0,1048576.0,9765625.0,6.0466176e7,2.8247526e8,1.0737418e9,3.4867843e9,1.0e10,2.5937424e10]
[1.0,1024.0,59049.0,1048576.0,9765625.0,6.0466176e7,2.8247526e8,1.0737418e9,3.4867843e9,1.0e10,2.5937424e10]
[1.0,1024.0,59049.0,1048576.0,9765625.0,6.0466176e7,2.8247526e8,1.0737418e9,3.4867843e9,1.0e10,2.5937424e10]
[1.0,1024.0,59049.0,1048576.0,9765625.0,6.0466176e7,2.8247526e8,1.0737418e9,3.4867843e9,1.0e10,2.5937424e10]
[1.0,1024.0,59049.0,1048576.0,9765625.0,6.0466176e7,2.82475249e8,1.073741824e9,3.486784401e9,1.0e10,2.5937424601e10]
[1.0,1024.0,59049.0,1048576.0,9765625.0,6.0466176e7,2.82475249e8,1.073741824e9,3.486784401e9,1.0e10,2.5937424601e10]
[1.0,1024.0,59049.0,1048576.0,9765625.0,6.0466176e7,2.82475249e8,1.073741824e9,3.486784401e9,1.0e10,2.5937424601e10]
[1.0,1024.0,59049.0,1048576.0,9765625.0,6.0466176e7,2.82475249e8,1.073741824e9,3.486784401e9,1.0e10,2.5937424601e10]

という風な内容が出てきたら正常です。

ベンチマークに使うコードはsimd/benchmark/Main.hsです。

Apple M4 Proでの結果

GHC 9.8.4 / LLVMバックエンドでの結果を載せておきます。

$ cabal bench -w ghc-9.8.4 --ghc-options=-fllvm -O2
benchmarking Float/f/scalar
time                 3.206 μs   (3.200 μs .. 3.211 μs)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 3.207 μs   (3.201 μs .. 3.215 μs)
std dev              22.70 ns   (15.40 ns .. 37.09 ns)

benchmarking Float/f/vector
time                 892.3 ns   (889.6 ns .. 894.8 ns)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 893.0 ns   (890.9 ns .. 894.5 ns)
std dev              6.179 ns   (4.750 ns .. 7.933 ns)

benchmarking Float/f/vector (unified)
time                 897.8 ns   (895.3 ns .. 900.2 ns)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 898.7 ns   (896.3 ns .. 902.2 ns)
std dev              9.800 ns   (6.691 ns .. 16.75 ns)

benchmarking Float/g/scalar
time                 3.215 μs   (3.211 μs .. 3.219 μs)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 3.213 μs   (3.210 μs .. 3.220 μs)
std dev              15.61 ns   (8.731 ns .. 28.17 ns)

benchmarking Float/g/vector
time                 893.0 ns   (890.7 ns .. 895.9 ns)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 894.2 ns   (892.1 ns .. 896.7 ns)
std dev              7.870 ns   (6.011 ns .. 11.68 ns)

benchmarking Float/g/vector (unified)
time                 901.2 ns   (897.6 ns .. 904.0 ns)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 898.0 ns   (895.9 ns .. 899.9 ns)
std dev              6.522 ns   (5.186 ns .. 8.885 ns)

benchmarking Double/f/scalar
time                 3.185 μs   (3.179 μs .. 3.193 μs)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 3.188 μs   (3.184 μs .. 3.199 μs)
std dev              21.43 ns   (11.81 ns .. 36.86 ns)

benchmarking Double/f/vector
time                 3.649 μs   (3.621 μs .. 3.674 μs)
                     1.000 R²   (0.999 R² .. 1.000 R²)
mean                 3.662 μs   (3.640 μs .. 3.674 μs)
std dev              54.79 ns   (39.52 ns .. 73.07 ns)
variance introduced by outliers: 13% (moderately inflated)

benchmarking Double/f/vector (unified)
time                 3.671 μs   (3.652 μs .. 3.688 μs)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 3.672 μs   (3.654 μs .. 3.683 μs)
std dev              45.43 ns   (27.65 ns .. 83.19 ns)

benchmarking Double/g/scalar
time                 3.221 μs   (3.214 μs .. 3.228 μs)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 3.220 μs   (3.215 μs .. 3.225 μs)
std dev              17.02 ns   (12.05 ns .. 28.72 ns)

benchmarking Double/g/vector
time                 1.658 μs   (1.656 μs .. 1.662 μs)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 1.662 μs   (1.659 μs .. 1.673 μs)
std dev              16.55 ns   (4.710 ns .. 36.25 ns)

benchmarking Double/g/vector (unified)
time                 1.661 μs   (1.658 μs .. 1.664 μs)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 1.658 μs   (1.656 μs .. 1.662 μs)
std dev              9.381 ns   (4.704 ns .. 17.62 ns)

Float の方は3.206/0.8923≈3.59でした。3倍以上4倍未満です。4並列なのでこんなもんですかね。

Double の方は、^ を使って計算した方はあまり速度が向上していません。冪乗を展開した方は3.221/1.658≈1.94でした。おおよそ2倍、2並列なので妥当ですね。

Ryzen 9 7940HSでの結果(LLVMバックエンド)

AVX-512が使えるRyzen 9 7940HS(Zen 4)での結果も載せておきます。OSはWSL2上のUbuntu 22.04です。

$ cabal bench -w ghc-9.8.4 --ghc-options=-fllvm -O2
benchmarking Float/f/scalar
time                 4.190 μs   (4.154 μs .. 4.231 μs)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 4.201 μs   (4.172 μs .. 4.261 μs)
std dev              139.4 ns   (77.76 ns .. 238.0 ns)
variance introduced by outliers: 42% (moderately inflated)

benchmarking Float/f/vector
time                 1.288 μs   (1.276 μs .. 1.302 μs)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 1.293 μs   (1.282 μs .. 1.308 μs)
std dev              42.72 ns   (30.33 ns .. 65.09 ns)
variance introduced by outliers: 45% (moderately inflated)

benchmarking Float/f/vector (unified)
time                 1.290 μs   (1.279 μs .. 1.301 μs)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 1.291 μs   (1.280 μs .. 1.304 μs)
std dev              42.05 ns   (30.81 ns .. 60.24 ns)
variance introduced by outliers: 45% (moderately inflated)

benchmarking Float/g/scalar
time                 4.187 μs   (4.148 μs .. 4.235 μs)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 4.200 μs   (4.165 μs .. 4.257 μs)
std dev              149.5 ns   (109.6 ns .. 213.1 ns)
variance introduced by outliers: 46% (moderately inflated)

benchmarking Float/g/vector
time                 1.359 μs   (1.350 μs .. 1.370 μs)
                     1.000 R²   (0.999 R² .. 1.000 R²)
mean                 1.363 μs   (1.355 μs .. 1.375 μs)
std dev              32.39 ns   (21.39 ns .. 48.15 ns)
variance introduced by outliers: 30% (moderately inflated)

benchmarking Float/g/vector (unified)
time                 1.366 μs   (1.354 μs .. 1.378 μs)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 1.365 μs   (1.356 μs .. 1.378 μs)
std dev              35.55 ns   (26.86 ns .. 46.75 ns)
variance introduced by outliers: 34% (moderately inflated)

benchmarking Double/f/scalar
time                 4.214 μs   (4.183 μs .. 4.246 μs)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 4.225 μs   (4.195 μs .. 4.280 μs)
std dev              131.8 ns   (87.25 ns .. 222.4 ns)
variance introduced by outliers: 39% (moderately inflated)

benchmarking Double/f/vector
time                 2.159 μs   (2.140 μs .. 2.182 μs)
                     0.999 R²   (0.998 R² .. 0.999 R²)
mean                 2.208 μs   (2.180 μs .. 2.253 μs)
std dev              116.2 ns   (77.12 ns .. 170.5 ns)
variance introduced by outliers: 67% (severely inflated)

benchmarking Double/f/vector (unified)
time                 2.146 μs   (2.131 μs .. 2.165 μs)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 2.153 μs   (2.138 μs .. 2.172 μs)
std dev              57.85 ns   (44.10 ns .. 86.56 ns)
variance introduced by outliers: 34% (moderately inflated)

benchmarking Double/g/scalar
time                 4.192 μs   (4.159 μs .. 4.225 μs)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 4.195 μs   (4.171 μs .. 4.234 μs)
std dev              100.5 ns   (76.68 ns .. 130.0 ns)
variance introduced by outliers: 27% (moderately inflated)

benchmarking Double/g/vector
time                 2.187 μs   (2.171 μs .. 2.205 μs)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 2.192 μs   (2.177 μs .. 2.215 μs)
std dev              61.57 ns   (44.70 ns .. 88.71 ns)
variance introduced by outliers: 36% (moderately inflated)

benchmarking Double/g/vector (unified)
time                 2.181 μs   (2.159 μs .. 2.207 μs)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 2.193 μs   (2.176 μs .. 2.225 μs)
std dev              75.01 ns   (45.34 ns .. 123.4 ns)
variance introduced by outliers: 46% (moderately inflated)

Float の方は4.190/1.288≈3.25でした。これも3倍以上4倍未満です。4並列なのでこんなもんですかね。

こちらも、Double の方は、^ を使って計算した方はあまり速度が向上していません。冪乗を展開した方は4.192/2.181≈1.92でした。おおよそ2倍、2並列なので妥当ですね。

Ryzen 9 7940HSでの結果(NCGバックエンド)

GHC 9.12で実装されたNCGバックエンドでのSIMDサポートも試してみましょう。

$ cabal bench -w ghc-9.12.1 --builddir=dist-ncg -O2 --allow-newer
benchmarking Float/f/scalar
time                 38.18 μs   (38.00 μs .. 38.37 μs)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 38.00 μs   (37.80 μs .. 38.19 μs)
std dev              670.4 ns   (522.7 ns .. 874.4 ns)
variance introduced by outliers: 14% (moderately inflated)

benchmarking Float/f/vector
time                 7.118 μs   (7.049 μs .. 7.184 μs)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 7.121 μs   (7.080 μs .. 7.178 μs)
std dev              170.4 ns   (140.2 ns .. 209.8 ns)
variance introduced by outliers: 26% (moderately inflated)

benchmarking Float/f/vector (unified)
time                 7.151 μs   (7.106 μs .. 7.195 μs)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 7.131 μs   (7.091 μs .. 7.177 μs)
std dev              147.8 ns   (123.3 ns .. 186.5 ns)
variance introduced by outliers: 21% (moderately inflated)

benchmarking Float/g/scalar
time                 13.76 μs   (13.63 μs .. 13.89 μs)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 13.39 μs   (13.22 μs .. 13.54 μs)
std dev              539.9 ns   (434.0 ns .. 699.9 ns)
variance introduced by outliers: 48% (moderately inflated)

benchmarking Float/g/vector
time                 1.397 μs   (1.384 μs .. 1.412 μs)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 1.395 μs   (1.384 μs .. 1.408 μs)
std dev              35.98 ns   (29.38 ns .. 44.60 ns)
variance introduced by outliers: 33% (moderately inflated)

benchmarking Float/g/vector (unified)
time                 1.199 μs   (1.187 μs .. 1.210 μs)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 1.195 μs   (1.186 μs .. 1.205 μs)
std dev              30.76 ns   (26.36 ns .. 35.65 ns)
variance introduced by outliers: 34% (moderately inflated)

benchmarking Double/f/scalar
time                 38.59 μs   (38.38 μs .. 38.83 μs)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 38.65 μs   (38.46 μs .. 38.91 μs)
std dev              699.1 ns   (550.4 ns .. 890.0 ns)
variance introduced by outliers: 14% (moderately inflated)

benchmarking Double/f/vector
time                 8.237 μs   (8.185 μs .. 8.302 μs)
                     1.000 R²   (0.999 R² .. 1.000 R²)
mean                 8.281 μs   (8.229 μs .. 8.349 μs)
std dev              195.3 ns   (145.6 ns .. 255.1 ns)
variance introduced by outliers: 26% (moderately inflated)

benchmarking Double/f/vector (unified)
time                 8.267 μs   (8.208 μs .. 8.332 μs)
                     1.000 R²   (0.999 R² .. 1.000 R²)
mean                 8.291 μs   (8.233 μs .. 8.363 μs)
std dev              205.6 ns   (162.1 ns .. 272.7 ns)
variance introduced by outliers: 27% (moderately inflated)

benchmarking Double/g/scalar
time                 13.91 μs   (13.79 μs .. 14.03 μs)
                     0.999 R²   (0.998 R² .. 0.999 R²)
mean                 13.64 μs   (13.47 μs .. 13.77 μs)
std dev              523.9 ns   (437.4 ns .. 658.1 ns)
variance introduced by outliers: 46% (moderately inflated)

benchmarking Double/g/vector
time                 2.667 μs   (2.640 μs .. 2.701 μs)
                     0.999 R²   (0.999 R² .. 0.999 R²)
mean                 2.690 μs   (2.671 μs .. 2.711 μs)
std dev              67.58 ns   (53.11 ns .. 86.86 ns)
variance introduced by outliers: 31% (moderately inflated)

benchmarking Double/g/vector (unified)
time                 2.267 μs   (2.241 μs .. 2.298 μs)
                     0.999 R²   (0.999 R² .. 1.000 R²)
mean                 2.264 μs   (2.246 μs .. 2.288 μs)
std dev              68.17 ns   (54.30 ns .. 86.66 ns)
variance introduced by outliers: 39% (moderately inflated)

Float の方は f が38.18/7.118≈5.36倍、g が13.76/1.199≈11.5倍でした。4並列なのに4倍以上速度向上しています。

Double の方は、f は38.59/8.237≈4.68倍、g は13.91/2.267≈6.14倍でした。これも2並列なのに2倍以上速度向上しています。

NCGの方でSIMDの利用による速度向上幅が大きいということは、スカラーのコードの最適化が足りないのでしょうか。

これまでを振り返って/一般化

これまで何回か、HaskellでEDSLを作る際に役立つ手法を見てきました。いずれも、演算子オーバーロードで普通のHaskellっぽく書けるようになっています。使った型を見ると、自動微分では Reverse s a、StableName編では Exp または Exp a の形をしていて、SIMDでは X4 a という形をしています。

これらは多くが f a という形をしています(aFloatDouble など、要素の型)。まあ演算子オーバーロードする都合上 f a の形になるのは当然なんですが、何かの意味を見出すことはできるでしょうか?

f はある種の関手 (functor) のようなものと思うことができるかもしれません。ただ、任意の a -> bf a -> f b に持ち上げることができるわけではありません。なので、「Hask圏の強自己関手」(Functor のインスタンス)ではなさそうです。

ひとつの見方としては、f はHaskellのいくつかの型を対象とし、いくつかの関数を射とする圏(Hask圏の部分圏)からHask圏への関手、と考えることができるでしょうか。もっとクールな見方ができるかは私にはわかりません。

ささやかな野望:Haskellによる数値計算フレームワーク

これまで、

  • 自動微分
  • DSLから中間言語を作ってLLVMでJITコンパイル
  • SIMD

などをHaskellで扱う方法を見てきました。これらはいずれも演算子オーバーロードで実装されており、f a という形の型を持ちます。これらを統一的に扱えないでしょうか?つまり、今回定義した SIMD というクラスを一般化して、他の用途でも同様に書けるようにならないでしょうか?

あるいは、HaskellにはすでにAccelerateというフレームワークがありますが、これは現状自動微分をサポートしていません(Support Automatic Differentiation · Issue #398 · AccelerateHS/accelerate)。Accelerateを自動微分に対応させたようなフレームワークを作れないでしょうか?

また、GoogleがPython向けに作っているJAXというフレームワークでは、自動微分、自動ベクトル化、JITコンパイルやGPUでの実行などが行えます。これのHaskell版を作れないでしょうか?

そんな野望をこの数年抱えていたのですが、どうやら私にはそれに取り組むための十分な時間がなさそうです。ですので、同じ志を持った人が現れた時に役に立てるように、必要な技術とアイディアをこの一連の記事にまとめているというわけです。

まあそんな泣き言ばかり言っていてもアレなので、この記事で提示した設計のSIMDラッパーライブラリーをそのうち公開する予定です。


お読みいただきありがとうございました。HaskellはDSLを作る上でも色々な可能性がある言語です。その可能性の一端を感じていただけましたでしょうか?

Discussion