test: clean SMEM write loops for HD=64
This commit is contained in:
@@ -47,9 +47,6 @@ test_umma_qk_hd64(const bf16_t* q, const bf16_t* k,
|
||||
uint32_t tb = *sTmemBase;
|
||||
|
||||
// Load Q and K into SMEM in canonical layout
|
||||
// Using the template with HD as a parameter
|
||||
// write_q_to_smem and write_k_to_smem need to work with hd=64
|
||||
// For now, use explicit loops
|
||||
// Zero all first
|
||||
for (int i = tid; i < 128 * hd; i += N_WARPS * 32) {
|
||||
sQ[i] = 0;
|
||||
@@ -58,13 +55,11 @@ test_umma_qk_hd64(const bf16_t* q, const bf16_t* k,
|
||||
__syncthreads();
|
||||
|
||||
// Write Q (1, hd) to row 0 of sQ in canonical layout
|
||||
// Canonical: core(g, c) at offset c * 16 * 64 + g * 64 + local_r * 8 + local_c
|
||||
for (int d = tid; d < hd; d += N_WARPS * 32) {
|
||||
int core_k = d / 8, local_c = d % 8;
|
||||
int idx = core_k * 16 * 64 + local_c; // tile_mn=0, local_r=0
|
||||
sQ[idx] = q[d];
|
||||
}
|
||||
|
||||
// Write K (sk, hd) to sK in canonical layout
|
||||
for (int i = tid; i < sk * hd; i += N_WARPS * 32) {
|
||||
int r = i / hd, c = i % hd;
|
||||
|
||||
Reference in New Issue
Block a user