🦆

[C#]Parallel並列化で、賢くて素早い行列積

2022/12/17に公開

はじめに

これまで、データアクセスの観点から行列積の効率化に挑戦してきました。
1 多次元配列よりジャグ配列の方が早い
2 unsafeの効果はジャグ配列の方が可視化しやすい
3Span<T>構造体はジャグ配列にしか使えないわりに、逆に遅くなる

上記はシングルタスクの効率化です。そろそろ並列処理を導入しても良いかと思いました。今回の記事では、並列化を用いた行列積の効率化を検証してみます。

注記

この記事は2022/11/28に書きました。以前の記事で頂いたアドバイスを反映できていません。

行列積を例に挙げている理由

1 3重ループのため、計算時間の差を可視化しやすい
2 ほどほどに複雑なアルゴリズムのため、教科書の次のレベルのコードになる
3 目的がはっきりしているため、パフォーマンス向上へのモチベーションを維持できる

参考URL

【C#4.0~】Parallel.Forによる並列処理 | イメージングソリューション

結果

以前の記事で疑似コードを見ながら効率化を考えてきました。同じ事を繰り返し書いてもくどいので、理論的な内容は割愛して、結果を下記に記します。
並列化ってスゲー!!

配列 転置 ポインター Span 並列化 計算時間
ジャグ配列 - - - - 1分53秒97
- - - 21秒45
- - - 41秒26
- - 8秒92
- - 30秒32
- 6秒56
- - 43秒55
多次元配列 - - - - 2分19秒99
- - - 26秒82
- - - 1分1秒18
- - 15秒93
- - 1分15秒19
- 1分23秒61

終わりに

並列化ってすごいですね。

テストコードが長くなったので、後ろに書きます。

文字数稼ぎではないんです! 信じてください!

テストコード
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace Sandbox.Tremendous1192.SelfEmployed.UnsafeMathDotNet.Zenn.Matrix
{
    /// <summary>
    /// 行列に関するテストクラス
    /// </summary>
    public static partial class TestMatrix
    {
        /// <summary>
        /// 行列積の速度比較
        /// </summary>
        public static void TestMultiplyUnsafeSpanParallel()
        {
            // 初期化
            int N = 2022; // 左側の行列の行
            int D = 2022; // 左側の行列の列
            int M = 2022; // 右側の行列の列
            Console.WriteLine("左の行列の行数" + N + "\t列数" + D + "\t右の行列の行数" + M);

            double[,] left2D=new double[N, D];
            double[,] right2D=new double[D, M];

            double[][] leftJag = new double[N][];
            double[][] rightJag = new double[D][];

            for (int i=0;i<N;i++)
            {
                leftJag[i] = new double[D];
                for (int k=0;k<D;k++)
                {
                    left2D[i, k] = i + k;
                    leftJag[i][k] = i + k;                
                }
            }
            for (int k = 0; k < D; k++)
            {
                rightJag[k] = new double[M];
                for (int j = 0; j < M; j++)
                {
                    right2D[k, j] = k + j + 1;
                    rightJag[k][j] = k + j + 1;
                }
            }


            Stopwatch sw= new Stopwatch();
            sw.Start();
            // 1 ジャグ配列の普通の計算
            {
                double[][] resultJag = new double[N][];
                for (int i = 0; i < N; i++)
                {
                    resultJag[i] = new double[M];
                    for (int j = 0; j < M; j++)
                    {
                        for (int k = 0; k < D; k++)
                        {
                            resultJag[i][j] += leftJag[i][k] * rightJag[k][j];
                        }
                    }
                }
            }
            GC.Collect();
            sw.Stop();
            string elapseJag = sw.Elapsed.ToString();
            //Console.WriteLine(elapseJag);
            Console.WriteLine("1 ジャグ配列の普通の計算:\t" + elapseJag);

            sw.Restart();
            // 1-*-1 ジャグ配列の普通の計算 + 並列化
            {
                double[][] resultJag = new double[N][];
                Parallel.For(0, N, i =>
                {
                    resultJag[i] = new double[M];
                    for (int j = 0; j < M; j++)
                    {
                        for (int k = 0; k < D; k++)
                        {
                            resultJag[i][j] += leftJag[i][k] * rightJag[k][j];
                        }
                    }
                });
            }
            GC.Collect();
            sw.Stop();
            string elapseJagParallel = sw.Elapsed.ToString();
            //Console.WriteLine(elapseJag);
            Console.WriteLine("1-*-1 ジャグ配列の普通の計算 + 並列化:\t" + elapseJagParallel);

            // 2 ジャグ配列の転置の計算
            sw.Restart();
            {
                double[][] transposedJag = new double[M][];
                for (int j = 0; j < M; j++)
                {
                    transposedJag[j] = new double[D];
                    for (int k = 0; k < D; k++)
                    {
                        transposedJag[j][k] = rightJag[k][j];
                    }
                }
                double[][] resultJag = new double[N][];
                for (int i = 0; i < N; i++)
                {
                    resultJag[i] = new double[M];
                    for (int j = 0; j < M; j++)
                    {
                        for (int k = 0; k < D; k++)
                        {
                            resultJag[i][j] += leftJag[i][k] * transposedJag[j][k];
                        }
                    }
                }
            }
            GC.Collect();
            sw.Stop();
            string elapseJagT = sw.Elapsed.ToString();
            //Console.WriteLine(elapseJagT);
            Console.WriteLine("2 ジャグ配列の右の行列の転置:\t" + elapseJagT);

            // 2-*-1 ジャグ配列の転置の計算 + 並列化
            sw.Restart();
            {
                double[][] transposedJag = new double[M][];
                for (int j = 0; j < M; j++)
                {
                    transposedJag[j] = new double[D];
                    for (int k = 0; k < D; k++)
                    {
                        transposedJag[j][k] = rightJag[k][j];
                    }
                }
                double[][] resultJag = new double[N][];
                Parallel.For(0, N, i =>
                {
                    resultJag[i] = new double[M];
                    for (int j = 0; j < M; j++)
                    {
                        for (int k = 0; k < D; k++)
                        {
                            resultJag[i][j] += leftJag[i][k] * transposedJag[j][k];
                        }
                    }
                });
            }
            GC.Collect();
            sw.Stop();
            string elapseJagTParallel = sw.Elapsed.ToString();
            //Console.WriteLine(elapseJagT);
            Console.WriteLine("2-*-1 ジャグ配列の右の行列の転置 + 並列化:\t" + elapseJagTParallel);

            // 2-1 ジャグ配列の転置の計算 + ポインター
            sw.Restart();
            unsafe
            {
                double[][] transposedJag = new double[M][];
                for (int j = 0; j < M; j++)
                {
                    transposedJag[j] = new double[D];
                    for (int k = 0; k < D; k++)
                    {
                        transposedJag[j][k] = rightJag[k][j];
                    }
                }
                double[][] resultJag = new double[N][];
                for (int i = 0; i < N; i++)
                {
                    resultJag[i] = new double[M];
                }

                for (int i = 0; i < N; i++)
                {
                    fixed (double* pResultJag = &resultJag[i][0])
                    {
                        int j = 0;
                        for (double* pR = pResultJag; pR != pResultJag + resultJag[i].Length; ++pR)
                        {
                            for (int k = 0; k < D; k++)
                            {
                              *pR += leftJag[i][k] * transposedJag[j][k];
                            }
                            ++j;
                        }
                    }
                }
            }
            GC.Collect();
            sw.Stop();
            string elapseJagTPointer = sw.Elapsed.ToString();
            //Console.WriteLine(elapseJagTPointer);
            Console.WriteLine("2-1 ジャグ配列の右の行列の転置 + ポインター:\t" + elapseJagTPointer);

            // 2-1-1 ジャグ配列の転置の計算 + ポインター + 並列化
            sw.Restart();
            unsafe
            {
                double[][] transposedJag = new double[M][];
                for (int j = 0; j < M; j++)
                {
                    transposedJag[j] = new double[D];
                    for (int k = 0; k < D; k++)
                    {
                        transposedJag[j][k] = rightJag[k][j];
                    }
                }
                double[][] resultJag = new double[N][];
                for (int i = 0; i < N; i++)
                {
                    resultJag[i] = new double[M];
                }

                Parallel.For(0, N, i =>
                {
                    fixed (double* pResultJag = &resultJag[i][0])
                    {
                        int j = 0;
                        for (double* pR = pResultJag; pR != pResultJag + resultJag[i].Length; ++pR)
                        {
                            for (int k = 0; k < D; k++)
                            {
                                *pR += leftJag[i][k] * transposedJag[j][k];
                            }
                            ++j;
                        }
                    }
                });
            }
            GC.Collect();
            sw.Stop();
            string elapseJagTPointerParallel = sw.Elapsed.ToString();
            //Console.WriteLine(elapseJagTPointer);
            Console.WriteLine("2-1-1 ジャグ配列の右の行列の転置 + ポインター + 並列化:\t" + elapseJagTPointerParallel);

            // 2-2 ジャグ配列の転置の計算 + Span
            sw.Restart();
            {
                double[][] transposedJag = new double[M][];
                for (int j = 0; j < M; j++)
                {
                    transposedJag[j] = new double[D];
                    for (int k = 0; k < D; k++)
                    {
                        transposedJag[j][k] = rightJag[k][j];
                    }
                }
                double[][] resultJag = new double[N][];
                Span<double[]> spanResultJag = resultJag.AsSpan();
                for (int i = 0; i < N; i++)
                {
                    resultJag[i] = new double[M];
                    for (int j = 0; j < M; j++)
                    {
                        for (int k = 0; k < D; k++)
                        {
                            spanResultJag[i][j] += leftJag[i][k] * transposedJag[j][k];
                        }
                    }
                }
            }
            GC.Collect();
            sw.Stop();
            string elapseJagTSpan = sw.Elapsed.ToString();
            //Console.WriteLine(elapseJagT);
            Console.WriteLine("2-2 ジャグ配列の右の行列の転置+Span:\t" + elapseJagTSpan);


            Console.WriteLine("\n\n");
            // 1 多次元配列の普通の計算
            sw.Restart();
            {
                double[,] resultMulti = new double[N,M];
                for (int i = 0; i < N; i++)
                {
                    for (int j = 0; j < M; j++)
                    {
                        for (int k = 0; k < D; k++)
                        {
                            resultMulti[i,j] += left2D[i, k] * right2D[k, j];
                        }
                    }
                }
            }
            GC.Collect();
            sw.Stop();
            string elapseMulti = sw.Elapsed.ToString();
            //Console.WriteLine(elapseMulti);
            Console.WriteLine("1 多次元配列の普通の計算:\t" + elapseMulti);

            // 1-*-1 多次元配列の普通の計算 + 並列化
            sw.Restart();
            {
                double[,] resultMulti = new double[N, M];
                Parallel.For(0, N, i =>
                {
                    for (int j = 0; j < M; j++)
                    {
                        for (int k = 0; k < D; k++)
                        {
                            resultMulti[i, j] += left2D[i, k] * right2D[k, j];
                        }
                    }
                });
            }
            GC.Collect();
            sw.Stop();
            string elapseMultiParallel = sw.Elapsed.ToString();
            //Console.WriteLine(elapseMulti);
            Console.WriteLine("1-*-1 多次元配列の普通の計算 + 並列化:\t" + elapseMultiParallel);

            // 2 多次元配列の転置の計算
            sw.Restart();
            {
                double[,] transposedMulti = new double[M, D];
                for (int j = 0; j < M; j++)
                {
                    for (int k = 0; k < D; k++)
                    {
                        transposedMulti[j, k] = right2D[k, j];
                    }
                }
                double[,] resultMulti = new double[N, M];
                for (int i = 0; i < N; i++)
                {
                    for (int j = 0; j < M; j++)
                    {
                        for (int k = 0; k < D; k++)
                        {
                            resultMulti[i, j] += left2D[i, k] * transposedMulti[j, k];
                        }
                    }
                }
            }
            GC.Collect();
            sw.Stop();
            string elapseMultiT = sw.Elapsed.ToString();
            //Console.WriteLine(elapseMultiT);
            Console.WriteLine("2 多次元配列の右の行列の転置:\t" + elapseMultiT);

            // 2-*-1 多次元配列の転置の計算 + 並列化
            sw.Restart();
            {
                double[,] transposedMulti = new double[M, D];
                for (int j = 0; j < M; j++)
                {
                    for (int k = 0; k < D; k++)
                    {
                        transposedMulti[j, k] = right2D[k, j];
                    }
                }
                double[,] resultMulti = new double[N, M];
                Parallel.For(0, N, i =>
                {
                    for (int j = 0; j < M; j++)
                    {
                        for (int k = 0; k < D; k++)
                        {
                            resultMulti[i, j] += left2D[i, k] * transposedMulti[j, k];
                        }
                    }
                });
            }
            GC.Collect();
            sw.Stop();
            string elapseMultiTParallel = sw.Elapsed.ToString();
            //Console.WriteLine(elapseMultiT);
            Console.WriteLine("2-*-1 多次元配列の右の行列の転置 + 並列化:\t" + elapseMultiTParallel);

            // 2-1 多次元配列の転置とポインターの計算
            sw.Restart();
            unsafe
            {
                double[,] transposedMulti = new double[M, D];
                for (int j = 0; j < M; j++)
                {
                    for (int k = 0; k < D; k++)
                    {
                        transposedMulti[j, k] = right2D[k, j];
                    }
                }
                double[,] resultMulti = new double[N, M];
                int count = 0;
                fixed(double* pMulti = resultMulti) 
                {
                    for (double* pM=pMulti;pM!=pMulti+resultMulti.Length;++pM)
                    {
                        for (int k = 0; k < D; k++)
                        {
                            *pM += left2D[count / M, k] * transposedMulti[count % M, k];
                        }
                        ++count;
                    }                
                }
            }
            GC.Collect();
            sw.Stop();
            string elapseMultiTPointer = sw.Elapsed.ToString();
            //Console.WriteLine(elapseMultiTPointer);
            Console.WriteLine("2-1 多次元配列の右の行列の転置 + ポインター:\t" + elapseMultiTPointer);

            // 2-1-1 多次元配列の転置とポインターの計算 + 並列化
            sw.Restart();
            unsafe
            {
                double[,] transposedMulti = new double[M, D];
                for (int j = 0; j < M; j++)
                {
                    for (int k = 0; k < D; k++)
                    {
                        transposedMulti[j, k] = right2D[k, j];
                    }
                }
                double[,] resultMulti = new double[N, M];
                int count = 0;
                fixed (double* pMulti = resultMulti)
                {
                    double* pM = pMulti;
                    Parallel.For(0, N*M, i =>
                    {
                        for (int k = 0; k < D; k++)
                        {
                            *pM += left2D[count / M, k] * transposedMulti[count % M, k];
                        }
                        ++count;
                        ++pM;
                    });
                }
            }
            GC.Collect();
            sw.Stop();
            string elapseMultiTPointerParallel = sw.Elapsed.ToString();
            //Console.WriteLine(elapseMultiTPointer);
            Console.WriteLine("2-1-1 多次元配列の右の行列の転置 + ポインター + 並列化:\t" + elapseMultiTPointerParallel);

            //Console.WriteLine("1 ジャグ配列の普通の計算:\t" + elapseJag);
            //Console.WriteLine("1-*-1 ジャグ配列の普通の計算 + 並列化:\t" + elapseJagParallel);
            //Console.WriteLine("2 ジャグ配列の右の行列の転置:\t" + elapseJagT);
            //Console.WriteLine("2-*-1 ジャグ配列の右の行列の転置 + 並列化:\t" + elapseJagTParallel);
            //Console.WriteLine("2-1 ジャグ配列の右の行列の転置 + ポインター:\t" + elapseJagTPointer);
            //Console.WriteLine("2-1-1 ジャグ配列の右の行列の転置 + ポインター + 並列化:\t" + elapseJagTPointerParallel);
            //Console.WriteLine("2-2 ジャグ配列の右の行列の転置+Span:\t" + elapseJagTSpan);
            //Console.WriteLine("2-2 ジャグ配列の右の行列の転置+Span: 並列化は使えない");

            //Console.WriteLine("1 多次元配列の普通の計算:\t" + elapseMulti);
            //Console.WriteLine("1-*-1 多次元配列の普通の計算 + 並列化:\t" + elapseMultiParallel);
            //Console.WriteLine("2 多次元配列の右の行列の転置:\t" + elapseMultiT);
            //Console.WriteLine("2-*-1 多次元配列の右の行列の転置 + 並列化:\t" + elapseMultiT);
            //Console.WriteLine("2-1 多次元配列の右の行列の転置 + ポインター:\t" + elapseMultiTPointer);
            //Console.WriteLine("2-1-1 多次元配列の右の行列の転置 + ポインター + 並列化:\t" + elapseMultiTPointerParallel);

        }

    }
}

Discussion