🦆

[C#]unsafe ポインターで行列積を高速計算する

2022/11/19に公開

目的

行列積の計算は3重ループのため、アルゴリズムやコードの効率化の結果を可視化しやすいです。同じアルゴリズムの場合、ジャグ配列の方が多次元配列より早いです。では、unsafeでポインターをいじるとどちらが早いでしょうか?

注意事項と参考URL

2011年の時点で、C#でポインターをいじる旨味は無いとのこと。ポインターをいじるとバグの温床になりかねないので、自己責任でお願いします。
個人的には、最初に左の行列の列数と右の行列の行数が等しい場合にunsafeを用いたメソッドで計算すれば良いのでは、とも思います。
1 C#で配列に大量にアクセスする場合、ポインターを使うのと使わ... - Yahoo!知恵袋
2 unsafe キーワード - C# リファレンス | Microsoft Learn
3 unsafe - C# によるプログラミング入門 | ++C++; // 未確認飛行 C

結果

最初に結果を示します。ジャグ配列が圧倒的に早いです。

配列 計算方法 計算時間
ジャグ配列 ①普通の行列積の計算 1分47秒
②右の行列を転置した計算 37秒57
③ ②+戻り値へのアクセスをポインターにした計算 28秒86
多次元配列 ①普通の行列積の計算 2分17秒54
②右の行列を転置した計算 59秒96
③ ②+戻り値へのアクセスをポインターにした計算 1分8秒15

行列積のアルゴリズムと計算量

アルゴリズムやコードの効率の話をするために、行列積の式を確認します。ND列の行列ADM列の行列Bの行列積(NM列の行列C)を考えます。

C=AB

行列Cの一般項は下記になります。

c_{i,j}=\sum_{k=0}^{D-1} a_{i,k}b_{k,j} \ \ \ \ where \ 0 \leqq i \leqq N-1, 0 \leqq j \leqq M-1

ij列目の要素を計算するためにD個の積の和を計算するので、行列積の計算量は下記になります。

O(NM \times (2D)) = O(2DMN)

この式を見る限りでは、計算量はデータ量に比例するので、データを間引かない限り計算量を減らせないようにも見えます。そこで、疑似コードでプログラムの計算時間を考え直してみます。

疑似コードで考えるプログラムの計算時間

データアクセスやループの継続条件を明示した疑似コードを書いてみました。このように書くと効率化の余地がありそうです。

疑似コードでの行列積
for(左側の行列の行 i)
{
    for(右側の行列の列 j)
    {
        for(左側の行列の列 k)
	{
	    左側の行列a[i,k]にアクセス
	    右側の行列b[k,j]にアクセス
	    a[i,k]とb[k,j]の積
	    積c[i,j]にアクセス
	    積c[i,j]にa[i,k]とb[k,j]の積を加える
	    次のkに移る(k+1の開始判定,境界判定a[i,k+1],b[k+1,j])
	}
	次のjに移る(j+1の開始判定,境界判定b[*,j+1],c[i,j+1])
    }
    次のiに移る(i+1の開始判定,境界判定a[i+1,*],c[i+1,*])
}

右側の行列を転置する

よくある効率化手法です。右側の行列を転置して、メモリアクセスを効率化します。

c_{i,j}=\sum_{k=0}^{D-1} a_{i,k}b^T_{j,k} \ \ \ \ where \ 0 \leqq i \leqq N-1, 0 \leqq j \leqq M-1

疑似コードは下記になります。変化点が見えるようにコメントアウトしました。C#の場合、列に相当するデータへのアクセスの方が早いので、プログラムの計算時間が短くなります。

疑似コードでの行列積(右の行列の転置)
for(左側の行列の列 k)
{
    for(右側の行列の列 j)
    {
        b^T[k,j]=b[j,k]
    }
}
for(左側の行列の行 i)
{
    for(右側の行列の列 j)
    {
        for(左側の行列の列 k)
	{
	    左側の行列a[i,k]にアクセス
	    右側の行列b^T[j,k]/*b[k,j]*/にアクセス
	    a[i,k]とb^T[j,k]/*b[k,j]*/の積
	    積c[i,j]にアクセス
	    積c[i,j]にa[i,k]とb^T[j,k]/*b[k,j]*/の積を加える
	    次のkに移る(k+1の開始判定,境界判定a[i,k+1],b^T[j,k+1]/*b[k+1,j]*/)
	}
	次のjに移る(j+1の開始判定,境界判定b^T[j+1,*]/*b[*,j+1]*/,c[i,j+1])
    }
    次のiに移る(i+1の開始判定,境界判定a[i+1,*],c[i+1,*])
}

ポインター

foreachで数値の代入なんてできないんですが、ま、疑似コードなんで多少はね?

疑似コードでの行列積(右の行列の転置 + 多重配列でポインター使用)
for(左側の行列の列 k)
{
    for(右側の行列の列 j)
    {
        b^T[k,j]=b[j,k]
    }
}
unsafe foreach(double* c in C行列積の各要素へのアクセスをポインターで行う)
{
    i=count/M
    j=count%M
        for(左側の行列の列 k)
	{
	    左側の行列a[i,k]にアクセス
	    右側の行列b^T[j,k]/*b[k,j]*/にアクセス
	    a[i,k]とb^T[j,k]/*b[k,j]*/の積
	    積c[i,j]にアクセス
	    積*c[i,j]/*c[i,j]*/にa[i,k]とb^T[j,k]/*b[k,j]*/の積を加える
	    次のkに移る(k+1の開始判定,境界判定a[i,k+1],b^T[j,k+1]/*b[k+1,j]*/)
	}
    ++count;
    次の要素に移る(境界判定c)
}

テストプログラム

2022行2022列の正方行列の積で①普通の行列積の計算、②右の行列を転置した計算、③ ②+戻り値へのアクセスをポインターにした計算の3種類を比較しました。

行列の計算速度の比較
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace Sandbox.Tremendous1192.SelfEmployed.UnsafeMathDotNet.Zenn.Matrix
{
    /// <summary>
    /// 行列に関するテストクラス
    /// </summary>
    public static partial class TestMatrix
    {
        /// <summary>
        /// 行列積の速度比較
        /// </summary>
        public static void TestMultiply()
        {
            // 初期化
            int N = 2022; // 左側の行列の行
            int D = 2022; // 左側の行列の列
            int M = 2022; // 右側の行列の列
            Console.WriteLine("左の行列の行数" + N + "\t列数" + D + "\t右の行列の行数" + M);

            double[,] left2D=new double[N, D];
            double[,] right2D=new double[D, M];

            double[][] leftJag = new double[N][];
            double[][] rightJag = new double[D][];

            for (int i=0;i<N;i++)
            {
                leftJag[i] = new double[D];
                for (int k=0;k<D;k++)
                {
                    left2D[i, k] = i + k;
                    leftJag[i][k] = i + k;                
                }
            }
            for (int k = 0; k < D; k++)
            {
                rightJag[k] = new double[M];
                for (int j = 0; j < M; j++)
                {
                    right2D[k, j] = k + j + 1;
                    rightJag[k][j] = k + j + 1;
                }
            }


            Stopwatch sw= new Stopwatch();
            sw.Start();
            // ジャグ配列の普通の計算
            {
                double[][] resultJag = new double[N][];
                for (int i = 0; i < N; i++)
                {
                    resultJag[i] = new double[M];
                    for (int j = 0; j < M; j++)
                    {
                        for (int k = 0; k < D; k++)
                        {
                            resultJag[i][j] += leftJag[i][k] * rightJag[k][j];
                        }
                    }
                }
            }
            GC.Collect();
            sw.Stop();
            string elapseJag = sw.Elapsed.ToString();
            //Console.WriteLine(elapseJag);

            // ジャグ配列の転置の計算
            sw.Restart();
            {
                double[][] transposedJag = new double[M][];
                for (int j = 0; j < M; j++)
                {
                    transposedJag[j] = new double[D];
                    for (int k = 0; k < D; k++)
                    {
                        transposedJag[j][k] = rightJag[k][j];
                    }
                }
                double[][] resultJag = new double[N][];
                for (int i = 0; i < N; i++)
                {
                    resultJag[i] = new double[M];
                    for (int j = 0; j < M; j++)
                    {
                        for (int k = 0; k < D; k++)
                        {
                            resultJag[i][j] += leftJag[i][k] * transposedJag[j][k];
                        }
                    }
                }
            }
            GC.Collect();
            sw.Stop();
            string elapseJagT = sw.Elapsed.ToString();
            //Console.WriteLine(elapseJagT);


            // ジャグ配列の転置の計算
            sw.Restart();
            unsafe
            {
                double[][] transposedJag = new double[M][];
                for (int j = 0; j < M; j++)
                {
                    transposedJag[j] = new double[D];
                    for (int k = 0; k < D; k++)
                    {
                        transposedJag[j][k] = rightJag[k][j];
                    }
                }
                double[][] resultJag = new double[N][];
                for (int i = 0; i < N; i++)
                {
                    resultJag[i] = new double[M];
                }

                for (int i = 0; i < N; i++)
                {
                    fixed (double* pResultJag = &resultJag[i][0])
                    {
                        int j = 0;
                        for (double* pR = pResultJag; pR != pResultJag + resultJag[i].Length; ++pR)
                        {
                            for (int k = 0; k < D; k++)
                            {
                              *pR += leftJag[i][k] * transposedJag[j][k];
                            }
                            ++j;
                        }
                    }
                }
            }
            GC.Collect();
            sw.Stop();
            string elapseJagTPointer = sw.Elapsed.ToString();
            //Console.WriteLine(elapseJagTPointer);



            //多次元配列の普通の計算
            sw.Restart();
            {
                double[,] resultMulti = new double[N,M];
                for (int i = 0; i < N; i++)
                {
                    for (int j = 0; j < M; j++)
                    {
                        for (int k = 0; k < D; k++)
                        {
                            resultMulti[i,j] += left2D[i, k] * right2D[k, j];
                        }
                    }
                }
            }
            GC.Collect();
            sw.Stop();
            string elapseMulti = sw.Elapsed.ToString();
            //Console.WriteLine(elapseMulti);

            // 多次元配列の転置の計算
            sw.Restart();
            {
                double[,] transposedMulti = new double[M, D];
                for (int j = 0; j < M; j++)
                {
                    for (int k = 0; k < D; k++)
                    {
                        transposedMulti[j, k] = right2D[k, j];
                    }
                }
                double[,] resultMulti = new double[N, M];
                for (int i = 0; i < N; i++)
                {
                    for (int j = 0; j < M; j++)
                    {
                        for (int k = 0; k < D; k++)
                        {
                            resultMulti[i, j] += left2D[i, k] * transposedMulti[j, k];
                        }
                    }
                }
            }
            GC.Collect();
            sw.Stop();
            string elapseMultiT = sw.Elapsed.ToString();
            //Console.WriteLine(elapseMultiT);

            // 多次元配列の転置とポインターの計算
            sw.Restart();
            unsafe
            {
                double[,] transposedMulti = new double[M, D];
                for (int j = 0; j < M; j++)
                {
                    for (int k = 0; k < D; k++)
                    {
                        transposedMulti[j, k] = right2D[k, j];
                    }
                }
                double[,] resultMulti = new double[N, M];
                int count = 0;
                fixed(double* pMulti = resultMulti) 
                {
                    for (double* pM=pMulti;pM!=pMulti+resultMulti.Length;++pM)
                    {
                        for (int k = 0; k < D; k++)
                        {
                            *pM += left2D[count / M, k] * transposedMulti[count % M, k];
                        }
                        ++count;
                    }                
                }
            }
            GC.Collect();
            sw.Stop();
            string elapseMultiTPointer = sw.Elapsed.ToString();
            //Console.WriteLine(elapseMultiTPointer);



            Console.WriteLine("ジャグ配列の普通の計算:\t" + elapseJag);
            Console.WriteLine("ジャグ配列の右の行列の転置:\t" + elapseJagT);
            Console.WriteLine("ジャグ配列の右の行列の転置+ポインター:\t" + elapseJagTPointer);

            Console.WriteLine("多次元配列の普通の計算:\t" + elapseMulti);
            Console.WriteLine("多次元配列の右の行列の転置:\t" + elapseMultiT);
            Console.WriteLine("多次元配列の右の行列の転置+ポインター:\t" + elapseMultiTPointer);
        }

    }
}

結果

ジャグ配列のほうが圧倒的に早かったです。多次元配列の方はポインターの方が遅かったです。単純比較としてポインターを中途半端にしか使わなかったことも影響しているかもしれません。

配列 計算方法 計算時間
ジャグ配列 ①普通の行列積の計算 1分47秒
②右の行列を転置した計算 37秒57
③ ②+戻り値へのアクセスをポインターにした計算 28秒86
多次元配列 ①普通の行列積の計算 2分17秒54
②右の行列を転置した計算 59秒96
③ ②+戻り値へのアクセスをポインターにした計算 1分8秒15

終わりに

やっぱりジャグ配列がナンバー1 !!

Discussion