debug: read Q/K directly from SMEM

This commit is contained in:
2026-05-28 08:43:39 +00:00
parent 58b610c96c
commit c64bd7b875

View File

@@ -28,7 +28,7 @@ fmha_qk_verify(
// Must be 16-byte aligned for UMMA
extern __shared__ char sbuf[];
uint32_t* sTmemBase = (uint32_t*)sbuf;
bf16_t* sQ = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~15);
bf16_t* sQ = (bf16_t*)(((uintptr_t)(sbuf + 4) + 127) & ~127);
bf16_t* sK = sQ + 128 * HD;
// Load Q: (1, HD) padded to (128, HD) with zeros
@@ -70,14 +70,20 @@ fmha_qk_verify(
// Quick test: verify SMEM data was loaded correctly
// Write Q[0,0..3] * K[0,0..3] dot product (scalar) to s_out[0] as sanity check
if (tid == 0) {
// Read first few Q values directly from SMEM
float q0 = bf16_to_f32(sQ[0]);
float q1 = bf16_to_f32(sQ[1]);
float k0 = bf16_to_f32(sK[0]);
float k1 = bf16_to_f32(sK[1]);
float dot = 0;
for (int d = 0; d < HD; d++) {
dot += bf16_to_f32(sQ[d]) * bf16_to_f32(sK[d]);
}
s_out[0] = dot * scale;
s_out[1] = (float)(sQ_smem);
s_out[2] = (float)(sK_smem);
s_out[3] = (float)(tmem_base);
s_out[1] = q0; // first Q value
s_out[2] = k0; // first K value
s_out[3] = (float)(sQ_smem & 0xFFFF); // low 16 bits of SMEM address
}
__syncthreads();
// TEMPORARILY SKIP MMA — just verify SMEM loads