debug: verify TMEM r/w works before MMA
This commit is contained in:
@@ -67,6 +67,27 @@ fmha_qk_verify(
|
||||
uint64_t desc_q = make_umma_desc_mn_none(sQ_smem, HD);
|
||||
uint64_t desc_k = make_umma_desc_k_none(sK_smem, HD);
|
||||
|
||||
// Quick test: write known data to TMEM and read it back
|
||||
// This verifies TMEM read/write works in this kernel
|
||||
if (wid == 0) {
|
||||
tmem_store(tmem_base + 0, f32_to_u32(42.0f), f32_to_u32(43.0f), f32_to_u32(44.0f), f32_to_u32(45.0f));
|
||||
tmem_fence_store();
|
||||
}
|
||||
__syncthreads();
|
||||
if (tid == 0) {
|
||||
uint32_t u0, u1, u2, u3;
|
||||
tmem_load(tmem_base + 0, u0, u1, u2, u3);
|
||||
printf("[qk] TMEM verify: %f %f %f %f\n", u32_to_f32(u0), u32_to_f32(u1), u32_to_f32(u2), u32_to_f32(u3));
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Now zero it for the actual MMA
|
||||
if (wid == 0) {
|
||||
tmem_store(tmem_base + 0, 0, 0, 0, 0);
|
||||
tmem_fence_store();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// QK GEMM: S = Q @ K^T (SS, both SMEM → TMEM)
|
||||
// MMA: called by ONE lane (elect_one_sync pattern)
|
||||
if (tid == 0) printf("[qk] tmem_base=%u sQ_smem=0x%x sK_smem=0x%x\n", tmem_base, sQ_smem, sK_smem);
|
||||
|
||||
Reference in New Issue
Block a user