Files
nvfp4-megamoe-kernel/tests/archive/test_sparse_decode.py
biondizzle 3fb3c925af Restructure: cutedsl/ -> dsv4/ with proper layering
- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4/layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
2026-05-21 17:30:44 +00:00

72 lines
3.3 KiB
Python

import sys, torch, torch.nn.functional as F
sys.path.insert(0, "/root/dsv4-nvfp4-workspace/kernel")
from dsv4.ops.decode_sparse import native_sparse_decode_attention
# Debug sweep: compare native_sparse_decode_attention against a PyTorch SDPA
# reference built from the same FP8 caches, and print the cosine similarity
# for a handful of (tokens, swa_len, topk_len) configurations.
torch.manual_seed(42)
torch.cuda.set_device(0)

DEVICE = "cuda:0"
# NH: heads per token, HD: head dim, BS: rows per cache block,
# WIN / TOPK: padded widths of the SWA and top-k index tables.
NH, HD, BS, WIN, TOPK = 128, 512, 256, 128, 16
FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn


def _quantize_fp8_rows(rows, num_blocks):
    """Quantise ``rows`` to FP8 with one scale per row.

    Returns ``(cache, inv_scale)``: ``cache`` is the FP8 payload viewed as
    uint8 with shape ``(num_blocks, BS, HD)``; ``inv_scale`` is the bf16
    per-row dequantisation factor with shape ``(num_blocks * BS, 1)``.
    """
    # Clamp so an all-zero row does not divide by zero.
    row_amax = rows.float().abs().amax(-1, keepdim=True).clamp(min=1e-12)
    fp8 = (rows.float() * FP8_E4M3_MAX / row_amax).to(torch.float8_e4m3fn)
    cache = fp8.reshape(num_blocks, BS, HD).view(torch.uint8)
    inv_scale = (row_amax / FP8_E4M3_MAX).to(torch.bfloat16)
    return cache, inv_scale


def _padded_indices(num_tokens, width, num_valid, first):
    """Return a ``(num_tokens, width)`` int64 table: ``first .. first+num_valid-1``
    in the leading slots of every row, ``-1`` in the remaining padding slots."""
    table = torch.full((num_tokens, width), -1, dtype=torch.int64, device=DEVICE)
    table[:, :num_valid] = torch.arange(
        first, first + num_valid, dtype=torch.int64, device=DEVICE
    )
    return table


def _gather_dequant(cache, inv_scale, indices):
    """Gather flat row ``indices`` from a blocked uint8 FP8 ``cache`` and
    dequantise to bf16. Padding (-1) is clamped to row 0; those slots are
    excluded later by the attention mask."""
    safe = indices.clamp(min=0)
    rows_per_block = cache.shape[1]
    raw = cache[safe // rows_per_block, safe % rows_per_block].view(torch.float8_e4m3fn)
    return (raw.to(torch.bfloat16) * inv_scale[safe]).to(torch.bfloat16)


def _padding_mask(indices, lengths):
    """Boolean mask (True = do not attend) for one padded index segment:
    a slot is masked if it lies at/after the row's length or holds -1."""
    slots = torch.arange(indices.shape[1], device=indices.device).unsqueeze(0)
    return (slots >= lengths.unsqueeze(1)) | (indices < 0)


for nt, swa_l, topk_l in [(2, 32, 8), (2, 64, 16), (4, 32, 16), (4, 64, 8)]:
    nb = 4  # cache blocks per cache

    # Draw order (q, SWA rows, compressed rows) is kept fixed so the seeded
    # random data is reproducible from run to run.
    q = torch.randn(nt, NH, HD, dtype=torch.bfloat16, device=DEVICE) * 0.1

    # SWA cache
    kv_bf = torch.randn(nb * BS, HD, dtype=torch.bfloat16, device=DEVICE) * 0.5
    swa_cache, inv_sc = _quantize_fp8_rows(kv_bf, nb)

    # Compressed cache
    comp_bf = torch.randn(nb * BS, HD, dtype=torch.bfloat16, device=DEVICE) * 0.3
    comp_cache, inv_sc2 = _quantize_fp8_rows(comp_bf, nb)

    # Index tables: SWA rows 0..swa_l-1 and compressed rows 1000..1000+topk_l-1,
    # each padded with -1 up to the fixed table width.
    si = _padded_indices(nt, WIN, swa_l, 0)
    sl = torch.full((nt,), swa_l, dtype=torch.int64, device=DEVICE)
    ti = _padded_indices(nt, TOPK, topk_l, 1000)
    tl = torch.full((nt,), topk_l, dtype=torch.int64, device=DEVICE)

    # NOTE(review): -inf sink logits presumably disable the sink term so the
    # kernel is comparable to the sink-free reference below - confirm.
    sink = torch.full((NH,), float("-inf"), dtype=torch.float32, device=DEVICE)
    ascale = HD ** -0.5

    # Reference: combined SDPA over [SWA segment | top-k segment]
    kv_comb = torch.cat(
        [
            _gather_dequant(swa_cache, inv_sc, si),
            _gather_dequant(comp_cache, inv_sc2, ti),
        ],
        dim=1,
    )
    total = WIN + TOPK

    # Build mask. Each segment is padded independently, so the length test
    # must be applied per segment: compressed keys start at position WIN, and
    # a single `pos >= sl + tl` test over the concatenation would mask all of
    # them out.
    mask = torch.cat([_padding_mask(si, sl), _padding_mask(ti, tl)], dim=1)
    fm = torch.zeros(mask.shape, dtype=torch.bfloat16, device=DEVICE)
    fm.masked_fill_(mask, float("-inf"))

    # Flatten (head, token) into the batch dim; every head of a token shares
    # that token's keys/values and mask.
    qt = q.permute(1, 0, 2).reshape(NH * nt, 1, HD)
    kve = kv_comb.unsqueeze(0).expand(NH, nt, total, HD).reshape(NH * nt, total, HD)
    mb = fm.unsqueeze(0).unsqueeze(2).expand(NH, nt, 1, total).reshape(NH * nt, 1, total)
    ref = (
        F.scaled_dot_product_attention(
            qt, kve, kve, attn_mask=mb, is_causal=False, scale=ascale
        )
        .reshape(NH, nt, HD)
        .permute(1, 0, 2)
    )

    # Deliberately broad catch: this is a sweep, so one failing configuration
    # is reported and the remaining ones still run.
    try:
        nat = native_sparse_decode_attention(
            q, swa_cache, inv_sc, si, sl, comp_cache, inv_sc2, ti, tl,
            sink, BS, ascale, WIN, compress_ratio=4,
        )
        c = F.cosine_similarity(
            ref.flatten().unsqueeze(0).float(), nat.flatten().unsqueeze(0).float()
        ).item()
        print(f"tokens={nt} swa={swa_l} topk={topk_l} cosine={c:.6f} {'OK' if c>=0.99 else 'LOW'}")
    except Exception as e:
        print(f"tokens={nt} swa={swa_l} topk={topk_l} FAILED: {e}")