From acf17e001e2f1983516defe3cf6bca5d098f06fd Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 14:16:43 +0000 Subject: [PATCH] Fix SMEM allocation (was half the needed size) + re-enable full pipeline --- tests/unit/test_fmha_ts_full.cu | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/tests/unit/test_fmha_ts_full.cu b/tests/unit/test_fmha_ts_full.cu index 1a876999..85787ed5 100644 --- a/tests/unit/test_fmha_ts_full.cu +++ b/tests/unit/test_fmha_ts_full.cu @@ -90,8 +90,7 @@ test_fmha_ts_full(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k, asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); __syncthreads(); - // ===== STEP 2: Softmax — SKIPPED FOR DEBUG ===== - /* + // ===== STEP 2: Softmax ===== if (wid == 0) { float s_vals[SK], row_max = -INFINITY; for (int n = 0; n < SK / 8; n++) { @@ -128,14 +127,14 @@ test_fmha_ts_full(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k, tmem_fence_store(); } __syncthreads(); - */ - // ===== STEP 3: PV GEMM (TS) — SKIPPED FOR DEBUG ===== - /* + // ===== STEP 3: PV GEMM (TS) ===== + // P(128,128) × V(128,16) → O(128,16) + // 8 K-tiles: A = P cols [16*kt..16*kt+15), B = V K-tile kt { uint32_t idesc_pv = make_idesc(BLOCK_MN, HD); - for (int kt = 0; kt < 1; kt++) { // DEBUG: single PV K-tile + for (int kt = 0; kt < NKT_PV; kt++) { uint32_t tmem_a = tb + kt * MMA_K_BF16; // A from P's kt-th 16 columns bf16_t* sv = sV + kt * 256; uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sv), 16); @@ -148,10 +147,9 @@ test_fmha_ts_full(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k, } asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); __syncthreads(); - */ - // ===== STEP 4: Epilogue — read S from TMEM (QK output) ===== - // Just verify QK works by reading first 16 values + // ===== STEP 4: Epilogue — read O from TMEM ===== + // MMA output is scaled by 0.5, so multiply by 2.0 if (wid == 0) { float o_vals[HD]; for (int n = 0; n < HD / 8; n++) { @@ -159,9 +157,9 @@ test_fmha_ts_full(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k, asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) - : "r"(tb + n*8)); + : "r"(tb_o + n*8)); asm volatile("tcgen05.wait::ld.sync.aligned;"); - if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c]; + if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c] * 2.0f; } if (lane == 0) for (int d=0;d>>(d_q, d_k, d_v, d_o, d_o_scalar, SCALE); cudaError_t err = cudaDeviceSynchronize();