From facd509c3c5c9fd0bfef482a3f3c9bdc5f3be837 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 11:59:08 +0000 Subject: [PATCH] test: remove sanity check (zeroing loop overwrites), fix verify offsets --- tests/unit/test_umma_qk_hd64.cu | 23 +++++------------------ 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/tests/unit/test_umma_qk_hd64.cu b/tests/unit/test_umma_qk_hd64.cu index 76e85f00..1ead976f 100644 --- a/tests/unit/test_umma_qk_hd64.cu +++ b/tests/unit/test_umma_qk_hd64.cu @@ -46,7 +46,6 @@ test_umma_qk_hd64(const bf16_t* q, const bf16_t* k, __syncthreads(); uint32_t tb = *sTmemBase; - // Load Q and K into SMEM in canonical layout // Zero all first for (int i = tid; i < 128 * hd; i += N_WARPS * 32) { sQ[i] = 0; @@ -54,19 +53,6 @@ test_umma_qk_hd64(const bf16_t* q, const bf16_t* k, } __syncthreads(); - // Sanity check: write to sQ[0] - if (tid == 0) { - uint16_t one_bf16 = f32_to_bf16(1.0f); - sQ[0] = one_bf16; - } - __syncthreads(); - if (tid == 0) { - uint16_t val = sQ[0]; - float fval = bf16_to_f32(val); - s_out[250] = fval; // Should be 1.0 - } - __syncthreads(); - // Write Q (1, hd) to row 0 of sQ in canonical layout for (int d = tid; d < hd; d += N_WARPS * 32) { int core_k = d / 8, local_c = d % 8; @@ -84,13 +70,14 @@ test_umma_qk_hd64(const bf16_t* q, const bf16_t* k, __syncthreads(); // Verify SMEM data for first K-tile (columns 0-15) + // In canonical layout, Q[d] for row 0 is at core_k * 16 * 64 + local_c if (tid == 0) { - // Q row 0, d=0..7: core(0,0) at offset 0, local_r=0, local_c=d + // Q row 0, d=0..7: core_k=0, local_c=d → sQ[d] for (int d = 0; d < 8; d++) - s_out[200+d] = bf16_to_f32(sQ[d]); // core(0,0), row 0, col d - // Q row 0, d=16..23: core(0,2) at offset 2*1024 = 2048, local_r=0, local_c=d-16 + s_out[200+d] = bf16_to_f32(sQ[d]); + // Q row 0, d=8..15: core_k=1, local_c=d-8 → sQ[1024 + d-8] for (int d = 0; d < 8; d++) - s_out[208+d] = bf16_to_f32(sQ[2048 + d]); // core(0,2), row 0, col 0..7 + s_out[208+d] = bf16_to_f32(sQ[1024 + d]); } __syncthreads(); uint32_t sQ_smem = __cvta_generic_to_shared(sQ);