ArmのScalable Matrix Extension (SME)を試す
最近のCPUには行列乗算に役立つ命令が載っていることがあります。IntelのAdvanced Matrix Extensions (AMX)、AppleのAMX、IBM PowerのMatrix-Multiply Assist (MMA),そしてここで取り上げるArmのScalable Matrix Extension (SME)です。
SMEはここ数年(2021年ごろから?)話は聞きますが、実物の話は聞かないという状況でした(私の中では)。それが最近発表されたApple M4に実装されているという噂を聞いて、俄然興味が出てきました。Apple M4の実物は私は持っていませんが、QEMUを使うとSMEの動作確認ができるようです。やってみましょう。
環境構築とベクトル長
Ubuntu 24.04上のGCC 14/Clang 18とQEMU 8.2で動作確認します。Ubuntuはx86_64で動いています。
$ sudo apt install gcc-14-aarch64-linux-gnu clang-18 qemu-user
SMEはSVEの拡張みたいな感じなので、ベクトル長がCPU依存です。SMEを使うコードはStreaming ModeとNon-streaming Modeの状態を持ち、Streaming Modeではベクトル長が通常のものから変わる可能性があります。Streaming Modeでのベクトル長(ビット単位)を
Hello world代わりに、それぞれのモードでのベクトル長を確認するコードを書いてみましょう。
#include <stdio.h>
#include <arm_sme.h>
__arm_locally_streaming
void streaming_fn(void)
{
printf("Streaming mode: svcntw() = %u, svcntsw() = %u\n", (unsigned)svcntw(), (unsigned)svcntsw());
}
int main(void)
{
printf("Has SME? %s\n", __arm_has_sme() ? "Yes" : "No");
printf("Non-streaming mode: svcntw() = %u, svcntsw() = %u\n", (unsigned)svcntw(), (unsigned)svcntsw());
streaming_fn();
}
__arm_locally_streaming
属性をつけた関数の内部でStreaming Modeが有効になります。svcnt
系の関数は従来のSVEにもあったやつで、これの値がStreaming Modeかどうかで変化します。svcnts
系の関数はSMEのやつで、Streaming Mode中でのベクトル長を返します。
ACLEだと 執筆時点のACLEだと __arm_locally_streaming
じゃなくて __attribute__((arm_locally_streaming))
のようですが、実際のコンパイラーに実装されているのは __arm_locally_streaming
です。謎です。将来変わるのかもしれません。__attribute__
を使っていましたが、ACLE 2024 Q1で __arm_locally_streaming
のようなキーワードを使うように改訂されたようです。コンパイラーに実装されたのは改訂後の仕様です。
御託はいいので、コンパイルして実行してみましょう。
GCCの場合:
$ aarch64-linux-gnu-gcc-14 -O2 -march=armv9-a+sme -static -o veclen veclen.c
$ qemu-aarch64 ./veclen
Has SME? Yes
Non-streaming mode: svcntw() = 16, svcntsw() = 8
Streaming mode: svcntw() = 8, svcntsw() = 8
Clangの場合:
$ clang-18 -O2 --target=aarch64-linux-gnu -march=armv9-a+sme -static -o veclen veclen.c
$ qemu-aarch64 ./veclen
Has SME? Yes
Non-streaming mode: svcntw() = 16, svcntsw() = 8
Streaming mode: svcntw() = 8, svcntsw() = 8
出力されたベクトル長を解釈します。Non-streaming modeでの svcntw()
は16を返しました。通常のSVEでのベクトル長がワード(32ビット)単位で16個、つまり512ビットということですね。一方、svcntsw()
は8を返しているので、Streaming SVE vectorは256ビットのようです。
なお、GCC 13のクロスコンパイラーが入った状態でClangでスタティックリンクすると undefined reference to `__arm_sme_state'
というリンクエラーが出ました。動的リンクして qemu-aarch64 -L /usr/aarch64-linux-gnu/
とするか、ここでの手順のようにGCC 14を入れましょう。
QEMUでは -cpu
オプションでベクトル長を変えることができます。Non-streaming時のベクトル長は sve-default-vector-length
で、Streamingベクトル長は sme-default-vector-length
で指定できます(詳細はArm CPU Features — QEMU documentationを参照)。単位はバイトです。指定をいじってみましょう。
$ qemu-aarch64 -cpu max,sve-default-vector-length=256,sme-default-vector-length=128 ./veclen
Has SME? Yes
Non-streaming mode: svcntw() = 64, svcntsw() = 32
Streaming mode: svcntw() = 32, svcntsw() = 32
Non-streaming時のベクトル長を256バイト(2048ビット)、Streaming時のベクトル長を128バイト(1024ビット)に指定しました。svcnt
系の関数の出力が変わっているのが見て取れます。
行列乗算のための命令とレジスター
SMEのコアとなる命令は、ベクトルの外積を計算・累積するMOP(sum of outer products)命令群です。浮動小数点数で加算するMOP命令はFMOPA(floating-point sum of outer products and accumulate)という感じです。
ベクトルの外積は数式で書くと
みたいなやつですね。これを繰り返して累積すれば行列乗算になるという寸法です。
乗算と加算は1回の丸め(FMA相当)で計算されます。
レジスターとしては、ZAというでかい領域が追加されます。これは
ZA領域はZA[N]
(ZA.B[N]
、16ビット要素なら ZA.H[N]
、32ビット要素なら ZA.S[N]
等という風に呼びます。
ZA領域を分割してタイルと呼ばれる領域を取り出すこともできます。8ビット要素ならZA0.B
と呼びます。32ビット要素ならZA0.S
, ZA1.S
, ZA2.S
, ZA3.S
と呼びます。
タイルの要素は行ベクトル、列ベクトルとしてアクセスできます。これらのベクトルのことをスライスと呼びます。行ベクトルは水平方向 (horizontal) の H
を、列ベクトルは垂直方向 (vertical) の V
をつけて、ZA0H.S[i]
とか ZA2V.S[j]
という風に呼びます。
行列乗算の実装
行列乗算
まずは、基本となるリファレンス実装です。速度とかは気にしません。ループの順番を入れ替えると早くなるみたいな話もありますが、ここでは考えません。
そこそこのサイズの行列積を計算して、左上の10×10を表示します。
#include <stdio.h>
void matmul(size_t l, size_t m, size_t n, const float A[l][m], const float B[m][n], float C[l][n])
{
for (size_t i = 0; i < l; ++i) {
for (size_t k = 0; k < n; ++k) {
C[i][k] = 0.0f;
for (size_t j = 0; j < m; ++j) {
C[i][k] += A[i][j] * B[j][k];
}
}
}
}
int main(void)
{
float A[100][200];
float B[200][150];
float C[100][150];
for (size_t i = 0; i < 100; ++i) {
for (size_t j = 0; j < 200; ++j) {
A[i][j] = (float)i + (float)j;
}
}
for (size_t i = 0; i < 200; ++i) {
for (size_t j = 0; j < 150; ++j) {
B[i][j] = (float)i - (float)j;
}
}
matmul(100, 200, 150, A, B, C);
for (size_t i = 0; i < 10; ++i) {
for (size_t j = 0; j < 10; ++j) {
printf("%g ", C[i][j]);
}
puts("");
}
}
GCCでのコンパイル・実行例:
$ aarch64-linux-gnu-gcc-14 -O2 -march=armv9-a+sme -static -o matmul matmul.c
$ qemu-aarch64 ./matmul
2.6467e+06 2.6268e+06 2.6069e+06 2.587e+06 2.5671e+06 2.5472e+06 2.5273e+06 2.5074e+06 2.4875e+06 2.4676e+06
2.6666e+06 2.6465e+06 2.6264e+06 2.6063e+06 2.5862e+06 2.5661e+06 2.546e+06 2.5259e+06 2.5058e+06 2.4857e+06
2.6865e+06 2.6662e+06 2.6459e+06 2.6256e+06 2.6053e+06 2.585e+06 2.5647e+06 2.5444e+06 2.5241e+06 2.5038e+06
2.7064e+06 2.6859e+06 2.6654e+06 2.6449e+06 2.6244e+06 2.6039e+06 2.5834e+06 2.5629e+06 2.5424e+06 2.5219e+06
2.7263e+06 2.7056e+06 2.6849e+06 2.6642e+06 2.6435e+06 2.6228e+06 2.6021e+06 2.5814e+06 2.5607e+06 2.54e+06
2.7462e+06 2.7253e+06 2.7044e+06 2.6835e+06 2.6626e+06 2.6417e+06 2.6208e+06 2.5999e+06 2.579e+06 2.5581e+06
2.7661e+06 2.745e+06 2.7239e+06 2.7028e+06 2.6817e+06 2.6606e+06 2.6395e+06 2.6184e+06 2.5973e+06 2.5762e+06
2.786e+06 2.7647e+06 2.7434e+06 2.7221e+06 2.7008e+06 2.6795e+06 2.6582e+06 2.6369e+06 2.6156e+06 2.5943e+06
2.8059e+06 2.7844e+06 2.7629e+06 2.7414e+06 2.7199e+06 2.6984e+06 2.6769e+06 2.6554e+06 2.6339e+06 2.6124e+06
2.8258e+06 2.8041e+06 2.7824e+06 2.7607e+06 2.739e+06 2.7173e+06 2.6956e+06 2.6739e+06 2.6522e+06 2.6305e+06
SMEを使った行列乗算ですが、
内側のループ(
しかし、FEAT_SME_FA64
を実装していれば使える)ようです。
ここでは、
まずは関数宣言です。SMEを使うので __arm_locally_streaming
を指定し、さらにZAを使用することを表す属性をつけます。ZAの状態は呼び出し元から受け継がないので、__arm_new("za")
を指定します。これもACLEとコンパイラーの実装で乖離しているように見えますが、ここでは現在のコンパイラーが受理してくれる記法を使います。
__arm_locally_streaming
__arm_new("za")
void matmul_sme(size_t l, size_t m, size_t n, const float A[l][m], const float B[m][n], float C[l][n])
{
// ...
}
for (size_t i = 0; i < l; i += svcntw()) {
svbool_t maskA = svwhilelt_b32_s64(i, l);
for (size_t k = 0; k < n; k += svcntw()) {
svbool_t maskB = svwhilelt_b32_s64(k, n);
// ...
}
}
32ビットだとタイルを4つ使用できます。ここではアキュムレーター用に0番を、
まずは、タイルのクリアと、
#define TILE_ACC 0
#define TILE_A 1
svzero_za(); // 全てのタイルをゼロクリアする
size_t limit_ii = min(l, i + svcntw());
for (size_t j = 0; j < m; j += svcntw()) {
svbool_t maskTA = svwhilelt_b32_s64(j, m);
svzero_mask_za(17 << TILE_A); // TILE_A をゼロクリアする(指定方法が独特なので注意)
for (size_t ii = i; ii < limit_ii; ++ii) {
svld1_hor_za32(TILE_A, ii /*, 0 */, maskTA, &A[ii][j]);
}
// ... この辺で外積の計算と加算を行う ...
}
// ... この辺で C の部分行列を書き込む ...
svld1_hor_za32
はメモリからタイルの列ベクトルにロードする組み込み関数です。ACLEの記述と実際の実装で引数の個数が違いました。謎です。スライスのインデックスは適宜modで計算されるので、ii - i
みたいなことをする必要はありません。
次は、外積の計算と加算です。
size_t limit_jj = min(m, j + svcntw());
for (size_t jj = j; jj < limit_jj; ++jj) {
svfloat32_t a;
a = svread_ver_za32_f32_m(a, maskA, TILE_A, jj); // タイルから列ベクトルを取得する
svfloat32_t b = svld1_f32(maskB, &B[jj][k]);
svmopa_za32_f32_m(TILE_ACC, maskA, maskB, a, b); // 外積を計算してタイルに加算する
}
svread_ver_za32_f32_m
はベクトルを読み出す関数だと思うのですが、なぜか引数にもベクトルを指定する必要があります。謎です。プレディケートで無効になった要素が引数から取得されるんでしょうか。これもスライスのインデックスは適宜modされるので、jj - j
とする必要はありません。
最後に、累積したタイルを
for (size_t ii = i; ii < limit_ii; ++ii) {
svst1_hor_za32(TILE_ACC, ii /*, 0 */, maskB, &C[ii][k]);
}
svst1_hor_za32
もACLEと実際の実装で引数の個数が違いました。謎です。これもスライスのインデックスは適宜modされるので、ii - i
とする必要はありません。
完全なソースコードはこんな感じになります:
#include <stdio.h>
#include <arm_sme.h>
void matmul(size_t l, size_t m, size_t n, const float A[l][m], const float B[m][n], float C[l][n])
{
for (size_t i = 0; i < l; ++i) {
for (size_t k = 0; k < n; ++k) {
C[i][k] = 0.0f;
for (size_t j = 0; j < m; ++j) {
C[i][k] += A[i][j] * B[j][k];
}
}
}
}
static inline size_t min(size_t x, size_t y) {
return x > y ? y : x;
}
__arm_locally_streaming
__arm_new("za")
void matmul_sme(size_t l, size_t m, size_t n, const float A[l][m], const float B[m][n], float C[l][n])
{
for (size_t i = 0; i < l; i += svcntw()) {
svbool_t maskA = svwhilelt_b32_s64(i, l);
for (size_t k = 0; k < n; k += svcntw()) {
svbool_t maskB = svwhilelt_b32_s64(k, n);
#define TILE_ACC 0
#define TILE_A 1
svzero_za();
size_t limit_ii = min(l, i + svcntw());
for (size_t j = 0; j < m; j += svcntw()) {
svbool_t maskTA = svwhilelt_b32_s64(j, m);
svzero_mask_za(17 << TILE_A);
for (size_t ii = i; ii < limit_ii; ++ii) {
svld1_hor_za32(TILE_A, ii /*, 0 */, maskTA, &A[ii][j]);
}
size_t limit_jj = min(m, j + svcntw());
for (size_t jj = j; jj < limit_jj; ++jj) {
svfloat32_t a;
a = svread_ver_za32_f32_m(a, maskA, TILE_A, jj);
svfloat32_t b = svld1_f32(maskB, &B[jj][k]);
svmopa_za32_f32_m(TILE_ACC, maskA, maskB, a, b);
}
}
for (size_t ii = i; ii < limit_ii; ++ii) {
svst1_hor_za32(TILE_ACC, ii /*, 0 */, maskB, &C[ii][k]);
}
}
}
}
int main(void)
{
float A[100][200];
float B[200][150];
float C[100][150];
for (size_t i = 0; i < 100; ++i) {
for (size_t j = 0; j < 200; ++j) {
A[i][j] = (float)i + (float)j;
}
}
for (size_t i = 0; i < 200; ++i) {
for (size_t j = 0; j < 150; ++j) {
B[i][j] = (float)i - (float)j;
}
}
matmul(100, 200, 150, A, B, C);
for (size_t i = 0; i < 10; ++i) {
for (size_t j = 0; j < 10; ++j) {
printf("%g ", C[i][j]);
}
puts("");
}
puts("---");
matmul_sme(100, 200, 150, A, B, C);
for (size_t i = 0; i < 10; ++i) {
for (size_t j = 0; j < 10; ++j) {
printf("%g ", C[i][j]);
}
puts("");
}
}
GCCでコンパイル、QEMUで実行してみましょう。
$ aarch64-linux-gnu-gcc-14 -O2 -march=armv9-a+sme -static -o matmul-sme matmul-sme.c
$ qemu-aarch64 ./matmul-sme
2.6467e+06 2.6268e+06 2.6069e+06 2.587e+06 2.5671e+06 2.5472e+06 2.5273e+06 2.5074e+06 2.4875e+06 2.4676e+06
2.6666e+06 2.6465e+06 2.6264e+06 2.6063e+06 2.5862e+06 2.5661e+06 2.546e+06 2.5259e+06 2.5058e+06 2.4857e+06
2.6865e+06 2.6662e+06 2.6459e+06 2.6256e+06 2.6053e+06 2.585e+06 2.5647e+06 2.5444e+06 2.5241e+06 2.5038e+06
2.7064e+06 2.6859e+06 2.6654e+06 2.6449e+06 2.6244e+06 2.6039e+06 2.5834e+06 2.5629e+06 2.5424e+06 2.5219e+06
2.7263e+06 2.7056e+06 2.6849e+06 2.6642e+06 2.6435e+06 2.6228e+06 2.6021e+06 2.5814e+06 2.5607e+06 2.54e+06
2.7462e+06 2.7253e+06 2.7044e+06 2.6835e+06 2.6626e+06 2.6417e+06 2.6208e+06 2.5999e+06 2.579e+06 2.5581e+06
2.7661e+06 2.745e+06 2.7239e+06 2.7028e+06 2.6817e+06 2.6606e+06 2.6395e+06 2.6184e+06 2.5973e+06 2.5762e+06
2.786e+06 2.7647e+06 2.7434e+06 2.7221e+06 2.7008e+06 2.6795e+06 2.6582e+06 2.6369e+06 2.6156e+06 2.5943e+06
2.8059e+06 2.7844e+06 2.7629e+06 2.7414e+06 2.7199e+06 2.6984e+06 2.6769e+06 2.6554e+06 2.6339e+06 2.6124e+06
2.8258e+06 2.8041e+06 2.7824e+06 2.7607e+06 2.739e+06 2.7173e+06 2.6956e+06 2.6739e+06 2.6522e+06 2.6305e+06
---
2.6467e+06 2.6268e+06 2.6069e+06 2.587e+06 2.5671e+06 2.5472e+06 2.5273e+06 2.5074e+06 2.4875e+06 2.4676e+06
2.6666e+06 2.6465e+06 2.6264e+06 2.6063e+06 2.5862e+06 2.5661e+06 2.546e+06 2.5259e+06 2.5058e+06 2.4857e+06
2.6865e+06 2.6662e+06 2.6459e+06 2.6256e+06 2.6053e+06 2.585e+06 2.5647e+06 2.5444e+06 2.5241e+06 2.5038e+06
2.7064e+06 2.6859e+06 2.6654e+06 2.6449e+06 2.6244e+06 2.6039e+06 2.5834e+06 2.5629e+06 2.5424e+06 2.5219e+06
2.7263e+06 2.7056e+06 2.6849e+06 2.6642e+06 2.6435e+06 2.6228e+06 2.6021e+06 2.5814e+06 2.5607e+06 2.54e+06
2.7462e+06 2.7253e+06 2.7044e+06 2.6835e+06 2.6626e+06 2.6417e+06 2.6208e+06 2.5999e+06 2.579e+06 2.5581e+06
2.7661e+06 2.745e+06 2.7239e+06 2.7028e+06 2.6817e+06 2.6606e+06 2.6395e+06 2.6184e+06 2.5973e+06 2.5762e+06
2.786e+06 2.7647e+06 2.7434e+06 2.7221e+06 2.7008e+06 2.6795e+06 2.6582e+06 2.6369e+06 2.6156e+06 2.5943e+06
2.8059e+06 2.7844e+06 2.7629e+06 2.7414e+06 2.7199e+06 2.6984e+06 2.6769e+06 2.6554e+06 2.6339e+06 2.6124e+06
2.8258e+06 2.8041e+06 2.7824e+06 2.7607e+06 2.739e+06 2.7173e+06 2.6956e+06 2.6739e+06 2.6522e+06 2.6305e+06
ループで書いたやつと一致してそうです。よかった!
もちろん、Clangでもコンパイルできます。
$ clang-18 -O2 --target=aarch64-linux-gnu -march=armv9-a+sme -static -o matmul-sme matmul-sme.c
$ qemu-aarch64 ./matmul-sme
... 略 ...
雑感
SMEで行列乗算を実装してみました。
AppleのCPUが行列乗算用の専用命令(AMX)を持っていたことは知られていましたが、SMEという形で実装されれば利用しやすくなって嬉しいです。SMEが実装されたらApple AMXは引退ですかね。
他のプラットフォームとの比較で言うと、IntelやIBMのサーバー向けCPUはなかなか個人では手が出ませんが(逸般の誤家庭にはあるのかもしれませんが)、AppleのCPUは普通に個人でも使えます。つまり、Apple M4がMacに搭載されれば個人のパソコンでSMEプログラミングができるようになり、命令セットオタク(CPUの総合的な性能よりも命令セット星取表を重視する輩)としては楽しみです。私の手元にはIntel MacとかApple M1 Macとかが転がっていてそろそろ買い替えたい気持ちが高まっていますが、Apple M4 Macが出てくるまで待つべきかもしれません。
気がかりなのは、ACLEの記述とGCC 14/Clang 18の実装に差があったことです。ここに書いたコードは将来動かなくなってしまうかもしれません。アセンブリーでゴリゴリ書く人には関係ないのかもしれませんが。 執筆時点のACLEは古かったようですが、ここに書いたコードは改訂後の仕様に沿っているので大丈夫そうです。
念の為ですが、私は行列乗算については素人ですし、ここで書いた行列乗算のコードにも間違いが含まれる可能性があります。この記事を参考にする場合はその点を頭に入れておいてください。
参考リンク
- Arm Architecture Reference Manual for A-profile architecture
-
The Scalable Matrix Extension (SME), for Armv9-A
- 現在はArchitecuter Reference ManualにマージされたのでこれはRETIRED扱い
- ACLE: Arm C Language Extensions
- LLVM: Support for AArch64 Scalable Matrix Extension in LLVM — LLVM 19.0.0git documentation
- Clang: Attributes in Clang — Clang 19.0.0git documentation
- QEMU: Arm CPU Features — QEMU documentation
Discussion