test: HD=16 FMHA softmax only (skip PV for now)

This commit is contained in:
2026-05-28 13:03:06 +00:00
parent 834d682443
commit 38d7bcd776

View File

@@ -134,54 +134,17 @@ test_fmha_hd16(const bf16_t* q, const bf16_t* k, const bf16_t* v,
}
__syncthreads();
// ================================================================
// STEP 3: PV GEMM — P (TMEM) × V (SMEM) → O (TMEM)
// tcgen05.mma TS: A from TMEM, B from SMEM, result to TMEM.
// For each PV K-tile kt (K=16):
// A = P[:, 16*kt:16*kt+16] → TMEM starting at tb + 16*kt
// B = V[16*kt:16*kt+16, :] → SMEM at sV0 + kt * V_TILE_SZ
// C = O (128, 16) → TMEM starting at tb (overwrite P's first 16 cols)
// accumulate across K-tiles
// ================================================================
// For HD=16, O is (128, 16) in TMEM — 16 TMEM columns.
// But we allocated 128 TMEM columns (from QK). O needs 16 columns.
// We can reuse the same TMEM region (tb) for O since P is being consumed.
// The O output TMEM base can be tb (reusing the same allocation).
//
// idesc for PV: M=128, N=16 (O is 128×16)
// MMA_M = 128/16 = 8, MMA_N = 16/8 = 2
uint32_t idesc_pv = make_idesc(BLOCK_MN, HD);
for (int kt = 0; kt < VKT; kt++) {
bf16_t* sv = sV0 + kt * V_TILE_SZ;
uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sv), MMA_K_BF16);
// P's K-tile in TMEM: columns [16*kt, 16*kt+15]
// For tcgen05.mma TS, the A operand is a TMEM address.
// The hardware reads 16 consecutive TMEM columns starting from tmem_a.
uint32_t tmem_a = tb + kt * MMA_K_BF16; // 16 columns of P
if (tid == 0) umma_ts_f16(tb, tmem_a, dv, idesc_pv, kt > 0);
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
__syncthreads();
}
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
__syncthreads();
// ================================================================
// STEP 4: Epilogue — O (TMEM) → normalize → BF16 → GMEM
// O is (128, 16) in TMEM, only row 0 has data.
// Read row 0 from TMEM, write to GMEM.
// ================================================================
// SKIP PV for now — just verify softmax
// Read P from TMEM and write to output
if (wid == 0) {
// O is in the first 16 TMEM columns (cols 0..15)
float o_vals[HD];
for (int n = 0; n < HD / 8; n++) { // 2 iterations for HD=16
float p_vals[SK];
for (int n = 0; n < SK / 8; n++) {
float tmp[8];
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) : "r"(tb + n*8));
asm volatile("tcgen05.wait::ld.sync.aligned;");
if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c];
if (lane == 0) for (int c=0;c<8;c++) p_vals[n*8+c] = tmp[c];
}
if (lane == 0) for (int d=0;d<HD;d++) o_out[d] = f32_to_bf16(o_vals[d]);
if (lane == 0) for (int j=0;j<SK;j++) o_out[j] = f32_to_bf16(p_vals[j]);
}
__syncthreads();