From 99b6de316b4ff8bbcf8aaec6cdc888dee787ece5 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 3 Jun 2026 02:59:19 +0000 Subject: [PATCH] Fix prefill kernel: add missing tb base in PV TMEM read, fix ACCUMULATE for per-row PV Two critical fixes: 1. prefill_read_pv_all_subs: was missing 'tb' base in TMEM read address 2. PV MMA ACCUMULATE: use pv_kt == 0 (not kv_tile==0 && pv_kt==0 && n_sub==0) so each query row's PV starts fresh instead of accumulating into previous row's result --- dsv4/kernels/attention/fmha_mixed_fp8_prefill.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dsv4/kernels/attention/fmha_mixed_fp8_prefill.cuh b/dsv4/kernels/attention/fmha_mixed_fp8_prefill.cuh index de49e28e..57e413c6 100644 --- a/dsv4/kernels/attention/fmha_mixed_fp8_prefill.cuh +++ b/dsv4/kernels/attention/fmha_mixed_fp8_prefill.cuh @@ -157,7 +157,7 @@ __device__ void prefill_read_pv_all_subs(uint32_t tb, int qr, asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) - : "r"(rg_off + ns * 16 + c8 * 8)); + : "r"(tb + rg_off + ns * 16 + c8 * 8)); asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory"); } @@ -422,7 +422,7 @@ fmha_mixed_fp8_prefill_kernel(FmhaMixedFp8PrefillParams p) { } __syncthreads(); - bool first = (kv_tile == 0 && pv_kt == 0 && n_sub == 0); + bool first = (pv_kt == 0); // Fresh for each query row's PV if (is_mma_warp && lane == 0) { uint64_t dp = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sPk), 128); uint64_t dv = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sV), 16);