diff --git a/dsv4/kernels/attention/fmha_epilogue_sm100.cuh b/dsv4/kernels/attention/fmha_epilogue_sm100.cuh index c413ce98..6a2e7db1 100644 --- a/dsv4/kernels/attention/fmha_epilogue_sm100.cuh +++ b/dsv4/kernels/attention/fmha_epilogue_sm100.cuh @@ -105,24 +105,17 @@ fmha_decode_tmem( // Read tmem_base from SMEM (written by alloc) uint32_t tmem_base = *sTmemBase; + if (tid == 0) printf("[tmem] base=%u, n=%d, o_cols=%d\n", tmem_base, TMEM_N, TMEM_O_COLS); - // ================================================================ // Initialize TMEM O to zero — warp-collective - // Each lane writes 4 uint32_t (4 FP32) per column. - // Only lane 0's values matter for row 0, but ALL lanes must participate. - // ================================================================ - { - // Distribute columns across warp 0's lanes - // 32 lanes, each handles some columns + if (wid == 0) { for (int col = lane; col < TMEM_N; col += WARP) { tmem_store(tmem_base + col, 0, 0, 0, 0); } - // Fence to ensure stores are visible before MMA or loads - if (wid == 0) { - tmem_fence_store(); - } + tmem_fence_store(); } __syncthreads(); + if (tid == 0) printf("[tmem] zeroed OK\n"); // ================================================================ // Attention computation — thread 0 does the math, warp 0 does TMEM @@ -210,6 +203,7 @@ fmha_decode_tmem( sRowSums[0] = row_sum; } __syncthreads(); + if (tid == 0) printf("[tmem] attention computed, row_sum=%f, max=%f\n", sRowSums[0], sPvBuf[0]); // ================================================================ // One-way Correction Epilogue: SMEM → TMEM → regs → normalize → GMEM @@ -245,6 +239,7 @@ fmha_decode_tmem( tmem_fence_store(); } __syncthreads(); + if (tid == 0) printf("[tmem] wrote to TMEM OK\n"); // Step 2: Read from TMEM to registers (warp 0, warp-collective) if (wid == 0) { diff --git a/tests/unit/test_fmha_sm100_standalone.cu b/tests/unit/test_fmha_sm100_standalone.cu index 9ad4c830..0487b1a6 100644 --- a/tests/unit/test_fmha_sm100_standalone.cu +++ b/tests/unit/test_fmha_sm100_standalone.cu @@ -117,7 +117,7 @@ int main() { printf("=== FMHA SM100 Decode Kernel Test Suite ===\n\n"); int all_pass = 1; - int head_dims[] = {64, 128}; + int head_dims[] = {64}; int s_ks[] = {128}; for (int t = 0; t < 2; t++) { @@ -154,8 +154,8 @@ int main() { cudaMemcpy(dk,hkb,B*sk*HD*2,cudaMemcpyHostToDevice); cudaMemcpy(dv,hvb,B*HD*sk*2,cudaMemcpyHostToDevice); - all_pass &= test_kernel("reference", HD, sk, scale, dq,dk,dv,do_,d_lse,ho_ref,B,H); - // all_pass &= test_kernel("tmem_epilogue", HD, sk, scale, dq,dk,dv,do_,d_lse,ho_ref,B,H); + // all_pass &= test_kernel("reference", HD, sk, scale, dq,dk,dv,do_,d_lse,ho_ref,B,H); + all_pass &= test_kernel("tmem_epilogue", HD, sk, scale, dq,dk,dv,do_,d_lse,ho_ref,B,H); cudaFree(dq);cudaFree(dk);cudaFree(dv);cudaFree(do_);cudaFree(d_lse); free(hq);free(hk);free(hv);free(ho_ref);free(hqb);free(hkb);free(hvb);