Fix SMEM allocation (was half the needed size) + re-enable full pipeline

This commit is contained in:
2026-05-28 14:16:43 +00:00
parent fa6c124163
commit acf17e001e

View File

@@ -90,8 +90,7 @@ test_fmha_ts_full(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
__syncthreads();
// ===== STEP 2: Softmax — SKIPPED FOR DEBUG =====
/*
// ===== STEP 2: Softmax =====
if (wid == 0) {
float s_vals[SK], row_max = -INFINITY;
for (int n = 0; n < SK / 8; n++) {
@@ -128,14 +127,14 @@ test_fmha_ts_full(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
tmem_fence_store();
}
__syncthreads();
*/
// ===== STEP 3: PV GEMM (TS) — SKIPPED FOR DEBUG =====
/*
// ===== STEP 3: PV GEMM (TS) =====
// P(128,128) × V(128,16) → O(128,16)
// 8 K-tiles: A = P cols [16*kt .. 16*kt+15] (16 columns per tile), B = V K-tile kt
{
uint32_t idesc_pv = make_idesc(BLOCK_MN, HD);
for (int kt = 0; kt < 1; kt++) { // DEBUG: single PV K-tile
for (int kt = 0; kt < NKT_PV; kt++) {
uint32_t tmem_a = tb + kt * MMA_K_BF16; // A from P's kt-th 16 columns
bf16_t* sv = sV + kt * 256;
uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sv), 16);
@@ -148,10 +147,9 @@ test_fmha_ts_full(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
}
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
__syncthreads();
*/
// ===== STEP 4: Epilogue — read S from TMEM (QK output) =====
// Just verify QK works by reading first 16 values
// ===== STEP 4: Epilogue — read O from TMEM =====
// MMA output is scaled by 0.5, so multiply by 2.0
if (wid == 0) {
float o_vals[HD];
for (int n = 0; n < HD / 8; n++) {
@@ -159,9 +157,9 @@ test_fmha_ts_full(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k,
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
: "r"(tb + n*8));
: "r"(tb_o + n*8));
asm volatile("tcgen05.wait::ld.sync.aligned;");
if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c];
if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c] * 2.0f;
}
if (lane == 0) for (int d=0;d<HD;d++) o_out[d] = f32_to_bf16(o_vals[d]);
}
@@ -215,8 +213,9 @@ int main() {
cudaMemcpy(d_k, h_k, SK*HD*sizeof(bf16_t), cudaMemcpyHostToDevice);
cudaMemcpy(d_v, h_v, HD*SK*sizeof(bf16_t), cudaMemcpyHostToDevice);
// SMEM: tmem_base(4) + pad(12) + sQ(2048) + sK(2048) + sV(8*256=2048) + alignment
int smem = (4+16 + TILE_SZ*2 + NKT_PV*256 + 256 + 127) & ~127;
// SMEM: all sizes in BYTES
// tmem_base(4) + pad(12) + sQ(TILE_SZ*2) + sK(TILE_SZ*2) + sV(NKT_PV*256*2) + extra + align
int smem = (4+16 + TILE_SZ*2 + TILE_SZ*2 + NKT_PV*256*2 + 256 + 127) & ~127;
test_fmha_ts_full<<<1, 128, smem>>>(d_q, d_k, d_v, d_o, d_o_scalar, SCALE);
cudaError_t err = cudaDeviceSynchronize();