From 1c18c16c68f3bdaf67955e56afe16917d79402bb Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 31 May 2026 09:17:36 +0000 Subject: [PATCH] Fix production rope.py: FP32 arithmetic for forward_rope_partial + inverse_rope_bf16 --- dsv4/ops/rope.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/dsv4/ops/rope.py b/dsv4/ops/rope.py index 03aa4e9c..697e7dac 100644 --- a/dsv4/ops/rope.py +++ b/dsv4/ops/rope.py @@ -53,14 +53,14 @@ def forward_rope_partial( freqs = 1.0 / (10000.0 ** (torch.arange(0, rope_dim, 2, dtype=torch.float32, device=x.device) / rope_dim)) pos_float = positions.float() # (T,) angles = torch.outer(pos_float, freqs) # (T, half_rope) - cos_vals = torch.cos(angles).unsqueeze(1).to(x.dtype) # (T, 1, half_rope) - sin_vals = torch.sin(angles).unsqueeze(1).to(x.dtype) # (T, 1, half_rope) + cos_vals = torch.cos(angles).unsqueeze(1) # (T, 1, half_rope) FP32 + sin_vals = torch.sin(angles).unsqueeze(1) # FP32 for accuracy # Reshape x to (T, n_h, head_dim) x_heads = x.reshape(T, n_h, head_dim) - # Extract RoPE portion - x_rope = x_heads[:, :, nope_dim:] # (T, n_h, rope_dim) + # Extract RoPE portion — compute in FP32 + x_rope = x_heads[:, :, nope_dim:].float() # (T, n_h, rope_dim) FP32 x_even = x_rope[:, :, 0::2] # (T, n_h, half_rope) x_odd = x_rope[:, :, 1::2] # (T, n_h, half_rope) @@ -75,7 +75,7 @@ def forward_rope_partial( # Copy NoPE portion unchanged result = x_heads.clone() - result[:, :, nope_dim:] = x_rot + result[:, :, nope_dim:] = x_rot.to(torch.bfloat16) return result.reshape(T, n_h * head_dim) @@ -110,11 +110,12 @@ def inverse_rope_bf16( sin_all = cos_sin_cache[positions, half_rope:] # (T, 32) # Expand for broadcasting: (T, 1, 32) → broadcasts over heads - cos_all = cos_all.unsqueeze(1).to(o.dtype) - sin_all = sin_all.unsqueeze(1).to(o.dtype) + # CRITICAL: compute in FP32, not BF16! BF16 cos/sin destroys cos²+sin²=1 + cos_all = cos_all.unsqueeze(1).float() + sin_all = sin_all.unsqueeze(1).float() - # Extract RoPE portion: (T, H, rope_dim) - o_rope = o[:, :, nope_dim:] + # Extract RoPE portion: (T, H, rope_dim) — compute in FP32 + o_rope = o[:, :, nope_dim:].float() # Split into even/odd pairs (interleaved GPT-J style) o_even = o_rope[:, :, 0::2] # (T, H, 32) @@ -133,6 +134,6 @@ def inverse_rope_bf16( # Copy NoPE portion unchanged, replace RoPE portion result = o.clone() - result[:, :, nope_dim:] = o_inv + result[:, :, nope_dim:] = o_inv.to(torch.bfloat16) return result