Files
nvfp4-megamoe-kernel/tests/archive/test_blackwell_attn_b200.py
biondizzle 9cbdc92744 Restructure: cutedsl/ -> dsv4/ with proper layering
- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4/layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
2026-05-21 17:30:44 +00:00

319 lines
14 KiB
Python

#!/usr/bin/env python3
"""
DeepSeek-V4 Blackwell Attention — Full Pipeline Test
Tests the dsv4.reference.attention module with real weights:
1. Prefill: process N tokens, write KV to paged cache
2. Decode: process 1 new token, read ALL cached KV, attend
3. Verify decode output matches BF16 reference
This is the core of the fix for the vLLM Blackwell garbage output bug.
Usage (on B200):
cd /root/nvfp4-megamoe-kernel
PYTHONPATH=/root/nvfp4-megamoe-kernel tests/venv/bin/python tests/archive/test_blackwell_attn_b200.py
"""
import sys, os, json, torch, torch.nn.functional as F, math
from safetensors import safe_open
# Repository root; prepended to sys.path so the in-repo packages can be imported.
REPO = "/root/nvfp4-megamoe-kernel"
sys.path.insert(0, REPO)
# Checkpoint directory: model.safetensors.index.json and the weight shards are read from here.
MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
DEV = "cuda:0"
# H: in_features given to the q_a and kv projections.
# NH, HD: q is viewed as (tokens, NH, HD).
# NOPE, ROPE: the two sizes handed to the RoPE helpers (NOPE + ROPE == HD).
H = 7168; NH = 128; HD = 512; NOPE = 448; ROPE = 64
# QL: in_features given to q_b.
# OG, OL, HPG: o_a_proj is viewed as (OG, OL, HPG * HD), with HPG = NH // OG.
QL = 1536; OL = 1024; OG = 16; HPG = NH // OG
# EPS: epsilon passed to rms(). SCALE: scale passed to the attention functions.
# WINDOW is not referenced by any of the visible code in this file.
EPS = 1e-6; WINDOW = 128; SCALE = HD ** -0.5
# 16-entry table mapping a 4-bit code to its float value; indexed by dequant().
# NOTE: the literal -0 is the integer 0, so entry 8 is stored as +0.0, not -0.0.
E2M1 = torch.tensor([0,.5,1.,1.5,2.,3.,4.,6.,-0,-.5,-1.,-1.5,-2.,-3.,-4.,-6.], dtype=torch.float32)
# Tensors already loaded by P(), keyed by tensor name.
_cache = {}
def P(k, wm, md):
    """Return the tensor stored under key ``k``, memoised in ``_cache``.

    Args:
        k: Tensor name; used as the key into both ``wm`` and ``_cache``.
        wm: Mapping from tensor name to the file name that contains it.
        md: Directory that the file names in ``wm`` are joined onto.

    Returns:
        The tensor obtained from ``get_tensor(k)``; later calls with the
        same ``k`` return the same cached object without touching disk.
    """
    try:
        return _cache[k]
    except KeyError:
        pass

    shard_path = os.path.join(md, wm[k])
    with safe_open(shard_path, framework="pt") as shard:
        tensor = shard.get_tensor(k)

    _cache[k] = tensor
    return tensor
def dequant(w, sf, gs):
    """Expand packed 4-bit codes to bfloat16 using block scales and a global scale.

    Every element of ``w`` carries two 4-bit codes that index the
    module-level ``E2M1`` table: the low nibble fills the even output
    column and the high nibble fills the following odd column.

    Args:
        w: 2-D integer tensor of packed codes, shape ``(O, I / 2)``.
        sf: Block scales; each column is repeated 16 times along dim 1 and
            the result is cropped to ``(O, I)``.
        gs: Global scale multiplied into every element.

    Returns:
        A ``torch.bfloat16`` tensor of shape ``(O, I)`` on ``w``'s device.
    """
    table = E2M1.to(w.device)
    rows, packed_cols = w.shape
    cols = packed_cols * 2

    low_vals = table[(w & 0xF).long()]
    high_vals = table[((w >> 4) & 0xF).long()]
    # Stacking on a new last axis and flattening it interleaves the two
    # nibbles as low, high, low, high, ... along each row.
    unpacked = torch.stack((low_vals, high_vals), dim=-1).reshape(rows, cols)

    block_scale = sf.float().repeat_interleave(16, dim=1)[:rows, :cols]
    return (unpacked * block_scale * gs).to(torch.bfloat16)
def rms(x, w, eps=1e-6):
    """RMS-normalise ``x`` over its last dimension and scale it by ``w``.

    Args:
        x: Input tensor; statistics are taken in float32 over dim -1.
        w: Per-feature weight, converted to float32 before multiplying.
        eps: Value added to the mean square before the reciprocal sqrt.

    Returns:
        A tensor with the same dtype as ``x``.
    """
    mean_sq = x.float().pow(2).mean(-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_sq + eps)
    normalised = (x * inv_rms).float()
    scaled = w.float() * normalised
    return scaled.to(x.dtype)
def make_runner(w, sf, gs_t, inf, outf, fused=False, lw=None):
    """Build an initialised ``Nvfp4Linear`` from one packed weight and its scales.

    Args:
        w: Packed weight; reinterpreted as ``torch.float4_e2m1fn_x2`` and
            transposed before being handed to the runner.
        sf: Block scales; cast to ``torch.float8_e4m3fn`` when they are not
            already that dtype, then transposed.
        gs_t: Global-scale tensor. With one element its value is used; with
            several, the maximum is used.
        inf: Forwarded as ``in_features``.
        outf: Forwarded as ``out_features``; ``outf // 2`` is also the
            default split index for the fused case.
        fused: When true and ``gs_t`` has exactly two elements, the two
            global scales are folded into the block scales on either side
            of the split index so a single (max) global scale can be used.
        lw: Optional sequence whose first element overrides the split index.

    Returns:
        The runner after ``finalize_weights()`` and ``_ensure_initialized()``.
    """
    from dsv4.layers.linear import Nvfp4Linear

    packed = w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
    scales = sf if sf.dtype == torch.float8_e4m3fn else sf.to(torch.float8_e4m3fn)
    scales = scales.permute(1, 0).contiguous()

    if fused and gs_t.numel() == 2:
        first, second = gs_t[0].item(), gs_t[1].item()
        global_scale = max(first, second)
        if first != second:
            rescaled = scales.float()
            split = lw[0] if lw else outf // 2
            # NOTE(review): `scales` was transposed above, so these slices cut
            # its leading axis with an out-feature index — confirm that this
            # is the axis the fused halves are laid out along.
            rescaled[:split] *= first / global_scale
            rescaled[split:] *= second / global_scale
            scales = rescaled.to(torch.float8_e4m3fn)
    elif gs_t.numel() > 1:
        global_scale = gs_t.max().item()
    else:
        global_scale = gs_t.item()

    runner = Nvfp4Linear(
        in_features=inf,
        out_features=outf,
        max_num_tokens=8192,
        device=str(w.device),
    )
    runner.fp4 = [packed]
    runner.sf = [scales]
    runner.gs = [global_scale]
    runner.finalize_weights()
    runner._ensure_initialized()
    return runner
def build_cos_sin(max_pos=4096, rope_dim=ROPE):
    """Build the rotary cos/sin lookup table.

    This function used to be defined twice in a row; the earlier variant
    (which concatenated cos/sin twice) was shadowed by the later one and
    never ran, so only the effective definition is kept.

    Args:
        max_pos: Number of positions (rows) in the table.
        rope_dim: Rotary width; ``rope_dim // 2`` frequencies are generated.

    Returns:
        A float32 tensor of shape ``(max_pos, 2 * (rope_dim // 2))`` whose
        first half of columns holds the cosines and second half the sines.
    """
    half = rope_dim // 2
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, half, dtype=torch.float32) / half))
    positions = torch.arange(max_pos, dtype=torch.float32)
    freqs = torch.outer(positions, inv_freq)
    return torch.cat([freqs.cos(), freqs.sin()], dim=-1)
def _flat_cosine(a, b):
    """Return the cosine similarity of two tensors, flattened and taken in float32."""
    return F.cosine_similarity(
        a.flatten().unsqueeze(0).float(),
        b.flatten().unsqueeze(0).float(),
    ).item()


def _verdict(value, threshold):
    """Return "PASS" when ``value`` reaches ``threshold``, otherwise "FAIL"."""
    return "PASS" if value >= threshold else "FAIL"


def test_blackwell_attention(layer_id, compress_ratio):
    """Run the prefill, cached-decode and output-projection checks for one layer.

    Four stages run against the layer's real weights. Each prints a cosine
    similarity with a PASS/FAIL marker; nothing is asserted.

    1. Causal prefill over 8 tokens, compared with
       ``F.scaled_dot_product_attention``.
    2. One decode token attending over KV that was FP8-quantised, written
       to the paged cache and read back, compared with the last row of a
       9-token causal prefill.
    3. Inverse RoPE, grouped ``o_a`` projection and ``o_b`` projection
       applied to both the cached and the reference decode output.
    4. Three consecutive decode steps, each compared with a causal prefill
       over the whole sequence so far.

    Args:
        layer_id: Index used to form the ``model.layers.{layer_id}`` prefix
            of the weight names.
        compress_ratio: Only selects the label in the printed header ("SWA"
            when <= 1, otherwise "CSA(c=...)"); the computation does not
            depend on it.

    Returns:
        ``(c_full, cosines)``: the stage-3 cosine and the list of per-step
        cosines from stage 4.
    """
    from dsv4.reference.attention import (
        apply_gptj_rope, apply_inv_gptj_rope,
        kv_quantize_fp8, kv_dequantize_fp8,
        paged_kv_write, paged_kv_read,
        causal_prefill_attention, decode_attention,
    )

    torch.cuda.set_device(0)
    torch.manual_seed(42)
    torch.cuda.empty_cache()

    with open(os.path.join(MODEL, "model.safetensors.index.json")) as f:
        wm = json.load(f)["weight_map"]

    def load(key):
        """Load one checkpoint tensor and move it to the test device."""
        return P(key, wm, MODEL).to(DEV)

    p = f"model.layers.{layer_id}"
    a = f"{p}.self_attn"
    layer_type = "SWA" if compress_ratio <= 1 else f"CSA(c={compress_ratio})"
    print(f"\n{'='*70}")
    print(f" Layer {layer_id} ({layer_type}) — Blackwell Attention Test")
    print(f"{'='*70}")

    # Load weights
    emb = load("model.embed_tokens.weight")
    anorm = load(f"{p}.input_layernorm.weight")
    qn = load(f"{a}.q_a_norm.weight")
    kvn = load(f"{a}.kv_norm.weight")
    woa = load(f"{a}.o_a_proj.weight")

    def load_quantized(name):
        """Return (weight, block scale, global scale) for one projection."""
        base = f"{a}.{name}"
        return (
            load(f"{base}.weight"),
            load(f"{base}.weight_scale"),
            load(f"{base}.weight_scale_2"),
        )

    qa_w, qa_sf, qa_gs = load_quantized("q_a_proj")
    qb_w, qb_sf, qb_gs = load_quantized("q_b_proj")
    kv_w, kv_sf, kv_gs = load_quantized("kv_proj")
    wob_w, wob_sf, wob_gs = load_quantized("o_b_proj")

    r_qa = make_runner(qa_w, qa_sf, qa_gs, H, qa_w.shape[0])
    r_qb = make_runner(qb_w, qb_sf, qb_gs, QL, qb_w.shape[0])
    r_kv = make_runner(kv_w, kv_sf, kv_gs, H, kv_w.shape[0])
    r_wob = make_runner(wob_w, wob_sf, wob_gs, OG * OL, wob_w.shape[0])
    cos_sin = build_cos_sin(max_pos=4096).to(DEV)
    woa_3d = woa.view(OG, OL, HPG * HD)

    block_size = 256
    num_blocks = 64
    kv_cache_fp8 = torch.zeros(num_blocks, block_size, HD, dtype=torch.float8_e4m3fn, device=DEV)
    inv_scale_cache = torch.zeros(num_blocks * block_size, 1, dtype=torch.bfloat16, device=DEV)

    def project(ids, pos):
        """Embed ``ids`` and return the RoPE-applied ``(q, kv)`` at positions ``pos``.

        The runner calls are issued in the same order every time (q_a, kv,
        then q_b) so each use of this helper replays the same sequence.
        """
        n = ids.shape[0]
        normed = rms(emb[ids], anorm, EPS)
        qa = r_qa.run(normed)
        kv = r_kv.run(normed)
        qa_n = rms(qa, qn, EPS)
        kv_n = rms(kv, kvn, EPS)
        q = r_qb.run(qa_n).view(n, NH, HD)
        q_rope = apply_gptj_rope(q, pos, cos_sin, NOPE, ROPE)
        kv_rope = apply_gptj_rope(kv_n.unsqueeze(1), pos, cos_sin, NOPE, ROPE).squeeze(1)
        return q_rope, kv_rope

    def cache_write(kv_rope, slots):
        """Quantise ``kv_rope`` and store it, with its inverse scales, at ``slots``."""
        kv_fp8, inv_s = kv_quantize_fp8(kv_rope)
        paged_kv_write(kv_fp8, slots, kv_cache_fp8, block_size)
        for t in range(slots.shape[0]):
            inv_scale_cache[slots[t]] = inv_s[t]

    def cache_read(length):
        """Read slots ``0 .. length-1`` back from the cache and dequantise them."""
        slots = torch.arange(length, dtype=torch.int64, device=DEV)
        kv_fp8 = paged_kv_read(slots, kv_cache_fp8, block_size, length, HD)
        return kv_dequantize_fp8(kv_fp8, inv_scale_cache[slots])

    def output_projection(o, pos):
        """Apply inverse RoPE, the grouped o_a projection and the o_b projection."""
        o_inv = apply_inv_gptj_rope(o, pos, cos_sin, NOPE, ROPE)
        o_grouped = o_inv.view(1, OG, HPG * HD).permute(1, 0, 2)
        z = torch.bmm(o_grouped, woa_3d.transpose(1, 2)).permute(1, 0, 2).reshape(1, OG * OL)
        return r_wob.run(z)

    # ── Test 1: Prefill-only attention ────────────────────────────────
    print(f"\n --- Test 1: Prefill attention (8 tokens) ---")
    N = 8
    token_ids = torch.tensor([1, 450, 8403, 315, 5413, 374, 2198, 643], dtype=torch.long, device=DEV)
    positions = torch.arange(N, dtype=torch.int64, device=DEV)
    with torch.no_grad():
        q_rope, kv_rope = project(token_ids, positions)
        # Causal attention
        o_prefill = causal_prefill_attention(q_rope, kv_rope, SCALE)
        print(f" Prefill attention output: amax={o_prefill.amax():.4f} NaN={torch.isnan(o_prefill).any()}")
        # Reference: the same q/kv through PyTorch SDPA, with the single KV
        # head expanded across all NH query heads.
        q_t = q_rope.permute(1, 0, 2)
        kv_exp = kv_rope.unsqueeze(0).expand(NH, -1, -1)
        o_ref = F.scaled_dot_product_attention(q_t, kv_exp, kv_exp, is_causal=True, scale=SCALE).permute(1, 0, 2)
        c = _flat_cosine(o_prefill, o_ref)
        print(f" Prefill vs SDPA reference cosine: {c:.6f} {_verdict(c, 0.999)}")

    # ── Test 2: Decode attention with KV cache ────────────────────────
    print(f"\n --- Test 2: Decode attention (1 token, 8 cached) ---")
    with torch.no_grad():
        # Write prefill KV to cache; slot index == token position.
        cache_write(kv_rope, positions)
        # Decode: token at position 8
        decode_id = torch.tensor([991], dtype=torch.long, device=DEV)
        decode_pos = torch.tensor([N], dtype=torch.int64, device=DEV)
        q_rope_d, kv_rope_d = project(decode_id, decode_pos)
        # Write decode KV to cache
        cache_write(kv_rope_d, decode_pos)
        # Read ALL 9 tokens from cache
        kv_cached = cache_read(N + 1)
        # Decode attention: 1 query vs 9 cached KVs
        o_decode = decode_attention(q_rope_d, kv_cached, SCALE)
        print(f" Decode attention output: amax={o_decode.amax():.4f} NaN={torch.isnan(o_decode).any()}")
        # Reference: process all 9 tokens at once without the FP8 cache.
        all_ids = torch.cat([token_ids, decode_id])
        all_pos = torch.arange(N + 1, dtype=torch.int64, device=DEV)
        q_rope_all, kv_rope_all = project(all_ids, all_pos)
        o_ref_all = causal_prefill_attention(q_rope_all, kv_rope_all, SCALE)
        o_ref_decode = o_ref_all[N:]  # Only the decode token
        c = _flat_cosine(o_decode, o_ref_decode)
        print(f" Decode vs BF16 reference cosine: {c:.6f} {_verdict(c, 0.98)}")

    # ── Test 3: Full output pipeline (inverse RoPE + o_a + o_b) ──────
    print(f"\n --- Test 3: Full output pipeline ---")
    with torch.no_grad():
        # Using decode attention output
        attn_out_cached = output_projection(o_decode, decode_pos)
        # Using the uncached reference
        attn_out_ref = output_projection(o_ref_decode, decode_pos)
        c_full = _flat_cosine(attn_out_cached, attn_out_ref)
        print(f" Full pipeline cosine: {c_full:.6f} {_verdict(c_full, 0.98)}")
        print(f" Output amax: cached={attn_out_cached.amax():.4f} ref={attn_out_ref.amax():.4f}")

    # ── Test 4: Multi-step decode (3 decode steps) ───────────────────
    print(f"\n --- Test 4: Multi-step decode (3 steps) ---")
    decode_ids = torch.tensor([991, 1502, 4200], dtype=torch.long, device=DEV)
    with torch.no_grad():
        cosines = []
        for step in range(3):
            pos = N + step
            dpos = torch.tensor([pos], dtype=torch.int64, device=DEV)
            q_rope_s, kv_rope_s = project(decode_ids[step:step + 1], dpos)
            # Write to cache
            cache_write(kv_rope_s, dpos)
            # Read all cached KV
            kv_all_dequant = cache_read(pos + 1)
            # Decode attention
            o_s = decode_attention(q_rope_s, kv_all_dequant, SCALE)
            # Reference: causal prefill over the whole sequence so far.
            all_ids_ref = torch.cat([token_ids, decode_ids[:step + 1]])
            all_pos_ref = torch.arange(pos + 1, dtype=torch.int64, device=DEV)
            q_rope_ref, kv_rope_ref = project(all_ids_ref, all_pos_ref)
            o_ref_full = causal_prefill_attention(q_rope_ref, kv_rope_ref, SCALE)
            o_ref_last = o_ref_full[-1:]
            c = _flat_cosine(o_s, o_ref_last)
            cosines.append(c)
            print(f" Step {step} (pos={pos}, {pos+1} cached): cosine = {c:.6f} {_verdict(c, 0.98)}")

    # Cleanup
    del r_qa, r_qb, r_kv, r_wob
    torch.cuda.empty_cache()
    return c_full, cosines
def main():
    """Run the attention pipeline check on layer 60 and print a summary."""
    rule = "=" * 70
    print(rule)
    print(" DeepSeek-V4 Blackwell Attention Pipeline Test")
    print(" Tests cutedsl.blackwell_attention with real weights")
    print(rule)

    # Layer 60 is run with compress_ratio=0, which the test labels as SWA.
    c_swa, cosines_swa = test_blackwell_attention(60, 0)
    step_cosines = ", ".join(f"{c:.6f}" for c in cosines_swa)

    print(f"\n{rule}")
    print(" SUMMARY")
    print(" Layer 60 (SWA):")
    print(f" Full pipeline cosine: {c_swa:.6f}")
    print(f" Multi-step decode: {step_cosines}")
    print(rule)


if __name__ == "__main__":
    main()