Files
nvfp4-megamoe-kernel/tests/test_csa_sparse_attn_b200.py

400 lines
18 KiB
Python

#!/usr/bin/env python3
"""
CSA Sparse Attention Test
Tests the csa_sparse_gather_attention function with simulated compressor output:
1. Create compressed KV cache (simulating compressor output)
2. Create topk_indices (simulating indexer output)
3. Do sparse attention on compressed KV at topk positions
4. Do SWA attention on the window
5. Merge with sink weights
6. Compare against full attention reference
Usage (on B200):
cd /root/nvfp4-megamoe-kernel
PYTHONPATH=/root/nvfp4-megamoe-kernel tests/venv/bin/python tests/test_csa_sparse_attn_b200.py
"""
import sys, os, json, torch, torch.nn.functional as F, time
from safetensors import safe_open
REPO = "/root/nvfp4-megamoe-kernel"
sys.path.insert(0, REPO)
MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
DEV = "cuda:0"
H = 7168; NH = 128; HD = 512; NOPE = 448; ROPE = 64
QL = 1536; OL = 1024; OG = 16; HPG = NH // OG
EPS = 1e-6; WINDOW = 128; SCALE = HD ** -0.5
E2M1 = torch.tensor([0,.5,1.,1.5,2.,3.,4.,6.,-0,-.5,-1.,-1.5,-2.,-3.,-4.,6.], dtype=torch.float32)
_cache = {}
def P(k, wm, md):
    """Return the tensor named `k`, reading it from disk at most once.

    k:  weight name (also the key inside the safetensors shard).
    wm: mapping from weight name to shard file name.
    md: directory that contains the shard files.

    The loaded tensor is memoized in the module-level `_cache`.
    """
    try:
        return _cache[k]
    except KeyError:
        pass
    shard_path = os.path.join(md, wm[k])
    with safe_open(shard_path, framework="pt") as shard:
        tensor = shard.get_tensor(k)
    _cache[k] = tensor
    return tensor
def rms(x, w, eps=1e-6):
    """RMS-normalize `x` over its last dimension and scale by weight `w`.

    The mean square is accumulated in float32; the result is cast back to
    the dtype of `x`.
    """
    mean_sq = x.float().pow(2).mean(-1, keepdim=True)
    inv_norm = torch.rsqrt(mean_sq + eps)
    normalized = (x * inv_norm).float()
    return (w.float() * normalized).to(x.dtype)
def make_runner(w, sf, gs_t, inf, outf, fused=False, lw=None):
    """Build an initialized CuTeDSLNvfp4Linear from packed FP4 weights.

    w:     packed weight tensor, reinterpreted as float4_e2m1fn_x2.
    sf:    block scale factors (converted to float8_e4m3fn if needed).
    gs_t:  global scale tensor; one element, or two when `fused`.
    inf / outf: in_features / out_features handed to the runner.
    fused: when True and `gs_t` has two elements, a single global scale
           (the larger one) is used and the block scales of each part are
           rescaled to compensate.
    lw:    optional sequence whose first item is the split point between
           the two fused parts (defaults to outf // 2).
    """
    from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear

    packed = w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()

    block_scales = sf if sf.dtype == torch.float8_e4m3fn else sf.to(torch.float8_e4m3fn)
    block_scales = block_scales.permute(1, 0).contiguous()

    if fused and gs_t.numel() == 2:
        g_first = gs_t[0].item()
        g_second = gs_t[1].item()
        global_scale = max(g_first, g_second)
        if g_first != g_second:
            # Fold the per-part global scales into the block scales so that
            # one shared global scale can be used for both parts.
            rescaled = block_scales.float()
            split = lw[0] if lw else outf // 2
            rescaled[:split] *= g_first / global_scale
            rescaled[split:] *= g_second / global_scale
            block_scales = rescaled.to(torch.float8_e4m3fn)
    elif gs_t.numel() > 1:
        global_scale = gs_t.max().item()
    else:
        global_scale = gs_t.item()

    runner = CuTeDSLNvfp4Linear(
        in_features=inf,
        out_features=outf,
        max_num_tokens=8192,
        device=str(w.device),
    )
    runner.fp4 = [packed]
    runner.sf = [block_scales]
    runner.gs = [global_scale]
    runner.finalize_weights()
    runner._ensure_initialized()
    return runner
def build_cos_sin(max_pos=4096, rope_dim=ROPE):
    """Precompute the RoPE table for positions 0..max_pos-1.

    Returns a float32 tensor of shape (max_pos, 2 * (rope_dim // 2)): the
    first half of each row holds cosines, the second half sines, for
    frequencies 10000 ** (-i / half), i = 0..half-1.
    """
    n_freq = rope_dim // 2
    exponents = torch.arange(0, n_freq, dtype=torch.float32) / n_freq
    inv_freq = 1.0 / (10000.0 ** exponents)
    pos = torch.arange(max_pos, dtype=torch.float32)
    angles = torch.outer(pos, inv_freq)
    return torch.cat([angles.cos(), angles.sin()], dim=-1)
def apply_gptj_rope(x, positions, cos_sin, nope_dim, rope_dim):
    """Rotate the trailing dims of `x` (from `nope_dim` on) by RoPE.

    Adjacent (even, odd) element pairs are rotated by the angle stored in
    `cos_sin` for each row's position; rows of `cos_sin` are laid out as
    [cos | sin]. `x` may be 2-D (T, D) or 3-D (T, heads, D); in the 3-D
    case the angles are broadcast over the head axis.

    Returns a new tensor, except when `rope_dim` is 0 or `x` is empty, in
    which case `x` itself is returned.
    """
    if rope_dim == 0 or x.numel() == 0:
        return x

    n_pairs = rope_dim // 2
    cos = cos_sin[positions, :n_pairs].to(x.dtype)
    sin = cos_sin[positions, n_pairs:2 * n_pairs].to(x.dtype)
    if x.dim() == 3:
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)

    tail = x[..., nope_dim:]
    even, odd = tail[..., 0::2], tail[..., 1::2]
    rot_even = even * cos - odd * sin
    rot_odd = even * sin + odd * cos

    out = x.clone()
    out_tail = out[..., nope_dim:]  # view: writes land in `out`
    out_tail[..., 0::2] = rot_even
    out_tail[..., 1::2] = rot_odd
    return out
def apply_inv_gptj_rope(x, positions, cos_sin, nope_dim, rope_dim):
    """Undo apply_gptj_rope: rotate the trailing dims of `x` by the negative angle.

    Same layout conventions as the forward rotation: adjacent (even, odd)
    pairs starting at `nope_dim`, `cos_sin` rows stored as [cos | sin],
    angles broadcast over the head axis when `x` is 3-D.

    Returns a new tensor, except when `rope_dim` is 0 or `x` is empty, in
    which case `x` itself is returned.
    """
    if rope_dim == 0 or x.numel() == 0:
        return x

    n_pairs = rope_dim // 2
    cos = cos_sin[positions, :n_pairs].to(x.dtype)
    sin = cos_sin[positions, n_pairs:2 * n_pairs].to(x.dtype)
    if x.dim() == 3:
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)

    tail = x[..., nope_dim:]
    even, odd = tail[..., 0::2], tail[..., 1::2]
    rot_even = even * cos + odd * sin
    rot_odd = -even * sin + odd * cos

    out = x.clone()
    out_tail = out[..., nope_dim:]  # view: writes land in `out`
    out_tail[..., 0::2] = rot_even
    out_tail[..., 1::2] = rot_odd
    return out
def kv_quantize_fp8(kv_bf16):
amax = kv_bf16.float().abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
fp8_max = torch.tensor(448.0, dtype=torch.float32, device=kv_bf16.device)
scale = fp8_max / amax
kv_fp8 = (kv_bf16.float() * scale).to(torch.float8_e4m3fn)
inv_scale = (amax / fp8_max).to(torch.bfloat16)
return kv_fp8, inv_scale
def kv_dequantize_fp8(kv_fp8, inv_scale):
return (kv_fp8.to(torch.bfloat16) * inv_scale).to(torch.bfloat16)
def causal_prefill_attention(q, kv, scale):
T, NH, HD = q.shape
q_t = q.permute(1, 0, 2)
kv_exp = kv.unsqueeze(0).expand(NH, -1, -1)
out = F.scaled_dot_product_attention(q_t, kv_exp, kv_exp, is_causal=True, scale=scale)
return out.permute(1, 0, 2)
def csa_sparse_gather_attention(q, compressed_kv, topk_indices, topk_lens, scale, cos_sin, nope_dim, rope_dim):
"""CSA sparse attention: gather compressed KV at topk positions, attend.
q: (T, NH, HD) with RoPE already applied
compressed_kv: (num_compressed, HD) — all compressed KV vectors
topk_indices: (T, num_topk) — which compressed positions to attend to
topk_lens: (T,) — how many of the topk_indices are valid
"""
T, NH, HD = q.shape
device = q.device
num_topk = topk_indices.shape[-1]
# Gather compressed KV at topk positions
# Clamp to valid range
safe_idx = topk_indices.clamp(min=0, max=compressed_kv.shape[0] - 1)
# (T, num_topk, HD)
k_gathered = compressed_kv[safe_idx]
# Mask invalid positions (set to 0)
valid_mask = torch.arange(num_topk, device=device).unsqueeze(0) < topk_lens.unsqueeze(1)
k_gathered = k_gathered * valid_mask.unsqueeze(-1).to(k_gathered.dtype)
# Apply RoPE to gathered K at their original (compressed) positions
if rope_dim > 0 and cos_sin is not None:
kv_positions = safe_idx # The positions in the compressed cache
# BUT: compressed position i represents the i-th group of compress_ratio tokens
# The "position" for RoPE should be the original token position, not the compressed index
# For now, use the compressed index as a proxy (this is a simplification)
# In the real pipeline, the compressor stores KV with RoPE already applied
pass # Skip RoPE for now — the compressor already applies it
# Multi-head attention: expand K for all heads
# k_gathered: (T, num_topk, HD) → (T, NH, num_topk, HD)
k_heads = k_gathered.unsqueeze(1).expand(-1, NH, -1, -1)
v_heads = k_heads.clone()
# Q: (T, NH, HD) → (T*NH, 1, HD)
q_2d = q.reshape(T * NH, 1, HD)
k_2d = k_heads.reshape(T * NH, num_topk, HD)
v_2d = v_heads.reshape(T * NH, num_topk, HD)
# Attention mask: (T, num_topk) → (T*NH, 1, num_topk)
attn_mask = valid_mask.unsqueeze(1).expand(-1, NH, -1).reshape(T * NH, 1, num_topk)
out = F.scaled_dot_product_attention(
q_2d, k_2d, v_2d,
attn_mask=attn_mask if not attn_mask.all() else None,
scale=scale,
)
return out.squeeze(1).reshape(T, NH, HD)
def swa_cache_attention(q, swa_kv_cache, inv_scale_cache, positions, block_size, scale, window_size):
    """Sliding-window attention for one decode token, read from a paged KV cache.

    q:               (1, NH, HD) query of the single decode token.
    swa_kv_cache:    paged cache indexed as [block, offset_in_block]; when
                     its dtype is uint8 the bytes are reinterpreted as
                     float8_e4m3fn.
    inv_scale_cache: dequantization scales indexed by absolute slot.
    positions:       tensor whose first element is the decode position.
    block_size:      slots per cache block.
    scale:           softmax scale.
    window_size:     number of most recent slots (including the current
                     one) the token attends to.

    Returns a (1, NH, HD) tensor. Only the slots inside the window are
    gathered and dequantized.
    """
    pos = positions[0].item()
    window_start = max(0, pos - window_size + 1)
    slots = torch.arange(window_start, pos + 1, dtype=torch.int64, device=q.device)

    kv_cached = swa_kv_cache[slots // block_size, slots % block_size]
    if swa_kv_cache.dtype == torch.uint8:
        kv_cached = kv_cached.view(torch.float8_e4m3fn)
    kv_window = kv_dequantize_fp8(kv_cached, inv_scale_cache[slots])

    n_heads = q.shape[1]
    q_by_head = q.permute(1, 0, 2)
    kv_by_head = kv_window.unsqueeze(0).expand(n_heads, -1, -1)
    # Not causal: every slot in the window is at or before the query position.
    out = F.scaled_dot_product_attention(q_by_head, kv_by_head, kv_by_head, is_causal=False, scale=scale)
    return out.permute(1, 0, 2)
def test_csa_layer(layer_id, compress_ratio):
    """Test CSA/HCA sparse attention for a specific layer.

    Simulates the full pipeline:
    1. Prefill: project Q and KV, fill the SWA cache, build a stand-in
       compressed KV cache
    2. Decode: sparse attention on compressed KV + SWA on window
    3. Merge with sink weights
    4. Compare against full attention reference

    Returns (c_attn, c_full, c_swa): cosine similarity against the
    reference of the merged attention output, of the merged output after
    the output projections, and of the SWA-only output.

    Raises ValueError when the prefill length is not a multiple of
    `compress_ratio`. The weight cache is cleared and the CUDA allocator
    cache released on both success and failure.
    """
    torch.cuda.set_device(0)
    torch.cuda.empty_cache()
    try:
        return _run_csa_layer(layer_id, compress_ratio)
    finally:
        _cache.clear()
        torch.cuda.empty_cache()


def _run_csa_layer(layer_id, compress_ratio):
    """Load one layer's weights, run prefill + decode + reference, return the cosines."""
    cr = compress_ratio
    N = 128 if cr >= 128 else 16  # Prefill tokens (must be a multiple of compress_ratio)
    if cr < 1 or N % cr != 0:
        # Checked before any weight is loaded so a bad ratio fails fast.
        raise ValueError(f"N={N} must be multiple of compress_ratio={cr}")

    with open(os.path.join(MODEL, "model.safetensors.index.json")) as f:
        wm = json.load(f)["weight_map"]

    def load(key):
        return P(key, wm, MODEL).to(DEV)

    def load_runner(prefix, in_features):
        w = load(f"{prefix}.weight")
        return make_runner(
            w,
            load(f"{prefix}.weight_scale"),
            load(f"{prefix}.weight_scale_2"),
            in_features,
            w.shape[0],
        )

    p = f"model.layers.{layer_id}"
    a = f"{p}.self_attn"
    emb = load("model.embed_tokens.weight")
    anorm = load(f"{p}.input_layernorm.weight")
    qn = load(f"{a}.q_a_norm.weight")
    kvn = load(f"{a}.kv_norm.weight")
    woa_3d = load(f"{a}.o_a_proj.weight").view(OG, OL, HPG * HD)
    sinks = load(f"{a}.sinks")

    r_qa = load_runner(f"{a}.q_a_proj", H)
    r_qb = load_runner(f"{a}.q_b_proj", QL)
    r_kv = load_runner(f"{a}.kv_proj", H)
    r_wob = load_runner(f"{a}.o_b_proj", OG * OL)
    # Compressor weights
    r_comp_kv = load_runner(f"{a}.compressor.kv_proj", H)
    r_comp_gate = load_runner(f"{a}.compressor.gate_proj", H)

    cos_sin = build_cos_sin(max_pos=4096).to(DEV)

    # Paged KV caches
    block_size = 64
    max_tokens = 256
    num_blocks = (max_tokens + block_size - 1) // block_size
    swa_cache = torch.zeros(num_blocks, block_size, HD, dtype=torch.uint8, device=DEV)
    swa_inv_scale = torch.zeros(max_tokens, 1, dtype=torch.bfloat16, device=DEV)

    def project(ids, pos):
        """Embed `ids` and return (normed hidden, RoPE'd Q, normed KV, RoPE'd KV)."""
        normed = rms(emb[ids], anorm, EPS)
        qa = r_qa.run(normed)
        kv = r_kv.run(normed)
        qa_n = rms(qa, qn, EPS)
        kv_n = rms(kv, kvn, EPS)
        q = r_qb.run(qa_n).view(ids.numel(), NH, HD)
        q_rope = apply_gptj_rope(q, pos, cos_sin, NOPE, ROPE)
        kv_rope = apply_gptj_rope(kv_n.unsqueeze(1), pos, cos_sin, NOPE, ROPE).squeeze(1)
        return normed, q_rope, kv_n, kv_rope

    def cosine(x, y):
        return F.cosine_similarity(
            x.flatten().unsqueeze(0).float(), y.flatten().unsqueeze(0).float()
        ).item()

    with torch.no_grad():
        # ── PREFILL ─────────────────────────────────────────────
        token_ids = torch.arange(1, N + 1, dtype=torch.long, device=DEV)
        positions = torch.arange(N, dtype=torch.int64, device=DEV)
        normed, _, kv_n, kv_rope = project(token_ids, positions)

        # Write prefill KV to SWA cache (one indexed write per tensor).
        kv_fp8, inv_s = kv_quantize_fp8(kv_rope)
        swa_cache[positions // block_size, positions % block_size] = kv_fp8.view(torch.uint8)
        swa_inv_scale[positions] = inv_s

        # The compressor projections are run but their outputs are not
        # consumed below: average pooling over each group of `cr` tokens
        # stands in for the learned compressor.
        r_comp_kv.run(normed)
        r_comp_gate.run(normed)
        num_compressed = N // cr
        compressed_kv = kv_n.reshape(num_compressed, cr, HD).mean(dim=1)  # (num_compressed, HD)
        compressed_kv_rope = apply_gptj_rope(
            compressed_kv.unsqueeze(1),
            torch.arange(num_compressed, dtype=torch.int64, device=DEV),
            cos_sin, NOPE, ROPE,
        ).squeeze(1)

        # ── DECODE ──────────────────────────────────────────────
        decode_id = torch.tensor([N], dtype=torch.long, device=DEV)
        pos_d = torch.tensor([N], dtype=torch.int64, device=DEV)
        normed_d, q_rope_d, kv_n_d, kv_rope_d = project(decode_id, pos_d)

        # Write decode KV to SWA cache
        kv_fp8_d, inv_s_d = kv_quantize_fp8(kv_rope_d)
        swa_cache[N // block_size, N % block_size] = kv_fp8_d[0].view(torch.uint8)
        swa_inv_scale[N] = inv_s_d[0]

        r_comp_kv.run(normed_d)  # output unused, see the prefill note above

        # Append the decode token to the compressed cache.
        # NOTE(review): the appended entry is kv_n_d (no RoPE) while the
        # pooled entries carry RoPE — confirm this is intended.
        num_compressed_total = num_compressed + 1
        compressed_kv_all = torch.cat([compressed_kv_rope, kv_n_d], dim=0)

        # Decode: sparse attention over every compressed entry
        topk_d = torch.arange(num_compressed_total, dtype=torch.int64, device=DEV).unsqueeze(0)
        topk_lens_d = torch.tensor([num_compressed_total], dtype=torch.int64, device=DEV)
        sparse_out = csa_sparse_gather_attention(
            q_rope_d, compressed_kv_all, topk_d, topk_lens_d,
            SCALE, cos_sin, NOPE, ROPE,
        )

        # Decode: SWA attention
        swa_out = swa_cache_attention(
            q_rope_d, swa_cache, swa_inv_scale, pos_d, block_size, SCALE, WINDOW,
        )

        # Merge with sink weights
        sink_w = torch.sigmoid(sinks).view(1, NH, 1)
        merged_out = sparse_out * (1 - sink_w) + swa_out * sink_w

        # ── Reference: full causal attention on all tokens ──────
        all_ids = torch.cat([token_ids, decode_id])
        all_pos = torch.arange(N + 1, dtype=torch.int64, device=DEV)
        _, q_rope_ref, _, kv_rope_ref = project(all_ids, all_pos)
        o_ref = causal_prefill_attention(q_rope_ref, kv_rope_ref, SCALE)
        o_ref_decode = o_ref[-1:]  # Only the decode token

        # ── Full output pipeline ────────────────────────────────
        def project_out(o):
            """Inverse RoPE, grouped o_a projection, then o_b projection."""
            o_inv = apply_inv_gptj_rope(o, pos_d, cos_sin, NOPE, ROPE)
            o_grp = o_inv.reshape(1, OG, HPG * HD).permute(1, 0, 2)
            z = torch.bmm(o_grp, woa_3d.transpose(1, 2)).permute(1, 0, 2).reshape(1, OG * OL)
            return r_wob.run(z)

        # Cloned in case the runner reuses its output buffer between calls,
        # which would make the two results alias — TODO confirm.
        attn_merged = project_out(merged_out).clone()
        attn_ref = project_out(o_ref_decode)

        # ── COMPARE ─────────────────────────────────────────────
        # Note: CSA sparse attention with avg-pooled KV won't match full attention perfectly.
        c_attn = cosine(merged_out, o_ref_decode)
        c_full = cosine(attn_merged, attn_ref)
        # Also check SWA-only (window) attention
        c_swa = cosine(swa_out, o_ref_decode)

    return c_attn, c_full, c_swa
def main():
    """Run the C128A (layer 0) and C4A (layer 1) checks and print the cosines.

    Exits with a non-zero status (via SystemExit) if any reported cosine
    is NaN or infinite, since "no NaN" is the stated pass criterion.
    """
    import math  # needed only for the finiteness check below

    banner = "=" * 70
    print(banner)
    print(" CSA Sparse Attention Test")
    print(" Tests compressed KV gather + sparse attention + SWA merge")
    print(banner)

    non_finite = []
    for layer_id, cr in ((0, 128), (1, 4)):
        c_attn, c_full, c_swa = test_csa_layer(layer_id, cr)
        print(f" Layer {layer_id} (C{cr}A):")
        print(f" Merged (sparse+SWA) attn cosine: {c_attn:.4f}")
        print(f" Full pipeline cosine: {c_full:.4f}")
        print(f" SWA-only cosine: {c_swa:.4f}")
        if not all(math.isfinite(c) for c in (c_attn, c_full, c_swa)):
            non_finite.append(layer_id)

    print(f"\n{banner}")
    print(" SWA-only cosine should be >0.98 (proven in decode vs prefill test)")
    print(" Merged cosine may be lower (avg-pooled KV is an approximation)")
    print(" The important thing: no NaN, reasonable values")
    print(banner)

    if non_finite:
        raise SystemExit(f"FAIL: non-finite cosine similarity for layer(s) {non_finite}")


if __name__ == "__main__":
    main()