From 5542a9da00463682eba7eeeed986edfcd99b75f7 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 29 May 2026 18:57:42 +0000 Subject: [PATCH] debug: V loaded directly from GMEM (not TMA) to isolate PV issue --- dsv4/kernels/attention/fmha_6warp_tma.cuh | 28 +++++++---------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/dsv4/kernels/attention/fmha_6warp_tma.cuh b/dsv4/kernels/attention/fmha_6warp_tma.cuh index 2a46d697..b4803309 100644 --- a/dsv4/kernels/attention/fmha_6warp_tma.cuh +++ b/dsv4/kernels/attention/fmha_6warp_tma.cuh @@ -30,7 +30,7 @@ namespace dsv4::kernels::attention { struct FmhaTmaParams { const bf16_t* __restrict__ q; const bf16_t* __restrict__ k; - const bf16_t* __restrict__ v; + const bf16_t* __restrict__ v; // direct GMEM pointer for V bf16_t* __restrict__ o; float* __restrict__ lse; int s_k, T; @@ -69,6 +69,7 @@ fmha_tma_kernel(FmhaTmaParams params) { bf16_t* __restrict__ q_head = (bf16_t*)params.q + head_idx * params.q_head_stride + batch_idx * params.q_batch_stride; bf16_t* __restrict__ o_head = params.o + head_idx * params.o_head_stride + batch_idx * params.o_batch_stride; + const bf16_t* __restrict__ v_head = (const bf16_t*)params.v + head_idx * params.v_head_stride + batch_idx * params.v_batch_stride; float* __restrict__ lse_head = params.lse ? params.lse + head_idx * params.lse_head_stride + batch_idx * params.lse_batch_stride : nullptr; CUtensorMap* __restrict__ tma_k = params.tma_k; @@ -241,28 +242,15 @@ fmha_tma_kernel(FmhaTmaParams params) { } __syncthreads(); - // V sub-tile: TMA load - if (wid == 0 && lane == 0) { - tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)tma_v, mbar_addr, col_start, d_base); - tma_mbarrier_arrive_expect_tx(mbar_addr, TMA_TILE_BYTES); - } - tma_mbarrier_wait(mbar_addr, phase); phase ^= 1; - __syncthreads(); - - // Convert V: TMA loaded (16, 128) row-major → (16, 16) canonical with CORES_MN_V=2 - // V in GMEM is (HD, s_k). The TMA tile covers rows [d_base, d_base+16) and all cols. - // We need V_sub = V[d_base:d_base+16, col_start:col_start+16] - // From sTmaBuf (16, 128): element at (dd, r) = sTmaBuf[dd * 128 + r] - // where dd is the head-dim index (0..15) and r is the sequence index (0..127) - // We only need r in [col_start, col_start+16) + // V sub-tile: direct load from GMEM (same as working test_fmha_gen) + const bf16_t* __restrict__ v_head = (const bf16_t*)params.v; for (int i = tid; i < V_SUB_SZ; i += 128) sV[i] = 0; - for (int dd = tid / 32; dd < 16; dd += 4) { // 4 warps x 32 lanes + for (int dd = tid / 32; dd < 16; dd += 4) { for (int lr = lane; lr < MMA_K_BF16; lr += 32) { - int r = col_start + lr; // sequence index + int r = col_start + lr; if (r < s_k) { - int g_mn = dd / 8, g_k = lr / 8; - int llr = dd % 8, lc = lr % 8; - sV[g_k * CORES_MN_V * 64 + g_mn * 64 + llr * 8 + lc] = sTmaBuf[dd * 128 + r]; + int g_mn = dd / 8, g_k = lr / 8, llr = dd % 8, lc = lr % 8; + sV[g_k * CORES_MN_V * 64 + g_mn * 64 + llr * 8 + lc] = v_head[(d_base + dd) * s_k + r]; } } }