😸

.NET 7こそがC# SIMDプログラミングを始めるのに最適である理由

2022/12/02に公開

結論

  • fixedステートメントを使わずとも良い → インデントが1段浅くなる
  • x86系とARM系で同じことを2度も書かずとも良い
  • 名前空間やクラスがわかりやすく整理されていて検索の手間が少ない

背景

.NET Core 3.1の頃からHardware Intrinsicsが正式に導入され、同時にSIMDをC#から利用できるようになりました。
うろ覚えですがこの前後からSpan<T>のIndexOfがSIMDを利用して高速化を行なうようになったのではないでしょうか。
その後.NET 5,6と代を重ねるに連れて.NET内部ではより多くの箇所で(Last)IndexOf(Any)が使用され、よりSIMDの恩恵を受けるようになってきました。

UnityのBurst CompilerでSIMD API露出させるなどして多少知名度上がった気がしますが、実際にアプリを書く側でSIMDをわざわざ使う人はあんまりいない気がします。
その理由としては(SIMDについてよく知らないというのを除けば)、Vector(256|128)<T>を扱うのが辛いというのが最大の理由だったのではないでしょうか。

(ReadOnly)?Span<T>からVector(256|128)<T>を生成してそれに対して処理を施し、Span<T>に書き込むかあるいはスカラー値を得るという流れがSIMDでは頻出します。
この読み込み処理を効率的に行なうマネージドなAPIが実は.NET 7になるまで存在しませんでした。
unsafeなポインタをSpanから取得してSIMD API固有のLoadメソッドを呼び出す必要がありました。
GCが作動した際の安全性とパフォーマンス劣化を考慮すると出来ればポインタを使わずに(ReadOnly)?SpanからVector(128|256)<T>を取り出せる方が良かったのです。
.NET 7でとうとうVector(128|256).LoadUnsafe<T>(ref T)が登場しました。

具体例

以下のコードは簡単な暗号化が施されたファイルを復号するコードです。
4byteのマジックナンバー[0x01,0x02,0x03,0x04]に引き続いて28byteのデータ(これが復号キー)が配置され、その後に暗号化されているファイルデータが続きます。
復号するのは実に簡単で28byteずつそれぞれキーを引き算するだけです。
復号キーが32byteならメチャクチャ簡単にSIMD処理できるのですが、元のファイルフォーマットがx86時代のC++で処理されていましたので仕方ないですね。

SIMDを使用せずに書くと以下のようになります。

static void Decrypt(ReadOnlySpan<byte> cryptSpan, Span<byte> content)
{
    for (int index = 0, cryptIndex = 4; index < content.Length; ++index)
    {
        content[index] -= cryptSpan[cryptIndex];
        if (++cryptIndex == 28)
        {
            cryptIndex = 0;
        }
    }
}

.NET 6時代のSIMDバージョン

public static void Decrypt(ReadOnlySpan<byte> cryptSpan, Span<byte> content)
{
    byte* temp = stackalloc byte[56];
    fixed (byte* cryptSource = cryptSpan)
    {
        Buffer.MemoryCopy(cryptSource, temp, 28, 28);
        Buffer.MemoryCopy(cryptSource, temp + 28, 28, 28);
    }

    fixed (byte* contentPtr = content)
    {
        byte* end = contentPtr + content.Length;
        byte* itrEnd = contentPtr;

        if (Avx2.IsSupported)
        {
            const int LoopSize = 32;
            itrEnd += (content.Length / (LoopSize * 7)) * (LoopSize * 7);
            Vector256<byte> v0 = Avx.LoadVector256(temp);
            Vector256<byte> v1 = Avx.LoadVector256(temp + 4);
            Vector256<byte> v2 = Avx.LoadVector256(temp + 8);
            Vector256<byte> v3 = Avx.LoadVector256(temp + 12);
            Vector256<byte> v4 = Avx.LoadVector256(temp + 16);
            Vector256<byte> v5 = Avx.LoadVector256(temp + 20);
            Vector256<byte> v6 = Avx.LoadVector256(temp + 24);
            for (var itr = contentPtr; itr != itrEnd;)
            {
                Avx.Store(itr, Avx2.Subtract(Avx.LoadVector256(itr), v1));
                itr += LoopSize;
                Avx.Store(itr, Avx2.Subtract(Avx.LoadVector256(itr), v2));
                itr += LoopSize;
                Avx.Store(itr, Avx2.Subtract(Avx.LoadVector256(itr), v3));
                itr += LoopSize;
                Avx.Store(itr, Avx2.Subtract(Avx.LoadVector256(itr), v4));
                itr += LoopSize;
                Avx.Store(itr, Avx2.Subtract(Avx.LoadVector256(itr), v5));
                itr += LoopSize;
                Avx.Store(itr, Avx2.Subtract(Avx.LoadVector256(itr), v6));
                itr += LoopSize;
                Avx.Store(itr, Avx2.Subtract(Avx.LoadVector256(itr), v0));
                itr += LoopSize;
            }
        }
        else if (Sse2.IsSupported)
        {
            const int LoopSize = 16;
            itrEnd += (content.Length / (LoopSize * 7)) * (LoopSize * 7);
            Vector128<byte> v0 = Sse2.LoadVector128(temp);
            Vector128<byte> v1 = Sse2.LoadVector128(temp + 4);
            Vector128<byte> v2 = Sse2.LoadVector128(temp + 8);
            Vector128<byte> v3 = Sse2.LoadVector128(temp + 12);
            Vector128<byte> v4 = Sse2.LoadVector128(temp + 16);
            Vector128<byte> v5 = Sse2.LoadVector128(temp + 20);
            Vector128<byte> v6 = Sse2.LoadVector128(temp + 24);
            for (var itr = contentPtr; itr != itrEnd;)
            {
                Sse2.Store(itr, Sse2.Subtract(Sse2.LoadVector128(itr), v1));
                itr += LoopSize;
                Sse2.Store(itr, Sse2.Subtract(Sse2.LoadVector128(itr), v5));
                itr += LoopSize;
                Sse2.Store(itr, Sse2.Subtract(Sse2.LoadVector128(itr), v2));
                itr += LoopSize;
                Sse2.Store(itr, Sse2.Subtract(Sse2.LoadVector128(itr), v6));
                itr += LoopSize;
                Sse2.Store(itr, Sse2.Subtract(Sse2.LoadVector128(itr), v3));
                itr += LoopSize;
                Sse2.Store(itr, Sse2.Subtract(Sse2.LoadVector128(itr), v0));
                itr += LoopSize;
                Sse2.Store(itr, Sse2.Subtract(Sse2.LoadVector128(itr), v4));
                itr += LoopSize;
            }
        }
        else if (AdvSimd.IsSupported)
        {
            const int LoopSize = 16;
            itrEnd += (content.Length / (LoopSize * 7)) * (LoopSize * 7);
            Vector128<byte> v0 = AdvSimd.LoadVector128(temp);
            Vector128<byte> v1 = AdvSimd.LoadVector128(temp + 4);
            Vector128<byte> v2 = AdvSimd.LoadVector128(temp + 8);
            Vector128<byte> v3 = AdvSimd.LoadVector128(temp + 12);
            Vector128<byte> v4 = AdvSimd.LoadVector128(temp + 16);
            Vector128<byte> v5 = AdvSimd.LoadVector128(temp + 20);
            Vector128<byte> v6 = AdvSimd.LoadVector128(temp + 24);
            for (var itr = contentPtr; itr != itrEnd;)
            {
                AdvSimd.Store(itr, AdvSimd.Subtract(AdvSimd.LoadVector128(itr), v1));
                itr += LoopSize;
                AdvSimd.Store(itr, AdvSimd.Subtract(AdvSimd.LoadVector128(itr), v5));
                itr += LoopSize;
                AdvSimd.Store(itr, AdvSimd.Subtract(AdvSimd.LoadVector128(itr), v2));
                itr += LoopSize;
                AdvSimd.Store(itr, AdvSimd.Subtract(AdvSimd.LoadVector128(itr), v6));
                itr += LoopSize;
                AdvSimd.Store(itr, AdvSimd.Subtract(AdvSimd.LoadVector128(itr), v3));
                itr += LoopSize;
                AdvSimd.Store(itr, AdvSimd.Subtract(AdvSimd.LoadVector128(itr), v0));
                itr += LoopSize;
                AdvSimd.Store(itr, AdvSimd.Subtract(AdvSimd.LoadVector128(itr), v4));
                itr += LoopSize;
            }
        }

        for (int index = 4; itrEnd != end; itrEnd++)
        {
            *itrEnd -= temp[index];
            if (++index == 56)
            {
                index = 0;
            }
        }
    }
}

.NET 7対応の為に書き直したバージョン

public static void Decrypt(scoped ReadOnlySpan<byte> cryptSpan, Span<byte> content)
{
    const int stride256 = 32 * 7;
    const int stride128 = 16 * 7;
    Span<byte> temp = stackalloc byte[56];
    if (cryptSpan.Length == 28)
    {
        cryptSpan.CopyTo(temp);
        cryptSpan.CopyTo(temp.Slice(28));
        cryptSpan = temp;
    }
    ref var itr = ref MemoryMarshal.GetReference(content);
    ref var itrEnd = ref Unsafe.AddByteOffset(ref itr, content.Length);
    ref var cryptStart = ref MemoryMarshal.GetReference(cryptSpan);
    if (content.Length >= stride256 && Vector256.IsHardwareAccelerated)
    {
        var v0 = Vector256.LoadUnsafe(ref cryptStart);
        var v1 = Vector256.LoadUnsafe(ref cryptStart, 4);
        var v2 = Vector256.LoadUnsafe(ref cryptStart, 8);
        var v3 = Vector256.LoadUnsafe(ref cryptStart, 12);
        var v4 = Vector256.LoadUnsafe(ref cryptStart, 16);
        var v5 = Vector256.LoadUnsafe(ref cryptStart, 20);
        var v6 = Vector256.LoadUnsafe(ref cryptStart, 24);
        itrEnd = ref Unsafe.Subtract(ref itrEnd, stride256);
        do
        {
            (Vector256.LoadUnsafe(ref itr) - v1).StoreUnsafe(ref itr);
            itr = ref Unsafe.AddByteOffset(ref itr, 32);
            (Vector256.LoadUnsafe(ref itr) - v2).StoreUnsafe(ref itr);
            itr = ref Unsafe.AddByteOffset(ref itr, 32);
            (Vector256.LoadUnsafe(ref itr) - v3).StoreUnsafe(ref itr);
            itr = ref Unsafe.AddByteOffset(ref itr, 32);
            (Vector256.LoadUnsafe(ref itr) - v4).StoreUnsafe(ref itr);
            itr = ref Unsafe.AddByteOffset(ref itr, 32);
            (Vector256.LoadUnsafe(ref itr) - v5).StoreUnsafe(ref itr);
            itr = ref Unsafe.AddByteOffset(ref itr, 32);
            (Vector256.LoadUnsafe(ref itr) - v6).StoreUnsafe(ref itr);
            itr = ref Unsafe.AddByteOffset(ref itr, 32);
            (Vector256.LoadUnsafe(ref itr) - v0).StoreUnsafe(ref itr);
            itr = ref Unsafe.AddByteOffset(ref itr, 32);
        } while (!Unsafe.IsAddressGreaterThan(ref itr, ref itrEnd));
        itrEnd = ref Unsafe.AddByteOffset(ref itrEnd, stride256);
    }
    else if (content.Length >= stride128 && Vector128.IsHardwareAccelerated)
    {
        var v0 = Vector128.LoadUnsafe(ref cryptStart);
        var v1 = Vector128.LoadUnsafe(ref cryptStart, 4);
        var v2 = Vector128.LoadUnsafe(ref cryptStart, 8);
        var v3 = Vector128.LoadUnsafe(ref cryptStart, 12);
        var v4 = Vector128.LoadUnsafe(ref cryptStart, 16);
        var v5 = Vector128.LoadUnsafe(ref cryptStart, 20);
        var v6 = Vector128.LoadUnsafe(ref cryptStart, 24);
        itrEnd = ref Unsafe.Subtract(ref itrEnd, stride128);
        do
        {
            (Vector128.LoadUnsafe(ref itr) - v1).StoreUnsafe(ref itr);
            itr = ref Unsafe.AddByteOffset(ref itr, 16);
            (Vector128.LoadUnsafe(ref itr) - v5).StoreUnsafe(ref itr);
            itr = ref Unsafe.AddByteOffset(ref itr, 16);
            (Vector128.LoadUnsafe(ref itr) - v2).StoreUnsafe(ref itr);
            itr = ref Unsafe.AddByteOffset(ref itr, 16);
            (Vector128.LoadUnsafe(ref itr) - v6).StoreUnsafe(ref itr);
            itr = ref Unsafe.AddByteOffset(ref itr, 16);
            (Vector128.LoadUnsafe(ref itr) - v3).StoreUnsafe(ref itr);
            itr = ref Unsafe.AddByteOffset(ref itr, 16);
            (Vector128.LoadUnsafe(ref itr) - v0).StoreUnsafe(ref itr);
            itr = ref Unsafe.AddByteOffset(ref itr, 16);
            (Vector128.LoadUnsafe(ref itr) - v4).StoreUnsafe(ref itr);
            itr = ref Unsafe.AddByteOffset(ref itr, 16);
        } while (!Unsafe.IsAddressGreaterThan(ref itr, ref itrEnd));
        itrEnd = ref Unsafe.AddByteOffset(ref itrEnd, stride128);
    }

    ref var cryptItr = ref Unsafe.AddByteOffset(ref cryptStart, 4);
    ref var cryptEnd = ref Unsafe.AddByteOffset(ref cryptItr, 52);
    while (Unsafe.IsAddressLessThan(ref itr, ref itrEnd))
    {
        itr -= cryptItr;
        itr = ref Unsafe.AddByteOffset(ref itr, 1);
        cryptItr = ref Unsafe.AddByteOffset(ref cryptItr, 1);
        if (Unsafe.AreSame(ref cryptItr, ref cryptEnd))
        {
            cryptItr = ref cryptStart;
        }
    }
}

冗長ですので名前空間は省略しています。

.NET6時代のコードと比較するとポインタ操作がエグいことになっていますが、代わりにC#11で正式化されたstatic abstract methodのおかげで単純な四則演算を算術記号で表記できるなど大分書き心地が良くなっています。

Discussion