debug: V loaded directly from GMEM (not TMA) to isolate PV issue
This commit is contained in:
@@ -30,7 +30,7 @@ namespace dsv4::kernels::attention {
|
||||
struct FmhaTmaParams {
|
||||
const bf16_t* __restrict__ q;
|
||||
const bf16_t* __restrict__ k;
|
||||
const bf16_t* __restrict__ v;
|
||||
const bf16_t* __restrict__ v; // direct GMEM pointer for V
|
||||
bf16_t* __restrict__ o;
|
||||
float* __restrict__ lse;
|
||||
int s_k, T;
|
||||
@@ -69,6 +69,7 @@ fmha_tma_kernel(FmhaTmaParams params) {
|
||||
|
||||
bf16_t* __restrict__ q_head = (bf16_t*)params.q + head_idx * params.q_head_stride + batch_idx * params.q_batch_stride;
|
||||
bf16_t* __restrict__ o_head = params.o + head_idx * params.o_head_stride + batch_idx * params.o_batch_stride;
|
||||
const bf16_t* __restrict__ v_head = (const bf16_t*)params.v + head_idx * params.v_head_stride + batch_idx * params.v_batch_stride;
|
||||
float* __restrict__ lse_head = params.lse ? params.lse + head_idx * params.lse_head_stride + batch_idx * params.lse_batch_stride : nullptr;
|
||||
|
||||
CUtensorMap* __restrict__ tma_k = params.tma_k;
|
||||
@@ -241,28 +242,15 @@ fmha_tma_kernel(FmhaTmaParams params) {
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// V sub-tile: TMA load
|
||||
if (wid == 0 && lane == 0) {
|
||||
tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)tma_v, mbar_addr, col_start, d_base);
|
||||
tma_mbarrier_arrive_expect_tx(mbar_addr, TMA_TILE_BYTES);
|
||||
}
|
||||
tma_mbarrier_wait(mbar_addr, phase); phase ^= 1;
|
||||
__syncthreads();
|
||||
|
||||
// Convert V: TMA loaded (16, 128) row-major → (16, 16) canonical with CORES_MN_V=2
|
||||
// V in GMEM is (HD, s_k). The TMA tile covers rows [d_base, d_base+16) and all cols.
|
||||
// We need V_sub = V[d_base:d_base+16, col_start:col_start+16]
|
||||
// From sTmaBuf (16, 128): element at (dd, r) = sTmaBuf[dd * 128 + r]
|
||||
// where dd is the head-dim index (0..15) and r is the sequence index (0..127)
|
||||
// We only need r in [col_start, col_start+16)
|
||||
// V sub-tile: direct load from GMEM (same as working test_fmha_gen)
|
||||
const bf16_t* __restrict__ v_head = (const bf16_t*)params.v;
|
||||
for (int i = tid; i < V_SUB_SZ; i += 128) sV[i] = 0;
|
||||
for (int dd = tid / 32; dd < 16; dd += 4) { // 4 warps x 32 lanes
|
||||
for (int dd = tid / 32; dd < 16; dd += 4) {
|
||||
for (int lr = lane; lr < MMA_K_BF16; lr += 32) {
|
||||
int r = col_start + lr; // sequence index
|
||||
int r = col_start + lr;
|
||||
if (r < s_k) {
|
||||
int g_mn = dd / 8, g_k = lr / 8;
|
||||
int llr = dd % 8, lc = lr % 8;
|
||||
sV[g_k * CORES_MN_V * 64 + g_mn * 64 + llr * 8 + lc] = sTmaBuf[dd * 128 + r];
|
||||
int g_mn = dd / 8, g_k = lr / 8, llr = dd % 8, lc = lr % 8;
|
||||
sV[g_k * CORES_MN_V * 64 + g_mn * 64 + llr * 8 + lc] = v_head[(d_base + dd) * s_k + r];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user