From 71b353577d3e36daa94183dec70d996bf2d25244 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 29 May 2026 18:35:00 +0000 Subject: [PATCH] =?UTF-8?q?fix:=20QK=20direct=20test=20=E2=80=94=20per-K-s?= =?UTF-8?q?ub-tile=20Q=20load=20(same=20as=20working=20kernel)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/unit/test_qk_direct.cu | 37 +++++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/tests/unit/test_qk_direct.cu b/tests/unit/test_qk_direct.cu index fad1945a..5d0281ab 100644 --- a/tests/unit/test_qk_direct.cu +++ b/tests/unit/test_qk_direct.cu @@ -55,21 +55,37 @@ test_qk_direct_kernel(float* __restrict__ out_s, __syncthreads(); uint32_t tb = *sTmemBase; - // Direct load Q: row 0 only - write_q_to_smem(sQ, q); - __syncthreads(); - + // Load Q and K per K-sub-tile (same pattern as working fmha_6warp_multirow) for (int kt = 0; kt < NKT; kt++) { - // Load K sub-tile (128, 16) directly + // Load Q sub-tile (128, 16) — only T rows have data + for (int i = tid; i < TILE_SZ; i += NTHREADS) sK[i] = 0; // reuse sK as temp for Q sub-tile + // Actually use a separate sQ0 + // sQ0 = sK for now (they share the same buffer since we load Q first) + // Let me use the same SMEM for both Q0 and K0 since they're used sequentially + // Actually we have sQ (128*HD) and sK (128*16). Use sK for Q0 since Q0 is (128,16) too. + // But we need both Q0 and K in SMEM at the same time for MMA... + // OK let's use the first (128,16) chunk of sQ as sQ0 + bf16_t* sQ0 = (bf16_t*)sbuf + 128 * 0; // overlaps with start of sQ + for (int i = tid; i < TILE_SZ; i += NTHREADS) sQ0[i] = 0; + __syncthreads(); + for (int r = 0; r < T; r++) { + for (int d = threadIdx.x % 32; d < MMA_K_BF16; d += 32) { + int full_d = kt * MMA_K_BF16 + d; + if (full_d < HD) { + int ck = d/8, lc = d%8, cm = r/8, lr = r%8; + sQ0[ck*16*64 + cm*64 + lr*8 + lc] = q[r * HD + full_d]; + } + } + } + __syncthreads(); + + // Load K sub-tile for (int i = tid; i < TILE_SZ; i += NTHREADS) sK[i] = 0; __syncthreads(); - // K in GMEM: (s_k, HD). Sub-tile at columns [kt*16, kt*16+16) for (int i = tid; i < SK * MMA_K_BF16; i += NTHREADS) { - int r = i / MMA_K_BF16; - int c = i % MMA_K_BF16; + int r = i / MMA_K_BF16, c = i % MMA_K_BF16; int gmem_c = kt * MMA_K_BF16 + c; bf16_t val = k[r * HD + gmem_c]; - // Write to canonical int core_mn = r / 8, core_k = c / 8; int local_r = r % 8, local_c = c % 8; int dst_idx = core_k * 16 * 64 + core_mn * 64 + local_r * 8 + local_c; @@ -79,8 +95,7 @@ test_qk_direct_kernel(float* __restrict__ out_s, if (is_mma_warp) { uint32_t idesc = make_idesc(128, 128); - uint32_t sq_kt = (uint32_t)__cvta_generic_to_shared(sQ) + kt * 128 * 32; - uint64_t dq = make_umma_desc_kmajor_none(sq_kt, 128); + uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), 128); uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK), 128); if (tid == 128) umma_ss_f16(tb, dq, dk, idesc, kt > 0); asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");