diff --git a/single_shot_inference.py b/single_shot_inference.py index 61d03fa1..eb807f10 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -187,11 +187,17 @@ def apply_rope_partial(x, positions, cos_cache, sin_cache, head_dim, rope_dim): nope = hd - rope_dim cos = cos_cache[positions].unsqueeze(1) # (T, 1, half) FP32 sin = sin_cache[positions].unsqueeze(1) - out = x.clone() x_rope = x[:, :, nope:].float() # FP32 for accurate rotation - out[:, :, nope:] = (x_rope[..., 0::2] * cos - x_rope[..., 1::2] * sin).to(torch.bfloat16) - out[:, :, nope:][..., 1::2] = (x_rope[..., 0::2] * sin + x_rope[..., 1::2] * cos).to(torch.bfloat16) - return out + x_even = x_rope[..., 0::2] + x_odd = x_rope[..., 1::2] + rot_even = x_even * cos - x_odd * sin + rot_odd = x_even * sin + x_odd * cos + result = x.clone() + rope_out = torch.empty_like(x_rope) + rope_out[..., 0::2] = rot_even + rope_out[..., 1::2] = rot_odd + result[:, :, nope:] = rope_out.to(torch.bfloat16) + return result def apply_inverse_rope(o, positions, cos_cache, sin_cache, head_dim, rope_dim): @@ -201,11 +207,17 @@ def apply_inverse_rope(o, positions, cos_cache, sin_cache, head_dim, rope_dim): nope = hd - rope_dim cos = cos_cache[positions].unsqueeze(1) sin = sin_cache[positions].unsqueeze(1) - out = o.clone() o_rope = o[:, :, nope:].float() - out[:, :, nope:] = (o_rope[..., 0::2] * cos + o_rope[..., 1::2] * sin).to(torch.bfloat16) - out[:, :, nope:][..., 1::2] = (-o_rope[..., 0::2] * sin + o_rope[..., 1::2] * cos).to(torch.bfloat16) - return out + o_even = o_rope[..., 0::2] + o_odd = o_rope[..., 1::2] + inv_even = o_even * cos + o_odd * sin + inv_odd = -o_even * sin + o_odd * cos + result = o.clone() + rope_out = torch.empty_like(o_rope) + rope_out[..., 0::2] = inv_even + rope_out[..., 1::2] = inv_odd + result[:, :, nope:] = rope_out.to(torch.bfloat16) + return result class SimpleKVCache: """Per-layer KV cache for decode. Stores BF16 K,V accumulated across steps. diff --git a/tests/test_minimal_e2e.py b/tests/test_minimal_e2e.py index fc2a1691..4e5faafc 100644 --- a/tests/test_minimal_e2e.py +++ b/tests/test_minimal_e2e.py @@ -71,13 +71,17 @@ def apply_rope_partial(x, positions, cos_cache, sin_cache, head_dim, rope_dim): nope = hd - rope_dim cos = cos_cache[positions].unsqueeze(1) # (T, 1, half) FP32 sin = sin_cache[positions].unsqueeze(1) - out = x.clone() - # Compute in FP32 for numerical stability - x_rope = x[:, :, nope:].float() - out[:, :, nope:] = (x_rope[..., 0::2] * cos - x_rope[..., 1::2] * sin).to(torch.bfloat16) - # Second pass for odd elements (need original even values) - out[:, :, nope:][..., 1::2] = (x_rope[..., 0::2] * sin + x_rope[..., 1::2] * cos).to(torch.bfloat16) - return out + x_rope = x[:, :, nope:].float() # (T, n_h, rope_dim) + x_even = x_rope[..., 0::2] # (T, n_h, half) + x_odd = x_rope[..., 1::2] + rot_even = x_even * cos - x_odd * sin + rot_odd = x_even * sin + x_odd * cos + result = x.clone() + rope_out = torch.empty_like(x_rope) + rope_out[..., 0::2] = rot_even + rope_out[..., 1::2] = rot_odd + result[:, :, nope:] = rope_out.to(torch.bfloat16) + return result def apply_inverse_rope(o, positions, cos_cache, sin_cache, head_dim, rope_dim): """Apply inverse RoPE (conjugate rotation). Computes in FP32 for accuracy.""" @@ -85,11 +89,17 @@ def apply_inverse_rope(o, positions, cos_cache, sin_cache, head_dim, rope_dim): nope = hd - rope_dim cos = cos_cache[positions].unsqueeze(1) sin = sin_cache[positions].unsqueeze(1) - out = o.clone() o_rope = o[:, :, nope:].float() - out[:, :, nope:] = (o_rope[..., 0::2] * cos + o_rope[..., 1::2] * sin).to(torch.bfloat16) - out[:, :, nope:][..., 1::2] = (-o_rope[..., 0::2] * sin + o_rope[..., 1::2] * cos).to(torch.bfloat16) - return out + o_even = o_rope[..., 0::2] + o_odd = o_rope[..., 1::2] + inv_even = o_even * cos + o_odd * sin + inv_odd = -o_even * sin + o_odd * cos + result = o.clone() + rope_out = torch.empty_like(o_rope) + rope_out[..., 0::2] = inv_even + rope_out[..., 1::2] = inv_odd + result[:, :, nope:] = rope_out.to(torch.bfloat16) + return result def load_weights_to_cpu(checkpoint_dir): from safetensors.torch import load_file