Simplify prefill PV read: use decode kernel's exact pattern

Replace complex n_sub-iterating read with the same HD/8 iteration
pattern as the proven decode kernel. Extract from lane qr%32 instead
of always lane 0. For qr>=32, use warp 1; for qr>=64, add TMEM offset.

This should fix the row 1 accuracy issue (was cos=0.94 vs decode).
This commit is contained in:
2026-06-03 03:22:49 +00:00
parent 2bf5e74e61
commit 223c22488f

View File

@@ -131,13 +131,15 @@ __device__ void prefill_read_qk_rows(uint32_t tb, float* sLogits,
*/
/**
* Read a single row (query row qr) from ALL PV TMEM results.
* The PV MMA wrote to tb + n_sub * 16 for each n_sub (0..N_SUB-1).
* Row qr has valid data in all N_SUB groups.
* Uses the SAME approach as the decode kernel PV read, but extracts
* from the lane corresponding to row qr instead of always lane 0.
*
* Strategy: iterate over all n_sub values and read row qr from each.
* Uses tcgen05.ld.32x32b.x8 — lane (qr % 32) holds row qr's data.
* For qr in [32,63]: warp 1 has the data (both warps read from same address).
* For qr in [64,127]: TMEM offset +256, warp 0 has [64,95], warp 1 has [96,127].
* For qr < 32: warp 0, lane qr
* For qr 32-63: warp 1, lane (qr-32) -- same TMEM address, different rows
* For qr 64-95: same but TMEM offset +256
* For qr 96-127: same but TMEM offset +256
*
* This mirrors the proven decode kernel read pattern exactly.
*/
template<int HD=512, int N_SUB=32>
__device__ void prefill_read_pv_all_subs(uint32_t tb, int qr,
@@ -145,29 +147,25 @@ __device__ void prefill_read_pv_all_subs(uint32_t tb, int qr,
const int lane = threadIdx.x & 31;
const int wid = threadIdx.x >> 5;
int rg = qr / 32; // row-group: 0..3
int lane_idx = qr % 32;
int warp_with_data = (rg % 2 == 0) ? 0 : 1; // warp 0 for even RG, warp 1 for odd
uint32_t rg_off = (rg >= 2) ? 256 : 0; // TMEM column offset for row-groups 2-3
int local_lane = qr % 32;
int target_wid = (qr < 32) ? 0 : 1;
uint32_t rg_off = (qr >= 64) ? 256 : 0;
for (int ns = 0; ns < N_SUB; ns++) {
for (int c8 = 0; c8 < 2; c8++) { // 2 reads of 8 cols = 16 values per n_sub
float tmp[8];
if (wid == warp_with_data) {
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
: "r"(tb + rg_off + ns * 16 + c8 * 8));
asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
}
for (int n = 0; n < HD / 8; n++) {
float tmp[8];
if (wid == target_wid) {
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
: "r"(tb + rg_off + n * 8));
asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
}
if (wid == warp_with_data && lane == lane_idx) {
for (int c = 0; c < 8; c++) {
int d = ns * 16 + c8 * 8 + c;
if (d < HD) {
sOacc[qr * HD + d] += tmp[c] * rescale;
}
}
if (wid == target_wid && lane == local_lane) {
#pragma unroll
for (int c = 0; c < 8; c++) {
int d = n * 8 + c;
sOacc[qr * HD + d] += tmp[c] * rescale;
}
}
}