Files
nvfp4-megamoe-kernel/tests/test_sparse_decode.py
biondizzle bbba289bd8 feat: GPU-native SWA + sparse decode attention kernels (CuTeDSL)
- native_swa_decode.py: BlackwellSWADecodeKernel
  - CTA mapping: 1 CTA per (decode_token, q_head_group)
  - Online softmax with KV tile streaming (16 tokens/tile)
  - Pre-dequantized bf16 KV (fp8 dequant on host - MLIR cvt_fpext
    requires 32-bit aligned vector, no scalar fp8->bf16 support)
  - Cosine 0.9999+ vs PyTorch batched SDPA reference
  - Fallback _fallback_batched_sdp when CuTeDSL unavailable

- native_sparse_decode.py: BlackwellSparseDecodeKernel
  - Combined SWA + compressed KV in single attention pass
  - Supports CSA (cr=4) and HCA (cr=128) layers
  - Sink weight merge on host side
  - Cosine 0.9999+ vs combined SDPA reference

- fp8_bf16.py: Documents MLIR limitation (cvt_fpext requires
  vector<4xf8>, no scalar support). Pre-dequant is the workaround.

- vLLM wiring (attention.py):
  - SWA-only layers: native_swa_decode_attention
  - CSA/HCA layers: native_sparse_decode_attention with topk + attn_sink
  - csa_attention.py updated to use native kernels

- Tests: test_decode_pipeline.py, test_sparse_decode.py both passing
2026-05-20 05:46:15 +00:00

72 lines
3.4 KiB
Python

import sys, torch, torch.nn.functional as F
sys.path.insert(0, "/root/dsv4-nvfp4-workspace/kernel")
from cutedsl.native_sparse_decode import native_sparse_decode_attention
# Correctness check for native_sparse_decode_attention.
#
# For several (decode tokens, SWA length, top-k length) configurations this
# script builds two fp8-quantized paged caches (sliding-window KV and
# compressed KV), computes a reference output with PyTorch SDPA over the
# concatenation of the gathered SWA and compressed entries, and compares the
# native kernel's output to it by cosine similarity.
#
# One line is printed per configuration. After every configuration has run,
# the script exits non-zero if any of them raised or scored below MIN_COSINE,
# so a regression cannot pass silently.
torch.manual_seed(42)
torch.cuda.set_device(0)
NH, HD, BS, WIN, TOPK = 128, 512, 256, 128, 16
DEV = "cuda:0"
MIN_COSINE = 0.99
# 448 is the largest finite float8_e4m3fn value; rows are scaled so that their
# absolute maximum lands exactly on it before the fp8 cast.
f8m = torch.tensor(448.0, dtype=torch.float32, device=DEV)
failures = []
for nt, swa_l, topk_l in [(2, 32, 8), (2, 64, 16), (4, 32, 16), (4, 64, 8)]:
    # Queries: (nt, NH, HD) bf16.
    q = torch.randn(nt, NH, HD, dtype=torch.bfloat16, device=DEV) * 0.1
    nb = 4

    # SWA cache: per-row absmax quantization to fp8, stored as raw uint8 bytes
    # in (nb, BS, HD) pages. inv_sc holds the per-row dequant scale, (nb*BS, 1).
    kv_bf = torch.randn(nb * BS, HD, dtype=torch.bfloat16, device=DEV) * 0.5
    am = kv_bf.float().abs().amax(-1, keepdim=True).clamp(min=1e-12)
    swa_cache = (
        (kv_bf.float() * f8m / am)
        .to(torch.float8_e4m3fn)
        .reshape(nb, BS, HD)
        .view(torch.uint8)
    )
    inv_sc = (am / f8m).to(torch.bfloat16)

    # Compressed cache: same quantization scheme, different data scale.
    comp_bf = torch.randn(nb * BS, HD, dtype=torch.bfloat16, device=DEV) * 0.3
    am2 = comp_bf.float().abs().amax(-1, keepdim=True).clamp(min=1e-12)
    comp_cache = (
        (comp_bf.float() * f8m / am2)
        .to(torch.float8_e4m3fn)
        .reshape(nb, BS, HD)
        .view(torch.uint8)
    )
    inv_sc2 = (am2 / f8m).to(torch.bfloat16)

    # Index tables, identical for every token: the first swa_l SWA slots point
    # at cache rows 0..swa_l-1 and the first topk_l compressed slots at rows
    # 1000..1000+topk_l-1; unused trailing slots are padded with -1.
    swa_pos = torch.arange(WIN, dtype=torch.int64, device=DEV)
    topk_pos = torch.arange(TOPK, dtype=torch.int64, device=DEV)
    si = torch.where(
        swa_pos < swa_l, swa_pos, torch.full_like(swa_pos, -1)
    ).repeat(nt, 1)
    sl = torch.full((nt,), swa_l, dtype=torch.int64, device=DEV)
    ti = torch.where(
        topk_pos < topk_l, topk_pos + 1000, torch.full_like(topk_pos, -1)
    ).repeat(nt, 1)
    tl = torch.full((nt,), topk_l, dtype=torch.int64, device=DEV)

    # A -inf sink logit contributes exp(-inf) = 0 to a softmax denominator.
    # NOTE(review): presumably this disables the sink term in the kernel so
    # that plain SDPA is a valid reference - confirm against the kernel.
    sink = torch.full((NH,), float("-inf"), dtype=torch.float32, device=DEV)
    ascale = HD ** -0.5

    # Reference: combined SDPA.
    # Padding indices (-1) are clamped to 0 only so the gather is in bounds;
    # those slots are masked out below.
    safe_swa = si.clamp(min=0)
    swa_raw = swa_cache[safe_swa // BS, safe_swa % BS].view(torch.float8_e4m3fn)
    swa_kv = (swa_raw.to(torch.bfloat16) * inv_sc[safe_swa]).to(torch.bfloat16)
    comp_bs = comp_cache.shape[1]
    safe_topk = ti.clamp(min=0)
    comp_raw = comp_cache[safe_topk // comp_bs, safe_topk % comp_bs].view(
        torch.float8_e4m3fn
    )
    comp_kv = (comp_raw.to(torch.bfloat16) * inv_sc2[safe_topk]).to(torch.bfloat16)
    kv_comb = torch.cat([swa_kv, comp_kv], dim=1)  # (nt, WIN + TOPK, HD)
    total = WIN + TOPK

    # Build mask (True = excluded from attention).
    # The combined axis is laid out as [WIN SWA slots | TOPK compressed slots]
    # with each segment padded independently, so validity has to be decided
    # per segment. Comparing the combined position against sl + tl would mask
    # every compressed slot, because sl + tl is smaller than WIN here.
    swa_pad = swa_pos.unsqueeze(0) >= sl.unsqueeze(1)
    topk_pad = topk_pos.unsqueeze(0) >= tl.unsqueeze(1)
    mask = torch.cat([swa_pad | (si < 0), topk_pad | (ti < 0)], dim=1)
    fm = torch.zeros(mask.shape, dtype=torch.bfloat16, device=DEV)
    fm[mask] = float("-inf")

    # Fold heads into the batch dimension: every head of a token attends over
    # the same KV rows, which also serve as the values.
    qt = q.permute(1, 0, 2).reshape(NH * nt, 1, HD)
    kve = kv_comb.unsqueeze(0).expand(NH, nt, total, HD).reshape(NH * nt, total, HD)
    mb = fm.unsqueeze(0).unsqueeze(2).expand(NH, nt, 1, total).reshape(NH * nt, 1, total)
    ref = (
        F.scaled_dot_product_attention(
            qt, kve, kve, attn_mask=mb, is_causal=False, scale=ascale
        )
        .reshape(NH, nt, HD)
        .permute(1, 0, 2)
    )

    try:
        nat = native_sparse_decode_attention(q, swa_cache, inv_sc, si, sl, comp_cache, inv_sc2, ti, tl, sink, BS, ascale, WIN, compress_ratio=4)
        c = F.cosine_similarity(ref.flatten().unsqueeze(0).float(), nat.flatten().unsqueeze(0).float()).item()
    except Exception as e:
        # Keep going so every configuration is reported, but remember the
        # failure for the final exit status.
        print(f"tokens={nt} swa={swa_l} topk={topk_l} FAILED: {e}")
        failures.append((nt, swa_l, topk_l))
    else:
        print(f"tokens={nt} swa={swa_l} topk={topk_l} cosine={c:.6f} {'OK' if c>=MIN_COSINE else 'LOW'}")
        # `not (c >= ...)` also catches a NaN cosine.
        if not (c >= MIN_COSINE):
            failures.append((nt, swa_l, topk_l))

if failures:
    raise SystemExit(
        f"{len(failures)} sparse decode configuration(s) failed "
        f"(tokens, swa, topk): {failures}"
    )