Fix production rope.py: use FP32 arithmetic in forward_rope_partial and inverse_rope_bf16

This commit is contained in:
2026-05-31 09:17:36 +00:00
parent 970869d017
commit 1c18c16c68

View File

@@ -53,14 +53,14 @@ def forward_rope_partial(
freqs = 1.0 / (10000.0 ** (torch.arange(0, rope_dim, 2, dtype=torch.float32, device=x.device) / rope_dim))
pos_float = positions.float() # (T,)
angles = torch.outer(pos_float, freqs) # (T, half_rope)
cos_vals = torch.cos(angles).unsqueeze(1).to(x.dtype) # (T, 1, half_rope)
sin_vals = torch.sin(angles).unsqueeze(1).to(x.dtype) # (T, 1, half_rope)
cos_vals = torch.cos(angles).unsqueeze(1) # (T, 1, half_rope) FP32
sin_vals = torch.sin(angles).unsqueeze(1) # FP32 for accuracy
# Reshape x to (T, n_h, head_dim)
x_heads = x.reshape(T, n_h, head_dim)
# Extract RoPE portion
x_rope = x_heads[:, :, nope_dim:] # (T, n_h, rope_dim)
# Extract RoPE portion — compute in FP32
x_rope = x_heads[:, :, nope_dim:].float() # (T, n_h, rope_dim) FP32
x_even = x_rope[:, :, 0::2] # (T, n_h, half_rope)
x_odd = x_rope[:, :, 1::2] # (T, n_h, half_rope)
@@ -75,7 +75,7 @@ def forward_rope_partial(
# Copy NoPE portion unchanged
result = x_heads.clone()
result[:, :, nope_dim:] = x_rot
result[:, :, nope_dim:] = x_rot.to(torch.bfloat16)
return result.reshape(T, n_h * head_dim)
@@ -110,11 +110,12 @@ def inverse_rope_bf16(
sin_all = cos_sin_cache[positions, half_rope:] # (T, 32)
# Expand for broadcasting: (T, 1, 32) → broadcasts over heads
cos_all = cos_all.unsqueeze(1).to(o.dtype)
sin_all = sin_all.unsqueeze(1).to(o.dtype)
# CRITICAL: compute in FP32, not BF16. Rounding cos/sin to BF16 breaks the
# identity cos² + sin² = 1, so the rotation no longer preserves vector norms.
cos_all = cos_all.unsqueeze(1).float()
sin_all = sin_all.unsqueeze(1).float()
# Extract RoPE portion: (T, H, rope_dim)
o_rope = o[:, :, nope_dim:]
# Extract RoPE portion: (T, H, rope_dim) — compute in FP32
o_rope = o[:, :, nope_dim:].float()
# Split into even/odd pairs (interleaved GPT-J style)
o_even = o_rope[:, :, 0::2] # (T, H, 32)
@@ -133,6 +134,6 @@ def inverse_rope_bf16(
# Copy NoPE portion unchanged, replace RoPE portion
result = o.clone()
result[:, :, nope_dim:] = o_inv
result[:, :, nope_dim:] = o_inv.to(torch.bfloat16)
return result