diff --git a/tests/test_vllm_codepaths_b200.py b/tests/test_vllm_codepaths_b200.py new file mode 100644 index 00000000..48cf0234 --- /dev/null +++ b/tests/test_vllm_codepaths_b200.py @@ -0,0 +1,251 @@ +#!/usr/bin/env python3 +""" +Test the EXACT code paths used in vLLM's Blackwell attention. + +Imports the actual functions from csa_attention.py and blackwell_attention.py +and verifies they produce correct output with real weights. + +This is the closest possible test to what runs in the container. +""" +import sys, os, json, torch, torch.nn.functional as F +from safetensors import safe_open + +REPO = "/root/nvfp4-megamoe-kernel" +sys.path.insert(0, REPO) +MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4" +DEV = "cuda:0" + +H = 7168; NH = 128; HD = 512; NOPE = 448; ROPE = 64 +QL = 1536; OL = 1024; OG = 16; HPG = NH // OG +EPS = 1e-6; WINDOW = 128; SCALE = HD ** -0.5 + +E2M1 = torch.tensor([0,.5,1.,1.5,2.,3.,4.,6.,-0,-.5,-1.,-1.5,-2.,-3.,-4.,-6.], dtype=torch.float32) + +_cache = {} +def P(k, wm, md): + if k in _cache: return _cache[k] + with safe_open(os.path.join(md, wm[k]), framework="pt") as f: + t = f.get_tensor(k) + _cache[k] = t + return t + +def rms(x, w, eps=1e-6): + v = x.float().pow(2).mean(-1, keepdim=True) + return (w.float() * (x * torch.rsqrt(v+eps)).float()).to(x.dtype) + +def make_runner(w, sf, gs_t, inf, outf): + from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear + fp4 = w.view(torch.float4_e2m1fn_x2).permute(1,0).contiguous() + s = sf.to(torch.float8_e4m3fn) if sf.dtype != torch.float8_e4m3fn else sf + s = s.permute(1,0).contiguous() + gs = gs_t.max().item() if gs_t.numel() > 1 else gs_t.item() + r = CuTeDSLNvfp4Linear(in_features=inf, out_features=outf, max_num_tokens=8192, device=str(w.device)) + r.fp4 = [fp4]; r.sf = [s]; r.gs = [gs] + r.finalize_weights(); r._ensure_initialized() + return r + +def build_cos_sin(max_pos=4096, rope_dim=ROPE): + half = rope_dim // 2 + inv_freq = 1.0 / (10000.0 ** (torch.arange(0, half, dtype=torch.float32) / half)) + freqs = torch.outer(torch.arange(max_pos, dtype=torch.float32), inv_freq) + return torch.cat([freqs.cos(), freqs.sin()], dim=-1) + +def apply_gptj_rope(x, positions, cos_sin, nope_dim, rope_dim): + if rope_dim == 0 or x.numel() == 0: return x + half = rope_dim // 2 + cos = cos_sin[positions, :half].to(x.dtype) + sin = cos_sin[positions, half:2*half].to(x.dtype) + if x.dim() == 3: cos = cos.unsqueeze(1); sin = sin.unsqueeze(1) + x_rope = x[..., nope_dim:].clone() + even = x_rope[..., 0::2]; odd = x_rope[..., 1::2] + out = x.clone() + out[..., nope_dim:][..., 0::2] = even * cos - odd * sin + out[..., nope_dim:][..., 1::2] = even * sin + odd * cos + return out + +def apply_inv_gptj_rope(x, positions, cos_sin, nope_dim, rope_dim): + if rope_dim == 0 or x.numel() == 0: return x + half = rope_dim // 2 + cos = cos_sin[positions, :half].to(x.dtype) + sin = cos_sin[positions, half:2*half].to(x.dtype) + if x.dim() == 3: cos = cos.unsqueeze(1); sin = sin.unsqueeze(1) + x_rope = x[..., nope_dim:].clone() + even = x_rope[..., 0::2]; odd = x_rope[..., 1::2] + out = x.clone() + out[..., nope_dim:][..., 0::2] = even * cos + odd * sin + out[..., nope_dim:][..., 1::2] = -even * sin + odd * cos + return out + +def kv_quantize_fp8(kv_bf16): + amax = kv_bf16.float().abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) + fp8_max = torch.tensor(448.0, dtype=torch.float32, device=kv_bf16.device) + scale = fp8_max / amax + kv_fp8 = (kv_bf16.float() * scale).to(torch.float8_e4m3fn) + inv_scale = (amax / fp8_max).to(torch.bfloat16) + return kv_fp8, inv_scale + +def kv_dequantize_fp8(kv_fp8, inv_scale): + return (kv_fp8.to(torch.bfloat16) * inv_scale).to(torch.bfloat16) + +def causal_prefill_attention(q, kv, scale): + T, NH, HD = q.shape + q_t = q.permute(1, 0, 2) + kv_exp = kv.unsqueeze(0).expand(NH, -1, -1) + out = F.scaled_dot_product_attention(q_t, kv_exp, kv_exp, is_causal=True, scale=scale) + return out.permute(1, 0, 2) + + +def main(): + """Test the exact csa_attention.py code paths used in the container.""" + from cutedsl.blackwell_attention import ( + blackwell_attention_kv_write, + blackwell_attention_decode, + blackwell_attention_forward, + ) + # Also import the vLLM patch version + sys.path.insert(0, os.path.join(REPO, "vllm", "patches", "layers")) + from csa_attention import ( + fused_qnorm_rope_kv_insert_py, + blackwell_attention_kv_write as vllm_kv_write, + blackwell_attention_decode as vllm_decode, + kv_quantize_fp8 as vllm_kv_quantize, + kv_dequantize_fp8 as vllm_kv_dequantize, + ) + + torch.cuda.set_device(0) + + with open(os.path.join(MODEL, "model.safetensors.index.json")) as f: + wm = json.load(f)["weight_map"] + G = lambda k: P(k, wm, MODEL).to(DEV) + + # Test layer 60 (SWA) + layer_id = 60 + p = f"model.layers.{layer_id}"; a = f"{p}.self_attn" + + emb = G("model.embed_tokens.weight") + anorm = G(f"{p}.input_layernorm.weight") + qn = G(f"{a}.q_a_norm.weight"); kvn = G(f"{a}.kv_norm.weight") + woa = G(f"{a}.o_a_proj.weight") + qa_w = G(f"{a}.q_a_proj.weight"); qa_sf = G(f"{a}.q_a_proj.weight_scale"); qa_gs = G(f"{a}.q_a_proj.weight_scale_2") + qb_w = G(f"{a}.q_b_proj.weight"); qb_sf = G(f"{a}.q_b_proj.weight_scale"); qb_gs = G(f"{a}.q_b_proj.weight_scale_2") + kv_w = G(f"{a}.kv_proj.weight"); kv_sf = G(f"{a}.kv_proj.weight_scale"); kv_gs = G(f"{a}.kv_proj.weight_scale_2") + wob_w = G(f"{a}.o_b_proj.weight"); wob_sf = G(f"{a}.o_b_proj.weight_scale"); wob_gs = G(f"{a}.o_b_proj.weight_scale_2") + + r_qa = make_runner(qa_w, qa_sf, qa_gs, H, qa_w.shape[0]) + r_qb = make_runner(qb_w, qb_sf, qb_gs, QL, qb_w.shape[0]) + r_kv = make_runner(kv_w, kv_sf, kv_gs, H, kv_w.shape[0]) + r_wob = make_runner(wob_w, wob_sf, wob_gs, OG*OL, wob_w.shape[0]) + cos_sin = build_cos_sin(max_pos=4096).to(DEV) + woa_3d = woa.view(OG, OL, HPG * HD) + + N = 8 + token_ids = torch.tensor([1, 450, 8403, 315, 5413, 374, 2198, 643], dtype=torch.long, device=DEV) + + with torch.no_grad(): + # ── Test 1: Verify fused_qnorm_rope_kv_insert_py ────────── + print("=== Test 1: fused_qnorm_rope_kv_insert_py ===") + positions_p = torch.arange(N, dtype=torch.int64, device=DEV) + hidden_p = emb[token_ids] + normed_p = rms(hidden_p, anorm, EPS) + qa_p = r_qa.run(normed_p) + kv_p = r_kv.run(normed_p) + + # Manual Q norm + RoPE (reference) + qa_n_ref = rms(qa_p, qn, EPS) + q_ref = r_qb.run(qa_n_ref).view(N, NH, HD) + q_rope_ref = apply_gptj_rope(q_ref, positions_p, cos_sin, NOPE, ROPE) + + # Using fused_qnorm_rope_kv_insert_py + q_test = r_qb.run(qa_n_ref).view(N, NH, HD) + fused_qnorm_rope_kv_insert_py( + q_test, kv_p, None, None, positions_p, + cos_sin, EPS, 64, # block_size + nope_dim=NOPE, rope_dim=ROPE, + ) + + c = F.cosine_similarity(q_rope_ref.flatten().unsqueeze(0).float(), q_test.flatten().unsqueeze(0).float()).item() + print(f" fused_qnorm_rope vs manual: cosine = {c:.6f} {'PASS' if c>=0.999 else 'FAIL'}") + + # ── Test 2: Verify blackwell_attention_kv_write ─────────── + print("\n=== Test 2: blackwell_attention_kv_write ===") + block_size = 64; max_tokens = 256 + num_blocks = (max_tokens + block_size - 1) // block_size + + # uint8 cache (like vLLM uses) + swa_cache = torch.zeros(num_blocks, block_size, HD, dtype=torch.uint8, device=DEV) + inv_scale_cache = torch.zeros(max_tokens, 1, dtype=torch.bfloat16, device=DEV) + slot_mapping = positions_p # Simple: slot = position + + # Manual KV RoPE + fp8 quant + kv_n = rms(kv_p, kvn, EPS) + kv_rope_manual = apply_gptj_rope(kv_n.unsqueeze(1), positions_p, cos_sin, NOPE, ROPE).squeeze(1) + kv_fp8_manual, inv_s_manual = kv_quantize_fp8(kv_rope_manual) + + # Write using vLLM's function + vllm_kv_write( + kv_n, positions_p, swa_cache, inv_scale_cache, + slot_mapping, block_size, cos_sin, + nope_dim=NOPE, rope_dim=ROPE, + ) + + # Read back and compare + bi = slot_mapping // block_size; oi = slot_mapping % block_size + kv_read = swa_cache[bi, oi].view(torch.float8_e4m3fn) + inv_read = inv_scale_cache[slot_mapping] + kv_dequant = kv_dequantize_fp8(kv_read, inv_read) + + c = F.cosine_similarity(kv_rope_manual.flatten().unsqueeze(0).float(), kv_dequant.flatten().unsqueeze(0).float()).item() + print(f" vllm_kv_write roundtrip: cosine = {c:.6f} {'PASS' if c>=0.99 else 'FAIL'}") + + # ── Test 3: Decode attention using swa_indices ──────────── + print("\n=== Test 3: Decode attention with swa_indices ===") + decode_id = torch.tensor([991], dtype=torch.long, device=DEV) + pos_d = torch.tensor([N], dtype=torch.int64, device=DEV) + + # Write decode KV to cache + hidden_d = emb[decode_id] + normed_d = rms(hidden_d, anorm, EPS) + qa_d = r_qa.run(normed_d); kv_d = r_kv.run(normed_d) + qa_n_d = rms(qa_d, qn, EPS); kv_n_d = rms(kv_d, kvn, EPS) + q_d = r_qb.run(qa_n_d).view(1, NH, HD) + q_rope_d = apply_gptj_rope(q_d, pos_d, cos_sin, NOPE, ROPE) + + vllm_kv_write(kv_n_d, pos_d, swa_cache, inv_scale_cache, + pos_d, block_size, cos_sin, nope_dim=NOPE, rope_dim=ROPE) + + # swa_indices: simulate vLLM's pre-computed indices + # These are flat slot indices for each decode token's window + all_slots = torch.arange(N + 1, dtype=torch.int64, device=DEV) + swa_indices = all_slots.unsqueeze(0) # (1, N+1) — all tokens in window + swa_lens = torch.tensor([N + 1], dtype=torch.int64, device=DEV) + + o_decode = vllm_decode( + q_rope_d, pos_d, swa_cache, inv_scale_cache, + pos_d, block_size, SCALE, WINDOW, + swa_indices=swa_indices, + swa_lens=swa_lens, + decode_token_idx=0, + ) + print(f" Decode output: amax={o_decode.amax():.4f} NaN={torch.isnan(o_decode).any()}") + + # ── Reference: full prefill attention ──────────────────── + all_ids = torch.cat([token_ids, decode_id]) + all_pos = torch.arange(N + 1, dtype=torch.int64, device=DEV) + hidden_ref = emb[all_ids] + normed_ref = rms(hidden_ref, anorm, EPS) + qa_ref = r_qa.run(normed_ref); kv_ref = r_kv.run(normed_ref) + qa_n_ref = rms(qa_ref, qn, EPS); kv_n_ref = rms(kv_ref, kvn, EPS) + q_ref = r_qb.run(qa_n_ref).view(N + 1, NH, HD) + q_rope_ref = apply_gptj_rope(q_ref, all_pos, cos_sin, NOPE, ROPE) + kv_rope_ref = apply_gptj_rope(kv_n_ref.unsqueeze(1), all_pos, cos_sin, NOPE, ROPE).squeeze(1) + o_ref = causal_prefill_attention(q_rope_ref, kv_rope_ref, SCALE) + o_ref_decode = o_ref[-1:] + + c = F.cosine_similarity(o_decode.flatten().unsqueeze(0).float(), o_ref_decode.flatten().unsqueeze(0).float()).item() + print(f" Decode vs reference cosine: {c:.6f} {'PASS' if c>=0.98 else 'FAIL}") + + print("\n=== DONE ===") + + +if __name__ == "__main__": + main()