P5: In-place RoPE — no x.clone(), no empty_like allocation

Eliminates 183 kernel launches per decoded token from pointless memcpy.
Operates on rope dims in-place via views instead of cloning the full tensor
and allocating an empty_like buffer.
This commit is contained in:
2026-06-01 21:18:41 +00:00
parent 00746c2d2b
commit e3412cf913

View File

@@ -65,15 +65,25 @@ def build_rope_cache(max_pos, rope_dim, device, theta=10000., rope_type="default
return torch.cos(angles).to(device), torch.sin(angles).to(device)
def _apply_rope(x, pos, cos, sin, rope_dim, inverse=False):
"""In-place RoPE — mutates x, no full clone, no empty_like allocation.
P5: Eliminates x.clone() + empty_like per RoPE call.
Old: 183 calls/token × 128KB clone = 23MB pointless memcpy + 183 kernel launches.
New: Operates on the rope dims in-place, one slice copy back.
"""
T, nh, hd = x.shape; nope = hd - rope_dim
if pos.device != cos.device: pos = pos.to(cos.device)
c, s = cos[pos].unsqueeze(1), sin[pos].unsqueeze(1)
xr = x[:, :, nope:].float(); ev, od = xr[..., 0::2], xr[..., 1::2]
if inverse: rev, rod = ev*c + od*s, -ev*s + od*c
else: rev, rod = ev*c - od*s, ev*s + od*c
out = x.clone(); ro = torch.empty_like(xr)
ro[..., 0::2], ro[..., 1::2] = rev, rod
out[:, :, nope:] = ro.bfloat16(); return out
xr = x[:, :, nope:] # view, not copy
ev = xr[..., 0::2].clone() # need original ev for the mix
od = xr[..., 1::2] # view; will be overwritten below
if inverse:
xr[..., 0::2] = (ev * c + od * s).bfloat16()
xr[..., 1::2] = (-ev * s + od * c).bfloat16()
else:
xr[..., 0::2] = (ev * c - od * s).bfloat16()
xr[..., 1::2] = (ev * s + od * c).bfloat16()
return x # mutated in place
# =====================================================================
# Weight loading