CRITICAL FIX: Remove double Q normalization and fix RoPE sin slice

This commit is contained in:
2026-05-19 17:27:33 +00:00
parent facc6509e7
commit 658b12cb3d

View File

@@ -144,26 +144,20 @@ def fused_qnorm_rope_kv_insert_py(
nope_dim,
rope_dim,
) -> None:
"""Pure PyTorch: per-head RMS norm on Q + GPT-J RoPE on Q.
"""Pure PyTorch: RoPE on Q only.
NOTE: KV cache write is now handled separately in the attention forward
(blackwell_attention_kv_write), which also applies RoPE and fp8 quantization
before writing to the paged cache. This function only handles Q normalization
and RoPE.
Q is already normed (by fused_q_kv_rmsnorm), so we only apply RoPE.
The original CUDA kernel also does KV cache insert, but we handle that
separately in blackwell_attention_kv_write.
"""
T = q.shape[0]
if T == 0:
return
# Per-head RMS norm on Q (no learned weight)
q_f32 = q.float()
q_rms = q_f32.pow(2).mean(-1, keepdim=True)
q.copy_(torch.rsqrt(q_rms + eps) * q_f32)
# GPT-J RoPE on Q
# GPT-J RoPE on Q only (Q is already normed)
half = rope_dim // 2
cos_q = cos_sin_cache[positions, :half].unsqueeze(1).to(q.dtype)
sin_q = cos_sin_cache[positions, half:].unsqueeze(1).to(q.dtype)
sin_q = cos_sin_cache[positions, half:2*half].unsqueeze(1).to(q.dtype)
q_rope = q[:, :, nope_dim:].clone()
q[:, :, nope_dim:][:, :, 0::2] = q_rope[:, :, 0::2] * cos_q - q_rope[:, :, 1::2] * sin_q
q[:, :, nope_dim:][:, :, 1::2] = q_rope[:, :, 0::2] * sin_q + q_rope[:, :, 1::2] * cos_q