diff --git a/vllm/patches/layers/csa_attention.py b/vllm/patches/layers/csa_attention.py index c9f2b9af..2147f336 100644 --- a/vllm/patches/layers/csa_attention.py +++ b/vllm/patches/layers/csa_attention.py @@ -144,26 +144,20 @@ def fused_qnorm_rope_kv_insert_py( nope_dim, rope_dim, ) -> None: - """Pure PyTorch: per-head RMS norm on Q + GPT-J RoPE on Q. + """Pure PyTorch: RoPE on Q only. - NOTE: KV cache write is now handled separately in the attention forward - (blackwell_attention_kv_write), which also applies RoPE and fp8 quantization - before writing to the paged cache. This function only handles Q normalization - and RoPE. + Q is already normed (by fused_q_kv_rmsnorm), so we only apply RoPE. + The original CUDA kernel also does KV cache insert, but we handle that + separately in blackwell_attention_kv_write. """ T = q.shape[0] if T == 0: return - # Per-head RMS norm on Q (no learned weight) - q_f32 = q.float() - q_rms = q_f32.pow(2).mean(-1, keepdim=True) - q.copy_(torch.rsqrt(q_rms + eps) * q_f32) - - # GPT-J RoPE on Q + # GPT-J RoPE on Q only (Q is already normed) half = rope_dim // 2 cos_q = cos_sin_cache[positions, :half].unsqueeze(1).to(q.dtype) - sin_q = cos_sin_cache[positions, half:].unsqueeze(1).to(q.dtype) + sin_q = cos_sin_cache[positions, half:2*half].unsqueeze(1).to(q.dtype) q_rope = q[:, :, nope_dim:].clone() q[:, :, nope_dim:][:, :, 0::2] = q_rope[:, :, 0::2] * cos_q - q_rope[:, :, 1::2] * sin_q q[:, :, nope_dim:][:, :, 1::2] = q_rope[:, :, 0::2] * sin_q + q_rope[:, :, 1::2] * cos_q