debug: add printf to TMEM kernel to find hang point

This commit is contained in:
2026-05-28 07:39:53 +00:00
parent 44fb04fa1f
commit 0ddcc6bafd
2 changed files with 9 additions and 14 deletions

View File

@@ -105,24 +105,17 @@ fmha_decode_tmem(
// Read tmem_base from SMEM (written by alloc)
uint32_t tmem_base = *sTmemBase;
if (tid == 0) printf("[tmem] base=%u, n=%d, o_cols=%d\n", tmem_base, TMEM_N, TMEM_O_COLS);
// ================================================================
// Initialize TMEM O to zero — warp-collective
// Each lane writes 4 uint32_t (4 FP32) per column.
// Only lane 0's values matter for row 0, but ALL lanes must participate.
// ================================================================
{
// Distribute columns across warp 0's lanes
// 32 lanes, each handles some columns
if (wid == 0) {
for (int col = lane; col < TMEM_N; col += WARP) {
tmem_store(tmem_base + col, 0, 0, 0, 0);
}
// Fence to ensure stores are visible before MMA or loads
if (wid == 0) {
tmem_fence_store();
}
tmem_fence_store();
}
__syncthreads();
if (tid == 0) printf("[tmem] zeroed OK\n");
// ================================================================
// Attention computation — thread 0 does the math, warp 0 does TMEM
@@ -210,6 +203,7 @@ fmha_decode_tmem(
sRowSums[0] = row_sum;
}
__syncthreads();
if (tid == 0) printf("[tmem] attention computed, row_sum=%f, max=%f\n", sRowSums[0], sPvBuf[0]);
// ================================================================
// One-way Correction Epilogue: SMEM → TMEM → regs → normalize → GMEM
@@ -245,6 +239,7 @@ fmha_decode_tmem(
tmem_fence_store();
}
__syncthreads();
if (tid == 0) printf("[tmem] wrote to TMEM OK\n");
// Step 2: Read from TMEM to registers (warp 0, warp-collective)
if (wid == 0) {

View File

@@ -117,7 +117,7 @@ int main() {
printf("=== FMHA SM100 Decode Kernel Test Suite ===\n\n");
int all_pass = 1;
int head_dims[] = {64, 128};
int head_dims[] = {64};
int s_ks[] = {128};
for (int t = 0; t < 2; t++) {
@@ -154,8 +154,8 @@ int main() {
cudaMemcpy(dk,hkb,B*sk*HD*2,cudaMemcpyHostToDevice);
cudaMemcpy(dv,hvb,B*HD*sk*2,cudaMemcpyHostToDevice);
all_pass &= test_kernel("reference", HD, sk, scale, dq,dk,dv,do_,d_lse,ho_ref,B,H);
// all_pass &= test_kernel("tmem_epilogue", HD, sk, scale, dq,dk,dv,do_,d_lse,ho_ref,B,H);
// all_pass &= test_kernel("reference", HD, sk, scale, dq,dk,dv,do_,d_lse,ho_ref,B,H);
all_pass &= test_kernel("tmem_epilogue", HD, sk, scale, dq,dk,dv,do_,d_lse,ho_ref,B,H);
cudaFree(dq);cudaFree(dk);cudaFree(dv);cudaFree(do_);cudaFree(d_lse);
free(hq);free(hk);free(hv);free(ho_ref);free(hqb);free(hkb);free(hvb);