diff --git a/tests/test_nvfp4_attention_b200.py b/tests/test_nvfp4_attention_b200.py new file mode 100644 index 00000000..68c27223 --- /dev/null +++ b/tests/test_nvfp4_attention_b200.py @@ -0,0 +1,252 @@ +#!/usr/bin/env python3 +""" +Test NVFP4 attention: quantize Q and K, GEMM in NVFP4, softmax in BF16. + +Step 1: Verify NVFP4 quantize/dequant roundtrip for attention +Step 2: Q×K^T using CuTeDSL NVFP4 GEMM +Step 3: Softmax + attn×V +Step 4: Full pipeline with real weights, compare to BF16 SDPA + +Usage (on B200): + cd /root/nvfp4-megamoe-kernel + PYTHONPATH=/root/nvfp4-megamoe-kernel tests/venv/bin/python tests/test_nvfp4_attention_b200.py +""" + +import sys, os, json, torch, torch.nn.functional as F, math +from safetensors import safe_open + +REPO = "/root/nvfp4-megamoe-kernel" +sys.path.insert(0, REPO) +MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4" +DEV = "cuda:0" + +H = 7168; NH = 128; HD = 512; NOPE = 448; ROPE = 64 +QL = 1536; OL = 1024; OG = 16; HPG = NH // OG +EPS = 1e-6; WINDOW = 8192; SCALE = HD ** -0.5 + +E2M1 = torch.tensor([0,.5,1.,1.5,2.,3.,4.,6.,-0,-.5,-1.,-1.5,-2.,-3.,-4.,-6.], dtype=torch.float32) + +_cache = {} +def P(k, wm, md): + if k in _cache: return _cache[k] + with safe_open(os.path.join(md, wm[k]), framework="pt") as f: + t = f.get_tensor(k) + _cache[k] = t + return t + +def dequant(w, sf, gs): + d = w.device; lut = E2M1.to(d) + lo = lut[(w & 0xF).long()]; hi = lut[((w >> 4) & 0xF).long()] + O, I2 = w.shape; I = I2*2 + u = torch.empty(O, I, dtype=torch.float32, device=d) + u[:,0::2] = lo; u[:,1::2] = hi + bs = sf.float().repeat_interleave(16, dim=1)[:O,:I] + return (u * bs * gs).to(torch.bfloat16) + +def rms(x, w, eps=1e-6): + v = x.float().pow(2).mean(-1, keepdim=True) + return (w.float() * (x * torch.rsqrt(v+eps)).float()).to(x.dtype) + +def make_runner(w, sf, gs_t, inf, outf, fused=False, lw=None): + from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear + fp4 = w.view(torch.float4_e2m1fn_x2).permute(1,0).contiguous() + s = sf.to(torch.float8_e4m3fn) if sf.dtype != torch.float8_e4m3fn else sf + s = s.permute(1,0).contiguous() + if fused and gs_t.numel() == 2: + g1,g2 = gs_t[0].item(), gs_t[1].item(); gs = max(g1,g2) + if g1 != g2: + s32 = s.float(); sp = lw[0] if lw else outf//2 + s32[:sp] *= g1/gs; s32[sp:] *= g2/gs; s = s32.to(torch.float8_e4m3fn) + else: + gs = gs_t.max().item() if gs_t.numel() > 1 else gs_t.item() + r = CuTeDSLNvfp4Linear(in_features=inf, out_features=outf, max_num_tokens=8192, device=str(w.device)) + r.fp4 = [fp4]; r.sf = [s]; r.gs = [gs] + r.finalize_weights(); r._ensure_initialized() + return r + +def apply_gptj_rope(x, positions, cos_sin, nope, rope): + if rope == 0 or x.numel() == 0: return x + half = rope // 2 + cos = cos_sin[positions, :half].to(x.dtype) + sin = cos_sin[positions, half:].to(x.dtype) + if x.dim() == 3: cos = cos.unsqueeze(1); sin = sin.unsqueeze(1) + x_rope = x[..., nope:].clone() + even = x_rope[..., 0::2]; odd = x_rope[..., 1::2] + out = x.clone() + out[..., nope:][..., 0::2] = even * cos - odd * sin + out[..., nope:][..., 1::2] = even * sin + odd * cos + return out + +def build_cos_sin(max_pos=4096, rope_dim=ROPE): + half = rope_dim // 2 + inv_freq = 1.0 / (10000.0 ** (torch.arange(0, half, dtype=torch.float32) / half)) + freqs = torch.outer(torch.arange(max_pos, dtype=torch.float32), inv_freq) + return torch.cat([freqs.cos(), freqs.sin(), freqs.cos(), freqs.sin()], dim=-1) + + +def bf16_full_attention(q, kv, scale): + """BF16 reference: full self-attention with causal mask.""" + T, NH, HD = q.shape + q_2d = q.reshape(T * NH, HD) + kv_expanded = kv.unsqueeze(1).expand(-1, NH, -1).contiguous() + k_2d = kv_expanded.permute(1, 0, 2).unsqueeze(1).expand(NH, T, T, -1).contiguous().reshape(T * NH, T, HD) + v_2d = k_2d.clone() + scores = torch.matmul(q_2d.unsqueeze(1), k_2d.transpose(-1, -2)) * scale + query_pos = torch.arange(T, device=q.device).unsqueeze(1).repeat(1, NH).reshape(T * NH) + kv_pos = torch.arange(T, device=q.device).unsqueeze(0) + causal = kv_pos <= query_pos.unsqueeze(1) + scores = scores.squeeze(1).masked_fill(~causal, float('-inf')) + weights = F.softmax(scores.float(), dim=-1).to(q.dtype) + out = torch.matmul(weights.unsqueeze(1), v_2d).squeeze(1) + return out.reshape(T, NH, HD) + + +def nvfp4_qk_attention(q, kv, scale): + """NVFP4 attention: quantize Q and K for Q×K^T, then BF16 softmax + attn×V. + + Key insight: Q×K^T is (T*NH, HD) × (HD, T) = (T*NH, T). + This is a standard GEMM that CuTeDSL can handle. + We quantize Q as the "activation" and K^T as the "weight". + """ + from cutedsl.bridge import quantize_to_nvfp4, quantize_activation_nvfp4 + from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear + + T, NH, HD = q.shape + device = q.device + + # Q as activation: (T*NH, HD) → NVFP4 + q_2d = q.reshape(T * NH, HD) + q_fp4, q_sf, q_gs = quantize_to_nvfp4(q_2d) # (T*NH, HD//2), (T*NH, HD//16), scalar + + # K as weight: (T, HD) → transpose to (HD, T), quantize as weight + # In our framework, "weight" means quantized along K dim + kv_T = kv.T.contiguous() # (HD, T) + w_fp4, w_sf, w_gs = quantize_to_nvfp4(kv_T) # (HD//2, T), (HD//16, T), scalar + + # Use CuTeDSLNvfp4Linear runner for Q×K^T GEMM + # in_features=HD, out_features=T + # Q is "activation" side, K^T is "weight" side + M = T * NH + K = HD + N = T + + # Create runner for this specific (M, K, N) combination + runner = CuTeDSLNvfp4Linear( + in_features=K, out_features=N, max_num_tokens=M, device=str(device) + ) + + # Weight is kv_T: set up as (N, K//2) in N-major (standard row-major) + # runner expects: weight fp4 is (N, K//2), weight sf is (N, K//16) + # Our w_fp4 from quantize_to_nvfp4(kv_T) is (K//2, T) — that's (K_packed, N) + # Need to transpose to (N, K_packed) + w_fp4_loaded = w_fp4.T.contiguous() # (T, HD//2) = (N, K_packed) + w_sf_loaded = w_sf.T.contiguous() # (T, HD//16) = (N, K_sf) + + runner.fp4 = [w_fp4_loaded] + runner.sf = [w_sf_loaded] + runner.gs = [w_gs] + runner.finalize_weights() + runner._ensure_initialized() + + # Run: Q×K^T + # q_2d is (M, K) BF16, runner produces (M, N) BF16 + scores = runner.run(q_2d) * scale # (T*NH, T) + + # Causal mask + query_pos = torch.arange(T, device=device).unsqueeze(1).repeat(1, NH).reshape(T * NH) + kv_pos = torch.arange(T, device=device).unsqueeze(0) + causal = kv_pos <= query_pos.unsqueeze(1) + scores = scores.masked_fill(~causal, float('-inf')) + + # Softmax in BF16 (must be full precision for numerical stability) + weights = F.softmax(scores.float(), dim=-1).to(q.dtype) # (T*NH, T) + + # attn×V: (T*NH, T) × (T, HD) → (T*NH, HD) + # V = kv (shared, BF16) — no quantization needed here since attn weights are already BF16 + out = torch.matmul(weights, kv) # (T*NH, HD) + + return out.reshape(T, NH, HD) + + +def main(): + torch.cuda.set_device(0) + torch.manual_seed(42) + + print("=" * 70) + print(" NVFP4 Attention Kernel Test") + print("=" * 70) + + with open(os.path.join(MODEL, "model.safetensors.index.json")) as f: + wm = json.load(f)["weight_map"] + G = lambda k: P(k, wm, MODEL).to(DEV) + + p = "model.layers.0"; a = f"{p}.self_attn" + + # Load weights + emb = G("model.embed_tokens.weight") + anorm = G(f"{p}.input_layernorm.weight") + qn = G(f"{a}.q_a_norm.weight"); kvn = G(f"{a}.kv_norm.weight") + woa = G(f"{a}.o_a_proj.weight") + + qa_w = G(f"{a}.q_a_proj.weight"); qa_sf = G(f"{a}.q_a_proj.weight_scale"); qa_gs = G(f"{a}.q_a_proj.weight_scale_2") + qb_w = G(f"{a}.q_b_proj.weight"); qb_sf = G(f"{a}.q_b_proj.weight_scale"); qb_gs = G(f"{a}.q_b_proj.weight_scale_2") + kv_w = G(f"{a}.kv_proj.weight"); kv_sf = G(f"{a}.kv_proj.weight_scale"); kv_gs = G(f"{a}.kv_proj.weight_scale_2") + wob_w = G(f"{a}.o_b_proj.weight"); wob_sf = G(f"{a}.o_b_proj.weight_scale"); wob_gs = G(f"{a}.o_b_proj.weight_scale_2") + sinks = G(f"{a}.sinks") + + # BF16 references + qa_bf16 = dequant(qa_w, qa_sf, qa_gs.item()) + qb_bf16 = dequant(qb_w, qb_sf, qb_gs.item()) + kv_bf16 = dequant(kv_w, kv_sf, kv_gs.item()) + wob_bf16 = dequant(wob_w, wob_sf, wob_gs.item()) + + # CuTeDSL runners + r_qa = make_runner(qa_w, qa_sf, qa_gs, H, qa_w.shape[0]) + r_qb = make_runner(qb_w, qb_sf, qb_gs, QL, qb_w.shape[0]) + r_kv = make_runner(kv_w, kv_sf, kv_gs, H, kv_w.shape[0]) + r_wob = make_runner(wob_w, wob_sf, wob_gs, OG*OL, wob_w.shape[0]) + + # Input + token_ids = torch.tensor([1, 450, 8403, 315, 5413, 374], dtype=torch.long, device=DEV) + NT = len(token_ids) + cos_sin = build_cos_sin(max_pos=WINDOW + 256).to(DEV) + positions = torch.arange(NT, dtype=torch.int64, device=DEV) + + print(f" Input: {NT} tokens, {NH} heads, HD={HD}") + + with torch.no_grad(): + hidden = emb[token_ids] + normed = rms(hidden, anorm, EPS) + + # Projections + qa_cute = r_qa.run(normed) + kv_cute = r_kv.run(normed) + qa_n = rms(qa_cute, qn, EPS) + kv_n = rms(kv_cute, kvn, EPS) + q_cute = r_qb.run(qa_n).view(NT, NH, HD) + q_rope = apply_gptj_rope(q_cute, positions, cos_sin, NOPE, ROPE) + + # ── BF16 reference ──────────────────────────────────────────── + print("\n--- Step 1: BF16 reference attention ---") + o_bf16 = bf16_full_attention(q_rope, kv_n, SCALE) + print(f" BF16 attention output: amax={o_bf16.amax():.4f} NaN={torch.isnan(o_bf16).any()}") + + # ── NVFP4 Q×K^T attention ──────────────────────────────────── + print("\n--- Step 2: NVFP4 Q×K^T attention ---") + try: + o_nvfp4 = nvfp4_qk_attention(q_rope, kv_n, SCALE) + print(f" NVFP4 attention output: amax={o_nvfp4.amax():.4f} NaN={torch.isnan(o_nvfp4).any()}") + + c = F.cosine_similarity(o_nvfp4.flatten().unsqueeze(0).float(), o_bf16.flatten().unsqueeze(0).float()).item() + print(f" NVFP4 vs BF16 cosine: {c:.6f} {'✅' if c>=0.98 else '❌'}") + except Exception as e: + print(f" ERROR: {e}") + import traceback; traceback.print_exc() + + print("\n" + "=" * 70) + print(" Done") + print("=" * 70) + + +if __name__ == "__main__": + main()