diff --git a/dsv4/kernels/attention/fmha_mixed_fp8_prefill.cuh b/dsv4/kernels/attention/fmha_mixed_fp8_prefill.cuh index 839a8f5c..ab3f2d8e 100644 --- a/dsv4/kernels/attention/fmha_mixed_fp8_prefill.cuh +++ b/dsv4/kernels/attention/fmha_mixed_fp8_prefill.cuh @@ -46,9 +46,9 @@ struct FmhaMixedFp8PrefillParams { const float* __restrict__ sink_bias; // (B,H), optional int B, H, T, N, HD, NOPE, ROPE; - int q_nope_head_stride, q_nope_batch_stride; - int q_scale_head_stride, q_scale_batch_stride; - int q_rope_head_stride, q_rope_batch_stride; + int q_nope_t_stride, q_nope_head_stride, q_nope_batch_stride; + int q_scale_t_stride, q_scale_head_stride, q_scale_batch_stride; + int q_rope_t_stride, q_rope_head_stride, q_rope_batch_stride; int o_head_stride, o_batch_stride, o_t_stride; int lse_head_stride, lse_batch_stride, lse_t_stride; float scale; @@ -275,7 +275,7 @@ fmha_mixed_fp8_prefill_kernel(FmhaMixedFp8PrefillParams p) { int qr = t_start + r; for (int c = 0; c < MMA_K_F8; c++) { int d = kt * MMA_K_F8 + c; - sQ8[_pfill_cidx_f8(r, c)] = q8[qr * p.q_nope_head_stride + d]; + sQ8[_pfill_cidx_f8(r, c)] = q8[qr * p.q_nope_t_stride + d]; } } // K: same as decode @@ -303,7 +303,7 @@ fmha_mixed_fp8_prefill_kernel(FmhaMixedFp8PrefillParams p) { // Apply Q and K scales for (int r = tid; r < T_ACT; r += blockDim.x) { int qr = t_start + r; - float q_s = q8_scale[qr * p.q_scale_head_stride]; + float q_s = q8_scale[qr * p.q_scale_t_stride]; for (int c = 0; c < kv_len; c++) { float ks = p.k_nope_scale[kv_start + c]; sLogits[r * SK_TILE + c] *= q_s * ks; @@ -321,7 +321,7 @@ fmha_mixed_fp8_prefill_kernel(FmhaMixedFp8PrefillParams p) { int qr = t_start + r; for (int c = 0; c < MMA_K_F16; c++) { int d = kt * MMA_K_F16 + c; - sQ16[_pfill_cidx_bf16_128(r, c)] = qrope[qr * p.q_rope_head_stride + d]; + sQ16[_pfill_cidx_bf16_128(r, c)] = qrope[qr * p.q_rope_t_stride + d]; } } for (int i = tid; i < kv_len * MMA_K_F16; i += blockDim.x) { diff --git a/dsv4/kernels/attention/fmha_mixed_fp8_prefill_capi.cu b/dsv4/kernels/attention/fmha_mixed_fp8_prefill_capi.cu index e751acad..662ba126 100644 --- a/dsv4/kernels/attention/fmha_mixed_fp8_prefill_capi.cu +++ b/dsv4/kernels/attention/fmha_mixed_fp8_prefill_capi.cu @@ -20,9 +20,9 @@ int fmha_mixed_fp8_prefill_launch( void* lse_ptr, const float* sink_bias_ptr, int B, int H, int T, int N, int HD, int NOPE, int ROPE, - int q_nope_head_stride, int q_nope_batch_stride, - int q_scale_head_stride, int q_scale_batch_stride, - int q_rope_head_stride, int q_rope_batch_stride, + int q_nope_t_stride, int q_nope_head_stride, int q_nope_batch_stride, + int q_scale_t_stride, int q_scale_head_stride, int q_scale_batch_stride, + int q_rope_t_stride, int q_rope_head_stride, int q_rope_batch_stride, int o_head_stride, int o_batch_stride, int o_t_stride, int lse_head_stride, int lse_batch_stride, int lse_t_stride, float scale @@ -42,10 +42,13 @@ int fmha_mixed_fp8_prefill_launch( p.sink_bias = sink_bias_ptr; p.B = B; p.H = H; p.T = T; p.N = N; p.HD = HD; p.NOPE = NOPE; p.ROPE = ROPE; + p.q_nope_t_stride = q_nope_t_stride; p.q_nope_head_stride = q_nope_head_stride; p.q_nope_batch_stride = q_nope_batch_stride; + p.q_scale_t_stride = q_scale_t_stride; p.q_scale_head_stride = q_scale_head_stride; p.q_scale_batch_stride = q_scale_batch_stride; + p.q_rope_t_stride = q_rope_t_stride; p.q_rope_head_stride = q_rope_head_stride; p.q_rope_batch_stride = q_rope_batch_stride; p.o_head_stride = o_head_stride; diff --git a/dsv4/kernels/attention/fmha_mixed_fp8_prefill_op.py b/dsv4/kernels/attention/fmha_mixed_fp8_prefill_op.py index 5216b281..c43d2c97 100644 --- a/dsv4/kernels/attention/fmha_mixed_fp8_prefill_op.py +++ b/dsv4/kernels/attention/fmha_mixed_fp8_prefill_op.py @@ -137,9 +137,9 @@ def fmha_mixed_fp8_prefill_raw( sink_ptr, ctypes.c_int(B), ctypes.c_int(H), ctypes.c_int(T), ctypes.c_int(N), ctypes.c_int(HD), ctypes.c_int(NOPE), ctypes.c_int(rope_dim), - ctypes.c_int(q_nope_fp8.stride(1)), ctypes.c_int(q_nope_fp8.stride(0)), - ctypes.c_int(q_nope_scale.stride(1)), ctypes.c_int(q_nope_scale.stride(0)), - ctypes.c_int(q_rope.stride(1)), ctypes.c_int(q_rope.stride(0)), + ctypes.c_int(q_nope_fp8.stride(2)), ctypes.c_int(q_nope_fp8.stride(1)), ctypes.c_int(q_nope_fp8.stride(0)), + ctypes.c_int(q_nope_scale.stride(2)), ctypes.c_int(q_nope_scale.stride(1)), ctypes.c_int(q_nope_scale.stride(0)), + ctypes.c_int(q_rope.stride(2)), ctypes.c_int(q_rope.stride(1)), ctypes.c_int(q_rope.stride(0)), ctypes.c_int(o.stride(1)), ctypes.c_int(o.stride(0)), ctypes.c_int(o.stride(2)), ctypes.c_int(lse.stride(1)), ctypes.c_int(lse.stride(0)), ctypes.c_int(lse.stride(2)), ctypes.c_float(scale),