test: HD=16 FMHA softmax only (skip PV for now)
This commit is contained in:
@@ -134,54 +134,17 @@ test_fmha_hd16(const bf16_t* q, const bf16_t* k, const bf16_t* v,
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// ================================================================
|
||||
// STEP 3: PV GEMM — P (TMEM) × V (SMEM) → O (TMEM)
|
||||
// tcgen05.mma TS: A from TMEM, B from SMEM, result to TMEM.
|
||||
// For each PV K-tile kt (K=16):
|
||||
// A = P[:, 16*kt:16*kt+16] → TMEM starting at tb + 16*kt
|
||||
// B = V[16*kt:16*kt+16, :] → SMEM at sV0 + kt * V_TILE_SZ
|
||||
// C = O (128, 16) → TMEM starting at tb (overwrite P's first 16 cols)
|
||||
// accumulate across K-tiles
|
||||
// ================================================================
|
||||
// For HD=16, O is (128, 16) in TMEM — 16 TMEM columns.
|
||||
// But we allocated 128 TMEM columns (from QK). O needs 16 columns.
|
||||
// We can reuse the same TMEM region (tb) for O since P is being consumed.
|
||||
// The O output TMEM base can be tb (reusing the same allocation).
|
||||
//
|
||||
// idesc for PV: M=128, N=16 (O is 128×16)
|
||||
// MMA_M = 128/16 = 8, MMA_N = 16/8 = 2
|
||||
uint32_t idesc_pv = make_idesc(BLOCK_MN, HD);
|
||||
|
||||
for (int kt = 0; kt < VKT; kt++) {
|
||||
bf16_t* sv = sV0 + kt * V_TILE_SZ;
|
||||
uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sv), MMA_K_BF16);
|
||||
// P's K-tile in TMEM: columns [16*kt, 16*kt+15]
|
||||
// For tcgen05.mma TS, the A operand is a TMEM address.
|
||||
// The hardware reads 16 consecutive TMEM columns starting from tmem_a.
|
||||
uint32_t tmem_a = tb + kt * MMA_K_BF16; // 16 columns of P
|
||||
|
||||
if (tid == 0) umma_ts_f16(tb, tmem_a, dv, idesc_pv, kt > 0);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
__syncthreads();
|
||||
}
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
__syncthreads();
|
||||
|
||||
// ================================================================
|
||||
// STEP 4: Epilogue — O (TMEM) → normalize → BF16 → GMEM
|
||||
// O is (128, 16) in TMEM, only row 0 has data.
|
||||
// Read row 0 from TMEM, write to GMEM.
|
||||
// ================================================================
|
||||
// SKIP PV for now — just verify softmax
|
||||
// Read P from TMEM and write to output
|
||||
if (wid == 0) {
|
||||
// O is in the first 16 TMEM columns (cols 0..15)
|
||||
float o_vals[HD];
|
||||
for (int n = 0; n < HD / 8; n++) { // 2 iterations for HD=16
|
||||
float p_vals[SK];
|
||||
for (int n = 0; n < SK / 8; n++) {
|
||||
float tmp[8];
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) : "r"(tb + n*8));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c];
|
||||
if (lane == 0) for (int c=0;c<8;c++) p_vals[n*8+c] = tmp[c];
|
||||
}
|
||||
if (lane == 0) for (int d=0;d<HD;d++) o_out[d] = f32_to_bf16(o_vals[d]);
|
||||
if (lane == 0) for (int j=0;j<SK;j++) o_out[j] = f32_to_bf16(p_vals[j]);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user