- native_swa_decode.py: BlackwellSWADecodeKernel
- CTA mapping: 1 CTA per (decode_token, q_head_group)
- Online softmax with KV tile streaming (16 tokens/tile)
- Pre-dequantized bf16 KV (fp8 dequant on host - MLIR cvt_fpext
requires 32-bit aligned vector, no scalar fp8->bf16 support)
- Cosine 0.9999+ vs PyTorch batched SDPA reference
- Fallback _fallback_batched_sdp when CuTeDSL unavailable
- native_sparse_decode.py: BlackwellSparseDecodeKernel
- Combined SWA + compressed KV in single attention pass
- Supports CSA (cr=4) and HCA (cr=128) layers
- Sink weight merge on host side
- Cosine 0.9999+ vs combined SDPA reference
- fp8_bf16.py: Documents MLIR limitation (cvt_fpext requires
vector<4xf8>, no scalar support). Pre-dequant is the workaround.
- vLLM wiring (attention.py):
- SWA-only layers: native_swa_decode_attention
- CSA/HCA layers: native_sparse_decode_attention with topk + attn_sink
- csa_attention.py updated to use native kernels
- Tests: test_decode_pipeline.py, test_sparse_decode.py both passing
72 lines
3.4 KiB
Python
72 lines
3.4 KiB
Python
# Smoke test: compare native_sparse_decode_attention against a PyTorch SDPA
# reference that attends over sliding-window (SWA) KV entries and top-k
# compressed KV entries in a single softmax.
#
# For each (num_tokens, swa_len, topk_len) configuration the script prints the
# cosine similarity between the two outputs, tagged OK (>= 0.99) or LOW, or
# FAILED with the exception text if the native call raises.
import sys

import torch
import torch.nn.functional as F

# The kernel package is not installed; it is imported from the workspace tree.
sys.path.insert(0, "/root/dsv4-nvfp4-workspace/kernel")

from cutedsl.native_sparse_decode import native_sparse_decode_attention  # noqa: E402

torch.manual_seed(42)
torch.cuda.set_device(0)

DEV = "cuda:0"

# heads, head dim, cache block size, SWA window slots, top-k slots
NH, HD, BS, WIN, TOPK = 128, 512, 256, 128, 16
# Number of cache blocks; each cache holds NB * BS rows of HD values.
NB = 4
# 448.0 is the largest finite float8_e4m3fn value; rows are scaled so that
# their absolute maximum lands on it. Kept as an fp32 tensor as in the
# original arithmetic.
F8_MAX = torch.tensor(448.0, dtype=torch.float32, device=DEV)
# Softmax scale passed to both the reference and the native kernel.
ASCALE = HD ** -0.5


def _quantize_rows_fp8(rows):
    """Quantize (NB*BS, HD) rows to fp8 e4m3 with one scale per row.

    Returns ``(cache, inv_scale)`` where ``cache`` is the fp8 payload viewed
    as uint8 with shape (NB, BS, HD) and ``inv_scale`` is the bf16
    dequantization factor with shape (NB*BS, 1).
    """
    rows_f32 = rows.float()
    amax = rows_f32.abs().amax(-1, keepdim=True).clamp(min=1e-12)
    cache = (
        (rows_f32 * F8_MAX / amax)
        .to(torch.float8_e4m3fn)
        .reshape(NB, BS, HD)
        .view(torch.uint8)
    )
    return cache, (amax / F8_MAX).to(torch.bfloat16)


def _gather_dequant(cache, inv_scale, idx):
    """Gather rows ``idx`` (flat row ids, -1 = padding) and dequantize to bf16.

    Padding ids are clamped to row 0 so the gather stays in bounds; the rows
    they produce must be masked out by the caller.
    """
    safe = idx.clamp(min=0)
    block = cache.shape[1]
    raw = cache[safe // block, safe % block].view(torch.float8_e4m3fn)
    return (raw.to(torch.bfloat16) * inv_scale[safe]).to(torch.bfloat16)


def _padded_indices(num_tokens, num_slots, valid, first):
    """Return (num_tokens, num_slots) int64 ids: first..first+valid-1, then -1."""
    slots = torch.arange(num_slots, device=DEV)
    row = torch.where(slots < valid, slots + first, torch.full_like(slots, -1))
    return row.repeat(num_tokens, 1)


def _masked_slots(idx, lens):
    """Boolean (tokens, slots) mask: True where a slot must NOT be attended.

    A slot is masked when it lies at or beyond the token's valid length for
    this segment, or when its index is the -1 padding marker.
    """
    slots = torch.arange(idx.shape[1], device=idx.device)
    return (slots >= lens.unsqueeze(1)) | (idx < 0)


for nt, swa_l, topk_l in [(2, 32, 8), (2, 64, 16), (4, 32, 16), (4, 64, 8)]:
    # NOTE: the order of the randn calls is kept (q, SWA rows, compressed
    # rows) so the generated data matches the original script for seed 42.
    q = torch.randn(nt, NH, HD, dtype=torch.bfloat16, device=DEV) * 0.1

    # SWA cache
    kv_bf = torch.randn(NB * BS, HD, dtype=torch.bfloat16, device=DEV) * 0.5
    swa_cache, inv_sc = _quantize_rows_fp8(kv_bf)

    # Compressed cache
    comp_bf = torch.randn(NB * BS, HD, dtype=torch.bfloat16, device=DEV) * 0.3
    comp_cache, inv_sc2 = _quantize_rows_fp8(comp_bf)

    # Per-token gather indices, -1 padded, plus the valid length per token.
    # SWA uses rows 0..swa_l-1; top-k uses rows 1000..1000+topk_l-1.
    si = _padded_indices(nt, WIN, swa_l, 0)
    sl = torch.full((nt,), swa_l, dtype=torch.int64, device=DEV)
    ti = _padded_indices(nt, TOPK, topk_l, 1000)
    tl = torch.full((nt,), topk_l, dtype=torch.int64, device=DEV)

    # -inf sink logits contribute exp(-inf) = 0 to a softmax, so the plain
    # SDPA reference below needs no sink term.
    # NOTE(review): assumes the kernel treats `sink` as per-head logits - confirm.
    sink = torch.full((NH,), float("-inf"), dtype=torch.float32, device=DEV)

    # Reference: combined SDPA
    swa_kv = _gather_dequant(swa_cache, inv_sc, si)
    comp_kv = _gather_dequant(comp_cache, inv_sc2, ti)
    kv_comb = torch.cat([swa_kv, comp_kv], dim=1)
    total = WIN + TOPK

    # Build mask.
    # kv_comb is laid out as [WIN SWA slots | TOPK compressed slots], so the
    # length check has to be done per segment. A single `pos >= sl + tl` test
    # over the concatenated axis would mask every compressed slot whenever
    # WIN > sl + tl (true for all configurations here) and the reference would
    # silently degrade to SWA-only attention.
    mask = torch.cat([_masked_slots(si, sl), _masked_slots(ti, tl)], dim=1)
    fm = torch.zeros(mask.shape, dtype=torch.bfloat16, device=DEV)
    fm[mask] = float("-inf")

    # Fold heads into the batch dimension (row = head * nt + token); every
    # head shares the same gathered KV and mask for a given token.
    qt = q.permute(1, 0, 2).reshape(NH * nt, 1, HD)
    kve = kv_comb.unsqueeze(0).expand(NH, nt, total, HD).reshape(NH * nt, total, HD)
    mb = fm.unsqueeze(0).unsqueeze(2).expand(NH, nt, 1, total).reshape(NH * nt, 1, total)
    ref = (
        F.scaled_dot_product_attention(
            qt, kve, kve, attn_mask=mb, is_causal=False, scale=ASCALE
        )
        .reshape(NH, nt, HD)
        .permute(1, 0, 2)
    )

    try:
        nat = native_sparse_decode_attention(
            q, swa_cache, inv_sc, si, sl,
            comp_cache, inv_sc2, ti, tl,
            sink, BS, ASCALE, WIN, compress_ratio=4,
        )
        c = F.cosine_similarity(
            ref.flatten().unsqueeze(0).float(),
            nat.flatten().unsqueeze(0).float(),
        ).item()
        print(f"tokens={nt} swa={swa_l} topk={topk_l} cosine={c:.6f} {'OK' if c>=0.99 else 'LOW'}")
    except Exception as e:
        # Report and continue so one failing configuration does not hide the
        # results of the remaining ones.
        print(f"tokens={nt} swa={swa_l} topk={topk_l} FAILED: {e}")