debug: read Q/K directly from SMEM
This commit is contained in:
@@ -28,7 +28,7 @@ fmha_qk_verify(
|
||||
// Must be 16-byte aligned for UMMA
|
||||
extern __shared__ char sbuf[];
|
||||
uint32_t* sTmemBase = (uint32_t*)sbuf;
|
||||
bf16_t* sQ = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~15);
|
||||
bf16_t* sQ = (bf16_t*)(((uintptr_t)(sbuf + 4) + 127) & ~127);
|
||||
bf16_t* sK = sQ + 128 * HD;
|
||||
|
||||
// Load Q: (1, HD) padded to (128, HD) with zeros
|
||||
@@ -70,14 +70,20 @@ fmha_qk_verify(
|
||||
// Quick test: verify SMEM data was loaded correctly
|
||||
// Sanity check: write the scaled dot product of the first HD elements of sQ and sK to s_out[0]
|
||||
if (tid == 0) {
|
||||
// Read the first two Q and K values directly from SMEM
|
||||
float q0 = bf16_to_f32(sQ[0]);
|
||||
float q1 = bf16_to_f32(sQ[1]);
|
||||
float k0 = bf16_to_f32(sK[0]);
|
||||
float k1 = bf16_to_f32(sK[1]);
|
||||
|
||||
float dot = 0;
|
||||
for (int d = 0; d < HD; d++) {
|
||||
dot += bf16_to_f32(sQ[d]) * bf16_to_f32(sK[d]);
|
||||
}
|
||||
s_out[0] = dot * scale;
|
||||
s_out[1] = (float)(sQ_smem);
|
||||
s_out[2] = (float)(sK_smem);
|
||||
s_out[3] = (float)(tmem_base);
|
||||
s_out[1] = q0; // first Q value
|
||||
s_out[2] = k0; // first K value
|
||||
s_out[3] = (float)(sQ_smem & 0xFFFF); // low 16 bits of SMEM address
|
||||
}
|
||||
__syncthreads();
|
||||
// TEMPORARILY SKIP MMA — just verify SMEM loads
|
||||
|
||||
Reference in New Issue
Block a user