[C#]Span<T>構造体で高速かつ安全に配列へアクセスして、行列積を素早く計算する。
目的
今どきのC#erで危ねえポインターを使っているやついるぅ?! いねえよなあ!!
というわけで、以前の記事でunsafe
とポインターを扱ったので、今度はC# 7.2から追加されたSpan<T>
構造体を用いて行列積のパフォーマンスを見てみます。
ここで、行列積を例に挙げている理由は下記の3点です。
1 3重ループのため、計算時間の差を可視化しやすい
2 ほどほどに複雑なアルゴリズムのため、教科書の次のレベルのコードになる
3 目的がはっきりしているため、パフォーマンス向上へのモチベーションを維持できる
参考URL
多次元配列ではトリッキーなことをしないとSpan<T>
構造体を使えないようです。本末転倒なので今回は使用しません。
Span<T>構造体 - C# によるプログラミング入門 | ++C++; // 未確認飛行 C
c# - Span and two dimensional Arrays - Stack Overflow
結果
最初に計算結果を載せます。Span<T>
構造体の方が遅かったです。
理由をご存じの方がいらっしゃいましたら、コメントいただけますと助かります。
配列 | 計算方法 | 計算時間 |
---|---|---|
ジャグ配列 | ①普通の行列積の計算 | 1分47秒 |
②右の行列を転置した計算 | 37秒57 | |
③ ②+戻り値へのアクセスをポインターにした計算 | 28秒86 | |
④ ②+戻り値へのアクセスをSpanにした計算 | 44秒70 | |
多次元配列 | ①普通の行列積の計算 | 2分17秒54 |
②右の行列を転置した計算 | 59秒96 | |
③ ②+戻り値へのアクセスをポインターにした計算 | 1分8秒15 |
疑似コードで考えるプログラムの計算時間
今回のプログラムでは、右の行列の転置行列を作成して、戻り値の行列積をSpan<T>
構造体にした場合のパフォーマンスを比べてみます。
for(左側の行列の行 i)
{
for(右側の行列の列 j)
{
for(左側の行列の列 k)
{
左側の行列a[i,k]にアクセス
右側の行列b[k,j]にアクセス
a[i,k]とb[k,j]の積
積c[i,j]にアクセス
積c[i,j]にa[i,k]とb[k,j]の積を加える
次のkに移る(k+1の開始判定,境界判定a[i,k+1],b[k+1,j])
}
次のjに移る(j+1の開始判定,境界判定b[*,j+1],c[i,j+1])
}
次のiに移る(i+1の開始判定,境界判定a[i+1,*],c[i+1,*])
}
テストプログラム
前回のプログラムにSpan<T>
構造体の計算を加えたテストメソッドです。比較用にunsafe
が入っています。
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace Sandbox.Tremendous1192.SelfEmployed.UnsafeMathDotNet.Zenn.Matrix
{
/// <summary>
/// 行列に関するテストクラス
/// </summary>
public static partial class TestMatrix
{
/// <summary>
/// 行列積の速度比較
/// </summary>
public static void TestMultiplyUnsafeSpan()
{
// 初期化
int N = 2022; // 左側の行列の行
int D = 2022; // 左側の行列の列
int M = 2022; // 右側の行列の列
Console.WriteLine("左の行列の行数" + N + "\t列数" + D + "\t右の行列の行数" + M);
double[,] left2D=new double[N, D];
double[,] right2D=new double[D, M];
double[][] leftJag = new double[N][];
double[][] rightJag = new double[D][];
for (int i=0;i<N;i++)
{
leftJag[i] = new double[D];
for (int k=0;k<D;k++)
{
left2D[i, k] = i + k;
leftJag[i][k] = i + k;
}
}
for (int k = 0; k < D; k++)
{
rightJag[k] = new double[M];
for (int j = 0; j < M; j++)
{
right2D[k, j] = k + j + 1;
rightJag[k][j] = k + j + 1;
}
}
Stopwatch sw= new Stopwatch();
sw.Start();
// ジャグ配列の普通の計算
{
double[][] resultJag = new double[N][];
for (int i = 0; i < N; i++)
{
resultJag[i] = new double[M];
for (int j = 0; j < M; j++)
{
for (int k = 0; k < D; k++)
{
resultJag[i][j] += leftJag[i][k] * rightJag[k][j];
}
}
}
}
GC.Collect();
sw.Stop();
string elapseJag = sw.Elapsed.ToString();
//Console.WriteLine(elapseJag);
// ジャグ配列の転置の計算
sw.Restart();
{
double[][] transposedJag = new double[M][];
for (int j = 0; j < M; j++)
{
transposedJag[j] = new double[D];
for (int k = 0; k < D; k++)
{
transposedJag[j][k] = rightJag[k][j];
}
}
double[][] resultJag = new double[N][];
for (int i = 0; i < N; i++)
{
resultJag[i] = new double[M];
for (int j = 0; j < M; j++)
{
for (int k = 0; k < D; k++)
{
resultJag[i][j] += leftJag[i][k] * transposedJag[j][k];
}
}
}
}
GC.Collect();
sw.Stop();
string elapseJagT = sw.Elapsed.ToString();
//Console.WriteLine(elapseJagT);
// ジャグ配列の転置の計算 + ポインター
sw.Restart();
unsafe
{
double[][] transposedJag = new double[M][];
for (int j = 0; j < M; j++)
{
transposedJag[j] = new double[D];
for (int k = 0; k < D; k++)
{
transposedJag[j][k] = rightJag[k][j];
}
}
double[][] resultJag = new double[N][];
for (int i = 0; i < N; i++)
{
resultJag[i] = new double[M];
}
for (int i = 0; i < N; i++)
{
fixed (double* pResultJag = &resultJag[i][0])
{
int j = 0;
for (double* pR = pResultJag; pR != pResultJag + resultJag[i].Length; ++pR)
{
for (int k = 0; k < D; k++)
{
*pR += leftJag[i][k] * transposedJag[j][k];
}
++j;
}
}
}
}
GC.Collect();
sw.Stop();
string elapseJagTPointer = sw.Elapsed.ToString();
//Console.WriteLine(elapseJagTPointer);
// ジャグ配列の転置の計算 + Span
sw.Restart();
{
double[][] transposedJag = new double[M][];
for (int j = 0; j < M; j++)
{
transposedJag[j] = new double[D];
for (int k = 0; k < D; k++)
{
transposedJag[j][k] = rightJag[k][j];
}
}
double[][] resultJag = new double[N][];
Span<double[]> spanResultJag = resultJag.AsSpan();
for (int i = 0; i < N; i++)
{
resultJag[i] = new double[M];
for (int j = 0; j < M; j++)
{
for (int k = 0; k < D; k++)
{
spanResultJag[i][j] += leftJag[i][k] * transposedJag[j][k];
}
}
}
}
GC.Collect();
sw.Stop();
string elapseJagTSpan = sw.Elapsed.ToString();
//Console.WriteLine(elapseJagT);
//多次元配列の普通の計算
sw.Restart();
{
double[,] resultMulti = new double[N,M];
for (int i = 0; i < N; i++)
{
for (int j = 0; j < M; j++)
{
for (int k = 0; k < D; k++)
{
resultMulti[i,j] += left2D[i, k] * right2D[k, j];
}
}
}
}
GC.Collect();
sw.Stop();
string elapseMulti = sw.Elapsed.ToString();
//Console.WriteLine(elapseMulti);
// 多次元配列の転置の計算
sw.Restart();
{
double[,] transposedMulti = new double[M, D];
for (int j = 0; j < M; j++)
{
for (int k = 0; k < D; k++)
{
transposedMulti[j, k] = right2D[k, j];
}
}
double[,] resultMulti = new double[N, M];
for (int i = 0; i < N; i++)
{
for (int j = 0; j < M; j++)
{
for (int k = 0; k < D; k++)
{
resultMulti[i, j] += left2D[i, k] * transposedMulti[j, k];
}
}
}
}
GC.Collect();
sw.Stop();
string elapseMultiT = sw.Elapsed.ToString();
//Console.WriteLine(elapseMultiT);
// 多次元配列の転置とポインターの計算
sw.Restart();
unsafe
{
double[,] transposedMulti = new double[M, D];
for (int j = 0; j < M; j++)
{
for (int k = 0; k < D; k++)
{
transposedMulti[j, k] = right2D[k, j];
}
}
double[,] resultMulti = new double[N, M];
int count = 0;
fixed(double* pMulti = resultMulti)
{
for (double* pM=pMulti;pM!=pMulti+resultMulti.Length;++pM)
{
for (int k = 0; k < D; k++)
{
*pM += left2D[count / M, k] * transposedMulti[count % M, k];
}
++count;
}
}
}
GC.Collect();
sw.Stop();
string elapseMultiTPointer = sw.Elapsed.ToString();
//Console.WriteLine(elapseMultiTPointer);
/*
// 多次元配列の転置の計算 + Span (できない)
sw.Restart();
{
double[,] transposedMulti = new double[M, D];
for (int j = 0; j < M; j++)
{
for (int k = 0; k < D; k++)
{
transposedMulti[j, k] = right2D[k, j];
}
}
double[,] resultMulti = new double[N, M];
Span<double> spanResultMulti = resultMulti.AsSpan();
for (int i = 0; i < N; i++)
{
for (int j = 0; j < M; j++)
{
for (int k = 0; k < D; k++)
{
spanResultMulti[i, j] += left2D[i, k] * transposedMulti[j, k];
}
}
}
}
GC.Collect();
sw.Stop();
string elapseMultiTSpan = sw.Elapsed.ToString();
//Console.WriteLine(elapseMultiT);
*/
Console.WriteLine("ジャグ配列の普通の計算:\t" + elapseJag);
Console.WriteLine("ジャグ配列の右の行列の転置:\t" + elapseJagT);
Console.WriteLine("ジャグ配列の右の行列の転置+ポインター:\t" + elapseJagTPointer);
Console.WriteLine("ジャグ配列の右の行列の転置+Span:\t" + elapseJagTSpan);
Console.WriteLine("多次元配列の普通の計算:\t" + elapseMulti);
Console.WriteLine("多次元配列の右の行列の転置:\t" + elapseMultiT);
Console.WriteLine("多次元配列の右の行列の転置+ポインター:\t" + elapseMultiTPointer);
}
}
}
結果
各計算方法の計算時間を記しました。Span<T>
構造体にすると、逆に遅くなりました。どういうことでしょうか。しかし、遅くなったとはいえ、ジャグ配列の方が多次元配列よりも早く計算できています。
配列 | 計算方法 | 計算時間 |
---|---|---|
ジャグ配列 | ①普通の行列積の計算 | 1分47秒 |
②右の行列を転置した計算 | 37秒57 | |
③ ②+戻り値へのアクセスをポインターにした計算 | 28秒86 | |
④ ②+戻り値へのアクセスをSpanにした計算 | 44秒70 | |
多次元配列 | ①普通の行列積の計算 | 2分17秒54 |
②右の行列を転置した計算 | 59秒96 | |
③ ②+戻り値へのアクセスをポインターにした計算 | 1分8秒15 |
終わりに
どうしてSpan<T>
構造体のほうが遅いんですかねえ。ヒープからスタックに移すコストかなあ。
Discussion
実際にアセンブリコード見てみないと何とも言えませんが、連続したメモリ領域にアクセスする場合、Span<T>は最適化されやすいですが、ソースを見た限りではループ内部で new double[M] していて、通常の配列とSpan<T>を混ぜて使っちゃっていますね。これではSpan<T> の恩恵はほぼ無さそうです。https://sharplab.io/ で、JIT Asm選択してどんなアセンブリコード出力しているか見てみるといいかもしれません。
細かくnewする事自体も遅いので、もし速度を詰めていくなら、メモリアロケーションを抑える工夫も必要になってきます。(stackalloc、ArrayPool<T>を使う等)
他には、単純にReleaseビルドにしてなかった、古いフレームワークを使用していた(.NET Framework4.xなど)が考えられます。
アドバイスありがとうございます。
勉強になります。
そもそもとして、変なタイミングでnew double[M]したので
Span<T>の最適化の邪魔をしている可能性があるのですね。
アセンブリコードを調べてみます。
SharpLab https://sharplab.io/ の情報ありがとうございます。
stackalloc、ArrayPool<T>というものがあるのですね。
仰る通りDebugビルドでした。
フレームワークだけは.Net 7.0でした。
ジャグ配列の転置の計算 + Span のコードを、もう少しSpan化してみました。
ついでに、CommunityToolkit.HighPerformance の Span2D で二次元配列をSpan化してみました。
x86 Release での結果は、以下のようになりました。
ジャグ配列の方は最速がポインタ・Spanと入れ替わる事もあり、そこまで有意な差を得られませんでした。多次元配列の方は、Span2Dで明確に改善が見られました。ジャグ配列にそこまで見劣りしない速度で、2次元配列の書き味で使える事を考えると、中々良いかもしれません。
書き味、読みやすさ度外視で、ジャグ配列を更なるSpan化。
手間の割に…という感じなので、なんかジャグ配列の場合は変に小細工しない方が良さそうですね。
アドバイスに加えご検証ありがとうございます。
ジャグ配列はアクセス速度を意識した工夫の効果が低いというより、普通に書くだけでもパフォーマンスが高いという感じなのですね。
読みやすい書き方でパフォーマンスを確保できるのはすごいです。
教えていただいたSharpLabのX64 Releaseで実行時間を計算してみました。
今回の使い方では、ジャグ配列の最適化は基本的にコンパイラに任せるのが良さそうだと思いました。
※1 X86 ではMemoryGuardExceptionエラーが出ました。
※2 計測時間のばらつきが大きいので、何度か試した最短時間を比較しました。
CommunityToolkit.HighPerformance というものが公式から出ているのですね
2次元配列のSpanはオプションのようなものだったのですね。
試してみます。ありがとうございます。