diff --git a/tests/unit/test_umma_qk_hd64.cu b/tests/unit/test_umma_qk_hd64.cu index 719027f3..327c6cc7 100644 --- a/tests/unit/test_umma_qk_hd64.cu +++ b/tests/unit/test_umma_qk_hd64.cu @@ -47,9 +47,6 @@ test_umma_qk_hd64(const bf16_t* q, const bf16_t* k, uint32_t tb = *sTmemBase; // Load Q and K into SMEM in canonical layout - // Using the template with HD as a parameter - // write_q_to_smem and write_k_to_smem need to work with hd=64 - // For now, use explicit loops // Zero all first for (int i = tid; i < 128 * hd; i += N_WARPS * 32) { sQ[i] = 0; @@ -58,13 +55,11 @@ test_umma_qk_hd64(const bf16_t* q, const bf16_t* k, __syncthreads(); // Write Q (1, hd) to row 0 of sQ in canonical layout - // Canonical: core(g, c) at offset c * 16 * 64 + g * 64 + local_r * 8 + local_c for (int d = tid; d < hd; d += N_WARPS * 32) { int core_k = d / 8, local_c = d % 8; int idx = core_k * 16 * 64 + local_c; // tile_mn=0, local_r=0 sQ[idx] = q[d]; } - // Write K (sk, hd) to sK in canonical layout for (int i = tid; i < sk * hd; i += N_WARPS * 32) { int r = i / hd, c = i % hd;