fix: QK direct test — per-K-sub-tile Q load (same as working kernel)

This commit is contained in:
2026-05-29 18:35:00 +00:00
parent 35d0596893
commit 71b353577d

View File

@@ -55,21 +55,37 @@ test_qk_direct_kernel(float* __restrict__ out_s,
__syncthreads();
uint32_t tb = *sTmemBase;
// Direct load Q: row 0 only
write_q_to_smem<HD>(sQ, q);
__syncthreads();
// Load Q and K per K-sub-tile (same pattern as working fmha_6warp_multirow)
for (int kt = 0; kt < NKT; kt++) {
// Load K sub-tile (128, 16) directly
// Load Q sub-tile (128, 16) — only T rows have data
for (int i = tid; i < TILE_SZ; i += NTHREADS) sK[i] = 0; // reuse sK as temp for Q sub-tile
// Buffer choice for the Q sub-tile (Q0):
// We have sQ (128*HD) and sK (128*16); Q0 is (128,16), the same size as sK.
// Reusing sK for Q0 would only work if Q0 and K were used strictly one after
// the other, but both Q0 and K must be in SMEM at the same time for the MMA.
// So sK is kept for the K sub-tile, and the first (128,16) chunk of sQ is used as sQ0.
bf16_t* sQ0 = (bf16_t*)sbuf + 128 * 0; // overlaps with start of sQ
for (int i = tid; i < TILE_SZ; i += NTHREADS) sQ0[i] = 0;
__syncthreads();
for (int r = 0; r < T; r++) {
for (int d = threadIdx.x % 32; d < MMA_K_BF16; d += 32) {
int full_d = kt * MMA_K_BF16 + d;
if (full_d < HD) {
int ck = d/8, lc = d%8, cm = r/8, lr = r%8;
sQ0[ck*16*64 + cm*64 + lr*8 + lc] = q[r * HD + full_d];
}
}
}
__syncthreads();
// Load K sub-tile
for (int i = tid; i < TILE_SZ; i += NTHREADS) sK[i] = 0;
__syncthreads();
// K in GMEM: (s_k, HD). Sub-tile at columns [kt*16, kt*16+16)
for (int i = tid; i < SK * MMA_K_BF16; i += NTHREADS) {
int r = i / MMA_K_BF16;
int c = i % MMA_K_BF16;
int r = i / MMA_K_BF16, c = i % MMA_K_BF16;
int gmem_c = kt * MMA_K_BF16 + c;
bf16_t val = k[r * HD + gmem_c];
// Write to canonical
int core_mn = r / 8, core_k = c / 8;
int local_r = r % 8, local_c = c % 8;
int dst_idx = core_k * 16 * 64 + core_mn * 64 + local_r * 8 + local_c;
@@ -79,8 +95,7 @@ test_qk_direct_kernel(float* __restrict__ out_s,
if (is_mma_warp) {
uint32_t idesc = make_idesc(128, 128);
uint32_t sq_kt = (uint32_t)__cvta_generic_to_shared(sQ) + kt * 128 * 32;
uint64_t dq = make_umma_desc_kmajor_none(sq_kt, 128);
uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), 128);
uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK), 128);
if (tid == 128) umma_ss_f16(tb, dq, dk, idesc, kt > 0);
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");