CRITICAL FIX: Add T-dimension strides to prefill FMHA kernel

The kernel was using head strides for the T (query row) dimension,
which happened to work for T=1 (qr=0 always) but was wrong for T>1.

For (B,H,T,NOPE) layout:
- Head stride = T*NOPE, but T stride = NOPE
- Scale head stride = T, but T stride = 1
- RoPE head stride = T*ROPE, but T stride = ROPE

Added q_nope_t_stride, q_scale_t_stride, q_rope_t_stride to params
struct, C API, and Python wrapper.
This commit is contained in:
2026-06-03 03:48:17 +00:00
parent dd1cbe1faa
commit 5417f65b08
3 changed files with 15 additions and 12 deletions

View File

@@ -46,9 +46,9 @@ struct FmhaMixedFp8PrefillParams {
const float* __restrict__ sink_bias; // (B,H), optional
int B, H, T, N, HD, NOPE, ROPE;
int q_nope_head_stride, q_nope_batch_stride;
int q_scale_head_stride, q_scale_batch_stride;
int q_rope_head_stride, q_rope_batch_stride;
int q_nope_t_stride, q_nope_head_stride, q_nope_batch_stride;
int q_scale_t_stride, q_scale_head_stride, q_scale_batch_stride;
int q_rope_t_stride, q_rope_head_stride, q_rope_batch_stride;
int o_head_stride, o_batch_stride, o_t_stride;
int lse_head_stride, lse_batch_stride, lse_t_stride;
float scale;
@@ -275,7 +275,7 @@ fmha_mixed_fp8_prefill_kernel(FmhaMixedFp8PrefillParams p) {
int qr = t_start + r;
for (int c = 0; c < MMA_K_F8; c++) {
int d = kt * MMA_K_F8 + c;
sQ8[_pfill_cidx_f8(r, c)] = q8[qr * p.q_nope_head_stride + d];
sQ8[_pfill_cidx_f8(r, c)] = q8[qr * p.q_nope_t_stride + d];
}
}
// K: same as decode
@@ -303,7 +303,7 @@ fmha_mixed_fp8_prefill_kernel(FmhaMixedFp8PrefillParams p) {
// Apply Q and K scales
for (int r = tid; r < T_ACT; r += blockDim.x) {
int qr = t_start + r;
float q_s = q8_scale[qr * p.q_scale_head_stride];
float q_s = q8_scale[qr * p.q_scale_t_stride];
for (int c = 0; c < kv_len; c++) {
float ks = p.k_nope_scale[kv_start + c];
sLogits[r * SK_TILE + c] *= q_s * ks;
@@ -321,7 +321,7 @@ fmha_mixed_fp8_prefill_kernel(FmhaMixedFp8PrefillParams p) {
int qr = t_start + r;
for (int c = 0; c < MMA_K_F16; c++) {
int d = kt * MMA_K_F16 + c;
sQ16[_pfill_cidx_bf16_128(r, c)] = qrope[qr * p.q_rope_head_stride + d];
sQ16[_pfill_cidx_bf16_128(r, c)] = qrope[qr * p.q_rope_t_stride + d];
}
}
for (int i = tid; i < kv_len * MMA_K_F16; i += blockDim.x) {

View File

@@ -20,9 +20,9 @@ int fmha_mixed_fp8_prefill_launch(
void* lse_ptr,
const float* sink_bias_ptr,
int B, int H, int T, int N, int HD, int NOPE, int ROPE,
int q_nope_head_stride, int q_nope_batch_stride,
int q_scale_head_stride, int q_scale_batch_stride,
int q_rope_head_stride, int q_rope_batch_stride,
int q_nope_t_stride, int q_nope_head_stride, int q_nope_batch_stride,
int q_scale_t_stride, int q_scale_head_stride, int q_scale_batch_stride,
int q_rope_t_stride, int q_rope_head_stride, int q_rope_batch_stride,
int o_head_stride, int o_batch_stride, int o_t_stride,
int lse_head_stride, int lse_batch_stride, int lse_t_stride,
float scale
@@ -42,10 +42,13 @@ int fmha_mixed_fp8_prefill_launch(
p.sink_bias = sink_bias_ptr;
p.B = B; p.H = H; p.T = T; p.N = N;
p.HD = HD; p.NOPE = NOPE; p.ROPE = ROPE;
p.q_nope_t_stride = q_nope_t_stride;
p.q_nope_head_stride = q_nope_head_stride;
p.q_nope_batch_stride = q_nope_batch_stride;
p.q_scale_t_stride = q_scale_t_stride;
p.q_scale_head_stride = q_scale_head_stride;
p.q_scale_batch_stride = q_scale_batch_stride;
p.q_rope_t_stride = q_rope_t_stride;
p.q_rope_head_stride = q_rope_head_stride;
p.q_rope_batch_stride = q_rope_batch_stride;
p.o_head_stride = o_head_stride;

View File

@@ -137,9 +137,9 @@ def fmha_mixed_fp8_prefill_raw(
sink_ptr,
ctypes.c_int(B), ctypes.c_int(H), ctypes.c_int(T), ctypes.c_int(N),
ctypes.c_int(HD), ctypes.c_int(NOPE), ctypes.c_int(rope_dim),
ctypes.c_int(q_nope_fp8.stride(1)), ctypes.c_int(q_nope_fp8.stride(0)),
ctypes.c_int(q_nope_scale.stride(1)), ctypes.c_int(q_nope_scale.stride(0)),
ctypes.c_int(q_rope.stride(1)), ctypes.c_int(q_rope.stride(0)),
ctypes.c_int(q_nope_fp8.stride(2)), ctypes.c_int(q_nope_fp8.stride(1)), ctypes.c_int(q_nope_fp8.stride(0)),
ctypes.c_int(q_nope_scale.stride(2)), ctypes.c_int(q_nope_scale.stride(1)), ctypes.c_int(q_nope_scale.stride(0)),
ctypes.c_int(q_rope.stride(2)), ctypes.c_int(q_rope.stride(1)), ctypes.c_int(q_rope.stride(0)),
ctypes.c_int(o.stride(1)), ctypes.c_int(o.stride(0)), ctypes.c_int(o.stride(2)),
ctypes.c_int(lse.stride(1)), ctypes.c_int(lse.stride(0)), ctypes.c_int(lse.stride(2)),
ctypes.c_float(scale),