Files
nvfp4-megamoe-kernel/tests/test_decode_attention_b200.py

461 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
DeepSeek-V4 Decode Attention Pipeline Test
REPRODUCES THE BUG: The vLLM Blackwell path uses raw KV for attention,
which means decode (generating token N+1 when tokens 0..N are in the KV cache)
produces garbage because the cache is never written to.
This test simulates the actual decode scenario:
1. Prefill: compute KV for N tokens, write to paged cache
2. Decode: compute KV for 1 new token, write to cache, then attend to ALL cached KV
The key insight: during decode, you can't use raw KV — you need the KV cache
because previous tokens' KV was computed in a prior forward pass.
Architecture:
- KV latent is (T, 512) — single head, shared across all 128 Q heads
- After kv_norm + RoPE, KV is quantized to fp8 and stored in paged cache
- Attention: Q (128 heads) × K^T → softmax → × V
- For CSA/HCA: attention attends to compressed positions (every 4th or 128th)
- For SWA: attention attends to last WINDOW tokens
Usage (on B200):
cd /root/nvfp4-megamoe-kernel
PYTHONPATH=/root/nvfp4-megamoe-kernel tests/venv/bin/python tests/test_decode_attention_b200.py
"""
import sys, os, json, torch, torch.nn.functional as F, math
from safetensors import safe_open
# Repository checkout; put first on sys.path, presumably so that the
# `cutedsl` package imported inside make_runner resolves — confirm.
REPO = "/root/nvfp4-megamoe-kernel"
sys.path.insert(0, REPO)
# Checkpoint directory: test_prefill_decode reads model.safetensors.index.json
# from here and P() joins shard file names onto it.
MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
# Device every weight and activation in the test is moved to.
DEV = "cuda:0"
# H: in_features of q_a_proj / kv_proj; NH: number of query heads;
# HD: per-head width (NOPE + ROPE == HD); NOPE: leading channels left
# un-rotated; ROPE: trailing channels that receive rotary embedding.
H = 7168; NH = 128; HD = 512; NOPE = 448; ROPE = 64
# QL: in_features of q_b_proj; OL / OG: per-group output width and number of
# groups of the grouped o_a projection; HPG: query heads per output group.
QL = 1536; OL = 1024; OG = 16; HPG = NH // OG
# EPS: RMS-norm epsilon; WINDOW: sliding-window length in tokens;
# SCALE: softmax scale applied to the attention logits (HD ** -0.5).
EPS = 1e-6; WINDOW = 128; SCALE = HD ** -0.5
# 16-entry lookup table indexed by a 4-bit code in dequant(); entries 8..15
# mirror entries 0..7 with a negative sign. NOTE: `-0` is an int literal, so
# entry 8 is stored as +0.0 rather than -0.0.
E2M1 = torch.tensor([0,.5,1.,1.5,2.,3.,4.,6.,-0,-.5,-1.,-1.5,-2.,-3.,-4.,-6.], dtype=torch.float32)
# Memo of tensors already loaded by P(), keyed by tensor name.
_cache = {}
def P(k, wm, md):
    """Fetch checkpoint tensor ``k``, reading it from disk at most once.

    k:  tensor name to load.
    wm: mapping from tensor name to the shard file that contains it.
    md: directory the shard files live in.

    The tensor is memoised in the module-level ``_cache`` dict, so repeated
    calls with the same name return the same object.
    """
    try:
        return _cache[k]
    except KeyError:
        pass  # first request for this tensor: fall through and load it
    shard_path = os.path.join(md, wm[k])
    with safe_open(shard_path, framework="pt") as shard:
        tensor = shard.get_tensor(k)
    _cache[k] = tensor
    return tensor
def dequant(w, sf, gs):
    """Expand packed 4-bit weights to BF16 via the E2M1 lookup table.

    w:  2-D integer tensor; every element packs two 4-bit codes
        (low nibble first, high nibble second).
    sf: per-block scale tensor; each scale is repeated over 16 consecutive
        output columns.
    gs: global scale multiplied into every element.

    Returns a (rows, 2 * packed_cols) bfloat16 tensor.
    """
    dev = w.device
    table = E2M1.to(dev)
    low_vals = table[(w & 0xF).long()]
    high_vals = table[((w >> 4) & 0xF).long()]
    rows, packed_cols = w.shape
    cols = packed_cols * 2
    # Interleave so that column 2j holds the low nibble and 2j+1 the high one.
    unpacked = torch.stack((low_vals, high_vals), dim=-1).reshape(rows, cols)
    block_scale = sf.float().repeat_interleave(16, dim=1)[:rows, :cols]
    return (unpacked * block_scale * gs).to(torch.bfloat16)
def rms(x, w, eps=1e-6):
    """RMS-normalise ``x`` over its last dimension and scale by ``w``.

    The mean square is accumulated in float32; the result is cast back to
    the dtype of ``x``.
    """
    mean_sq = x.float().pow(2).mean(-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_sq + eps)
    normed = (x * inv_rms).float()
    return (w.float() * normed).to(x.dtype)
def make_runner(w, sf, gs_t, inf, outf, fused=False, lw=None):
    """Build an initialised CuTeDSLNvfp4Linear from checkpoint tensors.

    w:     packed weight tensor, reinterpreted as float4_e2m1fn_x2.
    sf:    block scales; converted to float8_e4m3fn if not already.
    gs_t:  global-scale tensor (one element, or two for a fused projection).
    inf / outf: in/out feature counts handed to the runner.
    fused: when True and ``gs_t`` has two elements, the two halves are
           brought onto one shared global scale (the larger of the two) by
           folding the ratio into the block scales.
    lw:    optional sequence whose first element is the split index between
           the two fused halves; defaults to ``outf // 2``.
    """
    from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear

    packed = w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
    scales = sf if sf.dtype == torch.float8_e4m3fn else sf.to(torch.float8_e4m3fn)
    scales = scales.permute(1, 0).contiguous()

    if fused and gs_t.numel() == 2:
        g_first, g_second = gs_t[0].item(), gs_t[1].item()
        global_scale = max(g_first, g_second)
        if g_first != g_second:
            scales_f32 = scales.float()
            split = lw[0] if lw else outf // 2
            # NOTE(review): the split is applied along dim 0 of the already
            # permuted scale tensor — confirm that axis is the output-feature
            # axis the split index refers to.
            scales_f32[:split] *= g_first / global_scale
            scales_f32[split:] *= g_second / global_scale
            scales = scales_f32.to(torch.float8_e4m3fn)
    elif gs_t.numel() > 1:
        global_scale = gs_t.max().item()
    else:
        global_scale = gs_t.item()

    runner = CuTeDSLNvfp4Linear(in_features=inf, out_features=outf, max_num_tokens=8192, device=str(w.device))
    runner.fp4 = [packed]
    runner.sf = [scales]
    runner.gs = [global_scale]
    runner.finalize_weights()
    runner._ensure_initialized()
    return runner
def build_cos_sin(max_pos=4096, rope_dim=ROPE):
    """Precompute the rotary table for positions ``0 .. max_pos - 1``.

    Returns a float32 tensor of shape (max_pos, 2 * (rope_dim // 2)): the
    first half of each row holds cosines, the second half sines, using a
    base of 10000.
    """
    n_pairs = rope_dim // 2
    exponents = torch.arange(0, n_pairs, dtype=torch.float32) / n_pairs
    inv_freq = 1.0 / (10000.0 ** exponents)
    pos = torch.arange(max_pos, dtype=torch.float32)
    angles = torch.outer(pos, inv_freq)
    return torch.cat([angles.cos(), angles.sin()], dim=-1)
def apply_gptj_rope(x, positions, cos_sin, nope, rope):
    """Rotate the trailing channels of ``x`` pairwise (interleaved layout).

    x:         (T, D) or (T, heads, D) tensor; channels ``[nope:]`` are
               rotated, channels ``[:nope]`` are copied through.
    positions: (T,) integer positions used to index ``cos_sin``.
    cos_sin:   table whose row p holds ``rope // 2`` cosines followed by
               ``rope // 2`` sines.

    Returns a new tensor; ``x`` is not modified. When ``rope`` is 0 or
    ``x`` is empty, ``x`` itself is returned.
    """
    if rope == 0 or x.numel() == 0:
        return x
    n_pairs = rope // 2
    cos = cos_sin[positions, :n_pairs].to(x.dtype)
    sin = cos_sin[positions, n_pairs:2 * n_pairs].to(x.dtype)
    if x.dim() == 3:
        # Insert a head axis so the per-token angles broadcast over heads.
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
    src = x[..., nope:].clone()
    a, b = src[..., 0::2], src[..., 1::2]
    rotated = x.clone()
    tail = rotated[..., nope:]  # view into `rotated`
    tail[..., 0::2] = a * cos - b * sin
    tail[..., 1::2] = a * sin + b * cos
    return rotated
def apply_inv_gptj_rope(x, positions, cos_sin, nope, rope):
    """Undo the pairwise rotation on the trailing channels of ``x``.

    Same arguments and layout as the forward rotation, but the angle is
    negated: each (even, odd) pair in channels ``[nope:]`` is rotated by
    ``-theta``. Channels ``[:nope]`` are copied through.

    Returns a new tensor; ``x`` is not modified. When ``rope`` is 0 or
    ``x`` is empty, ``x`` itself is returned.
    """
    if rope == 0 or x.numel() == 0:
        return x
    n_pairs = rope // 2
    cos = cos_sin[positions, :n_pairs].to(x.dtype)
    sin = cos_sin[positions, n_pairs:2 * n_pairs].to(x.dtype)
    if x.dim() == 3:
        # Insert a head axis so the per-token angles broadcast over heads.
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
    src = x[..., nope:].clone()
    a, b = src[..., 0::2], src[..., 1::2]
    unrotated = x.clone()
    tail = unrotated[..., nope:]  # view into `unrotated`
    tail[..., 0::2] = a * cos + b * sin
    tail[..., 1::2] = -a * sin + b * cos
    return unrotated
# ── KV Cache Kernels ────────────────────────────────────────────────
def kv_quantize_fp8(kv_bf16):
"""BF16 KV → fp8_e4m3 with per-token scale."""
amax = kv_bf16.float().abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
fp8_max = torch.tensor(448.0, dtype=torch.float32, device=kv_bf16.device)
scale = fp8_max / amax
kv_fp8 = (kv_bf16.float() * scale).to(torch.float8_e4m3fn)
inv_scale = (amax / fp8_max).to(torch.bfloat16)
return kv_fp8, inv_scale
def kv_dequantize_fp8(kv_fp8, inv_scale):
"""fp8 KV → BF16."""
return (kv_fp8.to(torch.bfloat16) * inv_scale).to(torch.bfloat16)
def paged_kv_write(kv_data, slot_mapping, cache, block_size):
    """Write rows into a paged cache in place. Works for fp8 or bf16.

    kv_data:      (T, D) tensor to write.
    slot_mapping: (T,) integer slot indices; slot ``s`` addresses
                  ``cache[s // block_size, s % block_size]``.
    cache:        (num_blocks, block_size, D) cache tensor, modified in place.

    Rows whose slot is negative (e.g. a -1 padding sentinel) or falls outside
    the cache are skipped. Negative slots must be rejected explicitly:
    floor-division maps -1 to block -1, which Python indexing would silently
    resolve to the LAST block and corrupt it.

    If several rows map to the same slot, the last one wins.
    """
    num_blocks, slots_per_block = cache.shape[0], cache.shape[1]
    # One host transfer for all slots instead of one .item() sync per token.
    slots = slot_mapping.tolist()
    for t in range(kv_data.shape[0]):
        slot = slots[t]
        if slot < 0:
            continue
        block_idx, offset = divmod(slot, block_size)
        if block_idx < num_blocks and offset < slots_per_block:
            cache[block_idx, offset] = kv_data[t]
def paged_kv_read(slot_mapping, cache, block_size, num_tokens, head_dim):
    """Gather rows from a paged cache.

    slot_mapping: integer slot indices, at least ``num_tokens`` long; slot
                  ``s`` addresses ``cache[s // block_size, s % block_size]``.
    cache:        (num_blocks, block_size, head_dim) cache tensor.

    Returns a (num_tokens, head_dim) tensor with the cache's dtype and
    device. Rows whose slot is negative (e.g. a -1 padding sentinel) or
    outside the cache are left as zeros. Negative slots must be rejected
    explicitly: floor-division maps -1 to block -1, which Python indexing
    would silently resolve to the LAST block and return unrelated data.
    """
    num_blocks, slots_per_block = cache.shape[0], cache.shape[1]
    kv = torch.zeros(num_tokens, head_dim, dtype=cache.dtype, device=cache.device)
    # One host transfer for all slots instead of one .item() sync per token.
    slots = slot_mapping.tolist()
    for t in range(num_tokens):
        slot = slots[t]
        if slot < 0:
            continue
        block_idx, offset = divmod(slot, block_size)
        if block_idx < num_blocks and offset < slots_per_block:
            kv[t] = cache[block_idx, offset]
    return kv
# ── Attention ────────────────────────────────────────────────────────
def full_causal_attention(q, kv, scale):
"""Full causal self-attention. q: (T_q, NH, HD), kv: (T_kv, HD).
Works for prefill (T_q == T_kv) and decode (T_q == 1, T_kv > 1).
Uses SDPA for efficiency.
"""
T_q, NH, HD = q.shape
T_kv = kv.shape[0]
# q: (NH, T_q, HD), k/v: (NH, T_kv, HD) — shared KV across heads
q_t = q.permute(1, 0, 2) # (NH, T_q, HD)
kv_exp = kv.unsqueeze(0).expand(NH, -1, -1) # (NH, T_kv, HD)
v_exp = kv_exp.clone()
# Causal mask: query at position i can attend to positions <= i
# For decode (T_q=1), all T_kv positions are valid (position T_kv-1 attends to 0..T_kv-1)
if T_q == T_kv:
# Prefill: standard causal
attn_mask = torch.tril(torch.ones(T_q, T_kv, device=q.device, dtype=torch.bool)).unsqueeze(0).expand(NH, -1, -1)
out = F.scaled_dot_product_attention(q_t, kv_exp, v_exp, attn_mask=attn_mask, scale=scale)
else:
# Decode or mixed: no masking needed (all positions are in the past)
out = F.scaled_dot_product_attention(q_t, kv_exp, v_exp, is_causal=False, scale=scale)
return out.permute(1, 0, 2) # (T_q, NH, HD)
def swa_decode_attention(q_new, kv_cache_bf16, positions_new, scale, window_size=WINDOW):
    """Sliding-window attention for one decode token.

    q_new:         (1, NH, HD) query of the new token.
    kv_cache_bf16: (total_len, HD) cached rows, used as both keys and values.
    positions_new: (1,) tensor holding the new token's position ``p``.
    scale:         softmax scale applied to the logits.
    window_size:   number of most recent positions to attend to.

    Attends to cache rows ``max(0, p - window_size + 1) .. p`` and returns a
    (1, NH, HD) tensor. The softmax is computed in float32.
    """
    cur = positions_new[0].item()
    first = max(0, cur - window_size + 1)
    window = kv_cache_bf16[first:cur + 1]  # (window_len, HD)

    n_heads = q_new.shape[1]
    head_dim = q_new.shape[2]
    queries = q_new.reshape(n_heads, head_dim).unsqueeze(1)  # (NH, 1, HD)
    keys = window.unsqueeze(0).expand(n_heads, -1, -1)  # (NH, window_len, HD)
    values = keys.clone()

    logits = torch.matmul(queries, keys.transpose(-1, -2)) * scale  # (NH, 1, window_len)
    probs = F.softmax(logits.float(), dim=-1).to(q_new.dtype)
    context = torch.matmul(probs, values).squeeze(1)  # (NH, HD)
    return context.unsqueeze(0)
def test_prefill_decode(layer_id, compress_ratio):
    """Run a prefill + single-token decode through a paged fp8 KV cache.

    Steps:
    1. PREFILL: project KV for N tokens, apply RoPE, quantise to fp8 and
       write to the paged cache.
    2. DECODE: project Q/KV for one new token and write its KV to the cache.
    3. Attend from the decode query to everything read back from the cache.
    4. Build a reference by processing all tokens at once without the cache.
    5. Push both attention outputs through inverse RoPE, the grouped o_a
       projection and o_b, and compare.
    Finally the "raw KV only" variant is evaluated to show how far it
    diverges from the reference.

    layer_id:       index of the decoder layer whose weights are loaded.
    compress_ratio: only selects the label printed in the header (SWA when
                    <= 1, CSA otherwise); the attention path is the same.

    Returns ``(c, c_full)``: cosine similarity of the cached-decode attention
    output vs. the reference, and of the full projected output vs. the
    reference. Results are printed; nothing is asserted.
    """
    def mark(ok):
        # Pass/fail glyph appended to each comparison line.
        return "✅" if ok else "❌"

    def cosine(a, b):
        # Cosine similarity between two tensors, flattened and in float32.
        return F.cosine_similarity(a.flatten().unsqueeze(0).float(), b.flatten().unsqueeze(0).float()).item()

    def abs_max(t):
        # Largest magnitude in the tensor (the signed maximum would hide
        # large negative values).
        return t.float().abs().max()

    torch.cuda.set_device(0)
    torch.manual_seed(42)
    torch.cuda.empty_cache()
    with open(os.path.join(MODEL, "model.safetensors.index.json")) as f:
        wm = json.load(f)["weight_map"]
    G = lambda k: P(k, wm, MODEL).to(DEV)
    p = f"model.layers.{layer_id}"; a = f"{p}.self_attn"
    layer_type = "SWA" if compress_ratio <= 1 else f"CSA(c={compress_ratio})"
    print(f"\n{'='*70}")
    print(f" Layer {layer_id} ({layer_type}) — Prefill+Decode Test")
    print(f"{'='*70}")

    # Load weights
    emb = G("model.embed_tokens.weight")
    anorm = G(f"{p}.input_layernorm.weight")
    qn = G(f"{a}.q_a_norm.weight"); kvn = G(f"{a}.kv_norm.weight")
    woa = G(f"{a}.o_a_proj.weight")
    qa_w = G(f"{a}.q_a_proj.weight"); qa_sf = G(f"{a}.q_a_proj.weight_scale"); qa_gs = G(f"{a}.q_a_proj.weight_scale_2")
    qb_w = G(f"{a}.q_b_proj.weight"); qb_sf = G(f"{a}.q_b_proj.weight_scale"); qb_gs = G(f"{a}.q_b_proj.weight_scale_2")
    kv_w = G(f"{a}.kv_proj.weight"); kv_sf = G(f"{a}.kv_proj.weight_scale"); kv_gs = G(f"{a}.kv_proj.weight_scale_2")
    wob_w = G(f"{a}.o_b_proj.weight"); wob_sf = G(f"{a}.o_b_proj.weight_scale"); wob_gs = G(f"{a}.o_b_proj.weight_scale_2")

    # CuTeDSL runners
    r_qa = make_runner(qa_w, qa_sf, qa_gs, H, qa_w.shape[0])
    r_qb = make_runner(qb_w, qb_sf, qb_gs, QL, qb_w.shape[0])
    r_kv = make_runner(kv_w, kv_sf, kv_gs, H, kv_w.shape[0])
    r_wob = make_runner(wob_w, wob_sf, wob_gs, OG*OL, wob_w.shape[0])

    # Setup
    N_PREFILL = 8  # Number of prefill tokens
    N_DECODE = 1  # Single decode token
    N_TOTAL = N_PREFILL + N_DECODE
    token_ids = torch.tensor([1, 450, 8403, 315, 5413, 374, 2198, 643, 991], dtype=torch.long, device=DEV)
    assert len(token_ids) >= N_TOTAL
    cos_sin = build_cos_sin(max_pos=4096).to(DEV)

    # Paged KV cache
    block_size = 256
    num_blocks = 64
    # Cache stores fp8 KV (with per-token inv_scale stored separately)
    kv_cache_fp8 = torch.zeros(num_blocks, block_size, HD, dtype=torch.float8_e4m3fn, device=DEV)
    # Per-token inv scales (indexed by slot)
    inv_scale_cache = torch.zeros(num_blocks * block_size, 1, dtype=torch.bfloat16, device=DEV)
    # RoPE'd BF16 copy of every KV row, kept alongside the fp8 cache for reference
    kv_cache_bf16 = torch.zeros(N_TOTAL, HD, dtype=torch.bfloat16, device=DEV)

    with torch.no_grad():
        # ════════════════════════════════════════════════════════════════
        # STEP 1: PREFILL — process tokens 0..N_PREFILL-1
        # ════════════════════════════════════════════════════════════════
        prefill_ids = token_ids[:N_PREFILL]
        prefill_pos = torch.arange(N_PREFILL, dtype=torch.int64, device=DEV)
        prefill_slots = prefill_pos  # slot = position (simplified)
        hidden_prefill = emb[prefill_ids]
        normed_prefill = rms(hidden_prefill, anorm, EPS)
        # Project KV
        kv_prefill = r_kv.run(normed_prefill)
        kv_normed_prefill = rms(kv_prefill, kvn, EPS)
        # Apply RoPE to KV BEFORE caching
        kv_rope_prefill = apply_gptj_rope(kv_normed_prefill.unsqueeze(1), prefill_pos, cos_sin, NOPE, ROPE).squeeze(1)
        # Quantize to fp8
        kv_fp8_prefill, inv_scale_prefill = kv_quantize_fp8(kv_rope_prefill)
        # Write to paged cache
        paged_kv_write(kv_fp8_prefill, prefill_slots, kv_cache_fp8, block_size)
        # Write inv_scale to flat cache (one indexed assignment, no per-token sync)
        inv_scale_cache[prefill_slots] = inv_scale_prefill
        # Also store BF16 reference (for verification)
        kv_cache_bf16[:N_PREFILL] = kv_rope_prefill
        print(f" Prefill: {N_PREFILL} tokens written to KV cache")
        print(f" KV cache fp8 amax: {abs_max(kv_fp8_prefill):.4f}")
        print(f" KV BF16 amax: {abs_max(kv_rope_prefill):.4f}")
        # Verify roundtrip: read back and compare
        kv_read = paged_kv_read(prefill_slots, kv_cache_fp8, block_size, N_PREFILL, HD)
        inv_read = inv_scale_cache[prefill_slots]
        kv_dequant = kv_dequantize_fp8(kv_read, inv_read)
        c = cosine(kv_rope_prefill, kv_dequant)
        print(f" KV cache roundtrip cosine: {c:.6f} {mark(c >= 0.99)}")

        # ════════════════════════════════════════════════════════════════
        # STEP 2: DECODE — process token N_PREFILL
        # ════════════════════════════════════════════════════════════════
        decode_id = token_ids[N_PREFILL:N_PREFILL + N_DECODE]
        decode_pos = torch.tensor([N_PREFILL], dtype=torch.int64, device=DEV)
        decode_slot = decode_pos
        hidden_decode = emb[decode_id]
        normed_decode = rms(hidden_decode, anorm, EPS)
        # Project Q and KV
        qa_decode = r_qa.run(normed_decode)
        kv_decode = r_kv.run(normed_decode)
        qa_n_decode = rms(qa_decode, qn, EPS)
        kv_n_decode = rms(kv_decode, kvn, EPS)
        q_decode = r_qb.run(qa_n_decode).view(N_DECODE, NH, HD)
        q_rope_decode = apply_gptj_rope(q_decode, decode_pos, cos_sin, NOPE, ROPE)
        # Apply RoPE to KV
        kv_rope_decode = apply_gptj_rope(kv_n_decode.unsqueeze(1), decode_pos, cos_sin, NOPE, ROPE).squeeze(1)
        # Write decode KV to cache
        kv_fp8_decode, inv_scale_decode = kv_quantize_fp8(kv_rope_decode)
        paged_kv_write(kv_fp8_decode, decode_slot, kv_cache_fp8, block_size)
        inv_scale_cache[decode_slot] = inv_scale_decode
        kv_cache_bf16[N_PREFILL:N_PREFILL + N_DECODE] = kv_rope_decode
        print(f"\n Decode: token {N_PREFILL} written to KV cache")

        # ════════════════════════════════════════════════════════════════
        # STEP 3: DECODE ATTENTION using KV cache
        # ════════════════════════════════════════════════════════════════
        # Read ALL KV from cache (tokens 0..N_PREFILL)
        all_slots = torch.arange(N_TOTAL, dtype=torch.int64, device=DEV)
        kv_all_fp8 = paged_kv_read(all_slots, kv_cache_fp8, block_size, N_TOTAL, HD)
        inv_scale_all = inv_scale_cache[all_slots]
        kv_all_dequant = kv_dequantize_fp8(kv_all_fp8, inv_scale_all)
        # SWA: attend to last WINDOW tokens (or all if total < WINDOW)
        if N_TOTAL <= WINDOW:
            # Full attention within window
            o_from_cache = full_causal_attention(
                q_rope_decode,  # (1, NH, HD) — only the decode token
                kv_all_dequant,  # (N_TOTAL, HD) — all cached KV
                SCALE,
            )
        else:
            o_from_cache = swa_decode_attention(
                q_rope_decode, kv_all_dequant, decode_pos, SCALE, WINDOW,
            )

        # ════════════════════════════════════════════════════════════════
        # STEP 4: BF16 REFERENCE — process ALL tokens at once
        # ════════════════════════════════════════════════════════════════
        all_ids = token_ids[:N_TOTAL]
        all_pos = torch.arange(N_TOTAL, dtype=torch.int64, device=DEV)
        hidden_all = emb[all_ids]
        normed_all = rms(hidden_all, anorm, EPS)
        qa_all = r_qa.run(normed_all)
        kv_all = r_kv.run(normed_all)
        qa_n_all = rms(qa_all, qn, EPS)
        kv_n_all = rms(kv_all, kvn, EPS)
        q_all = r_qb.run(qa_n_all).view(N_TOTAL, NH, HD)
        q_rope_all = apply_gptj_rope(q_all, all_pos, cos_sin, NOPE, ROPE)
        kv_rope_all = apply_gptj_rope(kv_n_all.unsqueeze(1), all_pos, cos_sin, NOPE, ROPE).squeeze(1)
        # Full BF16 attention on all tokens
        o_ref_all = full_causal_attention(q_rope_all, kv_rope_all, SCALE)
        o_ref_decode = o_ref_all[N_PREFILL:]  # Only the decode token's output

        # ════════════════════════════════════════════════════════════════
        # COMPARE: cached KV decode vs BF16 reference decode
        # ════════════════════════════════════════════════════════════════
        c = cosine(o_from_cache, o_ref_decode)
        print(f"\n Decode attention (cached KV) vs BF16 reference cosine: {c:.6f} {mark(c >= 0.98)}")
        print(f" Cached output amax: {abs_max(o_from_cache):.4f} BF16 ref amax: {abs_max(o_ref_decode):.4f}")
        print(f" Cached output NaN: {torch.isnan(o_from_cache).any()} BF16 NaN: {torch.isnan(o_ref_decode).any()}")

        # ════════════════════════════════════════════════════════════════
        # STEP 5: Full output pipeline — inverse RoPE + o_a BMM + o_b
        # ════════════════════════════════════════════════════════════════
        # Using cached attention output
        o_inv = apply_inv_gptj_rope(o_from_cache, decode_pos, cos_sin, NOPE, ROPE)
        o_grouped = o_inv.view(N_DECODE, OG, HPG * HD).permute(1, 0, 2)
        woa_3d = woa.view(OG, OL, HPG * HD)
        z_cached = torch.bmm(o_grouped, woa_3d.transpose(1, 2)).permute(1, 0, 2).reshape(N_DECODE, OG * OL)
        attn_out_cached = r_wob.run(z_cached)
        # Using BF16 reference
        o_inv_ref = apply_inv_gptj_rope(o_ref_decode, decode_pos, cos_sin, NOPE, ROPE)
        o_grouped_ref = o_inv_ref.view(N_DECODE, OG, HPG * HD).permute(1, 0, 2)
        z_ref = torch.bmm(o_grouped_ref, woa_3d.transpose(1, 2)).permute(1, 0, 2).reshape(N_DECODE, OG * OL)
        attn_out_ref = r_wob.run(z_ref)
        c_full = cosine(attn_out_cached, attn_out_ref)
        print(f" Full output (cached) vs BF16 reference cosine: {c_full:.6f} {mark(c_full >= 0.98)}")

        # ════════════════════════════════════════════════════════════════
        # BUG REPRODUCTION: What vLLM currently does (uses raw kv, not cache)
        # ════════════════════════════════════════════════════════════════
        print(f"\n --- BUG REPRODUCTION: vLLM Blackwell path ---")
        # vLLM's _attention_impl_blackwell calls full_sdpa_attention(q, kv, scale)
        # where kv is the RAW projection output (not from cache)
        # For decode, this only has 1 token of KV — missing all the prior tokens!
        o_buggy = full_causal_attention(q_rope_decode, kv_n_decode, SCALE)
        c_bug = cosine(o_buggy, o_ref_decode)
        print(f" Buggy (raw kv, no cache) cosine: {c_bug:.6f} ❌ (should be low — missing context)")
        print(f" This is why vLLM produces garbage: decode only has 1 KV vector,")
        print(f" but needs to attend to ALL prior tokens' KV from the cache.")

    # Cleanup
    del r_qa, r_qb, r_kv, r_wob
    torch.cuda.empty_cache()
    return c, c_full
def main():
    """Run the SWA-layer prefill/decode check and print a summary."""
    rule = "=" * 70
    print(rule)
    print(" DeepSeek-V4 Decode Attention Pipeline Test")
    print(" Reproduces the vLLM Blackwell bug: KV cache not used for decode")
    print(rule)

    # Test SWA layer (layer 60, compress_ratio=0)
    swa_scores = test_prefill_decode(60, 0)
    c_swa, c_swa_full = swa_scores

    # Test C128A layer (layer 0, compress_ratio=128) — for this test,
    # we just do full attention (not compressed) since compression
    # requires the compressor/indexer which is a separate concern
    # c_c128, c_c128_full = test_prefill_decode(0, 128)

    print(f"\n{rule}")
    print(" SUMMARY")
    print(f" Layer 60 (SWA): decode attention cosine = {c_swa:.6f}, full output = {c_swa_full:.6f}")
    print(rule)
    print("\n KEY TAKEAWAY: The KV cache write/read + attention pipeline")
    print(" must work for decode. Once verified, we can build the vLLM")
    print(" attention backend that uses this pipeline.")


if __name__ == "__main__":
    main()