F#でアセンブリプログラミング(実践編)
データを繰り返し処理する部分をSIMDを使った並列処理で高速化してみようと思います。
復号アルゴリズム
ここに某ゲームデータがあります。このデータは暗号化されています。しかし、既に解析されており、そのアルゴリズムはGrfCryptと名前が付けられGitHubなどで公開されています。例えばここ。
この中でProcessBlock
と呼ばれる関数は8byteデータを受け取り復号して返します。つまりすべてのデータは8byte区切りされこの関数に渡されるわけで、ここを高速化することに意義があります。
ということで、ProcessBlock
の高速化について順を追って説明していきます。
オリジナルコード
いつ書いたか思い出せませんが、私がF#に移植したコードはこんな感じです。
let IP = [|
0x39uy; 0x31uy; 0x29uy; 0x21uy; 0x19uy; 0x11uy; 0x09uy; 0x01uy; 0x3Buy; 0x33uy; 0x2Buy; 0x23uy; 0x1Buy; 0x13uy; 0x0Buy; 0x03uy;
0x3Duy; 0x35uy; 0x2Duy; 0x25uy; 0x1Duy; 0x15uy; 0x0Duy; 0x05uy; 0x3Fuy; 0x37uy; 0x2Fuy; 0x27uy; 0x1Fuy; 0x17uy; 0x0Fuy; 0x07uy;
0x38uy; 0x30uy; 0x28uy; 0x20uy; 0x18uy; 0x10uy; 0x08uy; 0x00uy; 0x3Auy; 0x32uy; 0x2Auy; 0x22uy; 0x1Auy; 0x12uy; 0x0Auy; 0x02uy;
0x3Cuy; 0x34uy; 0x2Cuy; 0x24uy; 0x1Cuy; 0x14uy; 0x0Cuy; 0x04uy; 0x3Euy; 0x36uy; 0x2Euy; 0x26uy; 0x1Euy; 0x16uy; 0x0Euy; 0x06uy;
|]
let private IPINV = [|
0x27uy; 0x07uy; 0x2Fuy; 0x0Fuy; 0x37uy; 0x17uy; 0x3Fuy; 0x1Fuy; 0x26uy; 0x06uy; 0x2Euy; 0x0Euy; 0x36uy; 0x16uy; 0x3Euy; 0x1Euy;
0x25uy; 0x05uy; 0x2Duy; 0x0Duy; 0x35uy; 0x15uy; 0x3Duy; 0x1Duy; 0x24uy; 0x04uy; 0x2Cuy; 0x0Cuy; 0x34uy; 0x14uy; 0x3Cuy; 0x1Cuy;
0x23uy; 0x03uy; 0x2Buy; 0x0Buy; 0x33uy; 0x13uy; 0x3Buy; 0x1Buy; 0x22uy; 0x02uy; 0x2Auy; 0x0Auy; 0x32uy; 0x12uy; 0x3Auy; 0x1Auy;
0x21uy; 0x01uy; 0x29uy; 0x09uy; 0x31uy; 0x11uy; 0x39uy; 0x19uy; 0x20uy; 0x00uy; 0x28uy; 0x08uy; 0x30uy; 0x10uy; 0x38uy; 0x18uy;
|]
let private S =[|
[|
0xEFuy; 0x03uy; 0x41uy; 0xFDuy; 0xD8uy; 0x74uy; 0x1Euy; 0x47uy; 0x26uy; 0xEFuy; 0xFBuy; 0x22uy; 0xB3uy; 0xD8uy; 0x84uy; 0x1Euy;
0x39uy; 0xACuy; 0xA7uy; 0x60uy; 0x62uy; 0xC1uy; 0xCDuy; 0xBAuy; 0x5Cuy; 0x96uy; 0x90uy; 0x59uy; 0x05uy; 0x3Buy; 0x7Auy; 0x85uy;
0x40uy; 0xFDuy; 0x1Euy; 0xC8uy; 0xE7uy; 0x8Auy; 0x8Buy; 0x21uy; 0xDAuy; 0x43uy; 0x64uy; 0x9Fuy; 0x2Duy; 0x14uy; 0xB1uy; 0x72uy;
0xF5uy; 0x5Buy; 0xC8uy; 0xB6uy; 0x9Cuy; 0x37uy; 0x76uy; 0xECuy; 0x39uy; 0xA0uy; 0xA3uy; 0x05uy; 0x52uy; 0x6Euy; 0x0Fuy; 0xD9uy;
|];
[|
0xA7uy; 0xDDuy; 0x0Duy; 0x78uy; 0x9Euy; 0x0Buy; 0xE3uy; 0x95uy; 0x60uy; 0x36uy; 0x36uy; 0x4Fuy; 0xF9uy; 0x60uy; 0x5Auy; 0xA3uy;
0x11uy; 0x24uy; 0xD2uy; 0x87uy; 0xC8uy; 0x52uy; 0x75uy; 0xECuy; 0xBBuy; 0xC1uy; 0x4Cuy; 0xBAuy; 0x24uy; 0xFEuy; 0x8Fuy; 0x19uy;
0xDAuy; 0x13uy; 0x66uy; 0xAFuy; 0x49uy; 0xD0uy; 0x90uy; 0x06uy; 0x8Cuy; 0x6Auy; 0xFBuy; 0x91uy; 0x37uy; 0x8Duy; 0x0Duy; 0x78uy;
0xBFuy; 0x49uy; 0x11uy; 0xF4uy; 0x23uy; 0xE5uy; 0xCEuy; 0x3Buy; 0x55uy; 0xBCuy; 0xA2uy; 0x57uy; 0xE8uy; 0x22uy; 0x74uy; 0xCEuy;
|];
[|
0x2Cuy; 0xEAuy; 0xC1uy; 0xBFuy; 0x4Auy; 0x24uy; 0x1Fuy; 0xC2uy; 0x79uy; 0x47uy; 0xA2uy; 0x7Cuy; 0xB6uy; 0xD9uy; 0x68uy; 0x15uy;
0x80uy; 0x56uy; 0x5Duy; 0x01uy; 0x33uy; 0xFDuy; 0xF4uy; 0xAEuy; 0xDEuy; 0x30uy; 0x07uy; 0x9Buy; 0xE5uy; 0x83uy; 0x9Buy; 0x68uy;
0x49uy; 0xB4uy; 0x2Euy; 0x83uy; 0x1Fuy; 0xC2uy; 0xB5uy; 0x7Cuy; 0xA2uy; 0x19uy; 0xD8uy; 0xE5uy; 0x7Cuy; 0x2Fuy; 0x83uy; 0xDAuy;
0xF7uy; 0x6Buy; 0x90uy; 0xFEuy; 0xC4uy; 0x01uy; 0x5Auy; 0x97uy; 0x61uy; 0xA6uy; 0x3Duy; 0x40uy; 0x0Buy; 0x58uy; 0xE6uy; 0x3Duy;
|];
[|
0x4Duy; 0xD1uy; 0xB2uy; 0x0Fuy; 0x28uy; 0xBDuy; 0xE4uy; 0x78uy; 0xF6uy; 0x4Auy; 0x0Fuy; 0x93uy; 0x8Buy; 0x17uy; 0xD1uy; 0xA4uy;
0x3Auy; 0xECuy; 0xC9uy; 0x35uy; 0x93uy; 0x56uy; 0x7Euy; 0xCBuy; 0x55uy; 0x20uy; 0xA0uy; 0xFEuy; 0x6Cuy; 0x89uy; 0x17uy; 0x62uy;
0x17uy; 0x62uy; 0x4Buy; 0xB1uy; 0xB4uy; 0xDEuy; 0xD1uy; 0x87uy; 0xC9uy; 0x14uy; 0x3Cuy; 0x4Auy; 0x7Euy; 0xA8uy; 0xE2uy; 0x7Duy;
0xA0uy; 0x9Fuy; 0xF6uy; 0x5Cuy; 0x6Auy; 0x09uy; 0x8Duy; 0xF0uy; 0x0Fuy; 0xE3uy; 0x53uy; 0x25uy; 0x95uy; 0x36uy; 0x28uy; 0xCBuy;
|];
|]
let private E = [|
0x3Fuy; 0x20uy; 0x21uy; 0x22uy; 0x23uy; 0x24uy; 0x23uy; 0x24uy; 0x25uy; 0x26uy; 0x27uy; 0x28uy;
0x27uy; 0x28uy; 0x29uy; 0x2Auy; 0x2Buy; 0x2Cuy; 0x2Buy; 0x2Cuy; 0x2Duy; 0x2Euy; 0x2Fuy; 0x30uy;
0x2Fuy; 0x30uy; 0x31uy; 0x32uy; 0x33uy; 0x34uy; 0x33uy; 0x34uy; 0x35uy; 0x36uy; 0x37uy; 0x38uy;
0x37uy; 0x38uy; 0x39uy; 0x3Auy; 0x3Buy; 0x3Cuy; 0x3Buy; 0x3Cuy; 0x3Duy; 0x3Euy; 0x3Fuy; 0x20uy;
|]
let private P = [|
0x0Fuy; 0x06uy; 0x13uy; 0x14uy; 0x1Cuy; 0x0Buy; 0x1Buy; 0x10uy; 0x00uy; 0x0Euy; 0x16uy; 0x19uy; 0x04uy; 0x11uy; 0x1Euy; 0x09uy;
0x01uy; 0x07uy; 0x17uy; 0x0Duy; 0x1Fuy; 0x1Auy; 0x02uy; 0x08uy; 0x12uy; 0x0Cuy; 0x1Duy; 0x05uy; 0x15uy; 0x0Auy; 0x03uy; 0x18uy;
|]
let mixbit (table : byte[]) scale src dst =
for dstoffset = 0 to table.Length / scale - 1 do
for dstshift = 0 to scale - 1 do
let value = int table.[dstoffset * scale + dstshift]
let srcoffset = value >>> 3
let srcshift = value &&& 7
if 0x80uy >>> srcshift &&& NativePtr.get src srcoffset <> 0uy then
let b = NativePtr.toByRef (NativePtr.add dst dstoffset)
b <- b ^^^ (0x01uy <<< scale - dstshift - 1)
let ProcessBlock buf =
let stack1 = NativePtr.stackalloc 16
let stack2 = NativePtr.add stack1 8
mixbit IP 8 buf stack1
mixbit E 6 stack1 stack2
S |> Array.iteri (fun i s -> (s[NativePtr.get stack2 (i * 2) |> int] &&& 0xF0uy) ||| (s[NativePtr.get stack2 (i * 2 + 1) |> int] &&& 0x0Fuy) |> NativePtr.set stack2 i)
mixbit P 8 stack2 stack1
NativePtr.toNativeInt buf |> NativePtr.ofNativeInt |> NativePtr.write <| 0uL
mixbit IPINV 8 stack1 buf
F#での最適化
まずはF#言語で表現できる範囲での最適化をします。
- 配列アクセスは効率が悪いので、あらかじめ固定しておきポインターアクセスします。
- 元の処理はbig endianのようで、bit shift方向とbyte shift方向が一致していませんでした。これはデータを反転すれば一致させることができます。
- 他の配列と異なり
E
はある程度パターン化しておりなおかつ連番ですのでまとめて処理できます。結果も1byteずつS
の処理に渡されるため、ここは全てループアンローリングしてしまった方がいいでしょう。
そんなこんなを盛り込むとこんな感じになりました。
let private IP, IPINV, S, P =
let IP = [|
0x06uy; 0x0Euy; 0x16uy; 0x1Euy; 0x26uy; 0x2Euy; 0x36uy; 0x3Euy; 0x04uy; 0x0Cuy; 0x14uy; 0x1Cuy; 0x24uy; 0x2Cuy; 0x34uy; 0x3Cuy;
0x02uy; 0x0Auy; 0x12uy; 0x1Auy; 0x22uy; 0x2Auy; 0x32uy; 0x3Auy; 0x00uy; 0x08uy; 0x10uy; 0x18uy; 0x20uy; 0x28uy; 0x30uy; 0x38uy;
0x07uy; 0x0Fuy; 0x17uy; 0x1Fuy; 0x27uy; 0x2Fuy; 0x37uy; 0x3Fuy; 0x05uy; 0x0Duy; 0x15uy; 0x1Duy; 0x25uy; 0x2Duy; 0x35uy; 0x3Duy;
0x03uy; 0x0Buy; 0x13uy; 0x1Buy; 0x23uy; 0x2Buy; 0x33uy; 0x3Buy; 0x01uy; 0x09uy; 0x11uy; 0x19uy; 0x21uy; 0x29uy; 0x31uy; 0x39uy;
|]
let IPINV = [|
0x18uy; 0x38uy; 0x10uy; 0x30uy; 0x08uy; 0x28uy; 0x00uy; 0x20uy; 0x19uy; 0x39uy; 0x11uy; 0x31uy; 0x09uy; 0x29uy; 0x01uy; 0x21uy;
0x1Auy; 0x3Auy; 0x12uy; 0x32uy; 0x0Auy; 0x2Auy; 0x02uy; 0x22uy; 0x1Buy; 0x3Buy; 0x13uy; 0x33uy; 0x0Buy; 0x2Buy; 0x03uy; 0x23uy;
0x1Cuy; 0x3Cuy; 0x14uy; 0x34uy; 0x0Cuy; 0x2Cuy; 0x04uy; 0x24uy; 0x1Duy; 0x3Duy; 0x15uy; 0x35uy; 0x0Duy; 0x2Duy; 0x05uy; 0x25uy;
0x1Euy; 0x3Euy; 0x16uy; 0x36uy; 0x0Euy; 0x2Euy; 0x06uy; 0x26uy; 0x1Fuy; 0x3Fuy; 0x17uy; 0x37uy; 0x0Fuy; 0x2Fuy; 0x07uy; 0x27uy;
|]
let S = [|
// S0
0xEFuy; 0x03uy; 0x41uy; 0xFDuy; 0xD8uy; 0x74uy; 0x1Euy; 0x47uy; 0x26uy; 0xEFuy; 0xFBuy; 0x22uy; 0xB3uy; 0xD8uy; 0x84uy; 0x1Euy;
0x39uy; 0xACuy; 0xA7uy; 0x60uy; 0x62uy; 0xC1uy; 0xCDuy; 0xBAuy; 0x5Cuy; 0x96uy; 0x90uy; 0x59uy; 0x05uy; 0x3Buy; 0x7Auy; 0x85uy;
0x40uy; 0xFDuy; 0x1Euy; 0xC8uy; 0xE7uy; 0x8Auy; 0x8Buy; 0x21uy; 0xDAuy; 0x43uy; 0x64uy; 0x9Fuy; 0x2Duy; 0x14uy; 0xB1uy; 0x72uy;
0xF5uy; 0x5Buy; 0xC8uy; 0xB6uy; 0x9Cuy; 0x37uy; 0x76uy; 0xECuy; 0x39uy; 0xA0uy; 0xA3uy; 0x05uy; 0x52uy; 0x6Euy; 0x0Fuy; 0xD9uy;
// S1
0xA7uy; 0xDDuy; 0x0Duy; 0x78uy; 0x9Euy; 0x0Buy; 0xE3uy; 0x95uy; 0x60uy; 0x36uy; 0x36uy; 0x4Fuy; 0xF9uy; 0x60uy; 0x5Auy; 0xA3uy;
0x11uy; 0x24uy; 0xD2uy; 0x87uy; 0xC8uy; 0x52uy; 0x75uy; 0xECuy; 0xBBuy; 0xC1uy; 0x4Cuy; 0xBAuy; 0x24uy; 0xFEuy; 0x8Fuy; 0x19uy;
0xDAuy; 0x13uy; 0x66uy; 0xAFuy; 0x49uy; 0xD0uy; 0x90uy; 0x06uy; 0x8Cuy; 0x6Auy; 0xFBuy; 0x91uy; 0x37uy; 0x8Duy; 0x0Duy; 0x78uy;
0xBFuy; 0x49uy; 0x11uy; 0xF4uy; 0x23uy; 0xE5uy; 0xCEuy; 0x3Buy; 0x55uy; 0xBCuy; 0xA2uy; 0x57uy; 0xE8uy; 0x22uy; 0x74uy; 0xCEuy;
// S2
0x2Cuy; 0xEAuy; 0xC1uy; 0xBFuy; 0x4Auy; 0x24uy; 0x1Fuy; 0xC2uy; 0x79uy; 0x47uy; 0xA2uy; 0x7Cuy; 0xB6uy; 0xD9uy; 0x68uy; 0x15uy;
0x80uy; 0x56uy; 0x5Duy; 0x01uy; 0x33uy; 0xFDuy; 0xF4uy; 0xAEuy; 0xDEuy; 0x30uy; 0x07uy; 0x9Buy; 0xE5uy; 0x83uy; 0x9Buy; 0x68uy;
0x49uy; 0xB4uy; 0x2Euy; 0x83uy; 0x1Fuy; 0xC2uy; 0xB5uy; 0x7Cuy; 0xA2uy; 0x19uy; 0xD8uy; 0xE5uy; 0x7Cuy; 0x2Fuy; 0x83uy; 0xDAuy;
0xF7uy; 0x6Buy; 0x90uy; 0xFEuy; 0xC4uy; 0x01uy; 0x5Auy; 0x97uy; 0x61uy; 0xA6uy; 0x3Duy; 0x40uy; 0x0Buy; 0x58uy; 0xE6uy; 0x3Duy;
// S3
0x4Duy; 0xD1uy; 0xB2uy; 0x0Fuy; 0x28uy; 0xBDuy; 0xE4uy; 0x78uy; 0xF6uy; 0x4Auy; 0x0Fuy; 0x93uy; 0x8Buy; 0x17uy; 0xD1uy; 0xA4uy;
0x3Auy; 0xECuy; 0xC9uy; 0x35uy; 0x93uy; 0x56uy; 0x7Euy; 0xCBuy; 0x55uy; 0x20uy; 0xA0uy; 0xFEuy; 0x6Cuy; 0x89uy; 0x17uy; 0x62uy;
0x17uy; 0x62uy; 0x4Buy; 0xB1uy; 0xB4uy; 0xDEuy; 0xD1uy; 0x87uy; 0xC9uy; 0x14uy; 0x3Cuy; 0x4Auy; 0x7Euy; 0xA8uy; 0xE2uy; 0x7Duy;
0xA0uy; 0x9Fuy; 0xF6uy; 0x5Cuy; 0x6Auy; 0x09uy; 0x8Duy; 0xF0uy; 0x0Fuy; 0xE3uy; 0x53uy; 0x25uy; 0x95uy; 0x36uy; 0x28uy; 0xCBuy;
|]
let P = [|
0x17uy; 0x1Cuy; 0x0Cuy; 0x1Buy; 0x13uy; 0x14uy; 0x01uy; 0x08uy; 0x0Euy; 0x19uy; 0x16uy; 0x03uy; 0x1Euy; 0x11uy; 0x09uy; 0x07uy;
0x0Fuy; 0x05uy; 0x1Duy; 0x18uy; 0x0Auy; 0x10uy; 0x00uy; 0x06uy; 0x1Fuy; 0x04uy; 0x0Duy; 0x12uy; 0x02uy; 0x1Auy; 0x0Buy; 0x15uy;
|]
let getaddr (arr : byte[]) =
NativeMemory.AlignedAlloc(unativeint arr.Length, if Vector512.IsHardwareAccelerated then unativeint Vector512<byte>.Count else unativeint Vector256<byte>.Count)
|> NativePtr.ofVoidPtr<byte>
|>! fun addr -> Marshal.Copy(arr, 0, NativePtr.toNativeInt addr, arr.Length)
getaddr IP, getaddr IPINV, getaddr S, getaddr P
let private ProcessBlockOptimized (buf : nativeptr<byte>) =
let buf = NativePtr.toNativeInt buf |> NativePtr.ofNativeInt<uint64>
let inline shuffle (table : nativeptr<byte>) src =
let mutable dst = 0uL
for offset = 0 to 31 do dst <- src >>> int (NativePtr.get table offset) &&& 1uL <<< offset ||| dst
dst
let inline mixbit8 (table : nativeptr<byte>) src =
shuffle (NativePtr.add table 32) src <<< 32 ||| shuffle table src
let inline mixbitESP src =
0uL ||| (((src >>> 0x23 &&& 0b11111uL) ||| (src >>> 0x33 &&& 0b100000uL) |> int ||| 0 |> NativePtr.get S &&& 0xF0uy) ||| ((src >>> 0x2F &&& 0b1uL) ||| (src >>> 0x1F &&& 0b111110uL) |> int ||| 0 |> NativePtr.get S &&& 0x0Fuy) |> uint64 <<< 0)
||| (((src >>> 0x2B &&& 0b11111uL) ||| (src >>> 0x1B &&& 0b100000uL) |> int ||| 64 |> NativePtr.get S &&& 0xF0uy) ||| ((src >>> 0x37 &&& 0b1uL) ||| (src >>> 0x27 &&& 0b111110uL) |> int ||| 64 |> NativePtr.get S &&& 0x0Fuy) |> uint64 <<< 8)
||| (((src >>> 0x33 &&& 0b11111uL) ||| (src >>> 0x23 &&& 0b100000uL) |> int ||| 128 |> NativePtr.get S &&& 0xF0uy) ||| ((src >>> 0x3F &&& 0b1uL) ||| (src >>> 0x2F &&& 0b111110uL) |> int ||| 128 |> NativePtr.get S &&& 0x0Fuy) |> uint64 <<< 16)
||| (((src >>> 0x3B &&& 0b11111uL) ||| (src >>> 0x2B &&& 0b100000uL) |> int ||| 192 |> NativePtr.get S &&& 0xF0uy) ||| ((src >>> 0x27 &&& 0b1uL) ||| (src >>> 0x37 &&& 0b111110uL) |> int ||| 192 |> NativePtr.get S &&& 0x0Fuy) |> uint64 <<< 24)
|> shuffle P
let tmp = NativePtr.read buf |> mixbit8 IP
mixbitESP tmp ^^^ tmp |> mixbit8 IPINV |> NativePtr.write buf
.NET7向けAVX2化
shuffle
関数の部分は32回shiftを繰り返しています。こういうところがSIMDの出番です。
AVX2であれば256bitレジスタを使い、64bit整数を4個同時にshiftできます。具体的にはVPSRLVQ
命令(Avx2.ShiftRightLogicalVariable
メソッド)です。shiftで使わないビットが影響を及ぼさないように事前に&&& 0xFFuL
とビットマスクが必要です。
AVX2はshiftした結果から1bitずつ取り出すのが実は苦手ですが、256bitレジスタをバイト単位で最上位ビットを集めるVPMOVMSKB
命令(Avx2.MoveMask
メソッド、Vector256.ExtractMostSignificantBits
メソッド)が使えます。この命令が使えるように<<< 7
することで最下位ビットの値を最上位ビットへ移動します。
これらを使うと、64bit整数を4個同時にshiftした結果をいい感じに集約して256bit = 32byteの並びを得れば32bit分の結果が得られます。2回繰り返せば64bit分が得られます。
集約する際、128bit境界をなるべく跨がないように集めるとパフォーマンスが良くなります。この場合、VPUNPCKLBW
命令(Avx2.UnpackLow
メソッド)を使うと2つのレジスタの下位バイトを集約することができます。このための準備としてshift結果を下位バイトに集めておくため、VPSHUFB
命令(Avx2.Shuffle
メソッド)を使います。unpackするサイズを順番に広げていけば、集約しながら整列も進みます。
集約されていく様子はコメントにメモしてあります。
ループアンローリングしたE
部分ですが、本来は64bit shiftですが、shift量を調節すれば32bit shiftで完結できそうです。256bitレジスタ2回で処理できます。その後のS
から読み取る部分は、本来であればGATHER命令の出番ですが、Downfall脆弱性によりパフォーマンスが低下するので、SIMD処理を諦め汎用レジスタで処理します。その後のP
の部分は下位32bitしか使われていないため、256bitレジスタで8個同時にshiftできるため、mixbit8とは関数を分けました。
ずいぶん長いコードになってしまいましたが、関数呼び出しは全てインライン展開され、分岐やループのない真っ直ぐなコードになりました。
let private ProcessBlockAvx2 (buf : nativeptr<byte>) =
let buf = NativePtr.toNativeInt buf |> NativePtr.ofNativeInt<uint64>
let ymm2 = Vector128.CreateScalarUnsafe 0xFFuL |> Avx2.BroadcastScalarToVector256
let ymm3 = Vector256.Create(0x08_00us).AsByte()
let inline mixbit8 (ptr : nativeptr<byte>) (ymm0 : Vector256<uint64>) =
let inline shuffle8 offset =
let mutable ymm1 = NativePtr.add ptr offset |> Vector256.LoadAligned |> _.As()
let ymm4 = Avx2.Shuffle(Avx2.ShiftRightLogicalVariable(ymm0, ymm1 &&& ymm2).As(), ymm3) in ymm1 <- ymm1 >>> 8 // ______________yx______________80 ← _______y_______x_______8_______0
let ymm5 = Avx2.Shuffle(Avx2.ShiftRightLogicalVariable(ymm0, ymm1 &&& ymm2).As(), ymm3) in ymm1 <- ymm1 >>> 8 // ______________yx______________91 ← _______y_______x_______9_______1
let ymm6 = Avx2.UnpackLow(ymm4, ymm5).AsUInt16() // ____________yyxx____________9810
let ymm4 = Avx2.Shuffle(Avx2.ShiftRightLogicalVariable(ymm0, ymm1 &&& ymm2).As(), ymm3) in ymm1 <- ymm1 >>> 8 // ______________yx______________A2 ← _______y_______x_______A_______2
let ymm5 = Avx2.Shuffle(Avx2.ShiftRightLogicalVariable(ymm0, ymm1 &&& ymm2).As(), ymm3) in ymm1 <- ymm1 >>> 8 // ______________yx______________B3 ← _______y_______x_______B_______3
let ymm4 = Avx2.UnpackLow(ymm4, ymm5).AsUInt16() // ____________yyxx____________BA32
let ymm6 = Avx2.UnpackLow(ymm6, ymm4).AsUInt32() // ________yyyyxxxx________BA983210
let ymm4 = Avx2.Shuffle(Avx2.ShiftRightLogicalVariable(ymm0, ymm1 &&& ymm2).As(), ymm3) in ymm1 <- ymm1 >>> 8 // ______________yx______________C4 ← _______y_______x_______C_______4
let ymm5 = Avx2.Shuffle(Avx2.ShiftRightLogicalVariable(ymm0, ymm1 &&& ymm2).As(), ymm3) in ymm1 <- ymm1 >>> 8 // ______________yx______________D5 ← _______y_______x_______D_______5
let ymm5 = Avx2.UnpackLow(ymm4, ymm5).AsUInt16() // ____________yyxx____________DC54
let ymm4 = Avx2.Shuffle(Avx2.ShiftRightLogicalVariable(ymm0, ymm1 &&& ymm2).As(), ymm3) in ymm1 <- ymm1 >>> 8 // ______________yx______________E6 ← _______y_______x_______E_______6
let ymm0 = Avx2.Shuffle(Avx2.ShiftRightLogicalVariable(ymm0, ymm1 &&& ymm2).As(), ymm3) in ymm1 <- ymm1 >>> 8 // ______________yx______________F7 ← _______y_______x_______F_______7
let ymm0 = Avx2.UnpackLow(ymm4, ymm0).AsUInt16() // ____________yyxx____________FE76
let ymm0 = Avx2.UnpackLow(ymm5, ymm0).AsUInt32() // ________yyyyxxxx________FEDC7654
Avx2.UnpackLow(ymm6, ymm0) <<< 7 |> _.AsByte().ExtractMostSignificantBits() |> uint64 // yyyyyyyyxxxxxxxxFEDCBA9876543210
shuffle8 32 <<< 32 ||| shuffle8 0
let inline mixbitESP (src : uint64) =
let getbyte index =
NativePtr.toNativeInt S + nativeint index |> NativePtr.ofNativeInt<byte> |> NativePtr.read
let ymm0 = src >>> 32 |> uint32 |> Vector256.Create
let ymm1 = Avx2.ShiftRightLogicalVariable(ymm0, Vector256.Create(0x03u, 0x0Bu, 0x13u, 0x1Bu, 0x0Fu, 0x17u, 0x1Fu, 0x07u))
let ymm0 = Avx2.ShiftLeftLogicalVariable(Avx2.ShiftRightLogicalVariable(ymm0, Vector256.Create(0x18u, 0x00u, 0x08u, 0x10u, 0x00u, 0x08u, 0x10u, 0x18u)), Vector256.Create(5u, 5u, 5u, 5u, 1u, 1u, 1u, 1u))
let ymm1 = ymm1 &&& Vector256.Create(0b011111u, 0b011111u, 0b011111u, 0b011111u, 0b000001u, 0b000001u, 0b000001u, 0b000001u)
let ymm0 = ymm0 &&& Vector256.Create(0b100000u, 0b100000u, 0b100000u, 0b100000u, 0b111110u, 0b111110u, 0b111110u, 0b111110u)
let ymm0 = ymm1 ||| ymm0 ||| Vector256.Create(0u, 64u, 128u, 192u, 0u, 64u, 128u, 192u)
let xmm1 = ymm0.GetUpper()
let xmm0 = ymm0.GetLower()
let xmm0 = Vector128.CreateScalar(getbyte xmm0[0]).WithElement(1, getbyte xmm0[1]).WithElement(2, getbyte xmm0[2]).WithElement(3, getbyte xmm0[3]).As() &&& Vector128.CreateScalar 0xF0_F0_F0_F0u
let xmm1 = Vector128.CreateScalar(getbyte xmm1[0]).WithElement(1, getbyte xmm1[1]).WithElement(2, getbyte xmm1[2]).WithElement(3, getbyte xmm1[3]).As() &&& Vector128.CreateScalar 0x0F_0F_0F_0Fu
let ymm0 = xmm0 ||| xmm1 |> Avx2.BroadcastScalarToVector256
let mutable ymm1 = Vector256.LoadAligned(P).As()
let ymm4 = Vector256.Create 0xFFu
let ymm5 = Vector256.Create(0x0C_08_04_00uL).AsByte()
let ymm3 = Avx2.Shuffle(Avx2.ShiftRightLogicalVariable(ymm0, ymm1 &&& ymm4).As(), ymm5) in ymm1 <- ymm1 >>> 8 // ____________yyxx____________C840 ← ___y___y___x___x___C___8___4___0
let ymm2 = Avx2.Shuffle(Avx2.ShiftRightLogicalVariable(ymm0, ymm1 &&& ymm4).As(), ymm5) in ymm1 <- ymm1 >>> 8 // ____________yyxx____________D951 ← ___y___y___x___x___D___9___5___1
let ymm2 = Avx2.UnpackLow(ymm3, ymm2).AsUInt16() // ________yyyyxxxx________DC985410
let ymm3 = Avx2.Shuffle(Avx2.ShiftRightLogicalVariable(ymm0, ymm1 &&& ymm4).As(), ymm5) in ymm1 <- ymm1 >>> 8 // ____________yyxx____________EA62 ← ___y___y___x___x___E___A___6___2
let ymm0 = Avx2.Shuffle(Avx2.ShiftRightLogicalVariable(ymm0, ymm1 &&& ymm4).As(), ymm5) in ymm1 <- ymm1 >>> 8 // ____________yyxx____________FB73 ← ___y___y___x___x___F___B___7___3
let ymm0 = Avx2.UnpackLow(ymm3, ymm0).AsUInt16() // ________yyyyxxxx________FEBA7632
Avx2.UnpackLow(ymm2, ymm0) <<< 7 |> _.AsByte().ExtractMostSignificantBits() |> uint64 // yyyyyyyyxxxxxxxxFEDCBA9876543210
let tmp = Avx2.BroadcastScalarToVector256 buf |> mixbit8 IP
mixbitESP tmp ^^^ tmp |> Vector256.Create |> mixbit8 IPINV |> NativePtr.write buf
.NET8向けAVX-512化
AVX-512であれば512bitレジスタを使い、64bit整数を8個同時にshiftできます。AVX2から半減します。しかし、シフト量を表すテーブルは32bitしかなくこのままでは512bitレジスタを活かせません。単純に上位にコピーしても同じアドレスをshiftしてしまい意味がありません。ずらそうとするとずらす手間と整列する手間がかかります。AVX2ではshift前にビットマスクもしていました。
これらを一気に解決する方法としてAVX-512で追加されたVPMOVZXBQ
命令(Avx512F.ConvertToVector512UInt64
メソッド)が使えます。8bit×8をゼロ拡張して64bit×8に配置してくれます。ちなみにVPMOVZXBQ
命令はレジスタだけでなくメモリを受け付けるため、AVX2版のAvx2.ConvertToVector256UInt64
メソッドはポインターを引数に取るオーバーロードが用意されていますが、Avx512F.ConvertToVector512UInt64
は確信犯的にポインターバージョンが削除されています。仕方がないので一旦Vector128
に読み込んでいます。
集約にはAVX-512で追加されたVPERMT2B
命令(Avx512Vbmi.PermuteVar64x8x2
メソッド)で64バイト×2を任意の位置にshuffleすることで一気に整列できます。VPMOVZXBQ
とVPERMT2B
のどちらも128bit境界を跨ぎますが、AVX2よりも効率的に処理できるので結果的には良いパフォーマンスが得られます。shiftしてmove maskする点は同じですが、どのバイトをshiftしどう集約するか、出来上がったコードはAVX2とは全く異なるアプローチになりました。
ループアンローリングしたE
部分ですが、512bitレジスタを使えば1回で処理できます。しかし256bitレジスタに比べて521bitレジスタでの実行は倍以上の時間がかかり、逆に遅くなってしまいます。そのため、ここはAVX2版から変更なしです。…と言いつつ、AVX-512で追加されたVPRORVQ
命令(Avx512F.VL.RotateRightVariable
メソッド)を使っています。AVX2版では溢れないようshift量を調整していましたが、rotateすることで溢れても気にしないようになっています。また、|||
を重ねていた部分もAVX-512で追加された三項演算子VPTERNLOGD
命令(Avx512F.VL.TernaryLogic
メソッド)に変えました。
S
を読み取る部分ですが、S
は32バイト=512bitに収まるので4レジスタに読み込みそこからpermuteで取り出すようにしました。
AVX2版よりは減りましたが、まだまだ長いコードです。
let private ProcessBlockAvx512 (buf : nativeptr<byte>) =
let buf = NativePtr.toNativeInt buf |> NativePtr.ofNativeInt<uint64>
let zmm15 = Vector128.Create(0x38_30_28_20_18_10_08_00uL, 0x78_70_68_60_58_50_48_40uL).ToVector256Unsafe().ToVector512Unsafe().AsByte()
let inline mixbit8 (ptr : nativeptr<byte>) (zmm0 : Vector512<uint64>) =
let ptr = NativePtr.toNativeInt ptr |> NativePtr.ofNativeInt<uint64>
let zmm1 = Avx512DQ.ShiftRightLogicalVariable(zmm0, NativePtr.get ptr 0 |> Vector128.CreateScalar |> _.AsByte() |> Avx512F.ConvertToVector512UInt64) // _______7_______6_______5_______4_______3_______2_______1_______0
let zmm2 = Avx512DQ.ShiftRightLogicalVariable(zmm0, NativePtr.get ptr 1 |> Vector128.CreateScalar |> _.AsByte() |> Avx512F.ConvertToVector512UInt64) // _______F_______E_______D_______C_______B_______A_______9_______8
let zmm3 = Avx512DQ.ShiftRightLogicalVariable(zmm0, NativePtr.get ptr 2 |> Vector128.CreateScalar |> _.AsByte() |> Avx512F.ConvertToVector512UInt64) // _______x_______x_______x_______x_______x_______x_______x_______x
let zmm4 = Avx512DQ.ShiftRightLogicalVariable(zmm0, NativePtr.get ptr 3 |> Vector128.CreateScalar |> _.AsByte() |> Avx512F.ConvertToVector512UInt64) // _______y_______y_______y_______y_______y_______y_______y_______y
let zmm5 = Avx512DQ.ShiftRightLogicalVariable(zmm0, NativePtr.get ptr 4 |> Vector128.CreateScalar |> _.AsByte() |> Avx512F.ConvertToVector512UInt64)
let zmm6 = Avx512DQ.ShiftRightLogicalVariable(zmm0, NativePtr.get ptr 5 |> Vector128.CreateScalar |> _.AsByte() |> Avx512F.ConvertToVector512UInt64)
let zmm7 = Avx512DQ.ShiftRightLogicalVariable(zmm0, NativePtr.get ptr 6 |> Vector128.CreateScalar |> _.AsByte() |> Avx512F.ConvertToVector512UInt64)
let zmm8 = Avx512DQ.ShiftRightLogicalVariable(zmm0, NativePtr.get ptr 7 |> Vector128.CreateScalar |> _.AsByte() |> Avx512F.ConvertToVector512UInt64)
let ymm1 = Avx512Vbmi.PermuteVar64x8x2(zmm1.As(), zmm15, zmm2.As()).GetLower() // ________________FEDCBA9876543210
let ymm3 = Avx512Vbmi.PermuteVar64x8x2(zmm3.As(), zmm15, zmm4.As()).GetLower() // ________________yyyyyyyyxxxxxxxx
let ymm5 = Avx512Vbmi.PermuteVar64x8x2(zmm5.As(), zmm15, zmm6.As()).GetLower()
let ymm7 = Avx512Vbmi.PermuteVar64x8x2(zmm7.As(), zmm15, zmm8.As()).GetLower()
let rax = ymm3.GetLower() |> ymm1.WithUpper |> _.AsUInt64() <<< 7 |> _.AsByte().ExtractMostSignificantBits() |> uint64 // yyyyyyyyxxxxxxxxFEDCBA9876543210
let rdx = ymm7.GetLower() |> ymm5.WithUpper |> _.AsUInt64() <<< 7 |> _.AsByte().ExtractMostSignificantBits() |> uint64
rdx <<< 32 ||| rax
let inline mixbitESP (src : uint64) =
let ymm0 = src >>> 32 |> uint32 |> Vector256.Create
let ymm1 = Avx2.ShiftRightLogicalVariable(ymm0, Vector256.Create(0x03u, 0x0Bu, 0x13u, 0x1Bu, 0x0Fu, 0x17u, 0x1Fu, 0x07u)) &&& Vector256.Create(0b011111u, 0b011111u, 0b011111u, 0b011111u, 0b000001u, 0b000001u, 0b000001u, 0b000001u)
let ymm0 = Avx512F.VL.RotateRightVariable(ymm0, Vector256.Create(0x13u, 0x1Bu, 0x03u, 0x0Bu, 0x1Fu, 0x07u, 0x0Fu, 0x17u)) &&& Vector256.Create(0b100000u, 0b100000u, 0b100000u, 0b100000u, 0b111110u, 0b111110u, 0b111110u, 0b111110u)
let zmm0 = Avx512F.VL.TernaryLogic(ymm0, ymm1, Vector256.Create(0x00000040_00000000uL).As(), 0xF0uy ||| 0xCCuy ||| 0xAAuy).ToVector512Unsafe().As()
let ymm2 = Vector256.Create(0xF0u, 0xF0u, 0xF0u, 0xF0u, 0x0Fu, 0x0Fu, 0x0Fu, 0x0Fu).As()
let ymm1 = Avx512Vbmi.PermuteVar64x8x2(NativePtr.add S 0 |> Vector512.LoadAligned, zmm0, NativePtr.add S 64 |> Vector512.LoadAligned).GetLower() &&& ymm2
let ymm0 = Avx512Vbmi.PermuteVar64x8x2(NativePtr.add S 128 |> Vector512.LoadAligned, zmm0, NativePtr.add S 192 |> Vector512.LoadAligned).GetLower() &&& ymm2
let xmm1 = ymm1.GetUpper() ||| ymm1.GetLower()
let xmm0 = ymm0.GetUpper() ||| ymm0.GetLower()
let zmm1 = Avx512Vbmi.VL.PermuteVar16x8x2(xmm1, Vector128.CreateScalar(0x1C_18_04_00u).As(), xmm0).AsUInt32() |> Avx512F.BroadcastScalarToVector512
let zmm2 = Vector512.Create(0x1C_18_14_10_0C_08_04_00uL, 0x3C_38_34_30_2C_28_24_20uL, 0x5C_58_54_50_4C_48_44_40uL, 0x7C_78_74_70_6C_68_64_60uL, 0uL, 0uL, 0uL, 0uL).AsByte()
let zmm0 = Avx512DQ.ShiftRightLogicalVariable(zmm1, NativePtr.add P 0 |> Vector128.LoadAligned |> Avx512F.ConvertToVector512UInt32) // ___F___E___D___C___B___A___9___8___7___6___5___4___3___2___1___0
let zmm1 = Avx512DQ.ShiftRightLogicalVariable(zmm1, NativePtr.add P 16 |> Vector128.LoadAligned |> Avx512F.ConvertToVector512UInt32) // ___y___y___y___y___y___y___y___y___x___x___x___x___x___x___x___x
let ymm0 = Avx512Vbmi.PermuteVar64x8x2(zmm0.As(), zmm2, zmm1.As()).GetLower() // yyyyyyyyxxxxxxxxFEDCBA9876543210
ymm0.AsUInt64() <<< 7 |> _.AsByte().ExtractMostSignificantBits() |> uint64
let tmp = NativePtr.read buf |> Vector512.Create |> mixbit8 IP
mixbitESP tmp ^^^ tmp |> Vector512.Create |> mixbit8 IPINV |> NativePtr.write buf
.NET9向けAVX-512化
実は.NET8での実装が間に合わず漏れた命令があります。VPMULTISHIFTQB
命令(Avx512Vbmi.MultiShift
メソッド)は64並列でshiftを行うことができます。並び変えることなく一気に実行できるので結果も整列済みになります。今までの苦労は何だったのかと。そしてこんな便利な命令が.NET8から漏れてしまったことが非常に残念です。
他の部分は変更有りませんが、全体としてかなりすっきりしました。
let private ProcessBlock (buf : nativeptr<byte>) =
let buf = NativePtr.toNativeInt buf |> NativePtr.ofNativeInt<uint64>
let inline mixbit8 (ptr : nativeptr<byte>) (src : uint64) =
Avx512Vbmi.MultiShift(Vector512.Load ptr, Vector512.Create src).AsUInt64() <<< 7 |> _.AsByte().ExtractMostSignificantBits() // yyyyyyyyxxxxxxxxFEDCBA9876543210
let inline mixbitESP (src : uint64) =
let ymm0 = src >>> 32 |> uint32 |> Vector256.Create
let ymm1 = Avx2.ShiftRightLogicalVariable(ymm0, Vector256.Create(0x03u, 0x0Bu, 0x13u, 0x1Bu, 0x0Fu, 0x17u, 0x1Fu, 0x07u)) &&& Vector256.Create(0b011111u, 0b011111u, 0b011111u, 0b011111u, 0b000001u, 0b000001u, 0b000001u, 0b000001u)
let ymm0 = Avx512F.VL.RotateRightVariable(ymm0, Vector256.Create(0x13u, 0x1Bu, 0x03u, 0x0Bu, 0x1Fu, 0x07u, 0x0Fu, 0x17u)) &&& Vector256.Create(0b100000u, 0b100000u, 0b100000u, 0b100000u, 0b111110u, 0b111110u, 0b111110u, 0b111110u)
let zmm0 = Avx512F.VL.TernaryLogic(ymm0, ymm1, Vector256.Create(0x00000040_00000000uL).As(), 0xF0uy ||| 0xCCuy ||| 0xAAuy).ToVector512Unsafe().As()
let ymm2 = Vector256.Create(0xF0u, 0xF0u, 0xF0u, 0xF0u, 0x0Fu, 0x0Fu, 0x0Fu, 0x0Fu).As()
let ymm1 = Avx512Vbmi.PermuteVar64x8x2(NativePtr.add S 0 |> Vector512.LoadAligned, zmm0, NativePtr.add S 64 |> Vector512.LoadAligned).GetLower() &&& ymm2
let ymm0 = Avx512Vbmi.PermuteVar64x8x2(NativePtr.add S 128 |> Vector512.LoadAligned, zmm0, NativePtr.add S 192 |> Vector512.LoadAligned).GetLower() &&& ymm2
let xmm1 = ymm1.GetUpper() ||| ymm1.GetLower()
let xmm0 = ymm0.GetUpper() ||| ymm0.GetLower()
let ymm0 = Avx512Vbmi.VL.PermuteVar16x8x2(xmm1, Vector128.CreateScalar(0x1C_18_04_00u).As(), xmm0).AsUInt32() |> Avx512F.BroadcastScalarToVector256
Avx512Vbmi.VL.MultiShift(Vector256.Load P, ymm0).AsUInt64() <<< 7 |> _.AsByte().ExtractMostSignificantBits() |> uint64 // yyyyyyyyxxxxxxxxFEDCBA9876543210
let tmp = NativePtr.read buf |> mixbit8 IP
mixbitESP tmp ^^^ tmp |> mixbit8 IPINV |> NativePtr.write buf
効果測定
BenchmarkDotNetで測ってみます。
- Intel Core i7-1065G7 CPU 1.30GHz
- .NET 8.0.0 (8.0.23.53103), X64 RyuJIT AVX2
Method | Mean | Error | StdDev | Ratio | Code Size |
---|---|---|---|---|---|
Baseline | 738.36 ns | 7.809 ns | 7.305 ns | 1.00 | 890 B |
Optimized | 349.12 ns | 1.949 ns | 1.823 ns | 0.47 | 699 B |
UseAVX2 | 84.37 ns | 0.852 ns | 0.797 ns | 0.11 | 1,402 B |
UseAVX512 | 52.22 ns | 0.608 ns | 0.539 ns | 0.07 | 873 B |
.NET8向けAVX-512化では元コードから14倍の高速化となりました。
Discussion