New files:
- dsv4/kernels/cuda/rope_cuda.cu: GPT-J-style interleaved RoPE kernel (forward and inverse)
- dsv4/ops/rope_cuda.py: Python bridge that loads the kernel via ctypes
- tests/unit/test_rope_cuda.py: correctness test (cosine similarity >= 0.999998)
Savings: ~915 → 183 kernel launches per token
127 lines
4.1 KiB
Python
127 lines
4.1 KiB
Python
#!/usr/bin/env python3
|
|
"""Test CUDA RoPE kernel correctness.
|
|
|
|
Compares the CUDA kernel output against a PyTorch reference implementation.
|
|
Production use requires a cosine similarity of at least 0.999998 against the reference.
|
|
"""
|
|
import torch
|
|
import math
|
|
import sys
|
|
|
|
def build_rope_cache(max_pos, rope_dim, device, theta=10000.):
    """Precompute cos/sin lookup tables for rotary position embeddings.

    The tables are built in float32 on the CPU and then copied to ``device``.

    Args:
        max_pos: number of positions, i.e. rows in each table.
        rope_dim: number of rotated channels. One column is produced per
            channel pair, taken from ``arange(0, rope_dim, 2)``.
        device: device the returned tables are moved to.
        theta: RoPE base; column ``i`` has frequency ``theta ** (-2i / rope_dim)``.

    Returns:
        ``(cos, sin)``: float32 tensors of shape ``(max_pos, ceil(rope_dim / 2))``
        holding ``cos(p * freq)`` and ``sin(p * freq)`` for position ``p``.
    """
    pair_offsets = torch.arange(0, rope_dim, 2, dtype=torch.float32)
    inv_freq = 1. / (theta ** (pair_offsets / rope_dim))
    positions = torch.arange(max_pos, dtype=torch.float32)
    phase = torch.outer(positions, inv_freq)
    return phase.cos().to(device), phase.sin().to(device)
|
|
|
|
def apply_rope_ref(x, pos, cos, sin, rope_dim, inverse=False):
    """PyTorch reference — the current _apply_rope implementation.

    Rotates the trailing ``rope_dim`` channels of ``x`` in place. Within that
    slice, channels ``(2i, 2i + 1)`` form a pair (interleaved layout). The
    leading ``hd - rope_dim`` channels are left untouched. Each rotated value
    is computed at the promoted precision of ``x`` and the tables, then rounded
    to bfloat16 before it is written back.

    Args:
        x: tensor unpacked as ``(T, nh, hd)``; modified in place.
        pos: index tensor that selects one row of ``cos``/``sin`` per token.
        cos, sin: position-indexed tables; each selected row is broadcast over
            the head axis and the channel pairs.
        rope_dim: number of trailing channels to rotate.
        inverse: if True, rotate by the negative angle (undoes the forward pass).

    Returns:
        ``x`` itself, after the in-place update.
    """
    _, _, head_dim = x.shape
    rot = x[:, :, head_dim - rope_dim:]  # view: assignments below land in x

    cos_p = cos[pos].unsqueeze(1)  # insert the head axis for broadcasting
    sin_p = sin[pos].unsqueeze(1)
    if inverse:
        # A rotation by -angle differs from the forward one only in sin's sign.
        sin_p = -sin_p

    # Snapshot both halves before writing, because `rot` aliases `x`.
    even = rot[..., 0::2].clone()
    odd = rot[..., 1::2].clone()

    rot[..., 0::2] = (even * cos_p - odd * sin_p).bfloat16()
    rot[..., 1::2] = (even * sin_p + odd * cos_p).bfloat16()
    return x
|
|
|
|
def _check_rope_output(label, x_in, ref, out, nope, threshold):
    """Print agreement metrics for one RoPE comparison and assert it passed.

    Args:
        label: heading printed above the metrics.
        x_in: tensor the computation started from; its first ``nope``
            channels are the expected (unrotated) values in ``out``.
        ref: expected output.
        out: output under test.
        nope: number of leading channels RoPE must leave bit-for-bit unchanged.
        threshold: minimum acceptable cosine similarity over the whole tensor.

    Raises:
        AssertionError: if ``out`` modified the non-rotary channels, or if the
            cosine similarity is below ``threshold``. A NaN similarity also
            fails, because the comparison is written as ``>=``.
    """
    ref_f = ref.float().flatten()
    out_f = out.float().flatten()
    cos_sim = torch.nn.functional.cosine_similarity(ref_f, out_f, dim=0).item()
    max_diff = (ref_f - out_f).abs().max().item()
    # The whole-tensor cosine is dominated by the untouched channels, so
    # check those channels exactly instead of relying on the similarity.
    nope_intact = torch.equal(out[..., :nope], x_in[..., :nope])
    # `>=` rather than `not <`: `nan < thr` is False and would silently pass.
    close_enough = cos_sim >= threshold

    print(f"\n{label}:")
    print(f" Cosine: {cos_sim:.8f}")
    print(f" Max diff: {max_diff:.8f}")
    print(" ✅ PASS" if (nope_intact and close_enough) else " ❌ FAIL")

    assert nope_intact, f"{label}: non-rotary channels [:{nope}] were modified"
    assert close_enough, f"{label}: cosine {cos_sim:.8f} < {threshold}"


def test_rope_cuda():
    """Check the CUDA RoPE kernel against the PyTorch reference.

    Runs four checks:

    * forward rotation at a single decode position;
    * inverse rotation, with both implementations fed the same input;
    * a forward-then-inverse round trip through the kernel;
    * a multi-token forward pass with distinct, unordered positions.

    Each check also requires the non-rotary channels to be left untouched.

    Raises:
        unittest.SkipTest: if no CUDA device is available (pytest reports it
            as a skip).
        AssertionError: on the first failing check. Pytest does not treat a
            ``False`` return value as a failure, so assertions are required.

    Returns:
        True when every check passes. This keeps the script entry point's
        exit-status logic working.
    """
    import unittest

    if not torch.cuda.is_available():
        raise unittest.SkipTest("CUDA device required for the RoPE kernel test")

    from dsv4.ops.rope_cuda import apply_rope

    device = "cuda:0"
    rope_dim = 64
    hd = 512
    n_h = 128
    T = 1  # decode
    max_pos = 4096
    nope = hd - rope_dim

    cos, sin = build_rope_cache(max_pos, rope_dim, device)

    # Forward RoPE at a single decode position.
    torch.manual_seed(42)
    x_in = torch.randn(T, n_h, hd, dtype=torch.bfloat16, device=device)
    positions = torch.tensor([100], dtype=torch.long, device=device)
    x_ref = apply_rope_ref(x_in.clone(), positions, cos, sin, rope_dim, inverse=False)
    x_cuda = x_in.clone()
    apply_rope(x_cuda, positions, cos, sin, rope_dim, inverse=False)
    _check_rope_output(
        f"Forward RoPE (T={T}, n_h={n_h}, hd={hd})",
        x_in, x_ref, x_cuda, nope, 0.999998,
    )

    # Inverse RoPE. Both implementations start from the same tensor (the
    # reference forward output). This isolates the inverse kernel instead of
    # compounding its error with forward-pass error.
    x_inv_in = x_ref.clone()
    x_ref_inv = apply_rope_ref(x_inv_in.clone(), positions, cos, sin, rope_dim, inverse=True)
    x_cuda_inv = x_inv_in.clone()
    apply_rope(x_cuda_inv, positions, cos, sin, rope_dim, inverse=True)
    _check_rope_output(
        f"Inverse RoPE (T={T}, n_h={n_h}, hd={hd})",
        x_inv_in, x_ref_inv, x_cuda_inv, nope, 0.999998,
    )

    # Round trip: forward followed by inverse should be close to identity.
    # The threshold is looser because each pass rounds to bfloat16.
    x_orig = torch.randn(T, n_h, hd, dtype=torch.bfloat16, device=device)
    x_rt = x_orig.clone()
    apply_rope(x_rt, positions, cos, sin, rope_dim, inverse=False)
    apply_rope(x_rt, positions, cos, sin, rope_dim, inverse=True)
    _check_rope_output(
        "Round-trip (forward + inverse)",
        x_orig, x_orig, x_rt, nope, 0.9999,
    )

    # Multi-token forward. The positions are distinct, unordered and reach
    # max_pos - 1, so pos[t] != t. A kernel that indexes the cache by token
    # index instead of positions[t] would otherwise pass.
    T2 = 8
    x2_in = torch.randn(T2, n_h, hd, dtype=torch.bfloat16, device=device)
    pos2 = torch.tensor(
        [max_pos - 1, 0, 100, 7, 2049, 1, 511, 1000], dtype=torch.long, device=device
    )
    x_ref2 = apply_rope_ref(x2_in.clone(), pos2, cos, sin, rope_dim, inverse=False)
    x_cuda2 = x2_in.clone()
    apply_rope(x_cuda2, pos2, cos, sin, rope_dim, inverse=False)
    _check_rope_output(
        f"Multi-token forward (T={T2}, n_h={n_h}, hd={hd})",
        x2_in, x_ref2, x_cuda2, nope, 0.999998,
    )

    return True
|
|
|
|
if __name__ == "__main__":
    # Script mode: seed the global RNG, run every check, and map the boolean
    # outcome onto the process exit status (0 = pass, 1 = fail).
    torch.manual_seed(42)
    sys.exit(0 if test_rope_cuda() else 1)
|