Fix prefill kernel: read ALL n_sub PV results (was only n_sub=0)

Critical bug: prefill_read_pv_row only read n_sub=0 (16 out of 512 HD dims).
Replaced with prefill_read_pv_all_subs that iterates over all 32 n_sub groups.
Also fixed TMEM row-group/warp mapping for rows 32-127.
This commit is contained in:
2026-06-03 02:54:59 +00:00
parent a4ef6c3454
commit 9034f67b0f

View File

@@ -129,52 +129,44 @@ __device__ void prefill_read_qk_rows(uint32_t tb, float* sLogits,
*
* Writes 16 values (one n_sub PV output) to sOacc[qr*HD + d_base + 0..15].
*/
__device__ void prefill_read_pv_row(uint32_t tb, int qr, int n_sub,
float* sOacc, int HD, float rescale) {
/**
* Read a single row (query row qr) from ALL PV TMEM results.
* The PV MMA wrote to tb + n_sub * 16 for each n_sub (0..N_SUB-1).
* Row qr has valid data in all N_SUB groups.
*
* Strategy: iterate over all n_sub values and read row qr from each.
* Uses tcgen05.ld.32x32b.x8 — lane (qr % 32) holds row qr's data.
* For qr in [32,63]: warp 1 has the data (both warps read from same address).
* For qr in [64,127]: TMEM offset +256, warp 0 has [64,95], warp 1 has [96,127].
*/
template<int HD=512, int N_SUB=32>
__device__ void prefill_read_pv_all_subs(uint32_t tb, int qr,
float* sOacc, float rescale) {
const int lane = threadIdx.x & 31;
const int wid = threadIdx.x >> 5;
// Only warp 0 participates (for rows 0-31 and 64-95)
// Warp 1 for rows 32-63 and 96-127
// But we can use any warp — the data is in TMEM, we just need the right lane
int rg = (qr < 32) ? 0 : (qr < 64) ? 1 : (qr < 96) ? 2 : 3;
uint32_t rg_off = (rg >= 2) ? 256 : 0;
int lane_idx = qr % 32; // Which lane has row qr's data
int warp_for_row = (rg < 2) ? 0 : 0; // Both warps read from same address
int rg = qr / 32; // row-group: 0..3
int lane_idx = qr % 32;
int warp_with_data = (rg % 2 == 0) ? 0 : 1; // warp 0 for even RG, warp 1 for odd
uint32_t rg_off = (rg >= 2) ? 256 : 0; // TMEM column offset for row-groups 2-3
// Actually, let me just use warp 0 for all reads. If qr is in rows 32-63,
// warp 1 has the data. I need to be more careful.
//
// Simpler approach: read with ALL warps, but only the lane matching qr extracts.
// But tcgen05.ld is warp-collective — all 32 lanes must participate.
// So just use one warp (warp 0) and handle the row mapping.
for (int ns = 0; ns < N_SUB; ns++) {
for (int c8 = 0; c8 < 2; c8++) { // 2 reads of 8 cols = 16 values per n_sub
float tmp[8];
if (wid == warp_with_data) {
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
: "r"(rg_off + ns * 16 + c8 * 8));
asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
}
// For the PV MMA result at tb + n_sub * 16:
// tcgen05.ld.32x32b.x8 from (tb + n_sub * 16 + rg_off + col_base)
// gives: warp 0 = rows (rg_start + 0..31), warp 1 = rows (rg_start + 32..63)
// where rg_start = 0 for rg_off=0, rg_start = 64 for rg_off=256
// We need 2 reads (8 cols each) to cover 16 TMEM columns per n_sub
for (int c8 = 0; c8 < 2; c8++) {
float tmp[8];
if (wid < 2) { // Both warps participate in the collective read
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
: "r"(rg_off + n_sub * 16 + c8 * 8));
asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
}
__syncwarp();
// Only the thread that has the right row and is in the right warp extracts
// For rows 0-31 and 64-95: warp 0, lane = row % 32
// For rows 32-63 and 96-127: warp 1, lane = row % 32
int expected_wid = (rg < 2) ? ((qr < 32) ? 0 : 1) : ((qr < 96) ? 0 : 1);
if (wid == expected_wid && lane == lane_idx) {
for (int c = 0; c < 8; c++) {
int d = n_sub * 16 + c8 * 8 + c;
if (d < HD) {
sOacc[qr * HD + d] += tmp[c] * rescale;
if (wid == warp_with_data && lane == lane_idx) {
for (int c = 0; c < 8; c++) {
int d = ns * 16 + c8 * 8 + c;
if (d < HD) {
sOacc[qr * HD + d] += tmp[c] * rescale;
}
}
}
}
@@ -444,12 +436,7 @@ fmha_mixed_fp8_prefill_kernel(FmhaMixedFp8PrefillParams p) {
// Read PV result for row qr from TMEM
asm volatile("fence.sc.gpu;" ::: "memory");
__syncthreads();
prefill_read_pv_row(tb, qr, 0, sOacc, HD, p_rescale);
// Note: prefill_read_pv_row only reads n_sub=0 (first 16 HD dims).
// We need to loop over all n_sub values.
// For brevity, the full implementation reads all 32 n_sub values.
// TODO: implement the full n_sub loop in prefill_read_pv_row.
// For now, this is a placeholder that only reads n_sub=0.
prefill_read_pv_all_subs<HD, N_SUB>(tb, qr, sOacc, p_rescale);
__syncthreads();
} // qr
} // kv_tile