
(w.i.p.) Notes on reaching 90% of peak with A64FX int8 sdot for Attention


Unlike the typical(?) HPC GEMM case, ML/AI workloads (especially Attention and FFN) involve low precision (e.g. int8) and small, non-square matrix shapes (e.g. [4096, 128]), so they need a different set of optimizations than HPC-oriented kernels.
(For CPUs, see FlashGEMM: https://zenn.dev/syoyo/articles/7a4e410389addb)

Also, unlike fmla (fused multiply-add), sdot (udot) accumulates into int32 (int64 for int16 sdot), which makes accumulator registers harder to reuse, so register pressure should be higher than with fmla.

movprfx cannot be fused with sdot. mov and dup are costly (they only issue on the FLA pipe), so it is better to clear the accumulator registers with eor.

For fmla, https://github.com/fujitsu/A64FX/tree/master/sample has a 5x4 case (roughly 92% of peak, it seems; maybe 96% is reachable with more effort), but for sdot a 5x4 tile does not reach the compute density needed for the 2 sdot/cycle peak, so 6x4 is currently optimal (though with 24 accumulators, 6 A registers, and 2 B registers it uses all 32 registers, so double-buffering of B is not possible).

With that in mind, I had Claude do a lot of the trial and error.
Currently 83%+ at d=512, up to about 90% (21 sdot/cycle/CMG), with sector cache hints.
L (seqlen, N) is around 640.
The matrices are sized so that the data fits within about one L2 way (512 KB) or less.

These days L is more like 2048~8192, so some clever overlapping with L2 may be needed.
(Or just assume Linear Attention instead...)

Rough approach

On SVE, even bit operations and integer operations unconditionally go through the FL pipes, so even small amounts of extra work such as data unpacking quickly make 100% of peak unreachable.
The kernel has to consist of ld/st + sdot/fmla and nothing else.
(They could at least have put bit operations and adds on a separate pipe...)

Each core computes a different matrix. Even without sharing data, the required L2 bandwidth is less than half of what is available (though latency hiding is an issue). At the moment, sharing actually hurt performance.
Sector cache hints are also used.

d = 256, 512. Performance may suffer at low d.
(d = head_dim = reduction axis)

A64FX Architecture Overview

Compute Resources (per CMG = 12 cores)

  • 12 cores per CMG (Core Memory Group)
  • 2 FPUs per core capable of SDOT
  • Each FPU: 1 SDOT instruction/cycle
  • Peak: 24 SDOT/cycle per CMG

SDOT Instruction

sdot z_acc.s, z_a.b, z_b.b
  • Computes 16 independent dot products in parallel
  • Each lane: 4 INT8 multiplies + accumulate into INT32
  • SVE vector length: 512 bits = 64 bytes = 16 INT32 lanes
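As a scalar reference (a sketch of the semantics above, not the architectural definition), one sdot at VL=512 behaves like:

```c
#include <stdint.h>

/* Scalar model of: sdot z_acc.s, z_a.b, z_b.b  (VL = 512 bits)
 * 16 int32 lanes; each lane accumulates a 4-element int8 dot product. */
static void sdot_ref(int32_t acc[16], const int8_t a[64], const int8_t b[64])
{
    for (int lane = 0; lane < 16; lane++) {
        int32_t sum = 0;
        for (int j = 0; j < 4; j++)
            sum += (int32_t)a[lane * 4 + j] * (int32_t)b[lane * 4 + j];
        acc[lane] += sum;
    }
}
```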

Memory Hierarchy

  • L1 Data Cache: 64 KB per core (4-way, 256B line)
  • L2 Cache: 8 MB shared per CMG (7 MB in practice)
  • Memory bandwidth: ~256 GB/s per CMG
  • Sector cache hints: Control L1/L2 placement via address bit 56

Kernel Design: 6×4 Tile

Tile Dimensions

MR = 6 rows (A dimension)
NR = 4 vectors = 64 INT32 elements (B dimension)
K  = reduction dimension (processed in groups of 4)

Register Allocation (32 SVE registers)

Accumulators (24 registers):
  z0-z3:   Row 0, columns 0-3
  z4-z7:   Row 1, columns 0-3
  z8-z11:  Row 2, columns 0-3
  z12-z15: Row 3, columns 0-3
  z16-z19: Row 4, columns 0-3
  z20-z23: Row 5, columns 0-3

A broadcasts (6 registers):
  z24-z29: 6 rows × 4 bytes each (broadcast via ld1rw)

B vectors (2 registers, double-buffered):
  z30-z31: Current B vectors being processed

Memory Layout

Matrix A (packed, row-interleaved per K-group):

K-group 0: [row0_k0-3][row1_k0-3][row2_k0-3][row3_k0-3][row4_k0-3][row5_k0-3]
K-group 1: [row0_k4-7][row1_k4-7][row2_k4-7][row3_k4-7][row4_k4-7][row5_k4-7]
...
Size per tile: MR × K = 6 × K bytes

Matrix B (packed, 4 vectors per K-group):

K-group 0: [vec0: 64B][vec1: 64B][vec2: 64B][vec3: 64B]  = 256 bytes
K-group 1: [vec0: 64B][vec1: 64B][vec2: 64B][vec3: 64B]  = 256 bytes
...
Size per N-tile: (K/4) × 256 bytes
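The two layouts above can be produced with packing routines along these lines (hypothetical helpers, assuming row-major A/B with leading dimensions lda/ldb; the memo does not show the actual packing code):

```c
#include <stdint.h>

/* Pack a 6-row A panel: per K-group of 4 int8s, the 6 rows' 4-byte
 * chunks are stored back to back, matching the ld1rw broadcasts. */
void pack_A_6xK(int8_t *dst, const int8_t *A, int lda, int K)
{
    for (int kg = 0; kg < K / 4; kg++)
        for (int r = 0; r < 6; r++)
            for (int j = 0; j < 4; j++)
                *dst++ = A[r * lda + kg * 4 + j];
}

/* Pack a 64-column B panel: per K-group, 4 SVE vectors of 64 bytes; each
 * vector holds 16 columns, 4 consecutive K values (int8) per int32 lane. */
void pack_B_64xK(int8_t *dst, const int8_t *B, int ldb, int K)
{
    for (int kg = 0; kg < K / 4; kg++)
        for (int v = 0; v < 4; v++)              /* 4 vectors per K-group */
            for (int lane = 0; lane < 16; lane++) {
                int col = v * 16 + lane;
                for (int j = 0; j < 4; j++)
                    *dst++ = B[(kg * 4 + j) * ldb + col];
            }
}
```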

Main Loop Structure

Loop Overview

2 K-groups per iteration = 8 K values = 48 SDOT instructions
72 total instructions per iteration = 288 bytes = 18 decode groups

Instruction Breakdown per Iteration

Type                  Count   Purpose
ld1b (B loads)            8   Load B vectors (64 bytes each)
ld1rw (A broadcasts)     12   Broadcast A elements (4 bytes each)
sdot                     48   24 per K-group × 2 K-groups
add (pointer)             2   Advance A and B pointers
subs                      1   Decrement loop counter
b.gt                      1   Conditional branch
Total                    72

Scheduling Pattern (K-group 0, first half)

// Load B vectors for columns 0-1
ld1b    {z30.b}, p0/z, [x1, #0, mul vl]    // B col 0
ld1b    {z31.b}, p0/z, [x1, #1, mul vl]    // B col 1

// 12 SDOTs using B cols 0-1 (all 6 A rows)
sdot    z0.s, z24.b, z30.b     // C[0,0] += A[0] · B[0]
sdot    z1.s, z24.b, z31.b     // C[0,1] += A[0] · B[1]
sdot    z4.s, z25.b, z30.b     // C[1,0] += A[1] · B[0]
sdot    z5.s, z25.b, z31.b     // C[1,1] += A[1] · B[1]
sdot    z8.s, z26.b, z30.b     // C[2,0] += A[2] · B[0]
sdot    z9.s, z26.b, z31.b     // C[2,1] += A[2] · B[1]
sdot    z12.s, z27.b, z30.b    // C[3,0] += A[3] · B[0]
sdot    z13.s, z27.b, z31.b    // C[3,1] += A[3] · B[1]
sdot    z16.s, z28.b, z30.b    // C[4,0] += A[4] · B[0]
sdot    z17.s, z28.b, z31.b    // C[4,1] += A[4] · B[1]
sdot    z20.s, z29.b, z30.b    // C[5,0] += A[5] · B[0]
sdot    z21.s, z29.b, z31.b    // C[5,1] += A[5] · B[1]

Scheduling Pattern (K-group 0, second half with A preload)

// Load B vectors for columns 2-3
ld1b    {z30.b}, p0/z, [x1, #2, mul vl]    // B col 2
ld1b    {z31.b}, p0/z, [x1, #3, mul vl]    // B col 3

// Interleaved SDOTs + A loads for next K-group
sdot    z2.s, z24.b, z30.b     // C[0,2]
sdot    z3.s, z24.b, z31.b     // C[0,3]
ld1rw   {z24.s}, p0/z, [x0, #24]           // Preload A[0] for K-group 1
sdot    z6.s, z25.b, z30.b     // C[1,2]
sdot    z7.s, z25.b, z31.b     // C[1,3]
ld1rw   {z25.s}, p0/z, [x0, #28]           // Preload A[1]
sdot    z10.s, z26.b, z30.b    // C[2,2]
sdot    z11.s, z26.b, z31.b    // C[2,3]
ld1rw   {z26.s}, p0/z, [x0, #32]           // Preload A[2]
sdot    z14.s, z27.b, z30.b    // C[3,2]
sdot    z15.s, z27.b, z31.b    // C[3,3]
ld1rw   {z27.s}, p0/z, [x0, #36]           // Preload A[3]
sdot    z18.s, z28.b, z30.b    // C[4,2]
sdot    z19.s, z28.b, z31.b    // C[4,3]
ld1rw   {z28.s}, p0/z, [x0, #40]           // Preload A[4]
sdot    z22.s, z29.b, z30.b    // C[5,2]
sdot    z23.s, z29.b, z31.b    // C[5,3]
ld1rw   {z29.s}, p0/z, [x0, #44]           // Preload A[5]

Key Optimizations

1. Out-of-Order Load Hiding

A64FX has 11-cycle load latency. The kernel exploits OoO execution:

  • Loads issued early, SDOTs use results later
  • A loads interleaved with SDOTs (2 SDOTs between each A load)
  • Register renaming allows loads to complete while SDOTs execute

Load-to-use distance:

  • B loads: ~12 SDOTs before use (well hidden)
  • A loads: ~24 SDOTs before use (fully hidden)

2. Sector Cache Hints for B Matrix

// Set bit 56 to use sector 1 (L2 streaming, bypass L1)
mov     x19, #1
lsl     x19, x19, #56
orr     x1, x1, x19

B matrix is streamed through (each element used once per tile):

  • Sector 1: Data goes to L2, not polluting L1
  • L1 reserved for A matrix (reused across N tiles)
  • Prevents cache thrashing
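In C, the same tagging can be sketched as follows (assuming the Fujitsu tag-address-override mode is enabled so that bit 56 selects the sector; `sector1` is a hypothetical helper name):

```c
#include <stdint.h>

/* Tag a pointer so accesses through it are placed in sector 1
 * (mirrors the mov/lsl/orr sequence above). */
static inline const int8_t *sector1(const int8_t *p)
{
    return (const int8_t *)((uintptr_t)p | (1ULL << 56));
}
```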

3. Optimal Branch Placement (Slot 3)

A64FX decodes 4 instructions/cycle (16 bytes):

Loop: 72 instructions = 288 bytes = 18 decode groups
Branch at instruction 72: offset 284 bytes
284 mod 16 = 12 → slot 3 (last in decode group)

Decode group structure at loop end:

[add x0] [add x1] [subs x6] [b.gt]
 slot 0   slot 1   slot 2   slot 3

Branch at slot 3 = fetch of next iteration can begin immediately.
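The slot arithmetic above can be checked with a one-liner (4-byte fixed-width instructions, 16-byte decode groups):

```c
/* Decode slot (0..3) of the n-th instruction (1-based) in the loop body. */
int decode_slot(int insn_index)
{
    int offset = (insn_index - 1) * 4;   /* byte offset from loop start    */
    return (offset % 16) / 4;            /* position in the 16-byte group */
}
```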

4. Loop Unrolling (2× K-groups)

Processing 2 K-groups per iteration:

  • 48 SDOTs per iteration (vs 24 for 1×)
  • Loop overhead: 4 instructions / 48 SDOTs = 8.3%
  • Reduced branch misprediction impact

5. 64-Byte Loop Alignment

.align 6        // 64-byte alignment
.Lmain_v5:
  • Loop start aligned to cache line boundary
  • Optimal instruction fetch
  • No partial cache line fetches at loop entry

Performance Analysis

Theoretical vs Achieved

Per iteration (2 K-groups):

  • SDOTs: 48
  • Minimum cycles: 48 / 2 = 24 cycles (2 SDOT/cycle per core)

Measured overhead:

  • Actual cycles: ~26.5 per iteration
  • Overhead: 10.4%

Overhead sources:

  1. Loop control: ~4 cycles (add, add, subs, b.gt)
  2. Memory latency (not fully hidden): ~2 cycles
  3. Instruction decode/issue gaps: ~0.5 cycles

Scaling to 12 Cores

Cores   Peak (SDOT/cyc)   Achieved   Efficiency
1                   2.0       1.81        90.5%
12                 24.0       21.7        90.3%

Near-linear scaling achieved through:

  • Independent memory per core (separate A, B, C allocations)
  • No shared data structures
  • L2 bandwidth sufficient for all cores

Function Signature

void micro_kernel_6x4_ooo_v5(
    const int8_t* A,      // x0: Packed A tile (MR × K bytes)
    const int8_t* B,      // x1: Packed B tile (NR × K/4 × 64 bytes)
    int32_t* C,           // x2: Output C tile (MR × NR int32s)
    int64_t K,            // x3: K dimension (must be multiple of 4)
    int64_t unused,       // x4: Reserved (pass 0)
    int64_t C_stride      // x5: Row stride for C in bytes
);

Usage Example

#define MR 6
#define NR 64
#define NUM_CORES 12

// Allocate per-core buffers
int8_t* A[NUM_CORES];   // MR × K bytes each
int8_t* B[NUM_CORES];   // (N/NR) × (K/4) × 256 bytes each
int32_t* C[NUM_CORES];  // MR × NR × (M/MR) × (N/NR) int32s each

#pragma omp parallel
{
    int tid = omp_get_thread_num();
    int64_t C_stride = NR * sizeof(int32_t);  // 256 bytes

    for (int mt = 0; mt < M/MR; mt++) {
        for (int nt = 0; nt < N/NR; nt++) {
            micro_kernel_6x4_ooo_v5(
                A[tid] + mt * MR * K,
                B[tid] + nt * (K/4) * 256,
                C[tid] + (mt * (N/NR) + nt) * MR * NR,
                K, 0, C_stride);
        }
    }
}
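For validating the assembly kernel, a scalar reference over the same packed layouts can look like this (a sketch under the A/B layouts described earlier; it zero-initializes C, and `micro_kernel_6x4_ref` is a hypothetical helper, not part of the actual code):

```c
#include <stdint.h>

/* Scalar reference for the 6x4 micro-kernel on packed A/B (MR=6, NR=64). */
void micro_kernel_6x4_ref(const int8_t *A, const int8_t *B,
                          int32_t *C, int64_t K, int64_t C_stride)
{
    int64_t ldc = C_stride / (int64_t)sizeof(int32_t);
    for (int r = 0; r < 6; r++)
        for (int c = 0; c < 64; c++)
            C[r * ldc + c] = 0;
    for (int64_t kg = 0; kg < K / 4; kg++) {
        const int8_t *a = A + kg * 6 * 4;    /* 6 rows x 4 bytes per group */
        const int8_t *b = B + kg * 256;      /* 4 vectors x 64 bytes       */
        for (int r = 0; r < 6; r++)
            for (int c = 0; c < 64; c++) {
                /* column c lives in vector c/16, lane c%16 */
                const int8_t *bc = b + (c / 16) * 64 + (c % 16) * 4;
                for (int j = 0; j < 4; j++)
                    C[r * ldc + c] += (int32_t)a[r * 4 + j] * (int32_t)bc[j];
            }
    }
}
```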

Comparison with Other Approaches

Kernel             SDOT/cyc   Efficiency   Notes
v5 (OoO, 2×)           21.7        90.3%   Best balance
v6 (4× unroll)         21.5        89.6%   Extra pointer update hurts
Basic (no OoO)         17.5        72.9%   Load latency not hidden
Intrinsics             10.8        45.0%   Compiler scheduling suboptimal

Summary

The v5 kernel achieves 90.3% of theoretical peak through:

  1. OoO-friendly scheduling: Loads issued early, results used late
  2. Sector cache hints: B streams through L2, A stays in L1
  3. Optimal register usage: All 32 SVE registers utilized
  4. Perfect branch alignment: b.gt at decode slot 3
  5. 2× K-unroll: Minimal loop overhead
  6. Per-core memory: No contention between cores

micro-blocking

T.B.W.: a 94%-efficient micro-blocking kernel

Bonus

A64FX SVE 1.0 has no fp16 dot-product instruction fdot (with fp32 accumulation); it only appears in SVE2.1 / SME2. So for floating point, fmla has to be used.

https://dougallj.github.io/asil/doc/fdot_z_zzz_16.html
