(WIP) Notes on reaching ~90% of peak with int8 sdot on A64FX for Attention
Unlike the usual HPC GEMM examples, ML/AI workloads (especially Attention and FFN) involve low precision (e.g. int8) and small, non-square matrix shapes (such as [4096, 128]), so they need a different set of optimizations than HPC-oriented kernels.
(See FlashGEMM for CPUs: https://zenn.dev/syoyo/articles/7a4e410389addb)
Also, unlike fmla (fused multiply-add), sdot (udot) accumulates into int32 (int64 for int16 sdot), which makes accumulator registers harder to reuse, so register pressure should be higher than with fmla.
movprfx cannot be paired with sdot. mov and dup are costly (FLA pipe only), so it is better to clear the accumulator registers with eor.
For fmla there is a 5x4 example at https://github.com/fujitsu/A64FX/tree/master/sample (roughly 92% of peak, it seems; maybe 96% is reachable with more effort), but for sdot that shape does not reach the required compute density (2 sdot/cycle peak), so 6x4 is currently optimal (although with 24 accumulators, 6 A registers, and 2 B registers all 32 are in use, so B cannot be double-buffered).
With that in mind, I had Claude iterate on this by trial and error.
Currently 83%+ at d=512, up to about 90% (21 sdot/cycle/CMG), with sector cache hints.
L (seqlen, i.e. N) is around 640.
The matrices are sized to fit within roughly one L2 way (512 KB) or less.
These days L is more like 2048-8192, so some clever overlapping with L2 will probably be needed.
(Or just assume Linear Attention instead...)
Basic approach
In SVE, even bit and integer operations unconditionally go through the FL pipes, so any extra processing such as data unpacking immediately makes 100% of peak unreachable.
The kernel has to consist of ld/st + sdot/fmla only.
(It would have been nice if at least bit operations and adds had their own pipe.)
Each core computes a separate matrix. Even without sharing, less than half of the L2 bandwidth suffices (though latency hiding becomes an issue). At the moment, sharing data between cores actually reduced performance.
Sector cache hints are also used.
d = 256, 512. Performance may suffer at lower d.
(d = head_dim = reduction axis)
A64FX Architecture Overview
Compute Resources (per CMG = 12 cores)
- 12 cores per CMG (Core Memory Group)
- 2 FPUs per core capable of SDOT
- Each FPU: 1 SDOT instruction/cycle
- Peak: 24 SDOT/cycle per CMG
SDOT Instruction
sdot z_acc.s, z_a.b, z_b.b
- Computes 16 independent dot products in parallel
- Each lane: 4 INT8 multiplies + accumulate into INT32
- SVE vector length: 512 bits = 64 bytes = 16 INT32 lanes
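As a concrete picture of the semantics, here is a scalar C reference of a single sdot on one 512-bit vector (a sketch for illustration only; `sdot_ref` is my name, not an API, and the hardware of course performs all 16 lanes in one instruction):

```c
#include <stdint.h>

/* Scalar reference of SVE int8 sdot: 16 INT32 lanes, each lane
   accumulating a 4-element int8 dot product. In the kernel below,
   the a[] bytes come from an ld1rw broadcast, so all 16 lanes see
   the same four A values. */
static void sdot_ref(int32_t acc[16], const int8_t a[64], const int8_t b[64]) {
    for (int lane = 0; lane < 16; lane++)
        for (int j = 0; j < 4; j++)
            acc[lane] += (int32_t)a[lane * 4 + j] * (int32_t)b[lane * 4 + j];
}
```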
Memory Hierarchy
- L1 Data Cache: 64 KB per core (4-way, 256B line)
- L2 Cache: 8 MB shared per CMG (7 MB usable in practice)
- Memory bandwidth: ~256 GB/s per CMG
- Sector cache hints: Control L1/L2 placement via address bit 56
Kernel Design: 6×4 Tile
Tile Dimensions
MR = 6 rows (A dimension)
NR = 4 vectors = 64 INT32 elements (B dimension)
K = reduction dimension (processed in groups of 4)
Register Allocation (32 SVE registers)
Accumulators (24 registers):
z0-z3: Row 0, columns 0-3
z4-z7: Row 1, columns 0-3
z8-z11: Row 2, columns 0-3
z12-z15: Row 3, columns 0-3
z16-z19: Row 4, columns 0-3
z20-z23: Row 5, columns 0-3
A broadcasts (6 registers):
z24-z29: 6 rows × 4 bytes each (broadcast via ld1rw)
B vectors (2 registers):
z30-z31: Current B vectors being processed (all 32 registers are in use, so there is no room to double-buffer B)
Memory Layout
Matrix A (packed, row-interleaved per K-group):
K-group 0: [row0_k0-3][row1_k0-3][row2_k0-3][row3_k0-3][row4_k0-3][row5_k0-3]
K-group 1: [row0_k4-7][row1_k4-7][row2_k4-7][row3_k4-7][row4_k4-7][row5_k4-7]
...
Size per tile: MR × K = 6 × K bytes
Matrix B (packed, 4 vectors per K-group):
K-group 0: [vec0: 64B][vec1: 64B][vec2: 64B][vec3: 64B] = 256 bytes
K-group 1: [vec0: 64B][vec1: 64B][vec2: 64B][vec3: 64B] = 256 bytes
...
Size per N-tile: (K/4) × 256 bytes
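The A packing described above can be sketched in C as follows (`pack_A` is a hypothetical helper, not part of the kernel; it assumes a row-major A tile with leading dimension `lda` and K a multiple of 4):

```c
#include <stdint.h>
#include <string.h>

#define MR 6

/* Pack a row-major MR x K int8 tile into the K-group-interleaved layout:
   for each group of 4 consecutive K values, the 4-byte chunks of all MR
   rows are stored back to back, so one ld1rw per row broadcasts one
   K-group of that row. */
static void pack_A(int8_t *dst, const int8_t *src, int64_t K, int64_t lda) {
    for (int64_t kg = 0; kg < K / 4; kg++)
        for (int r = 0; r < MR; r++)
            memcpy(dst + (kg * MR + r) * 4, src + r * lda + kg * 4, 4);
}
```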
Main Loop Structure
Loop Overview
2 K-groups per iteration = 8 K values = 48 SDOT instructions
72 total instructions per iteration = 288 bytes = 18 decode groups
Instruction Breakdown per Iteration
| Type | Count | Purpose |
|---|---|---|
| ld1b (B loads) | 8 | Load B vectors (64 bytes each) |
| ld1rw (A broadcasts) | 12 | Broadcast A elements (4 bytes each) |
| sdot | 48 | 24 per K-group × 2 K-groups |
| add (pointer) | 2 | Advance A and B pointers |
| subs | 1 | Decrement loop counter |
| b.gt | 1 | Conditional branch |
| **Total** | **72** | |
Scheduling Pattern (K-group 0, first half)
// Load B vectors for columns 0-1
ld1b {z30.b}, p0/z, [x1, #0, mul vl] // B col 0
ld1b {z31.b}, p0/z, [x1, #1, mul vl] // B col 1
// 12 SDOTs using B cols 0-1 (all 6 A rows)
sdot z0.s, z24.b, z30.b // C[0,0] += A[0] · B[0]
sdot z1.s, z24.b, z31.b // C[0,1] += A[0] · B[1]
sdot z4.s, z25.b, z30.b // C[1,0] += A[1] · B[0]
sdot z5.s, z25.b, z31.b // C[1,1] += A[1] · B[1]
sdot z8.s, z26.b, z30.b // C[2,0] += A[2] · B[0]
sdot z9.s, z26.b, z31.b // C[2,1] += A[2] · B[1]
sdot z12.s, z27.b, z30.b // C[3,0] += A[3] · B[0]
sdot z13.s, z27.b, z31.b // C[3,1] += A[3] · B[1]
sdot z16.s, z28.b, z30.b // C[4,0] += A[4] · B[0]
sdot z17.s, z28.b, z31.b // C[4,1] += A[4] · B[1]
sdot z20.s, z29.b, z30.b // C[5,0] += A[5] · B[0]
sdot z21.s, z29.b, z31.b // C[5,1] += A[5] · B[1]
Scheduling Pattern (K-group 0, second half with A preload)
// Load B vectors for columns 2-3
ld1b {z30.b}, p0/z, [x1, #2, mul vl] // B col 2
ld1b {z31.b}, p0/z, [x1, #3, mul vl] // B col 3
// Interleaved SDOTs + A loads for next K-group
sdot z2.s, z24.b, z30.b // C[0,2]
sdot z3.s, z24.b, z31.b // C[0,3]
ld1rw {z24.s}, p0/z, [x0, #24] // Preload A[0] for K-group 1
sdot z6.s, z25.b, z30.b // C[1,2]
sdot z7.s, z25.b, z31.b // C[1,3]
ld1rw {z25.s}, p0/z, [x0, #28] // Preload A[1]
sdot z10.s, z26.b, z30.b // C[2,2]
sdot z11.s, z26.b, z31.b // C[2,3]
ld1rw {z26.s}, p0/z, [x0, #32] // Preload A[2]
sdot z14.s, z27.b, z30.b // C[3,2]
sdot z15.s, z27.b, z31.b // C[3,3]
ld1rw {z27.s}, p0/z, [x0, #36] // Preload A[3]
sdot z18.s, z28.b, z30.b // C[4,2]
sdot z19.s, z28.b, z31.b // C[4,3]
ld1rw {z28.s}, p0/z, [x0, #40] // Preload A[4]
sdot z22.s, z29.b, z30.b // C[5,2]
sdot z23.s, z29.b, z31.b // C[5,3]
ld1rw {z29.s}, p0/z, [x0, #44] // Preload A[5]
Key Optimizations
1. Out-of-Order Load Hiding
A64FX has 11-cycle load latency. The kernel exploits OoO execution:
- Loads issued early, SDOTs use results later
- A loads interleaved with SDOTs (2 SDOTs between each A load)
- Register renaming allows loads to complete while SDOTs execute
Load-to-use distance:
- B loads: ~12 SDOTs before use (well hidden)
- A loads: ~24 SDOTs before use (fully hidden)
2. Sector Cache Hints for B Matrix
// Set bit 56 to use sector 1 (L2 streaming, bypass L1)
mov x19, #1
lsl x19, x19, #56
orr x1, x1, x19
B matrix is streamed through (each element used once per tile):
- Sector 1: Data goes to L2, not polluting L1
- L1 reserved for A matrix (reused across N tiles)
- Prevents cache thrashing
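In C, the same address tagging would look roughly like this (illustration only; on A64FX the top address bits are ignored for translation under tagged addressing, and actually enabling the sector cache also requires Fujitsu-specific configuration, which this sketch does not touch):

```c
#include <stdint.h>

/* Tag a pointer with address bit 56 so that, with the sector cache
   configured, accesses through it are treated as sector 1 (streaming).
   On other hardware, dereferencing the tagged pointer would fault;
   this only demonstrates the bit manipulation itself. */
static const int8_t *tag_sector1(const int8_t *p) {
    return (const int8_t *)((uintptr_t)p | ((uintptr_t)1 << 56));
}
```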
3. Optimal Branch Placement (Slot 3)
A64FX decodes 4 instructions/cycle (16 bytes):
Loop: 72 instructions = 288 bytes = 18 decode groups
Branch at instruction 72: offset 284 bytes
284 mod 16 = 12 → slot 3 (last in decode group)
Decode group structure at loop end:
[add x0] [add x1] [subs x6] [b.gt]
slot 0 slot 1 slot 2 slot 3
Branch at slot 3 = fetch of next iteration can begin immediately.
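The slot arithmetic can be written down explicitly (`decode_slot` is my helper name; instructions are 4 bytes and decode groups are 16 bytes, so the slot of the n-th instruction in the loop body is its byte offset mod 16, divided by 4):

```c
/* Decode slot (0-3) of the n-th instruction (1-based) in the loop body,
   assuming the loop start is 16-byte aligned. Instruction 72 sits at
   byte offset 284; 284 mod 16 = 12, i.e. slot 3. */
static int decode_slot(int n) {
    return ((n - 1) * 4 % 16) / 4;
}
```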
4. Loop Unrolling (2× K-groups)
Processing 2 K-groups per iteration:
- 48 SDOTs per iteration (vs 24 for 1×)
- Loop overhead: 4 instructions / 48 SDOTs = 8.3%
- Reduced branch misprediction impact
5. 64-Byte Loop Alignment
.align 6 // 64-byte alignment
.Lmain_v5:
- Loop start aligned to cache line boundary
- Optimal instruction fetch
- No partial cache line fetches at loop entry
Performance Analysis
Theoretical vs Achieved
Per iteration (2 K-groups):
- SDOTs: 48
- Minimum cycles: 48 / 2 = 24 cycles (2 SDOT/cycle per core)
Measured overhead:
- Actual cycles: ~26.5 per iteration
- Overhead: 10.4%
Overhead sources (these partially overlap, so they do not simply sum to the 2.5-cycle gap):
- Loop control: add, add, subs, b.gt
- Memory latency not fully hidden: ~2 cycles
- Instruction decode/issue gaps: ~0.5 cycles
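The same arithmetic as a tiny helper, for checking the numbers above (the names are mine, for illustration only):

```c
/* Efficiency = achieved throughput / peak throughput, here
   48 SDOTs in ~26.5 measured cycles against a 2 SDOT/cycle peak. */
static double efficiency(double sdots, double cycles, double peak_per_cycle) {
    return (sdots / cycles) / peak_per_cycle;
}
```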
Scaling to 12 Cores
| Cores | Peak (SDOT/cyc) | Achieved (SDOT/cyc) | Efficiency |
|---|---|---|---|
| 1 | 2.0 | 1.81 | 90.5% |
| 12 | 24.0 | 21.7 | 90.3% |
Near-linear scaling achieved through:
- Independent memory per core (separate A, B, C allocations)
- No shared data structures
- L2 bandwidth sufficient for all cores
Function Signature
void micro_kernel_6x4_ooo_v5(
const int8_t* A, // x0: Packed A tile (MR × K bytes)
    const int8_t* B,    // x1: Packed B tile ((K/4) × 256 bytes)
int32_t* C, // x2: Output C tile (MR × NR int32s)
int64_t K, // x3: K dimension (must be multiple of 4)
int64_t unused, // x4: Reserved (pass 0)
int64_t C_stride // x5: Row stride for C in bytes
);
Usage Example
#include <stdint.h>
#include <omp.h>

#define MR 6
#define NR 64
#define NUM_CORES 12

// Per-core buffers, assumed allocated and packed elsewhere into the
// layouts described above
int8_t*  A[NUM_CORES];  // (M/MR) × MR × K bytes each
int8_t*  B[NUM_CORES];  // (N/NR) × (K/4) × 256 bytes each
int32_t* C[NUM_CORES];  // (M/MR) × (N/NR) × MR × NR int32s each

void gemm_all_tiles(int64_t M, int64_t N, int64_t K) {
  #pragma omp parallel
  {
    int tid = omp_get_thread_num();
    int64_t C_stride = NR * sizeof(int32_t); // 256 bytes
    for (int64_t mt = 0; mt < M / MR; mt++) {
      for (int64_t nt = 0; nt < N / NR; nt++) {
        micro_kernel_6x4_ooo_v5(
            A[tid] + mt * MR * K,
            B[tid] + nt * (K / 4) * 256,
            C[tid] + (mt * (N / NR) + nt) * MR * NR,
            K, 0, C_stride);
      }
    }
  }
}
Comparison with Other Approaches
| Kernel | K-unroll | SDOT/cyc | Efficiency | Notes |
|---|---|---|---|---|
| v5 (OoO) | 2× | 21.7 | 90.3% | Best balance |
| v6 (4× unroll) | 4× | 21.5 | 89.6% | Extra pointer update hurts |
| Basic (no OoO) | 1× | 17.5 | 72.9% | Load latency not hidden |
| Intrinsics | 1× | 10.8 | 45.0% | Compiler scheduling suboptimal |
Summary
The v5 kernel achieves 90.3% of theoretical peak through:
- OoO-friendly scheduling: Loads issued early, results used late
- Sector cache hints: B streams through L2, A stays in L1
- Optimal register usage: All 32 SVE registers utilized
- Perfect branch alignment: b.gt at decode slot 3
- 2× K-unroll: Minimal loop overhead
- Per-core memory: No contention between cores
micro-blocking
T.B.W.: a 94%-efficient micro-blocking kernel
Bonus
A64FX implements SVE (version 1), which has no fp16 dot-product instruction with fp32 accumulation (fdot only arrives with SVE2.1 / SME2),
so for floating point you have to use fmla.