debug: verify TMEM r/w works before MMA

This commit is contained in:
2026-05-28 08:39:12 +00:00
parent a9d71ff6ab
commit 53139d24bf

View File

@@ -67,6 +67,27 @@ fmha_qk_verify(
uint64_t desc_q = make_umma_desc_mn_none(sQ_smem, HD);
uint64_t desc_k = make_umma_desc_k_none(sK_smem, HD);
// Quick test: write known data to TMEM and read it back
// This verifies TMEM read/write works in this kernel
if (wid == 0) {
tmem_store(tmem_base + 0, f32_to_u32(42.0f), f32_to_u32(43.0f), f32_to_u32(44.0f), f32_to_u32(45.0f));
tmem_fence_store();
}
__syncthreads();
if (tid == 0) {
uint32_t u0, u1, u2, u3;
tmem_load(tmem_base + 0, u0, u1, u2, u3);
printf("[qk] TMEM verify: %f %f %f %f\n", u32_to_f32(u0), u32_to_f32(u1), u32_to_f32(u2), u32_to_f32(u3));
}
__syncthreads();
// Now zero the test values at tmem_base + 0 again before the actual MMA
if (wid == 0) {
tmem_store(tmem_base + 0, 0, 0, 0, 0);
tmem_fence_store();
}
__syncthreads();
// QK GEMM: S = Q @ K^T (SS, both SMEM → TMEM)
// MMA: called by ONE lane (elect_one_sync pattern)
if (tid == 0) printf("[qk] tmem_base=%u sQ_smem=0x%x sK_smem=0x%x\n", tmem_base, sQ_smem, sK_smem);