diff --git a/tests/unit/test_umma_qk_hd64.cu b/tests/unit/test_umma_qk_hd64.cu index 66eb50f0..121804dc 100644 --- a/tests/unit/test_umma_qk_hd64.cu +++ b/tests/unit/test_umma_qk_hd64.cu @@ -75,7 +75,16 @@ test_umma_qk_hd64(const bf16_t* q, const bf16_t* k, } __syncthreads(); - // Construct base descriptors for Q and K + // Verify SMEM data for first K-tile (columns 0-15) + if (tid == 0) { + // Q row 0, d=0..7: core(0,0) at offset 0, local_r=0, local_c=d + for (int d = 0; d < 8; d++) + s_out[200+d] = bf16_to_f32(sQ[d]); // core(0,0), row 0, col d + // Q row 0, d=16..23: core(0,2) at offset 2*1024 = 2048, local_r=0, local_c=d-16 + for (int d = 0; d < 8; d++) + s_out[208+d] = bf16_to_f32(sQ[2048 + d]); // core(0,2), row 0, col 0..7 + } + __syncthreads(); uint32_t sQ_smem = __cvta_generic_to_shared(sQ); uint32_t sK_smem = __cvta_generic_to_shared(sK);