🦍

[C#]Span<T>構造体で高速かつ安全に配列へアクセスして、行列積を素早く計算する。

2022/12/03に公開約12,100字5件のコメント

目的

今どきのC#erで危ねえポインターを使っているやついるぅ?! いねえよなあ!!
というわけで、以前の記事unsafeとポインターを扱ったので、今度はC# 7.2から追加されたSpan<T>構造体を用いて行列積のパフォーマンスを見てみます。
ここで、行列積を例に挙げている理由は下記の3点です。
1 3重ループのため、計算時間の差を可視化しやすい
2 ほどほどに複雑なアルゴリズムのため、教科書の次のレベルのコードになる
3 目的がはっきりしているため、パフォーマンス向上へのモチベーションを維持できる

参考URL

多次元配列ではトリッキーなことをしないとSpan<T>構造体を使えないようです。本末転倒なので今回は使用しません。
Span<T>構造体 - C# によるプログラミング入門 | ++C++; // 未確認飛行 C
c# - Span and two dimensional Arrays - Stack Overflow

結果

最初に計算結果を載せます。Span<T>構造体の方が遅かったです。
理由をご存じの方がいらっしゃいましたら、コメントいただけますと助かります。

配列 計算方法 計算時間
ジャグ配列 ①普通の行列積の計算 1分47秒
②右の行列を転置した計算 37秒57
③ ②+戻り値へのアクセスをポインターにした計算 28秒86
④ ②+戻り値へのアクセスをSpanにした計算 44秒70
多次元配列 ①普通の行列積の計算 2分17秒54
②右の行列を転置した計算 59秒96
③ ②+戻り値へのアクセスをポインターにした計算 1分8秒15

疑似コードで考えるプログラムの計算時間

ND列の行列ADM列の行列Bの行列積C=ABを考えます。行列積の計算量回数はO(2DMN)なのでデータを間引かないと計算時間を短縮できませんが、データアクセスを意識した疑似コードを書くことで効率化のヒントを得ることができます。例えば、右側の行列の転置行列B^Tを用いるとより高速にデータにアクセスできます。
今回のプログラムでは、右の行列の転置行列を作成して、戻り値の行列積をSpan<T>構造体にした場合のパフォーマンスを比べてみます。

疑似コードでの行列積
for(左側の行列の行 i)
{
    for(右側の行列の列 j)
    {
        for(左側の行列の列 k)
	{
	    左側の行列a[i,k]にアクセス
	    右側の行列b[k,j]にアクセス
	    a[i,k]とb[k,j]の積
	    積c[i,j]にアクセス
	    積c[i,j]にa[i,k]とb[k,j]の積を加える
	    次のkに移る(k+1の開始判定,境界判定a[i,k+1],b[k+1,j])
	}
	次のjに移る(j+1の開始判定,境界判定b[*,j+1],c[i,j+1])
    }
    次のiに移る(i+1の開始判定,境界判定a[i+1,*],c[i+1,*])
}

テストプログラム

前回のプログラムにSpan<T>構造体の計算を加えたテストメソッドです。比較用にunsafeが入っています。

テストプログラム
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace Sandbox.Tremendous1192.SelfEmployed.UnsafeMathDotNet.Zenn.Matrix
{
    /// <summary>
    /// 行列に関するテストクラス
    /// </summary>
    public static partial class TestMatrix
    {
        /// <summary>
        /// 行列積の速度比較
        /// </summary>
        public static void TestMultiplyUnsafeSpan()
        {
            // 初期化
            int N = 2022; // 左側の行列の行
            int D = 2022; // 左側の行列の列
            int M = 2022; // 右側の行列の列
            Console.WriteLine("左の行列の行数" + N + "\t列数" + D + "\t右の行列の行数" + M);

            double[,] left2D=new double[N, D];
            double[,] right2D=new double[D, M];

            double[][] leftJag = new double[N][];
            double[][] rightJag = new double[D][];

            for (int i=0;i<N;i++)
            {
                leftJag[i] = new double[D];
                for (int k=0;k<D;k++)
                {
                    left2D[i, k] = i + k;
                    leftJag[i][k] = i + k;                
                }
            }
            for (int k = 0; k < D; k++)
            {
                rightJag[k] = new double[M];
                for (int j = 0; j < M; j++)
                {
                    right2D[k, j] = k + j + 1;
                    rightJag[k][j] = k + j + 1;
                }
            }


            Stopwatch sw= new Stopwatch();
            sw.Start();
            // ジャグ配列の普通の計算
            {
                double[][] resultJag = new double[N][];
                for (int i = 0; i < N; i++)
                {
                    resultJag[i] = new double[M];
                    for (int j = 0; j < M; j++)
                    {
                        for (int k = 0; k < D; k++)
                        {
                            resultJag[i][j] += leftJag[i][k] * rightJag[k][j];
                        }
                    }
                }
            }
            GC.Collect();
            sw.Stop();
            string elapseJag = sw.Elapsed.ToString();
            //Console.WriteLine(elapseJag);

            // ジャグ配列の転置の計算
            sw.Restart();
            {
                double[][] transposedJag = new double[M][];
                for (int j = 0; j < M; j++)
                {
                    transposedJag[j] = new double[D];
                    for (int k = 0; k < D; k++)
                    {
                        transposedJag[j][k] = rightJag[k][j];
                    }
                }
                double[][] resultJag = new double[N][];
                for (int i = 0; i < N; i++)
                {
                    resultJag[i] = new double[M];
                    for (int j = 0; j < M; j++)
                    {
                        for (int k = 0; k < D; k++)
                        {
                            resultJag[i][j] += leftJag[i][k] * transposedJag[j][k];
                        }
                    }
                }
            }
            GC.Collect();
            sw.Stop();
            string elapseJagT = sw.Elapsed.ToString();
            //Console.WriteLine(elapseJagT);


            // ジャグ配列の転置の計算 + ポインター
            sw.Restart();
            unsafe
            {
                double[][] transposedJag = new double[M][];
                for (int j = 0; j < M; j++)
                {
                    transposedJag[j] = new double[D];
                    for (int k = 0; k < D; k++)
                    {
                        transposedJag[j][k] = rightJag[k][j];
                    }
                }
                double[][] resultJag = new double[N][];
                for (int i = 0; i < N; i++)
                {
                    resultJag[i] = new double[M];
                }

                for (int i = 0; i < N; i++)
                {
                    fixed (double* pResultJag = &resultJag[i][0])
                    {
                        int j = 0;
                        for (double* pR = pResultJag; pR != pResultJag + resultJag[i].Length; ++pR)
                        {
                            for (int k = 0; k < D; k++)
                            {
                              *pR += leftJag[i][k] * transposedJag[j][k];
                            }
                            ++j;
                        }
                    }
                }
            }
            GC.Collect();
            sw.Stop();
            string elapseJagTPointer = sw.Elapsed.ToString();
            //Console.WriteLine(elapseJagTPointer);


            // ジャグ配列の転置の計算 + Span
            sw.Restart();
            {
                double[][] transposedJag = new double[M][];
                for (int j = 0; j < M; j++)
                {
                    transposedJag[j] = new double[D];
                    for (int k = 0; k < D; k++)
                    {
                        transposedJag[j][k] = rightJag[k][j];
                    }
                }
                double[][] resultJag = new double[N][];
                Span<double[]> spanResultJag = resultJag.AsSpan();
                for (int i = 0; i < N; i++)
                {
                    resultJag[i] = new double[M];
                    for (int j = 0; j < M; j++)
                    {
                        for (int k = 0; k < D; k++)
                        {
                            spanResultJag[i][j] += leftJag[i][k] * transposedJag[j][k];
                        }
                    }
                }
            }
            GC.Collect();
            sw.Stop();
            string elapseJagTSpan = sw.Elapsed.ToString();
            //Console.WriteLine(elapseJagT);



            //多次元配列の普通の計算
            sw.Restart();
            {
                double[,] resultMulti = new double[N,M];
                for (int i = 0; i < N; i++)
                {
                    for (int j = 0; j < M; j++)
                    {
                        for (int k = 0; k < D; k++)
                        {
                            resultMulti[i,j] += left2D[i, k] * right2D[k, j];
                        }
                    }
                }
            }
            GC.Collect();
            sw.Stop();
            string elapseMulti = sw.Elapsed.ToString();
            //Console.WriteLine(elapseMulti);

            // 多次元配列の転置の計算
            sw.Restart();
            {
                double[,] transposedMulti = new double[M, D];
                for (int j = 0; j < M; j++)
                {
                    for (int k = 0; k < D; k++)
                    {
                        transposedMulti[j, k] = right2D[k, j];
                    }
                }
                double[,] resultMulti = new double[N, M];
                for (int i = 0; i < N; i++)
                {
                    for (int j = 0; j < M; j++)
                    {
                        for (int k = 0; k < D; k++)
                        {
                            resultMulti[i, j] += left2D[i, k] * transposedMulti[j, k];
                        }
                    }
                }
            }
            GC.Collect();
            sw.Stop();
            string elapseMultiT = sw.Elapsed.ToString();
            //Console.WriteLine(elapseMultiT);

            // 多次元配列の転置とポインターの計算
            sw.Restart();
            unsafe
            {
                double[,] transposedMulti = new double[M, D];
                for (int j = 0; j < M; j++)
                {
                    for (int k = 0; k < D; k++)
                    {
                        transposedMulti[j, k] = right2D[k, j];
                    }
                }
                double[,] resultMulti = new double[N, M];
                int count = 0;
                fixed(double* pMulti = resultMulti) 
                {
                    for (double* pM=pMulti;pM!=pMulti+resultMulti.Length;++pM)
                    {
                        for (int k = 0; k < D; k++)
                        {
                            *pM += left2D[count / M, k] * transposedMulti[count % M, k];
                        }
                        ++count;
                    }                
                }
            }
            GC.Collect();
            sw.Stop();
            string elapseMultiTPointer = sw.Elapsed.ToString();
            //Console.WriteLine(elapseMultiTPointer);

            /*
            // 多次元配列の転置の計算 + Span (できない)
            sw.Restart();
            {
                double[,] transposedMulti = new double[M, D];
                for (int j = 0; j < M; j++)
                {
                    for (int k = 0; k < D; k++)
                    {
                        transposedMulti[j, k] = right2D[k, j];
                    }
                }
                double[,] resultMulti = new double[N, M];
                Span<double> spanResultMulti = resultMulti.AsSpan();
                for (int i = 0; i < N; i++)
                {
                    for (int j = 0; j < M; j++)
                    {
                        for (int k = 0; k < D; k++)
                        {
                            spanResultMulti[i, j] += left2D[i, k] * transposedMulti[j, k];
                        }
                    }
                }
            }
            GC.Collect();
            sw.Stop();
            string elapseMultiTSpan = sw.Elapsed.ToString();
            //Console.WriteLine(elapseMultiT);
            */

            Console.WriteLine("ジャグ配列の普通の計算:\t" + elapseJag);
            Console.WriteLine("ジャグ配列の右の行列の転置:\t" + elapseJagT);
            Console.WriteLine("ジャグ配列の右の行列の転置+ポインター:\t" + elapseJagTPointer);
            Console.WriteLine("ジャグ配列の右の行列の転置+Span:\t" + elapseJagTSpan);

            Console.WriteLine("多次元配列の普通の計算:\t" + elapseMulti);
            Console.WriteLine("多次元配列の右の行列の転置:\t" + elapseMultiT);
            Console.WriteLine("多次元配列の右の行列の転置+ポインター:\t" + elapseMultiTPointer);
        }

    }
}

結果

各計算方法の計算時間を記しました。Span<T>構造体にすると、逆に遅くなりました。どういうことでしょうか。しかし、遅くなったとはいえ、ジャグ配列の方が多次元配列よりも早く計算できています。

配列 計算方法 計算時間
ジャグ配列 ①普通の行列積の計算 1分47秒
②右の行列を転置した計算 37秒57
③ ②+戻り値へのアクセスをポインターにした計算 28秒86
④ ②+戻り値へのアクセスをSpanにした計算 44秒70
多次元配列 ①普通の行列積の計算 2分17秒54
②右の行列を転置した計算 59秒96
③ ②+戻り値へのアクセスをポインターにした計算 1分8秒15

終わりに

どうしてSpan<T>構造体のほうが遅いんですかねえ。ヒープからスタックに移すコストかなあ。

Discussion

実際にアセンブリコード見てみないと何とも言えませんが、連続したメモリ領域にアクセスする場合、Span<T>は最適化されやすいですが、ソースを見た限りではループ内部で new double[M] していて、通常の配列とSpan<T>を混ぜて使っちゃっていますね。これではSpan<T> の恩恵はほぼ無さそうです。https://sharplab.io/ で、JIT Asm選択してどんなアセンブリコード出力しているか見てみるといいかもしれません。
細かくnewする事自体も遅いので、もし速度を詰めていくなら、メモリアロケーションを抑える工夫も必要になってきます。(stackalloc、ArrayPool<T>を使う等)
他には、単純にReleaseビルドにしてなかった、古いフレームワークを使用していた(.NET Framework4.xなど)が考えられます。

アドバイスありがとうございます。
勉強になります。

そもそもとして、変なタイミングでnew double[M]したので
Span<T>の最適化の邪魔をしている可能性があるのですね。
アセンブリコードを調べてみます。
SharpLab https://sharplab.io/ の情報ありがとうございます。

stackalloc、ArrayPool<T>というものがあるのですね。

仰る通りDebugビルドでした。

フレームワークだけは.Net 7.0でした。

ジャグ配列の転置の計算 + Span のコードを、もう少しSpan化してみました。

// ジャグ配列の転置の計算 + Span
sw.Restart();
{
    double[][] transposedJag = new double[M][];
    var spanTransposedJag = transposedJag.AsSpan();
    var spanRightJag = rightJag.AsSpan();
    for (int j = 0; j < M; j++)
    {
        spanTransposedJag[j] = new double[D];
        for (int k = 0; k < D; k++)
        {
            spanTransposedJag[j][k] = spanRightJag[k][j];
        }
    }
    double[][] resultJag = new double[N][];
    Span<double[]> spanResultJag = resultJag.AsSpan();
    var spanLeftJag = leftJag.AsSpan();
    for (int i = 0; i < N; i++)
    {
        spanResultJag[i] = new double[M];
        for (int j = 0; j < M; j++)
        {
            for (int k = 0; k < D; k++)
            {
                spanResultJag[i][j] += spanLeftJag[i][k] * spanTransposedJag[j][k];
            }
        }
    }
}
GC.Collect();
sw.Stop();
string elapseJagTSpan = sw.Elapsed.ToString();

ついでに、CommunityToolkit.HighPerformanceSpan2D で二次元配列をSpan化してみました。

// 多次元配列の転置の計算 + Span
sw.Restart();
{
    double[,] transposedMulti = new double[M, D];
    var spanTransposed2D = transposedMulti.AsSpan2D();
    var spanRight2D = right2D.AsSpan2D();
    for (int j = 0; j < M; j++)
    {
        for (int k = 0; k < D; k++)
        {
            spanTransposed2D[j, k] = spanRight2D[k, j];
        }
    }
    double[,] resultMulti = new double[N, M];
    var spanResult2D = resultMulti.AsSpan2D();
    var spanLeft2D = left2D.AsSpan2D();
    for (int i = 0; i < N; i++)
    {
        for (int j = 0; j < M; j++)
        {
            for (int k = 0; k < D; k++)
            {
                spanResult2D[i, j] += spanLeft2D[i, k] * spanTransposed2D[j, k];
            }
        }
    }
}
GC.Collect();
sw.Stop();
string elapseMultiTSpan = sw.Elapsed.ToString();

x86 Release での結果は、以下のようになりました。

左の行列の行数2022      列数2022        右の行列の行数2022
ジャグ配列の普通の計算: 00:02:15.7473073
ジャグ配列の右の行列の転置:     00:00:26.2950668
ジャグ配列の右の行列の転置+ポインター:  00:00:26.4669140
ジャグ配列の右の行列の転置+Span:        00:00:26.9344696
多次元配列の普通の計算: 00:02:35.2385008
多次元配列の右の行列の転置:     00:00:45.9291456
多次元配列の右の行列の転置+ポインター:  00:00:29.1750611
多次元配列の右の行列の転置+Span:        00:00:26.9590227

ジャグ配列の方は最速がポインタ・Spanと入れ替わる事もあり、そこまで有意な差を得られませんでした。多次元配列の方は、Span2Dで明確に改善が見られました。ジャグ配列にそこまで見劣りしない速度で、2次元配列の書き味で使える事を考えると、中々良いかもしれません。

書き味、読みやすさ度外視で、ジャグ配列を更なるSpan化。

Stopwatch sw = Stopwatch.StartNew();
double[][] transposedJag = new double[M][];
var spanTransposedJag = transposedJag.AsSpan();
var spanRightJag = rightJag.AsSpan();
for (int j = 0; j < M; j++)
{
    spanTransposedJag[j] = new double[D];
    var spanTransposedLine = spanTransposedJag[j].AsSpan();
    for (int k = 0; k < D; k++)
    {
        spanTransposedLine[k] = spanRightJag[k][j];
    }
}
double[][] resultJag = new double[N][];
Span<double[]> spanResultJag = resultJag.AsSpan();
var spanLeftJag = leftJag.AsSpan();
for (int i = 0; i < N; i++)
{
    spanResultJag[i] = new double[M];
    var spanResultLine = spanResultJag[i].AsSpan();
    var spanLeftLine = spanLeftJag[i].AsSpan();
    for (int j = 0; j < M; j++)
    {
        var spanTransposedLine = spanTransposedJag[j].AsSpan();
        for (int k = 0; k < D; k++)
        {
            spanResultLine[j] += spanLeftLine[k] * spanTransposedLine[k];
        }
    }
}
GC.Collect();
sw.Stop();
ジャグ配列の右の行列の転置+Span(2):     00:00:26.2418475

手間の割に…という感じなので、なんかジャグ配列の場合は変に小細工しない方が良さそうですね。

アドバイスに加えご検証ありがとうございます。

ジャグ配列はアクセス速度を意識した工夫の効果が低いというより、普通に書くだけでもパフォーマンスが高いという感じなのですね。
読みやすい書き方でパフォーマンスを確保できるのはすごいです。

教えていただいたSharpLabのX64 Releaseで実行時間を計算してみました。
今回の使い方では、ジャグ配列の最適化は基本的にコンパイラに任せるのが良さそうだと思いました。
※1 X86 ではMemoryGuardExceptionエラーが出ました。
※2 計測時間のばらつきが大きいので、何度か試した最短時間を比較しました。

大項目 小項目 計測時間の最小値(時間:分:秒)
初期化 ジャグ配列のインスタンス化 00:00:00.0013357
まずSpanに変換して、インスタンス化 00:00:00.0027886
ジャグ配列をインスタンス化して、一括でSpanに変換する 00:00:00.0011635
書き込み ジャグ配列への書き込み 00:00:00.0189865
Spanジャグ配列への書き込み 00:00:00.0194189
大項目(初期化)小項目(ジャグ配列のインスタンス化)
using System;
public class MyClass {
    public static void Main() {
        DateTime start= DateTime.Now;        
        int N=2022;
        int D=100;
        double[][] testJag=new double[N][];
        // 初期化
        for(int i=0;i<testJag.Length;++i)
        {
            testJag[i]=new double[D];
        }
        DateTime finish=DateTime.Now;
        Console.WriteLine((finish-start));
    }
}
大項目(初期化)小項目(まずSpanに変換して、インスタンス化)
using System;
public class MyClass {
    public static void Main() {
        DateTime start= DateTime.Now;
        int N=2022;
        int D=100;
        double[][] testJag=new double[N][];
        Span<double[]> testSpan=testJag.AsSpan();
        // 初期化
        for(int i=0;i<testJag.Length;++i)
        {
            testSpan[i]=new double[D];
        }
        DateTime finish=DateTime.Now;
        Console.WriteLine((finish-start));
    }
}
大項目(初期化)小項目(ジャグ配列をインスタンス化して、一括でSpanに変換する)
using System;
public class MyClass {
    public static void Main() {
        DateTime start= DateTime.Now;
        int N=2022;
        int D=100;
        double[][] testJag=new double[N][];
        // 初期化
        for(int i=0;i<testJag.Length;++i)
        {
            testJag[i]=new double[D];
        }        
        Span<double[]> testSpan=testJag.AsSpan();   
        DateTime finish=DateTime.Now;
        Console.WriteLine((finish-start));
    }
}
大項目(書き込み)小項目(ジャグ配列への書き込み)
using System;
public class MyClass {
    public static void Main() {
     
        int N=2022;
        int D=100;
        double[][] testJag=new double[N][];
        // 初期化
        for(int i=0;i<testJag.Length;++i)
        {
            testJag[i]=new double[D];
        }
        DateTime start= DateTime.Now;   
        // 書き込み
        for(int i=0;i<testJag.Length;++i)
        {
            for(int j=0;j<testJag[i].Length;++j)
            {
                testJag[i][j]=i+j;
            }        
        }
        DateTime finish=DateTime.Now;
        Console.WriteLine((finish-start));
    }
}
大項目(書き込み)小項目(Spanジャグ配列への書き込み)
using System;
public class MyClass {
    public static void Main() {
        int N=2022;
        int D=100;
        double[][] testJag=new double[N][];
        Span<double[]> testSpan=testJag.AsSpan();
        // 初期化
        for(int i=0;i<testJag.Length;++i)
        {
            testSpan[i]=new double[D];
        }
        DateTime start= DateTime.Now;
        // 書き込み
        for(int i=0;i<testSpan.Length;++i)
        {
            for(int j=0;j<testSpan[i].Length;++j)
            {
                testSpan[i][j]=i+j;
            }        
        }
        DateTime finish=DateTime.Now;
        Console.WriteLine((finish-start));
    }
}

CommunityToolkit.HighPerformance というものが公式から出ているのですね
2次元配列のSpanはオプションのようなものだったのですね。
試してみます。ありがとうございます。

ログインするとコメントできます