Fix prefill kernel: add missing tb base in PV TMEM read, fix ACCUMULATE for per-row PV

Two critical fixes:
1. prefill_read_pv_all_subs: was missing 'tb' base in TMEM read address
2. PV MMA ACCUMULATE: use pv_kt == 0 (not kv_tile==0 && pv_kt==0 && n_sub==0)
   so each query row's PV starts fresh instead of accumulating into previous row's result
This commit is contained in:
2026-06-03 02:59:19 +00:00
parent 9034f67b0f
commit 99b6de316b

View File

@@ -157,7 +157,7 @@ __device__ void prefill_read_pv_all_subs(uint32_t tb, int qr,
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
: "r"(rg_off + ns * 16 + c8 * 8));
: "r"(tb + rg_off + ns * 16 + c8 * 8));
asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
}
@@ -422,7 +422,7 @@ fmha_mixed_fp8_prefill_kernel(FmhaMixedFp8PrefillParams p) {
}
__syncthreads();
bool first = (kv_tile == 0 && pv_kt == 0 && n_sub == 0);
bool first = (pv_kt == 0); // Fresh for each query row's PV
if (is_mma_warp && lane == 0) {
uint64_t dp = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sPk), 128);
uint64_t dv = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sV), 16);