fix: QK direct test — per-K-sub-tile Q load (same as working kernel)
This commit is contained in:
@@ -55,21 +55,37 @@ test_qk_direct_kernel(float* __restrict__ out_s,
|
||||
__syncthreads();
|
||||
uint32_t tb = *sTmemBase;
|
||||
|
||||
// Direct load Q: row 0 only
|
||||
write_q_to_smem<HD>(sQ, q);
|
||||
__syncthreads();
|
||||
|
||||
// Load Q and K per K-sub-tile (same pattern as working fmha_6warp_multirow)
|
||||
for (int kt = 0; kt < NKT; kt++) {
|
||||
// Load K sub-tile (128, 16) directly
|
||||
// Load Q sub-tile (128, 16) — only T rows have data
|
||||
for (int i = tid; i < TILE_SZ; i += NTHREADS) sK[i] = 0; // reuse sK as temp for Q sub-tile
|
||||
// Actually use a separate sQ0
|
||||
// sQ0 = sK for now (they share the same buffer since we load Q first)
|
||||
// Let me use the same SMEM for both Q0 and K0 since they're used sequentially
|
||||
// Actually we have sQ (128*HD) and sK (128*16). Use sK for Q0 since Q0 is (128,16) too.
|
||||
// But we need both Q0 and K in SMEM at the same time for MMA...
|
||||
// OK let's use the first (128,16) chunk of sQ as sQ0
|
||||
bf16_t* sQ0 = (bf16_t*)sbuf + 128 * 0; // overlaps with start of sQ
|
||||
for (int i = tid; i < TILE_SZ; i += NTHREADS) sQ0[i] = 0;
|
||||
__syncthreads();
|
||||
for (int r = 0; r < T; r++) {
|
||||
for (int d = threadIdx.x % 32; d < MMA_K_BF16; d += 32) {
|
||||
int full_d = kt * MMA_K_BF16 + d;
|
||||
if (full_d < HD) {
|
||||
int ck = d/8, lc = d%8, cm = r/8, lr = r%8;
|
||||
sQ0[ck*16*64 + cm*64 + lr*8 + lc] = q[r * HD + full_d];
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Load K sub-tile
|
||||
for (int i = tid; i < TILE_SZ; i += NTHREADS) sK[i] = 0;
|
||||
__syncthreads();
|
||||
// K in GMEM: (s_k, HD). Sub-tile at columns [kt*16, kt*16+16)
|
||||
for (int i = tid; i < SK * MMA_K_BF16; i += NTHREADS) {
|
||||
int r = i / MMA_K_BF16;
|
||||
int c = i % MMA_K_BF16;
|
||||
int r = i / MMA_K_BF16, c = i % MMA_K_BF16;
|
||||
int gmem_c = kt * MMA_K_BF16 + c;
|
||||
bf16_t val = k[r * HD + gmem_c];
|
||||
// Write to canonical
|
||||
int core_mn = r / 8, core_k = c / 8;
|
||||
int local_r = r % 8, local_c = c % 8;
|
||||
int dst_idx = core_k * 16 * 64 + core_mn * 64 + local_r * 8 + local_c;
|
||||
@@ -79,8 +95,7 @@ test_qk_direct_kernel(float* __restrict__ out_s,
|
||||
|
||||
if (is_mma_warp) {
|
||||
uint32_t idesc = make_idesc(128, 128);
|
||||
uint32_t sq_kt = (uint32_t)__cvta_generic_to_shared(sQ) + kt * 128 * 32;
|
||||
uint64_t dq = make_umma_desc_kmajor_none(sq_kt, 128);
|
||||
uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), 128);
|
||||
uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK), 128);
|
||||
if (tid == 128) umma_ss_f16(tb, dq, dk, idesc, kt > 0);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
|
||||
Reference in New Issue
Block a user