"""Compare ``native_sparse_decode_attention`` against an SDPA reference.

For each (tokens, swa_len, topk_len) configuration the script builds two
FP8-quantised KV caches (an "SWA" cache and a "compressed" cache), fills the
per-token index lists (padded with -1), computes a reference output with
``torch.nn.functional.scaled_dot_product_attention`` over the dequantised,
concatenated KV rows, and prints the cosine similarity between the reference
and the kernel output.

Requires a CUDA device and the project-local ``cutedsl`` package.
"""
import sys

import torch
import torch.nn.functional as F

# The kernel package is not installed; it is imported from the workspace tree.
sys.path.insert(0, "/root/dsv4-nvfp4-workspace/kernel")
from cutedsl.native_sparse_decode import native_sparse_decode_attention

torch.manual_seed(42)
torch.cuda.set_device(0)

DEVICE = "cuda:0"

# NH: query heads, HD: head dim, BS: rows per cache block,
# WIN: padded width of the SWA index list, TOPK: padded width of the top-k list.
NH, HD, BS, WIN, TOPK = 128, 512, 256, 128, 16
NUM_BLOCKS = 4  # each cache holds NUM_BLOCKS * BS = 1024 rows
# First row referenced by the top-k list; TOPK_BASE + TOPK must stay within
# NUM_BLOCKS * BS so every index addresses an existing cache row.
TOPK_BASE = 1000

# Value used as the per-row quantisation ceiling (448 is the largest finite
# float8_e4m3fn value).
F8_MAX = torch.tensor(448.0, dtype=torch.float32, device=DEVICE)


def _quantize_rows_fp8(rows):
    """Quantise ``rows`` (NUM_BLOCKS*BS, HD) to FP8 with one scale per row.

    Returns ``(cache, inv_scale)`` where ``cache`` is the FP8 payload reshaped
    to (NUM_BLOCKS, BS, HD) and reinterpreted as uint8, and ``inv_scale`` is
    the bf16 (NUM_BLOCKS*BS, 1) factor that maps FP8 values back to the
    original range.
    """
    absmax = rows.float().abs().amax(-1, keepdim=True).clamp(min=1e-12)
    quantised = (rows.float() * F8_MAX / absmax).to(torch.float8_e4m3fn)
    cache = quantised.reshape(NUM_BLOCKS, BS, HD).view(torch.uint8)
    inv_scale = (absmax / F8_MAX).to(torch.bfloat16)
    return cache, inv_scale


def _gather_dequant(cache, inv_scale, indices):
    """Gather rows ``indices`` from a blocked uint8 FP8 cache as bf16.

    Negative (padding) indices are clamped to row 0 so the gather is valid;
    callers must mask those slots out themselves.
    """
    safe = indices.clamp(min=0)
    block = cache.shape[1]
    raw = cache[safe // block, safe % block].view(torch.float8_e4m3fn)
    return (raw.to(torch.bfloat16) * inv_scale[safe]).to(torch.bfloat16)


# A -inf sink logit contributes exp(-inf) = 0 to the softmax denominator, so
# the plain SDPA reference below stays comparable.
sink = torch.full((NH,), float("-inf"), dtype=torch.float32, device=DEVICE)
ascale = HD ** -0.5

for nt, swa_l, topk_l in [(2, 32, 8), (2, 64, 16), (4, 32, 16), (4, 64, 8)]:
    # NOTE: the order of the three randn calls is kept so the seeded inputs
    # are the same as before.
    q = torch.randn(nt, NH, HD, dtype=torch.bfloat16, device=DEVICE) * 0.1

    # SWA cache
    kv_bf = torch.randn(NUM_BLOCKS * BS, HD, dtype=torch.bfloat16, device=DEVICE) * 0.5
    swa_cache, inv_sc = _quantize_rows_fp8(kv_bf)

    # Compressed cache
    comp_bf = torch.randn(NUM_BLOCKS * BS, HD, dtype=torch.bfloat16, device=DEVICE) * 0.3
    comp_cache, inv_sc2 = _quantize_rows_fp8(comp_bf)

    # Index lists: the first `len` slots of every row are valid, the rest -1.
    # Built with slice assignment instead of per-element GPU writes.
    si = torch.full((nt, WIN), -1, dtype=torch.int64, device=DEVICE)
    si[:, :swa_l] = torch.arange(swa_l, device=DEVICE)
    sl = torch.full((nt,), swa_l, dtype=torch.int64, device=DEVICE)
    ti = torch.full((nt, TOPK), -1, dtype=torch.int64, device=DEVICE)
    ti[:, :topk_l] = TOPK_BASE + torch.arange(topk_l, device=DEVICE)
    tl = torch.full((nt,), topk_l, dtype=torch.int64, device=DEVICE)

    # Reference: combined SDPA over [SWA slots | top-k slots].
    swa_kv = _gather_dequant(swa_cache, inv_sc, si)
    comp_kv = _gather_dequant(comp_cache, inv_sc2, ti)
    kv_comb = torch.cat([swa_kv, comp_kv], dim=1)
    total = WIN + TOPK

    # Build mask. The length cut-off is applied per segment: in the padded
    # concatenation the valid top-k entries start at column WIN, so a single
    # threshold at sl + tl over all columns would mask every top-k slot
    # whenever sl + tl <= WIN and the reference would ignore the compressed
    # cache entirely.
    swa_pos = torch.arange(WIN, device=DEVICE).unsqueeze(0)
    topk_pos = torch.arange(TOPK, device=DEVICE).unsqueeze(0)
    swa_masked = (swa_pos >= sl.unsqueeze(1)) | (si < 0)
    topk_masked = (topk_pos >= tl.unsqueeze(1)) | (ti < 0)
    mask = torch.cat([swa_masked, topk_masked], dim=1)
    fm = torch.zeros(mask.shape, dtype=torch.bfloat16, device=DEVICE)
    fm[mask] = float("-inf")

    # Fold heads into the batch dim: every head of a token shares that
    # token's KV rows and mask.
    qt = q.permute(1, 0, 2).reshape(NH * nt, 1, HD)
    kve = kv_comb.unsqueeze(0).expand(NH, nt, total, HD).reshape(NH * nt, total, HD)
    mb = fm.unsqueeze(0).unsqueeze(2).expand(NH, nt, 1, total).reshape(NH * nt, 1, total)
    ref = F.scaled_dot_product_attention(
        qt, kve, kve, attn_mask=mb, is_causal=False, scale=ascale
    ).reshape(NH, nt, HD).permute(1, 0, 2)

    # Best-effort per configuration: a failing config is reported and the
    # remaining ones still run. The comparison stays inside the try so that
    # errors surfacing after the call are also reported as FAILED rather
    # than aborting the sweep.
    try:
        nat = native_sparse_decode_attention(
            q, swa_cache, inv_sc, si, sl,
            comp_cache, inv_sc2, ti, tl,
            sink, BS, ascale, WIN, compress_ratio=4,
        )
        c = F.cosine_similarity(
            ref.flatten().unsqueeze(0).float(),
            nat.flatten().unsqueeze(0).float(),
        ).item()
        print(f"tokens={nt} swa={swa_l} topk={topk_l} cosine={c:.6f} {'OK' if c>=0.99 else 'LOW'}")
    except Exception as e:
        print(f"tokens={nt} swa={swa_l} topk={topk_l} FAILED: {e}")