test: clean SMEM write loops for HD=64

This commit is contained in:
2026-05-28 11:52:51 +00:00
parent 2ffbfda47d
commit 8936a2dec7

View File

@@ -47,9 +47,6 @@ test_umma_qk_hd64(const bf16_t* q, const bf16_t* k,
uint32_t tb = *sTmemBase;
// Load Q and K into SMEM in canonical layout
// Using the template with HD as a parameter
// write_q_to_smem and write_k_to_smem need to work with hd=64
// For now, use explicit loops
// Zero all first
for (int i = tid; i < 128 * hd; i += N_WARPS * 32) {
sQ[i] = 0;
@@ -58,13 +55,11 @@ test_umma_qk_hd64(const bf16_t* q, const bf16_t* k,
__syncthreads();
// Write Q (1, hd) to row 0 of sQ in canonical layout
// Canonical: core(g, c) at offset c * 16 * 64 + g * 64 + local_r * 8 + local_c
for (int d = tid; d < hd; d += N_WARPS * 32) {
int core_k = d / 8, local_c = d % 8;
int idx = core_k * 16 * 64 + local_c; // tile_mn=0, local_r=0
sQ[idx] = q[d];
}
// Write K (sk, hd) to sK in canonical layout
for (int i = tid; i < sk * hd; i += N_WARPS * 32) {
int r = i / hd, c = i % hd;