# Change summary (repository reorganisation):
# - Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
# - Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
# - Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
# - Moved PyTorch bridges to dsv4/ops/
# - Moved nn.Module layers to dsv4/layers/
# - Moved reference implementations to dsv4/reference/
# - Moved vendored CUTLASS code to vendored/
# - Archived ~190 debug tests to tests/archive/
# - Kept ~15 canonical tests in tests/unit/
# - Updated all import paths
# - Added stubs for future components (model/, cache/, loader/)
# - Updated pyproject.toml: dsv4-inference package name
#
# File: 319 lines, 14 KiB, Python
#!/usr/bin/env python3
"""
DeepSeek-V4 Blackwell Attention — Full Pipeline Test

Tests the dsv4.reference.attention helpers with real weights:

    1. Prefill: process N tokens, write KV to paged cache
    2. Decode: process 1 new token, read ALL cached KV, attend
    3. Verify decode output matches BF16 reference

It also checks the output projection path (inverse RoPE + o_a + o_b) and
three consecutive decode steps against a full recompute.

This is the core of the fix for the vLLM Blackwell garbage output bug.

Usage (on B200):

    cd /root/nvfp4-megamoe-kernel
    PYTHONPATH=/root/nvfp4-megamoe-kernel tests/venv/bin/python tests/test_blackwell_attn_b200.py
"""


import json
import math
import os
import sys

import torch
import torch.nn.functional as F
from safetensors import safe_open

# Repository root; prepended to sys.path so `dsv4` imports resolve when this
# file is run directly. Override with the DSV4_REPO environment variable.
REPO = os.environ.get("DSV4_REPO", "/root/nvfp4-megamoe-kernel")
sys.path.insert(0, REPO)
# NVFP4 checkpoint directory (must contain model.safetensors.index.json).
# Override with the DSV4_MODEL environment variable.
MODEL = os.environ.get("DSV4_MODEL", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4")
DEV = "cuda:0"

# Attention geometry.
H = 7168        # hidden size (in_features of q_a_proj and kv_proj)
NH = 128        # number of query heads
HD = 512        # per-head dim; also the width of one KV-cache row
NOPE = 448      # non-rotary part of each head
ROPE = 64       # rotary part of each head (NOPE + ROPE == HD)
QL = 1536       # q low-rank dim (in_features of q_b_proj)
OL = 1024       # o_a output width per group
OG = 16         # number of output-projection groups
HPG = NH // OG  # heads per output group

EPS = 1e-6          # RMSNorm epsilon
WINDOW = 128        # sliding-window size; not referenced elsewhere in this file
SCALE = HD ** -0.5  # softmax scale

# FP4 (E2M1) code -> value lookup table; codes 8-15 are the sign-bit-set
# (negative) counterparts of codes 0-7.
E2M1 = torch.tensor(
    [0, .5, 1., 1.5, 2., 3., 4., 6., -0, -.5, -1., -1.5, -2., -3., -4., -6.],
    dtype=torch.float32,
)

# Memo of checkpoint tensors loaded by P().
_cache = {}


def P(k, wm, md):
    """Load checkpoint tensor ``k``, memoising the result in ``_cache``.

    Args:
        k: tensor name, e.g. ``"model.embed_tokens.weight"``.
        wm: ``weight_map`` from ``model.safetensors.index.json``; maps each
            tensor name to the shard file that stores it.
        md: model directory that contains the shard files.

    Returns:
        The tensor exactly as ``safe_open(...).get_tensor(k)`` returns it.

    Raises:
        KeyError: if ``k`` is not listed in ``wm``.
    """
    # Key on the model directory as well as the name so that loading from a
    # second checkpoint directory cannot return a tensor cached from the first.
    key = (md, k)
    cached = _cache.get(key)
    if cached is not None:
        return cached
    with safe_open(os.path.join(md, wm[k]), framework="pt") as f:
        t = f.get_tensor(k)
    _cache[key] = t
    return t


def dequant(w, sf, gs):
    """Dequantise a packed NVFP4 weight to bfloat16.

    Each element of ``w`` packs two E2M1 codes. The low nibble becomes output
    column ``2j`` and the high nibble column ``2j + 1``. Codes are decoded
    through the ``E2M1`` table and multiplied by two scales:

    * the block scale from ``sf``, with each column repeated 16 times along
      dim 1 and cropped to the output shape;
    * the global scale ``gs``.

    Args:
        w: packed codes of shape (O, I/2); presumably uint8 — confirm with caller.
        sf: block scales; expected to cover at least O rows and I/16 columns.
        gs: global scale (number or broadcastable tensor).

    Returns:
        bfloat16 tensor of shape (O, I) on ``w``'s device.
    """
    device = w.device
    table = E2M1.to(device)
    low = table[(w & 0xF).long()]
    high = table[((w >> 4) & 0xF).long()]
    rows, packed_cols = w.shape
    cols = 2 * packed_cols
    # Interleave nibbles: (rows, packed_cols, 2) -> (rows, cols).
    values = torch.stack((low, high), dim=-1).reshape(rows, cols)
    block_scale = sf.float().repeat_interleave(16, dim=1)[:rows, :cols]
    return (values * block_scale * gs).to(torch.bfloat16)


def rms(x, w, eps=1e-6):
    """RMS-normalise ``x`` over its last dim and scale it by ``w``.

    The mean square is computed in float32. The normalised values are
    multiplied by ``w`` (cast to float32), and the result is cast back to
    ``x``'s dtype.
    """
    inv_rms = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    normalised = (x * inv_rms).float()
    return (w.float() * normalised).to(x.dtype)


def make_runner(w, sf, gs_t, inf, outf, fused=False, lw=None):
    """Build an initialised ``Nvfp4Linear`` from raw NVFP4 checkpoint tensors.

    Args:
        w: packed FP4 weight, viewed as ``torch.float4_e2m1fn_x2`` and
            transposed; assumes checkpoint layout (out, in/2) — TODO confirm.
        sf: block scale factors, cast to float8_e4m3fn if needed and
            transposed like ``w``; assumes layout (out, in/16).
        gs_t: global scale tensor (``weight_scale_2``). It has two elements
            when ``fused`` and the output rows are two concatenated projections.
        inf: ``in_features`` passed to ``Nvfp4Linear``.
        outf: ``out_features`` passed to ``Nvfp4Linear``.
        fused: when True and ``gs_t`` has two elements, use their max as the
            single global scale and rescale each half's block scales.
        lw: optional list of fused output widths; ``lw[0]`` is the split
            point between the two halves (default ``outf // 2``).

    Returns:
        The runner after ``finalize_weights()`` and ``_ensure_initialized()``.
    """
    from dsv4.layers.linear import Nvfp4Linear

    fp4 = w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
    s = sf if sf.dtype == torch.float8_e4m3fn else sf.to(torch.float8_e4m3fn)
    # Transpose so output features lie along dim 1 (matching fp4 above).
    s = s.permute(1, 0).contiguous()

    if fused and gs_t.numel() == 2:
        g1, g2 = gs_t[0].item(), gs_t[1].item()
        gs = max(g1, g2)
        if g1 != g2:
            sp = lw[0] if lw else outf // 2
            s32 = s.float()
            # `sp` is an output-feature index. After the transpose, output
            # features are columns, so split on dim 1 (not dim 0, which would
            # rescale input blocks instead).
            s32[:, :sp] *= g1 / gs
            s32[:, sp:] *= g2 / gs
            s = s32.to(torch.float8_e4m3fn)
    else:
        gs = gs_t.max().item() if gs_t.numel() > 1 else gs_t.item()

    r = Nvfp4Linear(in_features=inf, out_features=outf, max_num_tokens=8192, device=str(w.device))
    r.fp4 = [fp4]
    r.sf = [s]
    r.gs = [gs]
    r.finalize_weights()
    r._ensure_initialized()
    return r


def build_cos_sin(max_pos=4096, rope_dim=ROPE, base=10000.0):
    """Precompute the RoPE cos/sin table for positions ``0 .. max_pos - 1``.

    Args:
        max_pos: number of positions (rows) in the table.
        rope_dim: number of rotary dims. The table uses ``rope_dim // 2``
            frequencies ``base ** -(i / (rope_dim // 2))``.
        base: frequency base (default 10000.0, the value previously
            hard-coded).

    Returns:
        float32 tensor of shape (max_pos, 2 * (rope_dim // 2)). Each row holds
        the cosines followed by the sines.
    """
    half = rope_dim // 2
    inv_freq = 1.0 / (base ** (torch.arange(0, half, dtype=torch.float32) / half))
    freqs = torch.outer(torch.arange(max_pos, dtype=torch.float32), inv_freq)
    return torch.cat([freqs.cos(), freqs.sin()], dim=-1)


def _cosine(a, b):
    """Return the cosine similarity of ``a`` and ``b``, flattened to 1-D, in fp32."""
    return F.cosine_similarity(a.flatten().unsqueeze(0).float(), b.flatten().unsqueeze(0).float()).item()


def test_blackwell_attention(layer_id, compress_ratio):
    """Run the Blackwell attention pipeline checks for one decoder layer.

    Loads layer ``layer_id``'s attention weights from ``MODEL`` and builds
    NVFP4 runners for q_a / q_b / kv / o_b. It then runs four checks and
    prints a cosine similarity with a ✅/❌ marker for each:

      1. Causal prefill over 8 tokens vs ``F.scaled_dot_product_attention``.
      2. One decode step that reads 9 tokens back from an FP8 paged KV cache,
         vs recomputing all 9 tokens with ``causal_prefill_attention``.
      3. The output path (inverse RoPE, grouped o_a bmm, o_b) applied to the
         cached-decode output vs the reference output.
      4. Three consecutive decode steps, each vs a full recompute.

    Args:
        layer_id: decoder layer index used to build the weight names.
        compress_ratio: only used for the printed label (``<= 1`` -> "SWA").

    Returns:
        ``(c_full, cosines)``: the Test 3 cosine and the list of Test 4
        cosines. Nothing is asserted here; the caller decides pass/fail.
    """
    from dsv4.reference.attention import (  # noqa: F401
        apply_gptj_rope, apply_inv_gptj_rope,
        blackwell_attention_forward,  # NOTE(review): imported but never called below
        kv_quantize_fp8, kv_dequantize_fp8,
        paged_kv_write, paged_kv_read,
        causal_prefill_attention, decode_attention,
    )

    torch.cuda.set_device(0)
    torch.manual_seed(42)
    torch.cuda.empty_cache()

    with open(os.path.join(MODEL, "model.safetensors.index.json")) as f:
        wm = json.load(f)["weight_map"]

    def G(k):
        # Fetch a checkpoint tensor (memoised by P) and move it to DEV.
        return P(k, wm, MODEL).to(DEV)

    p = f"model.layers.{layer_id}"
    a = f"{p}.self_attn"
    layer_type = "SWA" if compress_ratio <= 1 else f"CSA(c={compress_ratio})"

    print(f"\n{'='*70}")
    print(f" Layer {layer_id} — {layer_type} — Blackwell Attention Test")
    print(f"{'='*70}")

    # ── Weights ───────────────────────────────────────────────────────
    emb = G("model.embed_tokens.weight")
    anorm = G(f"{p}.input_layernorm.weight")
    qn = G(f"{a}.q_a_norm.weight")
    kvn = G(f"{a}.kv_norm.weight")
    woa = G(f"{a}.o_a_proj.weight")
    # NOTE(review): `sinks` is loaded but no attention call below receives
    # it, so attention sinks are not exercised. WINDOW is not exercised
    # either: every sequence here is at most 11 tokens. Confirm intended.
    sinks = G(f"{a}.sinks")  # noqa: F841

    def nvfp4(name, in_features):
        # Load `name`'s packed weight + scales and wrap them in a runner. The
        # raw tensors go out of scope afterwards instead of living until return.
        w = G(f"{a}.{name}.weight")
        sf = G(f"{a}.{name}.weight_scale")
        gs = G(f"{a}.{name}.weight_scale_2")
        return make_runner(w, sf, gs, in_features, w.shape[0])

    r_qa = nvfp4("q_a_proj", H)
    r_qb = nvfp4("q_b_proj", QL)
    r_kv = nvfp4("kv_proj", H)
    r_wob = nvfp4("o_b_proj", OG * OL)

    cos_sin = build_cos_sin(max_pos=4096).to(DEV)

    def project(ids, pos):
        # Embed `ids` and run input norm -> q_a/kv -> norms -> q_b, then apply
        # RoPE at `pos`. Returns (q_rope of shape (T, NH, HD), kv_rope).
        normed = rms(emb[ids], anorm, EPS)
        qa = r_qa.run(normed)
        kv = r_kv.run(normed)
        qa_n = rms(qa, qn, EPS)
        kv_n = rms(kv, kvn, EPS)
        q = r_qb.run(qa_n).view(ids.numel(), NH, HD)
        q_rope = apply_gptj_rope(q, pos, cos_sin, NOPE, ROPE)
        kv_rope = apply_gptj_rope(kv_n.unsqueeze(1), pos, cos_sin, NOPE, ROPE).squeeze(1)
        return q_rope, kv_rope

    def reference_last(ids):
        # Recompute every token of `ids` from scratch (no cache) and return the
        # causal attention output for the last token only.
        pos = torch.arange(ids.numel(), dtype=torch.int64, device=DEV)
        q_all, kv_all = project(ids, pos)
        return causal_prefill_attention(q_all, kv_all, SCALE)[-1:]

    def output_proj(o, pos):
        # Inverse RoPE, grouped o_a projection via bmm, then the NVFP4 o_b runner.
        o_inv = apply_inv_gptj_rope(o, pos, cos_sin, NOPE, ROPE)
        o_grouped = o_inv.view(1, OG, HPG * HD).permute(1, 0, 2)
        woa_3d = woa.view(OG, OL, HPG * HD)
        z = torch.bmm(o_grouped, woa_3d.transpose(1, 2)).permute(1, 0, 2).reshape(1, OG * OL)
        return r_wob.run(z)

    # ── Test 1: Prefill-only attention ────────────────────────────────
    print(f"\n --- Test 1: Prefill attention (8 tokens) ---")
    N = 8
    token_ids = torch.tensor([1, 450, 8403, 315, 5413, 374, 2198, 643], dtype=torch.long, device=DEV)
    positions = torch.arange(N, dtype=torch.int64, device=DEV)

    with torch.no_grad():
        q_rope, kv_rope = project(token_ids, positions)

        # Causal attention
        o_prefill = causal_prefill_attention(q_rope, kv_rope, SCALE)
        print(f" Prefill attention output: amax={o_prefill.amax():.4f} NaN={torch.isnan(o_prefill).any()}")

        # SDPA reference: all NH heads attend to one shared KV tensor, which
        # serves as both K and V.
        q_t = q_rope.permute(1, 0, 2)
        kv_exp = kv_rope.unsqueeze(0).expand(NH, -1, -1)
        o_ref = F.scaled_dot_product_attention(q_t, kv_exp, kv_exp, is_causal=True, scale=SCALE).permute(1, 0, 2)
        c = _cosine(o_prefill, o_ref)
        print(f" Prefill vs SDPA reference cosine: {c:.6f} {'✅' if c>=0.999 else '❌'}")

    # ── Test 2: Decode attention with KV cache ────────────────────────
    print(f"\n --- Test 2: Decode attention (1 token, 8 cached) ---")

    block_size = 256
    num_blocks = 64
    kv_cache_fp8 = torch.zeros(num_blocks, block_size, HD, dtype=torch.float8_e4m3fn, device=DEV)
    inv_scale_cache = torch.zeros(num_blocks * block_size, 1, dtype=torch.bfloat16, device=DEV)

    def cache_write(kv, slots):
        # FP8-quantise `kv`, store it at `slots`, and record each slot's inverse scale.
        kv_fp8, inv_s = kv_quantize_fp8(kv)
        paged_kv_write(kv_fp8, slots, kv_cache_fp8, block_size)
        for t in range(slots.numel()):
            inv_scale_cache[slots[t]] = inv_s[t]

    def cache_read(n):
        # Read slots 0..n-1 back from the cache and dequantise them.
        slots = torch.arange(n, dtype=torch.int64, device=DEV)
        kv_fp8 = paged_kv_read(slots, kv_cache_fp8, block_size, n, HD)
        return kv_dequantize_fp8(kv_fp8, inv_scale_cache[slots])

    with torch.no_grad():
        # Write prefill KV to cache (single sequence: slot index == position).
        cache_write(kv_rope, positions)

        # Decode: token at position N
        decode_id = torch.tensor([991], dtype=torch.long, device=DEV)
        decode_pos = torch.tensor([N], dtype=torch.int64, device=DEV)
        q_rope_d, kv_rope_d = project(decode_id, decode_pos)
        cache_write(kv_rope_d, decode_pos)

        # Decode attention: 1 query vs all N + 1 cached KVs
        o_decode = decode_attention(q_rope_d, cache_read(N + 1), SCALE)
        print(f" Decode attention output: amax={o_decode.amax():.4f} NaN={torch.isnan(o_decode).any()}")

        # Reference: process all N + 1 tokens at once, keep the decode token.
        o_ref_decode = reference_last(torch.cat([token_ids, decode_id]))
        c = _cosine(o_decode, o_ref_decode)
        print(f" Decode vs BF16 reference cosine: {c:.6f} {'✅' if c>=0.98 else '❌'}")

    # ── Test 3: Full output pipeline (inverse RoPE + o_a + o_b) ──────
    print(f"\n --- Test 3: Full output pipeline ---")
    with torch.no_grad():
        attn_out_cached = output_proj(o_decode, decode_pos)
        attn_out_ref = output_proj(o_ref_decode, decode_pos)
        c_full = _cosine(attn_out_cached, attn_out_ref)
        print(f" Full pipeline cosine: {c_full:.6f} {'✅' if c_full>=0.98 else '❌'}")
        print(f" Output amax: cached={attn_out_cached.amax():.4f} ref={attn_out_ref.amax():.4f}")

    # ── Test 4: Multi-step decode (3 decode steps) ───────────────────
    print(f"\n --- Test 4: Multi-step decode (3 steps) ---")
    decode_ids = torch.tensor([991, 1502, 4200], dtype=torch.long, device=DEV)

    with torch.no_grad():
        cosines = []
        for step in range(decode_ids.numel()):
            pos = N + step
            dpos = torch.tensor([pos], dtype=torch.int64, device=DEV)
            q_rope_s, kv_rope_s = project(decode_ids[step:step + 1], dpos)
            cache_write(kv_rope_s, dpos)

            o_s = decode_attention(q_rope_s, cache_read(pos + 1), SCALE)
            o_ref_last = reference_last(torch.cat([token_ids, decode_ids[:step + 1]]))

            c = _cosine(o_s, o_ref_last)
            cosines.append(c)
            print(f" Step {step} (pos={pos}, {pos+1} cached): cosine = {c:.6f} {'✅' if c>=0.98 else '❌'}")

    # Cleanup
    del r_qa, r_qb, r_kv, r_wob
    torch.cuda.empty_cache()

    return c_full, cosines


# This is a parameterised driver called from main(), not a pytest test.
# Without this flag, pytest would collect it (test_* name) and error with
# "fixture 'layer_id' not found".
test_blackwell_attention.__test__ = False


def main():
    """Run the Blackwell attention checks on layer 60 and print a summary.

    Returns:
        Process exit status: 0 when the full-pipeline cosine and every
        multi-step decode cosine are >= 0.98 (the threshold used by the
        per-test ✅/❌ printouts), otherwise 1. A NaN cosine counts as a failure.
    """
    print("=" * 70)
    print(" DeepSeek-V4 Blackwell Attention Pipeline Test")
    print(" Tests cutedsl.blackwell_attention with real weights")
    print("=" * 70)

    # Test SWA layer (layer 60, compress_ratio=0)
    c_swa, cosines_swa = test_blackwell_attention(60, 0)

    print(f"\n{'='*70}")
    print(" SUMMARY")
    print(" Layer 60 (SWA):")
    print(f" Full pipeline cosine: {c_swa:.6f}")
    print(f" Multi-step decode: {', '.join(f'{c:.6f}' for c in cosines_swa)}")
    print(f"{'='*70}")

    # Previously the script always exited 0, so failures were invisible to CI.
    threshold = 0.98
    passed = c_swa >= threshold and all(c >= threshold for c in cosines_swa)
    return 0 if passed else 1


if __name__ == "__main__":
    sys.exit(main())