CRITICAL FIX: Add T-dimension strides to prefill FMHA kernel
The kernel was using head strides for the T (query row) dimension, which happened to work for T=1 (qr=0 always) but was wrong for T>1. For (B,H,T,NOPE) layout: - Head stride = T*NOPE, but T stride = NOPE - Scale head stride = T, but T stride = 1 - RoPE head stride = T*ROPE, but T stride = ROPE Added q_nope_t_stride, q_scale_t_stride, q_rope_t_stride to params struct, C API, and Python wrapper.
This commit is contained in:
@@ -46,9 +46,9 @@ struct FmhaMixedFp8PrefillParams {
|
||||
const float* __restrict__ sink_bias; // (B,H), optional
|
||||
|
||||
int B, H, T, N, HD, NOPE, ROPE;
|
||||
int q_nope_head_stride, q_nope_batch_stride;
|
||||
int q_scale_head_stride, q_scale_batch_stride;
|
||||
int q_rope_head_stride, q_rope_batch_stride;
|
||||
int q_nope_t_stride, q_nope_head_stride, q_nope_batch_stride;
|
||||
int q_scale_t_stride, q_scale_head_stride, q_scale_batch_stride;
|
||||
int q_rope_t_stride, q_rope_head_stride, q_rope_batch_stride;
|
||||
int o_head_stride, o_batch_stride, o_t_stride;
|
||||
int lse_head_stride, lse_batch_stride, lse_t_stride;
|
||||
float scale;
|
||||
@@ -275,7 +275,7 @@ fmha_mixed_fp8_prefill_kernel(FmhaMixedFp8PrefillParams p) {
|
||||
int qr = t_start + r;
|
||||
for (int c = 0; c < MMA_K_F8; c++) {
|
||||
int d = kt * MMA_K_F8 + c;
|
||||
sQ8[_pfill_cidx_f8(r, c)] = q8[qr * p.q_nope_head_stride + d];
|
||||
sQ8[_pfill_cidx_f8(r, c)] = q8[qr * p.q_nope_t_stride + d];
|
||||
}
|
||||
}
|
||||
// K: same as decode
|
||||
@@ -303,7 +303,7 @@ fmha_mixed_fp8_prefill_kernel(FmhaMixedFp8PrefillParams p) {
|
||||
// Apply Q and K scales
|
||||
for (int r = tid; r < T_ACT; r += blockDim.x) {
|
||||
int qr = t_start + r;
|
||||
float q_s = q8_scale[qr * p.q_scale_head_stride];
|
||||
float q_s = q8_scale[qr * p.q_scale_t_stride];
|
||||
for (int c = 0; c < kv_len; c++) {
|
||||
float ks = p.k_nope_scale[kv_start + c];
|
||||
sLogits[r * SK_TILE + c] *= q_s * ks;
|
||||
@@ -321,7 +321,7 @@ fmha_mixed_fp8_prefill_kernel(FmhaMixedFp8PrefillParams p) {
|
||||
int qr = t_start + r;
|
||||
for (int c = 0; c < MMA_K_F16; c++) {
|
||||
int d = kt * MMA_K_F16 + c;
|
||||
sQ16[_pfill_cidx_bf16_128(r, c)] = qrope[qr * p.q_rope_head_stride + d];
|
||||
sQ16[_pfill_cidx_bf16_128(r, c)] = qrope[qr * p.q_rope_t_stride + d];
|
||||
}
|
||||
}
|
||||
for (int i = tid; i < kv_len * MMA_K_F16; i += blockDim.x) {
|
||||
|
||||
@@ -20,9 +20,9 @@ int fmha_mixed_fp8_prefill_launch(
|
||||
void* lse_ptr,
|
||||
const float* sink_bias_ptr,
|
||||
int B, int H, int T, int N, int HD, int NOPE, int ROPE,
|
||||
int q_nope_head_stride, int q_nope_batch_stride,
|
||||
int q_scale_head_stride, int q_scale_batch_stride,
|
||||
int q_rope_head_stride, int q_rope_batch_stride,
|
||||
int q_nope_t_stride, int q_nope_head_stride, int q_nope_batch_stride,
|
||||
int q_scale_t_stride, int q_scale_head_stride, int q_scale_batch_stride,
|
||||
int q_rope_t_stride, int q_rope_head_stride, int q_rope_batch_stride,
|
||||
int o_head_stride, int o_batch_stride, int o_t_stride,
|
||||
int lse_head_stride, int lse_batch_stride, int lse_t_stride,
|
||||
float scale
|
||||
@@ -42,10 +42,13 @@ int fmha_mixed_fp8_prefill_launch(
|
||||
p.sink_bias = sink_bias_ptr;
|
||||
p.B = B; p.H = H; p.T = T; p.N = N;
|
||||
p.HD = HD; p.NOPE = NOPE; p.ROPE = ROPE;
|
||||
p.q_nope_t_stride = q_nope_t_stride;
|
||||
p.q_nope_head_stride = q_nope_head_stride;
|
||||
p.q_nope_batch_stride = q_nope_batch_stride;
|
||||
p.q_scale_t_stride = q_scale_t_stride;
|
||||
p.q_scale_head_stride = q_scale_head_stride;
|
||||
p.q_scale_batch_stride = q_scale_batch_stride;
|
||||
p.q_rope_t_stride = q_rope_t_stride;
|
||||
p.q_rope_head_stride = q_rope_head_stride;
|
||||
p.q_rope_batch_stride = q_rope_batch_stride;
|
||||
p.o_head_stride = o_head_stride;
|
||||
|
||||
@@ -137,9 +137,9 @@ def fmha_mixed_fp8_prefill_raw(
|
||||
sink_ptr,
|
||||
ctypes.c_int(B), ctypes.c_int(H), ctypes.c_int(T), ctypes.c_int(N),
|
||||
ctypes.c_int(HD), ctypes.c_int(NOPE), ctypes.c_int(rope_dim),
|
||||
ctypes.c_int(q_nope_fp8.stride(1)), ctypes.c_int(q_nope_fp8.stride(0)),
|
||||
ctypes.c_int(q_nope_scale.stride(1)), ctypes.c_int(q_nope_scale.stride(0)),
|
||||
ctypes.c_int(q_rope.stride(1)), ctypes.c_int(q_rope.stride(0)),
|
||||
ctypes.c_int(q_nope_fp8.stride(2)), ctypes.c_int(q_nope_fp8.stride(1)), ctypes.c_int(q_nope_fp8.stride(0)),
|
||||
ctypes.c_int(q_nope_scale.stride(2)), ctypes.c_int(q_nope_scale.stride(1)), ctypes.c_int(q_nope_scale.stride(0)),
|
||||
ctypes.c_int(q_rope.stride(2)), ctypes.c_int(q_rope.stride(1)), ctypes.c_int(q_rope.stride(0)),
|
||||
ctypes.c_int(o.stride(1)), ctypes.c_int(o.stride(0)), ctypes.c_int(o.stride(2)),
|
||||
ctypes.c_int(lse.stride(1)), ctypes.c_int(lse.stride(0)), ctypes.c_int(lse.stride(2)),
|
||||
ctypes.c_float(scale),
|
||||
|
||||
Reference in New Issue
Block a user