Add RoPE KV test
This commit is contained in:
152
tests/test_rope_kv_b200.py
Normal file
152
tests/test_rope_kv_b200.py
Normal file
@@ -0,0 +1,152 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Quick test: verify that applying RoPE to KV fixes the NaN issue.
|
||||
Test the full attention pipeline with RoPE on both Q and KV.
|
||||
"""
|
||||
import sys, os, json, torch, torch.nn.functional as F, math
|
||||
from safetensors import safe_open
|
||||
|
||||
REPO = "/root/nvfp4-megamoe-kernel"
|
||||
sys.path.insert(0, REPO)
|
||||
MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
|
||||
DEV = "cuda:0"
|
||||
|
||||
H = 7168; NH = 128; HD = 512; NOPE = 448; ROPE = 64
|
||||
QL = 1536; OL = 1024; OG = 16; HPG = NH // OG
|
||||
EPS = 1e-6; SCALE = HD ** -0.5
|
||||
|
||||
E2M1 = torch.tensor([0,.5,1.,1.5,2.,3.,4.,6.,-0,-.5,-1.,-1.5,-2.,-3.,-4.,-6.], dtype=torch.float32)
|
||||
|
||||
_cache = {}
|
||||
def P(k, wm, md):
|
||||
if k in _cache: return _cache[k]
|
||||
with safe_open(os.path.join(md, wm[k]), framework="pt") as f:
|
||||
t = f.get_tensor(k)
|
||||
_cache[k] = t
|
||||
return t
|
||||
|
||||
def dequant(w, sf, gs):
|
||||
d = w.device; lut = E2M1.to(d)
|
||||
lo = lut[(w & 0xF).long()]; hi = lut[((w >> 4) & 0xF).long()]
|
||||
O, I2 = w.shape; I = I2*2
|
||||
u = torch.empty(O, I, dtype=torch.float32, device=d)
|
||||
u[:,0::2] = lo; u[:,1::2] = hi
|
||||
bs = sf.float().repeat_interleave(16, dim=1)[:O,:I]
|
||||
return (u * bs * gs).to(torch.bfloat16)
|
||||
|
||||
def rms(x, w, eps=1e-6):
|
||||
v = x.float().pow(2).mean(-1, keepdim=True)
|
||||
return (w.float() * (x * torch.rsqrt(v+eps)).float()).to(x.dtype)
|
||||
|
||||
def make_runner(w, sf, gs_t, inf, outf):
|
||||
from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear
|
||||
fp4 = w.view(torch.float4_e2m1fn_x2).permute(1,0).contiguous()
|
||||
s = sf.to(torch.float8_e4m3fn) if sf.dtype != torch.float8_e4m3fn else sf
|
||||
s = s.permute(1,0).contiguous()
|
||||
gs = gs_t.max().item() if gs_t.numel() > 1 else gs_t.item()
|
||||
r = CuTeDSLNvfp4Linear(in_features=inf, out_features=outf, max_num_tokens=8192, device=str(w.device))
|
||||
r.fp4 = [fp4]; r.sf = [s]; r.gs = [gs]
|
||||
r.finalize_weights(); r._ensure_initialized()
|
||||
return r
|
||||
|
||||
def build_cos_sin(max_pos=8192, rope_dim=ROPE):
|
||||
half = rope_dim // 2
|
||||
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, half, dtype=torch.float32) / half))
|
||||
freqs = torch.outer(torch.arange(max_pos, dtype=torch.float32), inv_freq)
|
||||
return torch.cat([freqs.cos(), freqs.sin()], dim=-1)
|
||||
|
||||
def apply_gptj_rope(x, positions, cos_sin, nope, rope):
|
||||
if rope == 0 or x.numel() == 0: return x
|
||||
half = rope // 2
|
||||
cos = cos_sin[positions, :half].to(x.dtype)
|
||||
sin = cos_sin[positions, half:2*half].to(x.dtype)
|
||||
if x.dim() == 3: cos = cos.unsqueeze(1); sin = sin.unsqueeze(1)
|
||||
x_rope = x[..., nope:].clone()
|
||||
even = x_rope[..., 0::2]; odd = x_rope[..., 1::2]
|
||||
out = x.clone()
|
||||
out[..., nope:][..., 0::2] = even * cos - odd * sin
|
||||
out[..., nope:][..., 1::2] = even * sin + odd * cos
|
||||
return out
|
||||
|
||||
def full_sdpa_attention(q, kv, scale):
|
||||
T, NH, HD = q.shape
|
||||
q_2d = q.reshape(T * NH, HD)
|
||||
kv_exp = kv.unsqueeze(1).expand(-1, NH, -1).contiguous()
|
||||
k_2d = kv_exp.permute(1, 0, 2).unsqueeze(1).expand(NH, T, T, -1).contiguous().reshape(T * NH, T, HD)
|
||||
v_2d = k_2d.clone()
|
||||
scores = torch.matmul(q_2d.unsqueeze(1), k_2d.transpose(-1, -2)) * scale
|
||||
qpos = torch.arange(T, device=q.device).unsqueeze(1).repeat(1, NH).reshape(T * NH)
|
||||
kpos = torch.arange(T, device=q.device).unsqueeze(0)
|
||||
causal = kpos <= qpos.unsqueeze(1)
|
||||
scores = scores.squeeze(1).masked_fill(~causal, float('-inf'))
|
||||
weights = F.softmax(scores.float(), dim=-1).to(q.dtype)
|
||||
out = torch.matmul(weights.unsqueeze(1), v_2d).squeeze(1)
|
||||
return out.reshape(T, NH, HD)
|
||||
|
||||
torch.cuda.set_device(0)
|
||||
torch.manual_seed(42)
|
||||
|
||||
with open(os.path.join(MODEL, "model.safetensors.index.json")) as f:
|
||||
wm = json.load(f)["weight_map"]
|
||||
G = lambda k: P(k, wm, MODEL).to(DEV)
|
||||
|
||||
p = "model.layers.0"; a = f"{p}.self_attn"
|
||||
emb = G("model.embed_tokens.weight")
|
||||
anorm = G(f"{p}.input_layernorm.weight")
|
||||
qn = G(f"{a}.q_a_norm.weight"); kvn = G(f"{a}.kv_norm.weight")
|
||||
woa = G(f"{a}.o_a_proj.weight")
|
||||
qa_w = G(f"{a}.q_a_proj.weight"); qa_sf = G(f"{a}.q_a_proj.weight_scale"); qa_gs = G(f"{a}.q_a_proj.weight_scale_2")
|
||||
qb_w = G(f"{a}.q_b_proj.weight"); qb_sf = G(f"{a}.q_b_proj.weight_scale"); qb_gs = G"{a}.q_b_proj.weight_scale_2")
|
||||
kv_w = G(f"{a}.kv_proj.weight"); kv_sf = G(f"{a}.kv_proj.weight_scale"); kv_gs = G(f"{a}.kv_proj.weight_scale_2")
|
||||
wob_w = G(f"{a}.o_b_proj.weight"); wob_sf = G(f"{a}.o_b_proj.weight_scale"); wob_gs = G(f"{a}.o_b_proj.weight_scale_2")
|
||||
|
||||
r_qa = make_runner(qa_w, qa_sf, qa_gs, H, qa_w.shape[0])
|
||||
r_qb = make_runner(qb_w, qb_sf, qb_gs, QL, qb_w.shape[0])
|
||||
r_kv = make_runner(kv_w, kv_sf, kv_gs, H, kv_w.shape[0])
|
||||
r_wob = make_runner(wob_w, wob_sf, wob_gs, OG*OL, wob_w.shape[0])
|
||||
|
||||
cos_sin = build_cos_sin().to(DEV)
|
||||
NT = 6
|
||||
token_ids = torch.tensor([1, 450, 8403, 315, 5413, 374], dtype=torch.long, device=DEV)
|
||||
positions = torch.arange(NT, dtype=torch.int64, device=DEV)
|
||||
|
||||
with torch.no_grad():
|
||||
hidden = emb[token_ids]
|
||||
normed = rms(hidden, anorm, EPS)
|
||||
|
||||
# Projections
|
||||
qa = r_qa.run(normed); kv = r_kv.run(normed)
|
||||
qa_n = rms(qa, qn, EPS); kv_n = rms(kv, kvn, EPS)
|
||||
q = r_qb.run(qa_n).view(NT, NH, HD)
|
||||
q_rope = apply_gptj_rope(q, positions, cos_sin, NOPE, ROPE)
|
||||
|
||||
# Test 1: NO RoPE on KV (the bug)
|
||||
print("--- Test 1: No RoPE on KV (BUG) ---")
|
||||
o_no_rope = full_sdpa_attention(q_rope, kv_n, SCALE)
|
||||
print(f" Output: amax={o_no_rope.amax():.4f} NaN={torch.isnan(o_no_rope).any()}")
|
||||
|
||||
# Test 2: RoPE on KV (the fix)
|
||||
print("--- Test 2: RoPE on KV (FIX) ---")
|
||||
kv_rope = apply_gptj_rope(kv_n.unsqueeze(1), positions, cos_sin, NOPE, ROPE).squeeze(1)
|
||||
o_with_rope = full_sdpa_attention(q_rope, kv_rope, SCALE)
|
||||
print(f" Output: amax={o_with_rope.amax():.4f} NaN={torch.isnan(o_with_rope).any()}")
|
||||
|
||||
# Test 3: Full pipeline
|
||||
from cutedsl.csa_attention import apply_inv_gptj_rope
|
||||
o_inv = apply_inv_gptj_rope(o_with_rope, positions, cos_sin, NOPE, ROPE)
|
||||
o_grouped = o_inv.view(NT, OG, HPG * HD).permute(1, 0, 2)
|
||||
woa_3d = woa.view(OG, OL, HPG * HD)
|
||||
z = torch.bmm(o_grouped, woa_3d.transpose(1, 2)).permute(1, 0, 2).reshape(NT, OG * OL)
|
||||
attn_out = r_wob.run(z)
|
||||
|
||||
# LM head
|
||||
fnorm_w = G("model.norm.weight"); lm_head = G("lm_head.weight")
|
||||
x = hidden + attn_out
|
||||
x_n = rms(x, fnorm_w, EPS)
|
||||
logits = x_n @ lm_head.T
|
||||
log_std = logits[-1].float().std().item()
|
||||
top5 = torch.topk(logits[-1], 5)
|
||||
print(f"\n--- Logits ---")
|
||||
print(f" std={log_std:.4f} {'✅' if 0.5 < log_std < 50 else '❌'}")
|
||||
print(f" top5 tokens: {top5.indices.tolist()}")
|
||||
print(f" NaN in logits: {torch.isnan(logits).any()}")
|
||||
Reference in New Issue
Block a user