debug: add printf to TMEM kernel to find hang point
This commit is contained in:
@@ -105,24 +105,17 @@ fmha_decode_tmem(
|
||||
|
||||
// Read tmem_base from SMEM (written by alloc)
|
||||
uint32_t tmem_base = *sTmemBase;
|
||||
if (tid == 0) printf("[tmem] base=%u, n=%d, o_cols=%d\n", tmem_base, TMEM_N, TMEM_O_COLS);
|
||||
|
||||
// ================================================================
|
||||
// Initialize TMEM O to zero — warp-collective
|
||||
// Each lane writes 4 uint32_t (4 FP32) per column.
|
||||
// Only lane 0's values matter for row 0, but ALL lanes must participate.
|
||||
// ================================================================
|
||||
{
|
||||
// Distribute columns across warp 0's lanes
|
||||
// 32 lanes, each handles some columns
|
||||
if (wid == 0) {
|
||||
for (int col = lane; col < TMEM_N; col += WARP) {
|
||||
tmem_store(tmem_base + col, 0, 0, 0, 0);
|
||||
}
|
||||
// Fence to ensure stores are visible before MMA or loads
|
||||
if (wid == 0) {
|
||||
tmem_fence_store();
|
||||
}
|
||||
tmem_fence_store();
|
||||
}
|
||||
__syncthreads();
|
||||
if (tid == 0) printf("[tmem] zeroed OK\n");
|
||||
|
||||
// ================================================================
|
||||
// Attention computation — thread 0 does the math, warp 0 does TMEM
|
||||
@@ -210,6 +203,7 @@ fmha_decode_tmem(
|
||||
sRowSums[0] = row_sum;
|
||||
}
|
||||
__syncthreads();
|
||||
if (tid == 0) printf("[tmem] attention computed, row_sum=%f, max=%f\n", sRowSums[0], sPvBuf[0]);
|
||||
|
||||
// ================================================================
|
||||
// One-way Correction Epilogue: SMEM → TMEM → regs → normalize → GMEM
|
||||
@@ -245,6 +239,7 @@ fmha_decode_tmem(
|
||||
tmem_fence_store();
|
||||
}
|
||||
__syncthreads();
|
||||
if (tid == 0) printf("[tmem] wrote to TMEM OK\n");
|
||||
|
||||
// Step 2: Read from TMEM to registers (warp 0, warp-collective)
|
||||
if (wid == 0) {
|
||||
|
||||
@@ -117,7 +117,7 @@ int main() {
|
||||
printf("=== FMHA SM100 Decode Kernel Test Suite ===\n\n");
|
||||
|
||||
int all_pass = 1;
|
||||
int head_dims[] = {64, 128};
|
||||
int head_dims[] = {64};
|
||||
int s_ks[] = {128};
|
||||
|
||||
for (int t = 0; t < 2; t++) {
|
||||
@@ -154,8 +154,8 @@ int main() {
|
||||
cudaMemcpy(dk,hkb,B*sk*HD*2,cudaMemcpyHostToDevice);
|
||||
cudaMemcpy(dv,hvb,B*HD*sk*2,cudaMemcpyHostToDevice);
|
||||
|
||||
all_pass &= test_kernel("reference", HD, sk, scale, dq,dk,dv,do_,d_lse,ho_ref,B,H);
|
||||
// all_pass &= test_kernel("tmem_epilogue", HD, sk, scale, dq,dk,dv,do_,d_lse,ho_ref,B,H);
|
||||
// all_pass &= test_kernel("reference", HD, sk, scale, dq,dk,dv,do_,d_lse,ho_ref,B,H);
|
||||
all_pass &= test_kernel("tmem_epilogue", HD, sk, scale, dq,dk,dv,do_,d_lse,ho_ref,B,H);
|
||||
|
||||
cudaFree(dq);cudaFree(dk);cudaFree(dv);cudaFree(do_);cudaFree(d_lse);
|
||||
free(hq);free(hk);free(hv);free(ho_ref);free(hqb);free(hkb);free(hvb);
|
||||
|
||||
Reference in New Issue
Block a user