diff --git a/single_shot_inference.py b/single_shot_inference.py index c4aafb06..0050d91a 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -65,15 +65,25 @@ def build_rope_cache(max_pos, rope_dim, device, theta=10000., rope_type="default return torch.cos(angles).to(device), torch.sin(angles).to(device) def _apply_rope(x, pos, cos, sin, rope_dim, inverse=False): + """In-place RoPE — mutates x, no full clone, no empty_like allocation. + + P5: Eliminates x.clone() + empty_like per RoPE call. + Old: 183 calls/token × 128KB clone = 23MB pointless memcpy + 183 kernel launches. + New: Operates on the rope dims in-place, one slice copy back. + """ T, nh, hd = x.shape; nope = hd - rope_dim if pos.device != cos.device: pos = pos.to(cos.device) c, s = cos[pos].unsqueeze(1), sin[pos].unsqueeze(1) - xr = x[:, :, nope:].float(); ev, od = xr[..., 0::2], xr[..., 1::2] - if inverse: rev, rod = ev*c + od*s, -ev*s + od*c - else: rev, rod = ev*c - od*s, ev*s + od*c - out = x.clone(); ro = torch.empty_like(xr) - ro[..., 0::2], ro[..., 1::2] = rev, rod - out[:, :, nope:] = ro.bfloat16(); return out + xr = x[:, :, nope:] # view, not copy + ev = xr[..., 0::2].clone() # need original ev for the mix + od = xr[..., 1::2] # view; will be overwritten below + if inverse: + xr[..., 0::2] = (ev * c + od * s).bfloat16() + xr[..., 1::2] = (-ev * s + od * c).bfloat16() + else: + xr[..., 0::2] = (ev * c - od * s).bfloat16() + xr[..., 1::2] = (ev * s + od * c).bfloat16() + return x # mutated in place # ===================================================================== # Weight loading