🦆
[C#]Parallel並列化で、賢くて素早い行列積
はじめに
これまで、データアクセスの観点から行列積の効率化に挑戦してきました。
1 多次元配列よりジャグ配列の方が早い
2 unsafeの効果はジャグ配列の方が可視化しやすい
3Span<T>構造体はジャグ配列にしか使えないわりに、逆に遅くなる
上記はシングルタスクの効率化です。そろそろ並列処理を導入しても良いかと思いました。今回の記事では、並列化を用いた行列積の効率化を検証してみます。
注記
この記事は2022/11/28に書きました。以前の記事で頂いたアドバイスを反映できていません。
行列積を例に挙げている理由
1 3重ループのため、計算時間の差を可視化しやすい
2 ほどほどに複雑なアルゴリズムのため、教科書の次のレベルのコードになる
3 目的がはっきりしているため、パフォーマンス向上へのモチベーションを維持できる
参考URL
【C#4.0~】Parallel.Forによる並列処理 | イメージングソリューション
結果
以前の記事で疑似コードを見ながら効率化を考えてきました。同じ事を繰り返し書いてもくどいので、理論的な内容は割愛して、結果を下記に記します。
並列化ってスゲー!!
配列 | 転置 | ポインター | Span | 並列化 | 計算時間 |
---|---|---|---|---|---|
ジャグ配列 | - | - | - | - | 1分53秒97 |
〃 | - | - | - | 〇 | 21秒45 |
〃 | 〇 | - | - | - | 41秒26 |
〃 | 〇 | - | - | 〇 | 8秒92 |
〃 | 〇 | 〇 | - | - | 30秒32 |
〃 | 〇 | 〇 | - | 〇 | 6秒56 |
〃 | 〇 | - | 〇 | - | 43秒55 |
多次元配列 | - | - | - | - | 2分19秒99 |
〃 | - | - | - | 〇 | 26秒82 |
〃 | 〇 | - | - | - | 1分1秒18 |
〃 | 〇 | - | - | 〇 | 15秒93 |
〃 | 〇 | 〇 | - | - | 1分15秒19 |
〃 | 〇 | 〇 | - | 〇 | 1分23秒61 |
終わりに
並列化ってすごいですね。
テストコードが長くなったので、後ろに書きます。
文字数稼ぎではないんです! 信じてください!
テストコード
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace Sandbox.Tremendous1192.SelfEmployed.UnsafeMathDotNet.Zenn.Matrix
{
/// <summary>
/// 行列に関するテストクラス
/// </summary>
public static partial class TestMatrix
{
/// <summary>
/// 行列積の速度比較
/// </summary>
public static void TestMultiplyUnsafeSpanParallel()
{
// 初期化
int N = 2022; // 左側の行列の行
int D = 2022; // 左側の行列の列
int M = 2022; // 右側の行列の列
Console.WriteLine("左の行列の行数" + N + "\t列数" + D + "\t右の行列の行数" + M);
double[,] left2D=new double[N, D];
double[,] right2D=new double[D, M];
double[][] leftJag = new double[N][];
double[][] rightJag = new double[D][];
for (int i=0;i<N;i++)
{
leftJag[i] = new double[D];
for (int k=0;k<D;k++)
{
left2D[i, k] = i + k;
leftJag[i][k] = i + k;
}
}
for (int k = 0; k < D; k++)
{
rightJag[k] = new double[M];
for (int j = 0; j < M; j++)
{
right2D[k, j] = k + j + 1;
rightJag[k][j] = k + j + 1;
}
}
Stopwatch sw= new Stopwatch();
sw.Start();
// 1 ジャグ配列の普通の計算
{
double[][] resultJag = new double[N][];
for (int i = 0; i < N; i++)
{
resultJag[i] = new double[M];
for (int j = 0; j < M; j++)
{
for (int k = 0; k < D; k++)
{
resultJag[i][j] += leftJag[i][k] * rightJag[k][j];
}
}
}
}
GC.Collect();
sw.Stop();
string elapseJag = sw.Elapsed.ToString();
//Console.WriteLine(elapseJag);
Console.WriteLine("1 ジャグ配列の普通の計算:\t" + elapseJag);
sw.Restart();
// 1-*-1 ジャグ配列の普通の計算 + 並列化
{
double[][] resultJag = new double[N][];
Parallel.For(0, N, i =>
{
resultJag[i] = new double[M];
for (int j = 0; j < M; j++)
{
for (int k = 0; k < D; k++)
{
resultJag[i][j] += leftJag[i][k] * rightJag[k][j];
}
}
});
}
GC.Collect();
sw.Stop();
string elapseJagParallel = sw.Elapsed.ToString();
//Console.WriteLine(elapseJag);
Console.WriteLine("1-*-1 ジャグ配列の普通の計算 + 並列化:\t" + elapseJagParallel);
// 2 ジャグ配列の転置の計算
sw.Restart();
{
double[][] transposedJag = new double[M][];
for (int j = 0; j < M; j++)
{
transposedJag[j] = new double[D];
for (int k = 0; k < D; k++)
{
transposedJag[j][k] = rightJag[k][j];
}
}
double[][] resultJag = new double[N][];
for (int i = 0; i < N; i++)
{
resultJag[i] = new double[M];
for (int j = 0; j < M; j++)
{
for (int k = 0; k < D; k++)
{
resultJag[i][j] += leftJag[i][k] * transposedJag[j][k];
}
}
}
}
GC.Collect();
sw.Stop();
string elapseJagT = sw.Elapsed.ToString();
//Console.WriteLine(elapseJagT);
Console.WriteLine("2 ジャグ配列の右の行列の転置:\t" + elapseJagT);
// 2-*-1 ジャグ配列の転置の計算 + 並列化
sw.Restart();
{
double[][] transposedJag = new double[M][];
for (int j = 0; j < M; j++)
{
transposedJag[j] = new double[D];
for (int k = 0; k < D; k++)
{
transposedJag[j][k] = rightJag[k][j];
}
}
double[][] resultJag = new double[N][];
Parallel.For(0, N, i =>
{
resultJag[i] = new double[M];
for (int j = 0; j < M; j++)
{
for (int k = 0; k < D; k++)
{
resultJag[i][j] += leftJag[i][k] * transposedJag[j][k];
}
}
});
}
GC.Collect();
sw.Stop();
string elapseJagTParallel = sw.Elapsed.ToString();
//Console.WriteLine(elapseJagT);
Console.WriteLine("2-*-1 ジャグ配列の右の行列の転置 + 並列化:\t" + elapseJagTParallel);
// 2-1 ジャグ配列の転置の計算 + ポインター
sw.Restart();
unsafe
{
double[][] transposedJag = new double[M][];
for (int j = 0; j < M; j++)
{
transposedJag[j] = new double[D];
for (int k = 0; k < D; k++)
{
transposedJag[j][k] = rightJag[k][j];
}
}
double[][] resultJag = new double[N][];
for (int i = 0; i < N; i++)
{
resultJag[i] = new double[M];
}
for (int i = 0; i < N; i++)
{
fixed (double* pResultJag = &resultJag[i][0])
{
int j = 0;
for (double* pR = pResultJag; pR != pResultJag + resultJag[i].Length; ++pR)
{
for (int k = 0; k < D; k++)
{
*pR += leftJag[i][k] * transposedJag[j][k];
}
++j;
}
}
}
}
GC.Collect();
sw.Stop();
string elapseJagTPointer = sw.Elapsed.ToString();
//Console.WriteLine(elapseJagTPointer);
Console.WriteLine("2-1 ジャグ配列の右の行列の転置 + ポインター:\t" + elapseJagTPointer);
// 2-1-1 ジャグ配列の転置の計算 + ポインター + 並列化
sw.Restart();
unsafe
{
double[][] transposedJag = new double[M][];
for (int j = 0; j < M; j++)
{
transposedJag[j] = new double[D];
for (int k = 0; k < D; k++)
{
transposedJag[j][k] = rightJag[k][j];
}
}
double[][] resultJag = new double[N][];
for (int i = 0; i < N; i++)
{
resultJag[i] = new double[M];
}
Parallel.For(0, N, i =>
{
fixed (double* pResultJag = &resultJag[i][0])
{
int j = 0;
for (double* pR = pResultJag; pR != pResultJag + resultJag[i].Length; ++pR)
{
for (int k = 0; k < D; k++)
{
*pR += leftJag[i][k] * transposedJag[j][k];
}
++j;
}
}
});
}
GC.Collect();
sw.Stop();
string elapseJagTPointerParallel = sw.Elapsed.ToString();
//Console.WriteLine(elapseJagTPointer);
Console.WriteLine("2-1-1 ジャグ配列の右の行列の転置 + ポインター + 並列化:\t" + elapseJagTPointerParallel);
// 2-2 ジャグ配列の転置の計算 + Span
sw.Restart();
{
double[][] transposedJag = new double[M][];
for (int j = 0; j < M; j++)
{
transposedJag[j] = new double[D];
for (int k = 0; k < D; k++)
{
transposedJag[j][k] = rightJag[k][j];
}
}
double[][] resultJag = new double[N][];
Span<double[]> spanResultJag = resultJag.AsSpan();
for (int i = 0; i < N; i++)
{
resultJag[i] = new double[M];
for (int j = 0; j < M; j++)
{
for (int k = 0; k < D; k++)
{
spanResultJag[i][j] += leftJag[i][k] * transposedJag[j][k];
}
}
}
}
GC.Collect();
sw.Stop();
string elapseJagTSpan = sw.Elapsed.ToString();
//Console.WriteLine(elapseJagT);
Console.WriteLine("2-2 ジャグ配列の右の行列の転置+Span:\t" + elapseJagTSpan);
Console.WriteLine("\n\n");
// 1 多次元配列の普通の計算
sw.Restart();
{
double[,] resultMulti = new double[N,M];
for (int i = 0; i < N; i++)
{
for (int j = 0; j < M; j++)
{
for (int k = 0; k < D; k++)
{
resultMulti[i,j] += left2D[i, k] * right2D[k, j];
}
}
}
}
GC.Collect();
sw.Stop();
string elapseMulti = sw.Elapsed.ToString();
//Console.WriteLine(elapseMulti);
Console.WriteLine("1 多次元配列の普通の計算:\t" + elapseMulti);
// 1-*-1 多次元配列の普通の計算 + 並列化
sw.Restart();
{
double[,] resultMulti = new double[N, M];
Parallel.For(0, N, i =>
{
for (int j = 0; j < M; j++)
{
for (int k = 0; k < D; k++)
{
resultMulti[i, j] += left2D[i, k] * right2D[k, j];
}
}
});
}
GC.Collect();
sw.Stop();
string elapseMultiParallel = sw.Elapsed.ToString();
//Console.WriteLine(elapseMulti);
Console.WriteLine("1-*-1 多次元配列の普通の計算 + 並列化:\t" + elapseMultiParallel);
// 2 多次元配列の転置の計算
sw.Restart();
{
double[,] transposedMulti = new double[M, D];
for (int j = 0; j < M; j++)
{
for (int k = 0; k < D; k++)
{
transposedMulti[j, k] = right2D[k, j];
}
}
double[,] resultMulti = new double[N, M];
for (int i = 0; i < N; i++)
{
for (int j = 0; j < M; j++)
{
for (int k = 0; k < D; k++)
{
resultMulti[i, j] += left2D[i, k] * transposedMulti[j, k];
}
}
}
}
GC.Collect();
sw.Stop();
string elapseMultiT = sw.Elapsed.ToString();
//Console.WriteLine(elapseMultiT);
Console.WriteLine("2 多次元配列の右の行列の転置:\t" + elapseMultiT);
// 2-*-1 多次元配列の転置の計算 + 並列化
sw.Restart();
{
double[,] transposedMulti = new double[M, D];
for (int j = 0; j < M; j++)
{
for (int k = 0; k < D; k++)
{
transposedMulti[j, k] = right2D[k, j];
}
}
double[,] resultMulti = new double[N, M];
Parallel.For(0, N, i =>
{
for (int j = 0; j < M; j++)
{
for (int k = 0; k < D; k++)
{
resultMulti[i, j] += left2D[i, k] * transposedMulti[j, k];
}
}
});
}
GC.Collect();
sw.Stop();
string elapseMultiTParallel = sw.Elapsed.ToString();
//Console.WriteLine(elapseMultiT);
Console.WriteLine("2-*-1 多次元配列の右の行列の転置 + 並列化:\t" + elapseMultiTParallel);
// 2-1 多次元配列の転置とポインターの計算
sw.Restart();
unsafe
{
double[,] transposedMulti = new double[M, D];
for (int j = 0; j < M; j++)
{
for (int k = 0; k < D; k++)
{
transposedMulti[j, k] = right2D[k, j];
}
}
double[,] resultMulti = new double[N, M];
int count = 0;
fixed(double* pMulti = resultMulti)
{
for (double* pM=pMulti;pM!=pMulti+resultMulti.Length;++pM)
{
for (int k = 0; k < D; k++)
{
*pM += left2D[count / M, k] * transposedMulti[count % M, k];
}
++count;
}
}
}
GC.Collect();
sw.Stop();
string elapseMultiTPointer = sw.Elapsed.ToString();
//Console.WriteLine(elapseMultiTPointer);
Console.WriteLine("2-1 多次元配列の右の行列の転置 + ポインター:\t" + elapseMultiTPointer);
// 2-1-1 多次元配列の転置とポインターの計算 + 並列化
sw.Restart();
unsafe
{
double[,] transposedMulti = new double[M, D];
for (int j = 0; j < M; j++)
{
for (int k = 0; k < D; k++)
{
transposedMulti[j, k] = right2D[k, j];
}
}
double[,] resultMulti = new double[N, M];
int count = 0;
fixed (double* pMulti = resultMulti)
{
double* pM = pMulti;
Parallel.For(0, N*M, i =>
{
for (int k = 0; k < D; k++)
{
*pM += left2D[count / M, k] * transposedMulti[count % M, k];
}
++count;
++pM;
});
}
}
GC.Collect();
sw.Stop();
string elapseMultiTPointerParallel = sw.Elapsed.ToString();
//Console.WriteLine(elapseMultiTPointer);
Console.WriteLine("2-1-1 多次元配列の右の行列の転置 + ポインター + 並列化:\t" + elapseMultiTPointerParallel);
//Console.WriteLine("1 ジャグ配列の普通の計算:\t" + elapseJag);
//Console.WriteLine("1-*-1 ジャグ配列の普通の計算 + 並列化:\t" + elapseJagParallel);
//Console.WriteLine("2 ジャグ配列の右の行列の転置:\t" + elapseJagT);
//Console.WriteLine("2-*-1 ジャグ配列の右の行列の転置 + 並列化:\t" + elapseJagTParallel);
//Console.WriteLine("2-1 ジャグ配列の右の行列の転置 + ポインター:\t" + elapseJagTPointer);
//Console.WriteLine("2-1-1 ジャグ配列の右の行列の転置 + ポインター + 並列化:\t" + elapseJagTPointerParallel);
//Console.WriteLine("2-2 ジャグ配列の右の行列の転置+Span:\t" + elapseJagTSpan);
//Console.WriteLine("2-2 ジャグ配列の右の行列の転置+Span: 並列化は使えない");
//Console.WriteLine("1 多次元配列の普通の計算:\t" + elapseMulti);
//Console.WriteLine("1-*-1 多次元配列の普通の計算 + 並列化:\t" + elapseMultiParallel);
//Console.WriteLine("2 多次元配列の右の行列の転置:\t" + elapseMultiT);
//Console.WriteLine("2-*-1 多次元配列の右の行列の転置 + 並列化:\t" + elapseMultiT);
//Console.WriteLine("2-1 多次元配列の右の行列の転置 + ポインター:\t" + elapseMultiTPointer);
//Console.WriteLine("2-1-1 多次元配列の右の行列の転置 + ポインター + 並列化:\t" + elapseMultiTPointerParallel);
}
}
}
Discussion