CRITICAL FIX: Remove double Q normalization and fix RoPE sin slice
This commit is contained in:
@@ -144,26 +144,20 @@ def fused_qnorm_rope_kv_insert_py(
|
||||
nope_dim,
|
||||
rope_dim,
|
||||
) -> None:
|
||||
"""Pure PyTorch: per-head RMS norm on Q + GPT-J RoPE on Q.
|
||||
"""Pure PyTorch: RoPE on Q only.
|
||||
|
||||
NOTE: KV cache write is now handled separately in the attention forward
|
||||
(blackwell_attention_kv_write), which also applies RoPE and fp8 quantization
|
||||
before writing to the paged cache. This function only handles Q normalization
|
||||
and RoPE.
|
||||
Q is already normed (by fused_q_kv_rmsnorm), so we only apply RoPE.
|
||||
The original CUDA kernel also does KV cache insert, but we handle that
|
||||
separately in blackwell_attention_kv_write.
|
||||
"""
|
||||
T = q.shape[0]
|
||||
if T == 0:
|
||||
return
|
||||
|
||||
# Per-head RMS norm on Q (no learned weight)
|
||||
q_f32 = q.float()
|
||||
q_rms = q_f32.pow(2).mean(-1, keepdim=True)
|
||||
q.copy_(torch.rsqrt(q_rms + eps) * q_f32)
|
||||
|
||||
# GPT-J RoPE on Q
|
||||
# GPT-J RoPE on Q only (Q is already normed)
|
||||
half = rope_dim // 2
|
||||
cos_q = cos_sin_cache[positions, :half].unsqueeze(1).to(q.dtype)
|
||||
sin_q = cos_sin_cache[positions, half:].unsqueeze(1).to(q.dtype)
|
||||
sin_q = cos_sin_cache[positions, half:2*half].unsqueeze(1).to(q.dtype)
|
||||
q_rope = q[:, :, nope_dim:].clone()
|
||||
q[:, :, nope_dim:][:, :, 0::2] = q_rope[:, :, 0::2] * cos_q - q_rope[:, :, 1::2] * sin_q
|
||||
q[:, :, nope_dim:][:, :, 1::2] = q_rope[:, :, 0::2] * sin_q + q_rope[:, :, 1::2] * cos_q
|
||||
|
||||
Reference in New Issue
Block a user