From c64bd7b8751e18aeb91ccc029db96dc04d99746f Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 08:43:39 +0000 Subject: [PATCH] debug: read Q/K directly from SMEM --- dsv4/kernels/attention/fmha_qk_verify.cuh | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/dsv4/kernels/attention/fmha_qk_verify.cuh b/dsv4/kernels/attention/fmha_qk_verify.cuh index 8b58c0f1..cde0ab2e 100644 --- a/dsv4/kernels/attention/fmha_qk_verify.cuh +++ b/dsv4/kernels/attention/fmha_qk_verify.cuh @@ -28,7 +28,7 @@ fmha_qk_verify( // Must be 16-byte aligned for UMMA extern __shared__ char sbuf[]; uint32_t* sTmemBase = (uint32_t*)sbuf; - bf16_t* sQ = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~15); + bf16_t* sQ = (bf16_t*)(((uintptr_t)(sbuf + 4) + 127) & ~127); bf16_t* sK = sQ + 128 * HD; // Load Q: (1, HD) padded to (128, HD) with zeros @@ -70,14 +70,20 @@ fmha_qk_verify( // Quick test: verify SMEM data was loaded correctly // Write Q[0,0..3] * K[0,0..3] dot product (scalar) to s_out[0] as sanity check if (tid == 0) { + // Read first few Q values directly from SMEM + float q0 = bf16_to_f32(sQ[0]); + float q1 = bf16_to_f32(sQ[1]); + float k0 = bf16_to_f32(sK[0]); + float k1 = bf16_to_f32(sK[1]); + float dot = 0; for (int d = 0; d < HD; d++) { dot += bf16_to_f32(sQ[d]) * bf16_to_f32(sK[d]); } s_out[0] = dot * scale; - s_out[1] = (float)(sQ_smem); - s_out[2] = (float)(sK_smem); - s_out[3] = (float)(tmem_base); + s_out[1] = q0; // first Q value + s_out[2] = k0; // first K value + s_out[3] = (float)(sQ_smem & 0xFFFF); // low 16 bits of SMEM address } __syncthreads(); // TEMPORARILY SKIP MMA — just verify SMEM loads