🦉

F#でアセンブリプログラミング(実践編)

2023/12/24に公開

データを繰り返し処理する部分をSIMDを使った並列処理で高速化してみようと思います。

復号アルゴリズム

ここに某ゲームデータがあります。このデータは暗号化されています。しかし、既に解析されており、そのアルゴリズムはGrfCryptと名前が付けられGitHubなどで公開されています。例えばここ
この中でProcessBlockと呼ばれる関数は8byteデータを受け取り復号して返します。つまりすべてのデータは8byte区切りされこの関数に渡されるわけで、ここを高速化することに意義があります。

ということで、ProcessBlockの高速化について順を追って説明していきます。

オリジナルコード

いつ書いたか思い出せませんが、私がF#に移植したコードはこんな感じです。

let IP = [|
    0x39uy; 0x31uy; 0x29uy; 0x21uy; 0x19uy; 0x11uy; 0x09uy; 0x01uy;  0x3Buy; 0x33uy; 0x2Buy; 0x23uy; 0x1Buy; 0x13uy; 0x0Buy; 0x03uy;
    0x3Duy; 0x35uy; 0x2Duy; 0x25uy; 0x1Duy; 0x15uy; 0x0Duy; 0x05uy;  0x3Fuy; 0x37uy; 0x2Fuy; 0x27uy; 0x1Fuy; 0x17uy; 0x0Fuy; 0x07uy;
    0x38uy; 0x30uy; 0x28uy; 0x20uy; 0x18uy; 0x10uy; 0x08uy; 0x00uy;  0x3Auy; 0x32uy; 0x2Auy; 0x22uy; 0x1Auy; 0x12uy; 0x0Auy; 0x02uy;
    0x3Cuy; 0x34uy; 0x2Cuy; 0x24uy; 0x1Cuy; 0x14uy; 0x0Cuy; 0x04uy;  0x3Euy; 0x36uy; 0x2Euy; 0x26uy; 0x1Euy; 0x16uy; 0x0Euy; 0x06uy;
|]
let private IPINV = [|
    0x27uy; 0x07uy; 0x2Fuy; 0x0Fuy; 0x37uy; 0x17uy; 0x3Fuy; 0x1Fuy;  0x26uy; 0x06uy; 0x2Euy; 0x0Euy; 0x36uy; 0x16uy; 0x3Euy; 0x1Euy;
    0x25uy; 0x05uy; 0x2Duy; 0x0Duy; 0x35uy; 0x15uy; 0x3Duy; 0x1Duy;  0x24uy; 0x04uy; 0x2Cuy; 0x0Cuy; 0x34uy; 0x14uy; 0x3Cuy; 0x1Cuy;
    0x23uy; 0x03uy; 0x2Buy; 0x0Buy; 0x33uy; 0x13uy; 0x3Buy; 0x1Buy;  0x22uy; 0x02uy; 0x2Auy; 0x0Auy; 0x32uy; 0x12uy; 0x3Auy; 0x1Auy;
    0x21uy; 0x01uy; 0x29uy; 0x09uy; 0x31uy; 0x11uy; 0x39uy; 0x19uy;  0x20uy; 0x00uy; 0x28uy; 0x08uy; 0x30uy; 0x10uy; 0x38uy; 0x18uy;
|]
let private S =[|
    [|
        0xEFuy; 0x03uy; 0x41uy; 0xFDuy; 0xD8uy; 0x74uy; 0x1Euy; 0x47uy;  0x26uy; 0xEFuy; 0xFBuy; 0x22uy; 0xB3uy; 0xD8uy; 0x84uy; 0x1Euy;
        0x39uy; 0xACuy; 0xA7uy; 0x60uy; 0x62uy; 0xC1uy; 0xCDuy; 0xBAuy;  0x5Cuy; 0x96uy; 0x90uy; 0x59uy; 0x05uy; 0x3Buy; 0x7Auy; 0x85uy;
        0x40uy; 0xFDuy; 0x1Euy; 0xC8uy; 0xE7uy; 0x8Auy; 0x8Buy; 0x21uy;  0xDAuy; 0x43uy; 0x64uy; 0x9Fuy; 0x2Duy; 0x14uy; 0xB1uy; 0x72uy;
        0xF5uy; 0x5Buy; 0xC8uy; 0xB6uy; 0x9Cuy; 0x37uy; 0x76uy; 0xECuy;  0x39uy; 0xA0uy; 0xA3uy; 0x05uy; 0x52uy; 0x6Euy; 0x0Fuy; 0xD9uy;
    |];
    [|
        0xA7uy; 0xDDuy; 0x0Duy; 0x78uy; 0x9Euy; 0x0Buy; 0xE3uy; 0x95uy;  0x60uy; 0x36uy; 0x36uy; 0x4Fuy; 0xF9uy; 0x60uy; 0x5Auy; 0xA3uy;
        0x11uy; 0x24uy; 0xD2uy; 0x87uy; 0xC8uy; 0x52uy; 0x75uy; 0xECuy;  0xBBuy; 0xC1uy; 0x4Cuy; 0xBAuy; 0x24uy; 0xFEuy; 0x8Fuy; 0x19uy;
        0xDAuy; 0x13uy; 0x66uy; 0xAFuy; 0x49uy; 0xD0uy; 0x90uy; 0x06uy;  0x8Cuy; 0x6Auy; 0xFBuy; 0x91uy; 0x37uy; 0x8Duy; 0x0Duy; 0x78uy;
        0xBFuy; 0x49uy; 0x11uy; 0xF4uy; 0x23uy; 0xE5uy; 0xCEuy; 0x3Buy;  0x55uy; 0xBCuy; 0xA2uy; 0x57uy; 0xE8uy; 0x22uy; 0x74uy; 0xCEuy;
    |];
    [|
        0x2Cuy; 0xEAuy; 0xC1uy; 0xBFuy; 0x4Auy; 0x24uy; 0x1Fuy; 0xC2uy;  0x79uy; 0x47uy; 0xA2uy; 0x7Cuy; 0xB6uy; 0xD9uy; 0x68uy; 0x15uy;
        0x80uy; 0x56uy; 0x5Duy; 0x01uy; 0x33uy; 0xFDuy; 0xF4uy; 0xAEuy;  0xDEuy; 0x30uy; 0x07uy; 0x9Buy; 0xE5uy; 0x83uy; 0x9Buy; 0x68uy;
        0x49uy; 0xB4uy; 0x2Euy; 0x83uy; 0x1Fuy; 0xC2uy; 0xB5uy; 0x7Cuy;  0xA2uy; 0x19uy; 0xD8uy; 0xE5uy; 0x7Cuy; 0x2Fuy; 0x83uy; 0xDAuy;
        0xF7uy; 0x6Buy; 0x90uy; 0xFEuy; 0xC4uy; 0x01uy; 0x5Auy; 0x97uy;  0x61uy; 0xA6uy; 0x3Duy; 0x40uy; 0x0Buy; 0x58uy; 0xE6uy; 0x3Duy;
    |];
    [|
        0x4Duy; 0xD1uy; 0xB2uy; 0x0Fuy; 0x28uy; 0xBDuy; 0xE4uy; 0x78uy;  0xF6uy; 0x4Auy; 0x0Fuy; 0x93uy; 0x8Buy; 0x17uy; 0xD1uy; 0xA4uy;
        0x3Auy; 0xECuy; 0xC9uy; 0x35uy; 0x93uy; 0x56uy; 0x7Euy; 0xCBuy;  0x55uy; 0x20uy; 0xA0uy; 0xFEuy; 0x6Cuy; 0x89uy; 0x17uy; 0x62uy;
        0x17uy; 0x62uy; 0x4Buy; 0xB1uy; 0xB4uy; 0xDEuy; 0xD1uy; 0x87uy;  0xC9uy; 0x14uy; 0x3Cuy; 0x4Auy; 0x7Euy; 0xA8uy; 0xE2uy; 0x7Duy;
        0xA0uy; 0x9Fuy; 0xF6uy; 0x5Cuy; 0x6Auy; 0x09uy; 0x8Duy; 0xF0uy;  0x0Fuy; 0xE3uy; 0x53uy; 0x25uy; 0x95uy; 0x36uy; 0x28uy; 0xCBuy;
    |];
|]
let private E = [|
    0x3Fuy; 0x20uy; 0x21uy; 0x22uy; 0x23uy; 0x24uy;  0x23uy; 0x24uy; 0x25uy; 0x26uy; 0x27uy; 0x28uy;
    0x27uy; 0x28uy; 0x29uy; 0x2Auy; 0x2Buy; 0x2Cuy;  0x2Buy; 0x2Cuy; 0x2Duy; 0x2Euy; 0x2Fuy; 0x30uy;
    0x2Fuy; 0x30uy; 0x31uy; 0x32uy; 0x33uy; 0x34uy;  0x33uy; 0x34uy; 0x35uy; 0x36uy; 0x37uy; 0x38uy;
    0x37uy; 0x38uy; 0x39uy; 0x3Auy; 0x3Buy; 0x3Cuy;  0x3Buy; 0x3Cuy; 0x3Duy; 0x3Euy; 0x3Fuy; 0x20uy;
|]
let private P = [|
    0x0Fuy; 0x06uy; 0x13uy; 0x14uy; 0x1Cuy; 0x0Buy; 0x1Buy; 0x10uy;  0x00uy; 0x0Euy; 0x16uy; 0x19uy; 0x04uy; 0x11uy; 0x1Euy; 0x09uy;
    0x01uy; 0x07uy; 0x17uy; 0x0Duy; 0x1Fuy; 0x1Auy; 0x02uy; 0x08uy;  0x12uy; 0x0Cuy; 0x1Duy; 0x05uy; 0x15uy; 0x0Auy; 0x03uy; 0x18uy; 
|]

let mixbit (table : byte[]) scale src dst =
    for dstoffset = 0 to table.Length / scale - 1 do
    for dstshift = 0 to scale - 1 do
    let value = int table.[dstoffset * scale + dstshift]
    let srcoffset = value >>> 3
    let srcshift = value &&& 7
    if 0x80uy >>> srcshift &&& NativePtr.get src srcoffset <> 0uy then
        let b = NativePtr.toByRef (NativePtr.add dst dstoffset)
        b <- b ^^^ (0x01uy <<< scale - dstshift - 1)

let ProcessBlock buf =
    let stack1 = NativePtr.stackalloc 16
    let stack2 = NativePtr.add stack1 8
    mixbit IP 8 buf stack1
    mixbit E 6 stack1 stack2
    S |> Array.iteri (fun i s -> (s[NativePtr.get stack2 (i * 2) |> int] &&& 0xF0uy) ||| (s[NativePtr.get stack2 (i * 2 + 1) |> int] &&& 0x0Fuy) |> NativePtr.set stack2 i)
    mixbit P 8 stack2 stack1
    NativePtr.toNativeInt buf |> NativePtr.ofNativeInt |> NativePtr.write <| 0uL
    mixbit IPINV 8 stack1 buf

F#での最適化

まずはF#言語で表現できる範囲での最適化をします。

  • 配列アクセスは効率が悪いので、あらかじめ固定しておきポインターアクセスします。
  • 元の処理はbig endianのようで、bit shift方向とbyte shift方向が一致していませんでした。これはデータを反転すれば一致させることができます。
  • 他の配列と異なりEはある程度パターン化しておりなおかつ連番ですのでまとめて処理できます。結果も1byteずつSの処理に渡されるため、ここは全てループアンローリングしてしまった方がいいでしょう。

そんなこんなを盛り込むとこんな感じになりました。

let private IP, IPINV, S, P =
    let IP = [|
        0x06uy; 0x0Euy; 0x16uy; 0x1Euy; 0x26uy; 0x2Euy; 0x36uy; 0x3Euy;  0x04uy; 0x0Cuy; 0x14uy; 0x1Cuy; 0x24uy; 0x2Cuy; 0x34uy; 0x3Cuy;
        0x02uy; 0x0Auy; 0x12uy; 0x1Auy; 0x22uy; 0x2Auy; 0x32uy; 0x3Auy;  0x00uy; 0x08uy; 0x10uy; 0x18uy; 0x20uy; 0x28uy; 0x30uy; 0x38uy;
        0x07uy; 0x0Fuy; 0x17uy; 0x1Fuy; 0x27uy; 0x2Fuy; 0x37uy; 0x3Fuy;  0x05uy; 0x0Duy; 0x15uy; 0x1Duy; 0x25uy; 0x2Duy; 0x35uy; 0x3Duy;
        0x03uy; 0x0Buy; 0x13uy; 0x1Buy; 0x23uy; 0x2Buy; 0x33uy; 0x3Buy;  0x01uy; 0x09uy; 0x11uy; 0x19uy; 0x21uy; 0x29uy; 0x31uy; 0x39uy; 
    |]
    let IPINV = [|
        0x18uy; 0x38uy; 0x10uy; 0x30uy; 0x08uy; 0x28uy; 0x00uy; 0x20uy;  0x19uy; 0x39uy; 0x11uy; 0x31uy; 0x09uy; 0x29uy; 0x01uy; 0x21uy;
        0x1Auy; 0x3Auy; 0x12uy; 0x32uy; 0x0Auy; 0x2Auy; 0x02uy; 0x22uy;  0x1Buy; 0x3Buy; 0x13uy; 0x33uy; 0x0Buy; 0x2Buy; 0x03uy; 0x23uy;
        0x1Cuy; 0x3Cuy; 0x14uy; 0x34uy; 0x0Cuy; 0x2Cuy; 0x04uy; 0x24uy;  0x1Duy; 0x3Duy; 0x15uy; 0x35uy; 0x0Duy; 0x2Duy; 0x05uy; 0x25uy;
        0x1Euy; 0x3Euy; 0x16uy; 0x36uy; 0x0Euy; 0x2Euy; 0x06uy; 0x26uy;  0x1Fuy; 0x3Fuy; 0x17uy; 0x37uy; 0x0Fuy; 0x2Fuy; 0x07uy; 0x27uy;
    |]
    let S = [|
        // S0
        0xEFuy; 0x03uy; 0x41uy; 0xFDuy; 0xD8uy; 0x74uy; 0x1Euy; 0x47uy;  0x26uy; 0xEFuy; 0xFBuy; 0x22uy; 0xB3uy; 0xD8uy; 0x84uy; 0x1Euy;
        0x39uy; 0xACuy; 0xA7uy; 0x60uy; 0x62uy; 0xC1uy; 0xCDuy; 0xBAuy;  0x5Cuy; 0x96uy; 0x90uy; 0x59uy; 0x05uy; 0x3Buy; 0x7Auy; 0x85uy;
        0x40uy; 0xFDuy; 0x1Euy; 0xC8uy; 0xE7uy; 0x8Auy; 0x8Buy; 0x21uy;  0xDAuy; 0x43uy; 0x64uy; 0x9Fuy; 0x2Duy; 0x14uy; 0xB1uy; 0x72uy;
        0xF5uy; 0x5Buy; 0xC8uy; 0xB6uy; 0x9Cuy; 0x37uy; 0x76uy; 0xECuy;  0x39uy; 0xA0uy; 0xA3uy; 0x05uy; 0x52uy; 0x6Euy; 0x0Fuy; 0xD9uy;
        // S1
        0xA7uy; 0xDDuy; 0x0Duy; 0x78uy; 0x9Euy; 0x0Buy; 0xE3uy; 0x95uy;  0x60uy; 0x36uy; 0x36uy; 0x4Fuy; 0xF9uy; 0x60uy; 0x5Auy; 0xA3uy;
        0x11uy; 0x24uy; 0xD2uy; 0x87uy; 0xC8uy; 0x52uy; 0x75uy; 0xECuy;  0xBBuy; 0xC1uy; 0x4Cuy; 0xBAuy; 0x24uy; 0xFEuy; 0x8Fuy; 0x19uy;
        0xDAuy; 0x13uy; 0x66uy; 0xAFuy; 0x49uy; 0xD0uy; 0x90uy; 0x06uy;  0x8Cuy; 0x6Auy; 0xFBuy; 0x91uy; 0x37uy; 0x8Duy; 0x0Duy; 0x78uy;
        0xBFuy; 0x49uy; 0x11uy; 0xF4uy; 0x23uy; 0xE5uy; 0xCEuy; 0x3Buy;  0x55uy; 0xBCuy; 0xA2uy; 0x57uy; 0xE8uy; 0x22uy; 0x74uy; 0xCEuy;
        // S2
        0x2Cuy; 0xEAuy; 0xC1uy; 0xBFuy; 0x4Auy; 0x24uy; 0x1Fuy; 0xC2uy;  0x79uy; 0x47uy; 0xA2uy; 0x7Cuy; 0xB6uy; 0xD9uy; 0x68uy; 0x15uy;
        0x80uy; 0x56uy; 0x5Duy; 0x01uy; 0x33uy; 0xFDuy; 0xF4uy; 0xAEuy;  0xDEuy; 0x30uy; 0x07uy; 0x9Buy; 0xE5uy; 0x83uy; 0x9Buy; 0x68uy;
        0x49uy; 0xB4uy; 0x2Euy; 0x83uy; 0x1Fuy; 0xC2uy; 0xB5uy; 0x7Cuy;  0xA2uy; 0x19uy; 0xD8uy; 0xE5uy; 0x7Cuy; 0x2Fuy; 0x83uy; 0xDAuy;
        0xF7uy; 0x6Buy; 0x90uy; 0xFEuy; 0xC4uy; 0x01uy; 0x5Auy; 0x97uy;  0x61uy; 0xA6uy; 0x3Duy; 0x40uy; 0x0Buy; 0x58uy; 0xE6uy; 0x3Duy;
        // S3
        0x4Duy; 0xD1uy; 0xB2uy; 0x0Fuy; 0x28uy; 0xBDuy; 0xE4uy; 0x78uy;  0xF6uy; 0x4Auy; 0x0Fuy; 0x93uy; 0x8Buy; 0x17uy; 0xD1uy; 0xA4uy;
        0x3Auy; 0xECuy; 0xC9uy; 0x35uy; 0x93uy; 0x56uy; 0x7Euy; 0xCBuy;  0x55uy; 0x20uy; 0xA0uy; 0xFEuy; 0x6Cuy; 0x89uy; 0x17uy; 0x62uy;
        0x17uy; 0x62uy; 0x4Buy; 0xB1uy; 0xB4uy; 0xDEuy; 0xD1uy; 0x87uy;  0xC9uy; 0x14uy; 0x3Cuy; 0x4Auy; 0x7Euy; 0xA8uy; 0xE2uy; 0x7Duy;
        0xA0uy; 0x9Fuy; 0xF6uy; 0x5Cuy; 0x6Auy; 0x09uy; 0x8Duy; 0xF0uy;  0x0Fuy; 0xE3uy; 0x53uy; 0x25uy; 0x95uy; 0x36uy; 0x28uy; 0xCBuy;
    |]
    let P = [|
        0x17uy; 0x1Cuy; 0x0Cuy; 0x1Buy; 0x13uy; 0x14uy; 0x01uy; 0x08uy;  0x0Euy; 0x19uy; 0x16uy; 0x03uy; 0x1Euy; 0x11uy; 0x09uy; 0x07uy;
        0x0Fuy; 0x05uy; 0x1Duy; 0x18uy; 0x0Auy; 0x10uy; 0x00uy; 0x06uy;  0x1Fuy; 0x04uy; 0x0Duy; 0x12uy; 0x02uy; 0x1Auy; 0x0Buy; 0x15uy;
    |]
    let getaddr (arr : byte[]) =
        NativeMemory.AlignedAlloc(unativeint arr.Length, if Vector512.IsHardwareAccelerated then unativeint Vector512<byte>.Count else unativeint Vector256<byte>.Count)
        |> NativePtr.ofVoidPtr<byte>
        |>! fun addr -> Marshal.Copy(arr, 0, NativePtr.toNativeInt addr, arr.Length)
    getaddr IP, getaddr IPINV, getaddr S, getaddr P

let private ProcessBlockOptimized (buf : nativeptr<byte>) =
    let buf = NativePtr.toNativeInt buf |> NativePtr.ofNativeInt<uint64>
    let inline shuffle (table : nativeptr<byte>) src =
        let mutable dst = 0uL
        for offset = 0 to 31 do dst <- src >>> int (NativePtr.get table offset) &&& 1uL <<< offset ||| dst
        dst
    let inline mixbit8 (table : nativeptr<byte>) src =
        shuffle (NativePtr.add table 32) src <<< 32 ||| shuffle table src
    let inline mixbitESP src =
        0uL ||| (((src >>> 0x23 &&& 0b11111uL) ||| (src >>> 0x33 &&& 0b100000uL) |> int |||   0 |> NativePtr.get S &&& 0xF0uy) ||| ((src >>> 0x2F &&& 0b1uL) ||| (src >>> 0x1F &&& 0b111110uL) |> int |||   0 |> NativePtr.get S &&& 0x0Fuy) |> uint64 <<<  0)
            ||| (((src >>> 0x2B &&& 0b11111uL) ||| (src >>> 0x1B &&& 0b100000uL) |> int |||  64 |> NativePtr.get S &&& 0xF0uy) ||| ((src >>> 0x37 &&& 0b1uL) ||| (src >>> 0x27 &&& 0b111110uL) |> int |||  64 |> NativePtr.get S &&& 0x0Fuy) |> uint64 <<<  8)
            ||| (((src >>> 0x33 &&& 0b11111uL) ||| (src >>> 0x23 &&& 0b100000uL) |> int ||| 128 |> NativePtr.get S &&& 0xF0uy) ||| ((src >>> 0x3F &&& 0b1uL) ||| (src >>> 0x2F &&& 0b111110uL) |> int ||| 128 |> NativePtr.get S &&& 0x0Fuy) |> uint64 <<< 16)
            ||| (((src >>> 0x3B &&& 0b11111uL) ||| (src >>> 0x2B &&& 0b100000uL) |> int ||| 192 |> NativePtr.get S &&& 0xF0uy) ||| ((src >>> 0x27 &&& 0b1uL) ||| (src >>> 0x37 &&& 0b111110uL) |> int ||| 192 |> NativePtr.get S &&& 0x0Fuy) |> uint64 <<< 24)
        |> shuffle P
    let tmp = NativePtr.read buf |> mixbit8 IP
    mixbitESP tmp ^^^ tmp |> mixbit8 IPINV |> NativePtr.write buf

.NET7向けAVX2化

shuffle関数の部分は32回shiftを繰り返しています。こういうところがSIMDの出番です。
AVX2であれば256bitレジスタを使い、64bit整数を4個同時にshiftできます。具体的にはVPSRLVQ命令(Avx2.ShiftRightLogicalVariableメソッド)です。shiftで使わないビットが影響を及ぼさないように事前に&&& 0xFFuLとビットマスクが必要です。
AVX2はshiftした結果から1bitずつ取り出すのが実は苦手ですが、256bitレジスタをバイト単位で最上位ビットを集めるVPMOVMSKB命令(Avx2.MoveMaskメソッド、Vector256.ExtractMostSignificantBitsメソッド)が使えます。この命令が使えるように<<< 7することで最下位ビットの値を最上位ビットへ移動します。
これらを使うと、64bit整数を4個同時にshiftした結果をいい感じに集約して256bit = 32byteの並びを得れば32bit分の結果が得られます。2回繰り返せば64bit分が得られます。
集約する際、128bit境界をなるべく跨がないように集めるとパフォーマンスが良くなります。この場合、VPUNPCKLBW命令(Avx2.UnpackLowメソッド)を使うと2つのレジスタの下位バイトを集約することができます。このための準備としてshift結果を下位バイトに集めておくため、VPSHUFB命令(Avx2.Shuffleメソッド)を使います。unpackするサイズを順番に広げていけば、集約しながら整列も進みます。
集約されていく様子はコメントにメモしてあります。

ループアンローリングしたE部分ですが、本来は64bit shiftですが、shift量を調節すれば32bit shiftで完結できそうです。256bitレジスタ2回で処理できます。その後のSから読み取る部分は、本来であればGATHER命令の出番ですが、Downfall脆弱性によりパフォーマンスが低下するので、SIMD処理を諦め汎用レジスタで処理します。その後のPの部分は下位32bitしか使われていないため、256bitレジスタで8個同時にshiftできるため、mixbit8とは関数を分けました。

ずいぶん長いコードになってしまいましたが、関数呼び出しは全てインライン展開され、分岐やループのない真っ直ぐなコードになりました。

let private ProcessBlockAvx2 (buf : nativeptr<byte>) =
    let buf = NativePtr.toNativeInt buf |> NativePtr.ofNativeInt<uint64>
    let ymm2 = Vector128.CreateScalarUnsafe 0xFFuL |> Avx2.BroadcastScalarToVector256
    let ymm3 = Vector256.Create(0x08_00us).AsByte()
    let inline mixbit8 (ptr : nativeptr<byte>) (ymm0 : Vector256<uint64>) =
        let inline shuffle8 offset =
            let mutable ymm1 = NativePtr.add ptr offset |> Vector256.LoadAligned |> _.As()
            let ymm4 = Avx2.Shuffle(Avx2.ShiftRightLogicalVariable(ymm0, ymm1 &&& ymm2).As(), ymm3) in ymm1 <- ymm1 >>> 8   // ______________yx______________80 ← _______y_______x_______8_______0
            let ymm5 = Avx2.Shuffle(Avx2.ShiftRightLogicalVariable(ymm0, ymm1 &&& ymm2).As(), ymm3) in ymm1 <- ymm1 >>> 8   // ______________yx______________91 ← _______y_______x_______9_______1
            let ymm6 = Avx2.UnpackLow(ymm4, ymm5).AsUInt16()                                                                // ____________yyxx____________9810
            let ymm4 = Avx2.Shuffle(Avx2.ShiftRightLogicalVariable(ymm0, ymm1 &&& ymm2).As(), ymm3) in ymm1 <- ymm1 >>> 8   // ______________yx______________A2 ← _______y_______x_______A_______2
            let ymm5 = Avx2.Shuffle(Avx2.ShiftRightLogicalVariable(ymm0, ymm1 &&& ymm2).As(), ymm3) in ymm1 <- ymm1 >>> 8   // ______________yx______________B3 ← _______y_______x_______B_______3
            let ymm4 = Avx2.UnpackLow(ymm4, ymm5).AsUInt16()                                                                // ____________yyxx____________BA32
            let ymm6 = Avx2.UnpackLow(ymm6, ymm4).AsUInt32()                                                                // ________yyyyxxxx________BA983210
            let ymm4 = Avx2.Shuffle(Avx2.ShiftRightLogicalVariable(ymm0, ymm1 &&& ymm2).As(), ymm3) in ymm1 <- ymm1 >>> 8   // ______________yx______________C4 ← _______y_______x_______C_______4
            let ymm5 = Avx2.Shuffle(Avx2.ShiftRightLogicalVariable(ymm0, ymm1 &&& ymm2).As(), ymm3) in ymm1 <- ymm1 >>> 8   // ______________yx______________D5 ← _______y_______x_______D_______5
            let ymm5 = Avx2.UnpackLow(ymm4, ymm5).AsUInt16()                                                                // ____________yyxx____________DC54
            let ymm4 = Avx2.Shuffle(Avx2.ShiftRightLogicalVariable(ymm0, ymm1 &&& ymm2).As(), ymm3) in ymm1 <- ymm1 >>> 8   // ______________yx______________E6 ← _______y_______x_______E_______6
            let ymm0 = Avx2.Shuffle(Avx2.ShiftRightLogicalVariable(ymm0, ymm1 &&& ymm2).As(), ymm3) in ymm1 <- ymm1 >>> 8   // ______________yx______________F7 ← _______y_______x_______F_______7
            let ymm0 = Avx2.UnpackLow(ymm4, ymm0).AsUInt16()                                                                // ____________yyxx____________FE76
            let ymm0 = Avx2.UnpackLow(ymm5, ymm0).AsUInt32()                                                                // ________yyyyxxxx________FEDC7654
            Avx2.UnpackLow(ymm6, ymm0) <<< 7 |> _.AsByte().ExtractMostSignificantBits() |> uint64                           // yyyyyyyyxxxxxxxxFEDCBA9876543210
        shuffle8 32 <<< 32 ||| shuffle8 0
    let inline mixbitESP (src : uint64) =
        let getbyte index =
            NativePtr.toNativeInt S + nativeint index |> NativePtr.ofNativeInt<byte> |> NativePtr.read
        let ymm0 = src >>> 32 |> uint32 |> Vector256.Create
        let ymm1 = Avx2.ShiftRightLogicalVariable(ymm0, Vector256.Create(0x03u, 0x0Bu, 0x13u, 0x1Bu, 0x0Fu, 0x17u, 0x1Fu, 0x07u))
        let ymm0 = Avx2.ShiftLeftLogicalVariable(Avx2.ShiftRightLogicalVariable(ymm0, Vector256.Create(0x18u, 0x00u, 0x08u, 0x10u, 0x00u, 0x08u, 0x10u, 0x18u)), Vector256.Create(5u, 5u, 5u, 5u, 1u, 1u, 1u, 1u))
        let ymm1 = ymm1 &&& Vector256.Create(0b011111u, 0b011111u, 0b011111u, 0b011111u, 0b000001u, 0b000001u, 0b000001u, 0b000001u)
        let ymm0 = ymm0 &&& Vector256.Create(0b100000u, 0b100000u, 0b100000u, 0b100000u, 0b111110u, 0b111110u, 0b111110u, 0b111110u)
        let ymm0 = ymm1 ||| ymm0 ||| Vector256.Create(0u, 64u, 128u, 192u, 0u, 64u, 128u, 192u)
        let xmm1 = ymm0.GetUpper()
        let xmm0 = ymm0.GetLower()
        let xmm0 = Vector128.CreateScalar(getbyte xmm0[0]).WithElement(1, getbyte xmm0[1]).WithElement(2, getbyte xmm0[2]).WithElement(3, getbyte xmm0[3]).As() &&& Vector128.CreateScalar 0xF0_F0_F0_F0u
        let xmm1 = Vector128.CreateScalar(getbyte xmm1[0]).WithElement(1, getbyte xmm1[1]).WithElement(2, getbyte xmm1[2]).WithElement(3, getbyte xmm1[3]).As() &&& Vector128.CreateScalar 0x0F_0F_0F_0Fu
        let ymm0 = xmm0 ||| xmm1 |> Avx2.BroadcastScalarToVector256
        let mutable ymm1 = Vector256.LoadAligned(P).As()
        let ymm4 = Vector256.Create 0xFFu
        let ymm5 = Vector256.Create(0x0C_08_04_00uL).AsByte()
        let ymm3 = Avx2.Shuffle(Avx2.ShiftRightLogicalVariable(ymm0, ymm1 &&& ymm4).As(), ymm5) in ymm1 <- ymm1 >>> 8   // ____________yyxx____________C840 ← ___y___y___x___x___C___8___4___0
        let ymm2 = Avx2.Shuffle(Avx2.ShiftRightLogicalVariable(ymm0, ymm1 &&& ymm4).As(), ymm5) in ymm1 <- ymm1 >>> 8   // ____________yyxx____________D951 ← ___y___y___x___x___D___9___5___1
        let ymm2 = Avx2.UnpackLow(ymm3, ymm2).AsUInt16()                                                                // ________yyyyxxxx________DC985410
        let ymm3 = Avx2.Shuffle(Avx2.ShiftRightLogicalVariable(ymm0, ymm1 &&& ymm4).As(), ymm5) in ymm1 <- ymm1 >>> 8   // ____________yyxx____________EA62 ← ___y___y___x___x___E___A___6___2
        let ymm0 = Avx2.Shuffle(Avx2.ShiftRightLogicalVariable(ymm0, ymm1 &&& ymm4).As(), ymm5) in ymm1 <- ymm1 >>> 8   // ____________yyxx____________FB73 ← ___y___y___x___x___F___B___7___3
        let ymm0 = Avx2.UnpackLow(ymm3, ymm0).AsUInt16()                                                                // ________yyyyxxxx________FEBA7632
        Avx2.UnpackLow(ymm2, ymm0) <<< 7 |> _.AsByte().ExtractMostSignificantBits() |> uint64                           // yyyyyyyyxxxxxxxxFEDCBA9876543210
    let tmp = Avx2.BroadcastScalarToVector256 buf |> mixbit8 IP
    mixbitESP tmp ^^^ tmp |> Vector256.Create |> mixbit8 IPINV |> NativePtr.write buf

.NET8向けAVX-512化

AVX-512であれば512bitレジスタを使い、64bit整数を8個同時にshiftできます。AVX2から半減します。しかし、シフト量を表すテーブルは32bitしかなくこのままでは512bitレジスタを活かせません。単純に上位にコピーしても同じアドレスをshiftしてしまい意味がありません。ずらそうとするとずらす手間と整列する手間がかかります。AVX2ではshift前にビットマスクもしていました。
これらを一気に解決する方法としてAVX-512で追加されたVPMOVZXBQ命令(Avx512F.ConvertToVector512UInt64メソッド)が使えます。8bit×8をゼロ拡張して64bit×8に配置してくれます。ちなみにVPMOVZXBQ命令はレジスタだけでなくメモリを受け付けるため、AVX2版のAvx2.ConvertToVector256UInt64メソッドはポインターを引数に取るオーバーロードが用意されていますが、Avx512F.ConvertToVector512UInt64は確信犯的にポインターバージョンが削除されています。仕方がないので一旦Vector128に読み込んでいます。
集約にはAVX-512で追加されたVPERMT2B命令(Avx512Vbmi.PermuteVar64x8x2メソッド)で64バイト×2を任意の位置にshuffleすることで一気に整列できます。VPMOVZXBQVPERMT2Bのどちらも128bit境界を跨ぎますが、AVX2よりも効率的に処理できるので結果的には良いパフォーマンスが得られます。shiftしてmove maskする点は同じですが、どのバイトをshiftしどう集約するか、出来上がったコードはAVX2とは全く異なるアプローチになりました。

ループアンローリングしたE部分ですが、512bitレジスタを使えば1回で処理できます。しかし256bitレジスタに比べて521bitレジスタでの実行は倍以上の時間がかかり、逆に遅くなってしまいます。そのため、ここはAVX2版から変更なしです。…と言いつつ、AVX-512で追加されたVPRORVQ命令(Avx512F.VL.RotateRightVariableメソッド)を使っています。AVX2版では溢れないようshift量を調整していましたが、rotateすることで溢れても気にしないようになっています。また、|||を重ねていた部分もAVX-512で追加された三項演算子VPTERNLOGD命令(Avx512F.VL.TernaryLogicメソッド)に変えました。
Sを読み取る部分ですが、Sは32バイト=512bitに収まるので4レジスタに読み込みそこからpermuteで取り出すようにしました。

AVX2版よりは減りましたが、まだまだ長いコードです。

let private ProcessBlockAvx512 (buf : nativeptr<byte>) =
    let buf = NativePtr.toNativeInt buf |> NativePtr.ofNativeInt<uint64>
    let zmm15 = Vector128.Create(0x38_30_28_20_18_10_08_00uL, 0x78_70_68_60_58_50_48_40uL).ToVector256Unsafe().ToVector512Unsafe().AsByte()
    let inline mixbit8 (ptr : nativeptr<byte>) (zmm0 : Vector512<uint64>) =
        let ptr = NativePtr.toNativeInt ptr |> NativePtr.ofNativeInt<uint64>
        let zmm1 = Avx512DQ.ShiftRightLogicalVariable(zmm0, NativePtr.get ptr 0 |> Vector128.CreateScalar |> _.AsByte() |> Avx512F.ConvertToVector512UInt64)    // _______7_______6_______5_______4_______3_______2_______1_______0
        let zmm2 = Avx512DQ.ShiftRightLogicalVariable(zmm0, NativePtr.get ptr 1 |> Vector128.CreateScalar |> _.AsByte() |> Avx512F.ConvertToVector512UInt64)    // _______F_______E_______D_______C_______B_______A_______9_______8
        let zmm3 = Avx512DQ.ShiftRightLogicalVariable(zmm0, NativePtr.get ptr 2 |> Vector128.CreateScalar |> _.AsByte() |> Avx512F.ConvertToVector512UInt64)    // _______x_______x_______x_______x_______x_______x_______x_______x
        let zmm4 = Avx512DQ.ShiftRightLogicalVariable(zmm0, NativePtr.get ptr 3 |> Vector128.CreateScalar |> _.AsByte() |> Avx512F.ConvertToVector512UInt64)    // _______y_______y_______y_______y_______y_______y_______y_______y
        let zmm5 = Avx512DQ.ShiftRightLogicalVariable(zmm0, NativePtr.get ptr 4 |> Vector128.CreateScalar |> _.AsByte() |> Avx512F.ConvertToVector512UInt64)
        let zmm6 = Avx512DQ.ShiftRightLogicalVariable(zmm0, NativePtr.get ptr 5 |> Vector128.CreateScalar |> _.AsByte() |> Avx512F.ConvertToVector512UInt64)
        let zmm7 = Avx512DQ.ShiftRightLogicalVariable(zmm0, NativePtr.get ptr 6 |> Vector128.CreateScalar |> _.AsByte() |> Avx512F.ConvertToVector512UInt64)
        let zmm8 = Avx512DQ.ShiftRightLogicalVariable(zmm0, NativePtr.get ptr 7 |> Vector128.CreateScalar |> _.AsByte() |> Avx512F.ConvertToVector512UInt64)
        let ymm1 = Avx512Vbmi.PermuteVar64x8x2(zmm1.As(), zmm15, zmm2.As()).GetLower()                                                                          // ________________FEDCBA9876543210
        let ymm3 = Avx512Vbmi.PermuteVar64x8x2(zmm3.As(), zmm15, zmm4.As()).GetLower()                                                                          // ________________yyyyyyyyxxxxxxxx
        let ymm5 = Avx512Vbmi.PermuteVar64x8x2(zmm5.As(), zmm15, zmm6.As()).GetLower()
        let ymm7 = Avx512Vbmi.PermuteVar64x8x2(zmm7.As(), zmm15, zmm8.As()).GetLower()
        let rax = ymm3.GetLower() |> ymm1.WithUpper |> _.AsUInt64() <<< 7 |> _.AsByte().ExtractMostSignificantBits() |> uint64                                  // yyyyyyyyxxxxxxxxFEDCBA9876543210
        let rdx = ymm7.GetLower() |> ymm5.WithUpper |> _.AsUInt64() <<< 7 |> _.AsByte().ExtractMostSignificantBits() |> uint64
        rdx <<< 32 ||| rax
    let inline mixbitESP (src : uint64) =
        let ymm0 = src >>> 32 |> uint32 |> Vector256.Create
        let ymm1 = Avx2.ShiftRightLogicalVariable(ymm0, Vector256.Create(0x03u, 0x0Bu, 0x13u, 0x1Bu, 0x0Fu, 0x17u, 0x1Fu, 0x07u)) &&& Vector256.Create(0b011111u, 0b011111u, 0b011111u, 0b011111u, 0b000001u, 0b000001u, 0b000001u, 0b000001u)
        let ymm0 = Avx512F.VL.RotateRightVariable(ymm0, Vector256.Create(0x13u, 0x1Bu, 0x03u, 0x0Bu, 0x1Fu, 0x07u, 0x0Fu, 0x17u)) &&& Vector256.Create(0b100000u, 0b100000u, 0b100000u, 0b100000u, 0b111110u, 0b111110u, 0b111110u, 0b111110u)
        let zmm0 = Avx512F.VL.TernaryLogic(ymm0, ymm1, Vector256.Create(0x00000040_00000000uL).As(), 0xF0uy ||| 0xCCuy ||| 0xAAuy).ToVector512Unsafe().As()
        let ymm2 = Vector256.Create(0xF0u, 0xF0u, 0xF0u, 0xF0u, 0x0Fu, 0x0Fu, 0x0Fu, 0x0Fu).As()
        let ymm1 = Avx512Vbmi.PermuteVar64x8x2(NativePtr.add S   0 |> Vector512.LoadAligned, zmm0, NativePtr.add S  64 |> Vector512.LoadAligned).GetLower() &&& ymm2
        let ymm0 = Avx512Vbmi.PermuteVar64x8x2(NativePtr.add S 128 |> Vector512.LoadAligned, zmm0, NativePtr.add S 192 |> Vector512.LoadAligned).GetLower() &&& ymm2
        let xmm1 = ymm1.GetUpper() ||| ymm1.GetLower()
        let xmm0 = ymm0.GetUpper() ||| ymm0.GetLower()
        let zmm1 = Avx512Vbmi.VL.PermuteVar16x8x2(xmm1, Vector128.CreateScalar(0x1C_18_04_00u).As(), xmm0).AsUInt32() |> Avx512F.BroadcastScalarToVector512
        let zmm2 = Vector512.Create(0x1C_18_14_10_0C_08_04_00uL, 0x3C_38_34_30_2C_28_24_20uL, 0x5C_58_54_50_4C_48_44_40uL, 0x7C_78_74_70_6C_68_64_60uL, 0uL, 0uL, 0uL, 0uL).AsByte()
        let zmm0 = Avx512DQ.ShiftRightLogicalVariable(zmm1, NativePtr.add P  0 |> Vector128.LoadAligned |> Avx512F.ConvertToVector512UInt32)    // ___F___E___D___C___B___A___9___8___7___6___5___4___3___2___1___0
        let zmm1 = Avx512DQ.ShiftRightLogicalVariable(zmm1, NativePtr.add P 16 |> Vector128.LoadAligned |> Avx512F.ConvertToVector512UInt32)    // ___y___y___y___y___y___y___y___y___x___x___x___x___x___x___x___x
        let ymm0 = Avx512Vbmi.PermuteVar64x8x2(zmm0.As(), zmm2, zmm1.As()).GetLower()                                                           // yyyyyyyyxxxxxxxxFEDCBA9876543210
        ymm0.AsUInt64() <<< 7 |> _.AsByte().ExtractMostSignificantBits() |> uint64
    let tmp = NativePtr.read buf |> Vector512.Create |> mixbit8 IP
    mixbitESP tmp ^^^ tmp |> Vector512.Create |> mixbit8 IPINV |> NativePtr.write buf

.NET9向けAVX-512化

実は.NET8での実装が間に合わず漏れた命令があります。VPMULTISHIFTQB命令(Avx512Vbmi.MultiShiftメソッド)は64並列でshiftを行うことができます。並び変えることなく一気に実行できるので結果も整列済みになります。今までの苦労は何だったのかと。そしてこんな便利な命令が.NET8から漏れてしまったことが非常に残念です。

他の部分は変更有りませんが、全体としてかなりすっきりしました。

let private ProcessBlock (buf : nativeptr<byte>) =
    let buf = NativePtr.toNativeInt buf |> NativePtr.ofNativeInt<uint64>
    let inline mixbit8 (ptr : nativeptr<byte>) (src : uint64) =
        Avx512Vbmi.MultiShift(Vector512.Load ptr, Vector512.Create src).AsUInt64() <<< 7 |> _.AsByte().ExtractMostSignificantBits() // yyyyyyyyxxxxxxxxFEDCBA9876543210
    let inline mixbitESP (src : uint64) =
        let ymm0 = src >>> 32 |> uint32 |> Vector256.Create
        let ymm1 = Avx2.ShiftRightLogicalVariable(ymm0, Vector256.Create(0x03u, 0x0Bu, 0x13u, 0x1Bu, 0x0Fu, 0x17u, 0x1Fu, 0x07u)) &&& Vector256.Create(0b011111u, 0b011111u, 0b011111u, 0b011111u, 0b000001u, 0b000001u, 0b000001u, 0b000001u)
        let ymm0 = Avx512F.VL.RotateRightVariable(ymm0, Vector256.Create(0x13u, 0x1Bu, 0x03u, 0x0Bu, 0x1Fu, 0x07u, 0x0Fu, 0x17u)) &&& Vector256.Create(0b100000u, 0b100000u, 0b100000u, 0b100000u, 0b111110u, 0b111110u, 0b111110u, 0b111110u)
        let zmm0 = Avx512F.VL.TernaryLogic(ymm0, ymm1, Vector256.Create(0x00000040_00000000uL).As(), 0xF0uy ||| 0xCCuy ||| 0xAAuy).ToVector512Unsafe().As()
        let ymm2 = Vector256.Create(0xF0u, 0xF0u, 0xF0u, 0xF0u, 0x0Fu, 0x0Fu, 0x0Fu, 0x0Fu).As()
        let ymm1 = Avx512Vbmi.PermuteVar64x8x2(NativePtr.add S   0 |> Vector512.LoadAligned, zmm0, NativePtr.add S  64 |> Vector512.LoadAligned).GetLower() &&& ymm2
        let ymm0 = Avx512Vbmi.PermuteVar64x8x2(NativePtr.add S 128 |> Vector512.LoadAligned, zmm0, NativePtr.add S 192 |> Vector512.LoadAligned).GetLower() &&& ymm2
        let xmm1 = ymm1.GetUpper() ||| ymm1.GetLower()
        let xmm0 = ymm0.GetUpper() ||| ymm0.GetLower()
        let ymm0 = Avx512Vbmi.VL.PermuteVar16x8x2(xmm1, Vector128.CreateScalar(0x1C_18_04_00u).As(), xmm0).AsUInt32() |> Avx512F.BroadcastScalarToVector256
        Avx512Vbmi.VL.MultiShift(Vector256.Load P, ymm0).AsUInt64() <<< 7 |> _.AsByte().ExtractMostSignificantBits() |> uint64    // yyyyyyyyxxxxxxxxFEDCBA9876543210
    let tmp = NativePtr.read buf |> mixbit8 IP
    mixbitESP tmp ^^^ tmp |> mixbit8 IPINV |> NativePtr.write buf

効果測定

BenchmarkDotNetで測ってみます。

  • Intel Core i7-1065G7 CPU 1.30GHz
  • .NET 8.0.0 (8.0.23.53103), X64 RyuJIT AVX2
Method Mean Error StdDev Ratio Code Size
Baseline 738.36 ns 7.809 ns 7.305 ns 1.00 890 B
Optimized 349.12 ns 1.949 ns 1.823 ns 0.47 699 B
UseAVX2 84.37 ns 0.852 ns 0.797 ns 0.11 1,402 B
UseAVX512 52.22 ns 0.608 ns 0.539 ns 0.07 873 B

.NET8向けAVX-512化では元コードから14倍の高速化となりました。

Discussion