#!/usr/bin/env python3
"""
DeepSeek-V4 Decode Attention Pipeline Test

REPRODUCES THE BUG: The vLLM Blackwell path uses raw KV for attention,
which means decode (generating token N+1 when tokens 0..N are in the KV cache)
produces garbage because the cache is never written to.

This test simulates the actual decode scenario:
  1. Prefill: compute KV for N tokens, write to paged cache
  2. Decode: compute KV for 1 new token, write to cache, then attend to ALL cached KV

The key insight: during decode, you can't use raw KV — you need the KV cache
because previous tokens' KV was computed in a prior forward pass.

Architecture:
  - KV latent is (T, 512) — single head, shared across all 128 Q heads
  - After kv_norm + RoPE, KV is quantized to fp8 and stored in paged cache
  - Attention: Q (128 heads) × K^T → softmax → × V
  - For CSA/HCA: attention attends to compressed positions (every 4th or 128th)
  - For SWA: attention attends to last WINDOW tokens

Usage (on B200):
    cd /root/nvfp4-megamoe-kernel
    PYTHONPATH=/root/nvfp4-megamoe-kernel tests/venv/bin/python tests/test_decode_attention_b200.py
"""
|
||
|
||
import sys, os, json, torch, torch.nn.functional as F, math
|
||
from safetensors import safe_open
|
||
|
||
# Repository root; prepended to sys.path (presumably so the `cutedsl` package
# imported by make_runner resolves from the checkout — confirm).
REPO = "/root/nvfp4-megamoe-kernel"
sys.path.insert(0, REPO)
# Checkpoint directory; must contain model.safetensors.index.json and its shards.
MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
# Device every loaded weight and test tensor is placed on.
DEV = "cuda:0"

# H: in_features of the q_a / kv projections.  NH x HD: per-token query layout.
# NOPE / ROPE: leading un-rotated dims / trailing rotated dims of each HD vector.
H = 7168; NH = 128; HD = 512; NOPE = 448; ROPE = 64
# QL: in_features of q_b.  OG groups of HPG heads feed the grouped o_a BMM,
# which yields OL features per group; OG * OL is the in_features of o_b.
QL = 1536; OL = 1024; OG = 16; HPG = NH // OG
# EPS: RMS-norm epsilon.  WINDOW: sliding-window length.  SCALE: attention logit scale.
EPS = 1e-6; WINDOW = 128; SCALE = HD ** -0.5

# Value table indexed by a 4-bit code (used by dequant): codes 0-7 are the
# non-negative values, codes 8-15 their negatives.
E2M1 = torch.tensor([0,.5,1.,1.5,2.,3.,4.,6.,-0,-.5,-1.,-1.5,-2.,-3.,-4.,-6.], dtype=torch.float32)

# Memo of tensors already read from disk by P(), keyed by tensor name.
_cache = {}
|
||
def P(k, wm, md):
    """Return checkpoint tensor ``k``, reading it from disk at most once.

    k:  tensor name; also the key under which the result is memoised in ``_cache``
    wm: mapping from tensor name to the shard file (relative to ``md``) that holds it
    md: directory containing the shard files
    """
    try:
        return _cache[k]
    except KeyError:
        pass  # first request for this tensor: fall through and load it

    shard_path = os.path.join(md, wm[k])
    with safe_open(shard_path, framework="pt") as shard:
        tensor = shard.get_tensor(k)
    _cache[k] = tensor
    return tensor
|
||
|
||
def dequant(w, sf, gs):
    """Expand packed 4-bit weights to a BF16 matrix.

    w:  2-D integer tensor (O, I/2); each element packs two 4-bit codes,
        low nibble first
    sf: block scales; each column is repeated over 16 consecutive unpacked columns
    gs: global scale multiplied into every element

    Returns an (O, I) bfloat16 tensor on ``w``'s device.
    """
    table = E2M1.to(w.device)
    rows, packed_cols = w.shape
    cols = packed_cols * 2

    low = table[(w & 0xF).long()]
    high = table[((w >> 4) & 0xF).long()]
    # Interleave so the low nibble lands in even columns and the high nibble in odd ones.
    unpacked = torch.stack((low, high), dim=-1).reshape(rows, cols)

    block_scale = sf.float().repeat_interleave(16, dim=1)[:rows, :cols]
    return (unpacked * block_scale * gs).to(torch.bfloat16)
|
||
|
||
def rms(x, w, eps=1e-6):
    """RMS-normalise ``x`` over its last dim and apply the gain ``w``.

    The mean square is accumulated in fp32; the result is cast back to ``x.dtype``.
    """
    mean_sq = torch.pow(x.float(), 2).mean(-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_sq + eps)
    normalised = (x * inv_rms).float()
    return (w.float() * normalised).to(x.dtype)
|
||
|
||
def make_runner(w, sf, gs_t, inf, outf, fused=False, lw=None):
    """Build an initialised ``CuTeDSLNvfp4Linear`` from checkpoint tensors.

    w:     packed weight; reinterpreted as float4_e2m1fn_x2 and transposed
    sf:    block scales; cast to float8_e4m3fn if needed and transposed
    gs_t:  global scale tensor (one value, or two for a fused projection)
    inf:   in_features passed to the runner
    outf:  out_features passed to the runner
    fused: when True and ``gs_t`` holds two values, both halves share the larger
           global scale and the block scales are rescaled to compensate
    lw:    optional sequence whose first entry is the split point between the
           two fused halves (defaults to ``outf // 2``)

    NOTE(review): in the fused path the block scales are sliced along dim 0
    *after* the transpose, while the split point is expressed in output
    features — confirm this matches the layout the runner expects.
    """
    from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear

    fp8 = torch.float8_e4m3fn
    packed = w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
    scales = sf if sf.dtype == fp8 else sf.to(fp8)
    scales = scales.permute(1, 0).contiguous()

    if not (fused and gs_t.numel() == 2):
        global_scale = gs_t.max().item() if gs_t.numel() > 1 else gs_t.item()
    else:
        g_first, g_second = gs_t[0].item(), gs_t[1].item()
        global_scale = max(g_first, g_second)
        if g_first != g_second:
            # Fold each half's own global scale into its block scales so a
            # single shared global scale can be used.
            rescaled = scales.float()
            split = lw[0] if lw else outf // 2
            rescaled[:split] *= g_first / global_scale
            rescaled[split:] *= g_second / global_scale
            scales = rescaled.to(fp8)

    runner = CuTeDSLNvfp4Linear(
        in_features=inf,
        out_features=outf,
        max_num_tokens=8192,
        device=str(w.device),
    )
    runner.fp4 = [packed]
    runner.sf = [scales]
    runner.gs = [global_scale]
    runner.finalize_weights()
    runner._ensure_initialized()
    return runner
|
||
|
||
def build_cos_sin(max_pos=4096, rope_dim=ROPE):
    """Build the RoPE lookup table.

    Returns a float32 tensor of shape (max_pos, 2 * (rope_dim // 2)); row ``p``
    holds ``cos(p * f_i)`` for every frequency followed by ``sin(p * f_i)``,
    with ``f_i = 10000 ** -(i / (rope_dim // 2))``.
    """
    n_freq = rope_dim // 2
    exponents = torch.arange(0, n_freq, dtype=torch.float32) / n_freq
    inv_freq = 1.0 / (10000.0 ** exponents)
    positions = torch.arange(max_pos, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)
    return torch.cat([angles.cos(), angles.sin()], dim=-1)
|
||
|
||
def apply_gptj_rope(x, positions, cos_sin, nope, rope):
    """Rotate the trailing dims of ``x`` (from index ``nope`` on) by interleaved-pair RoPE.

    x:         (..., D) tensor; a 3-D input is treated as (T, heads, D)
    positions: index into the rows of ``cos_sin``, one per token
    cos_sin:   table whose rows are [cos | sin], ``rope // 2`` entries each
    nope:      number of leading dims left untouched
    rope:      number of rotated dims; 0 disables the rotation

    Returns a new tensor; ``x`` itself is returned unchanged when ``rope`` is 0
    or ``x`` is empty.
    """
    if rope == 0 or x.numel() == 0:
        return x

    n_pairs = rope // 2
    cos = cos_sin[positions, :n_pairs].to(x.dtype)
    sin = cos_sin[positions, n_pairs:2 * n_pairs].to(x.dtype)
    if x.dim() == 3:
        # Add a head axis so the per-token angles broadcast over heads.
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)

    tail = x[..., nope:]
    even, odd = tail[..., 0::2], tail[..., 1::2]
    # Rotate each (even, odd) pair, then re-interleave the pairs.
    rotated = torch.stack(
        (even * cos - odd * sin, even * sin + odd * cos), dim=-1
    ).flatten(-2)

    out = x.clone()
    out[..., nope:] = rotated
    return out
|
||
|
||
def apply_inv_gptj_rope(x, positions, cos_sin, nope, rope):
    """Undo interleaved-pair RoPE on the trailing dims of ``x`` (from index ``nope`` on).

    Same arguments as the forward rotation; each (even, odd) pair is rotated by
    the negated angle. Returns a new tensor; ``x`` itself is returned unchanged
    when ``rope`` is 0 or ``x`` is empty.
    """
    if rope == 0 or x.numel() == 0:
        return x

    n_pairs = rope // 2
    cos = cos_sin[positions, :n_pairs].to(x.dtype)
    sin = cos_sin[positions, n_pairs:2 * n_pairs].to(x.dtype)
    if x.dim() == 3:
        # Add a head axis so the per-token angles broadcast over heads.
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)

    tail = x[..., nope:]
    even, odd = tail[..., 0::2], tail[..., 1::2]
    # Rotate each pair backwards, then re-interleave the pairs.
    unrotated = torch.stack(
        (even * cos + odd * sin, -even * sin + odd * cos), dim=-1
    ).flatten(-2)

    out = x.clone()
    out[..., nope:] = unrotated
    return out
|
||
|
||
|
||
# ── KV Cache Kernels ────────────────────────────────────────────────
|
||
|
||
def kv_quantize_fp8(kv_bf16):
|
||
"""BF16 KV → fp8_e4m3 with per-token scale."""
|
||
amax = kv_bf16.float().abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
|
||
fp8_max = torch.tensor(448.0, dtype=torch.float32, device=kv_bf16.device)
|
||
scale = fp8_max / amax
|
||
kv_fp8 = (kv_bf16.float() * scale).to(torch.float8_e4m3fn)
|
||
inv_scale = (amax / fp8_max).to(torch.bfloat16)
|
||
return kv_fp8, inv_scale
|
||
|
||
def kv_dequantize_fp8(kv_fp8, inv_scale):
    """Widen quantised KV to bfloat16 and multiply by its per-row ``inv_scale``."""
    widened = kv_fp8.to(torch.bfloat16)
    return torch.mul(widened, inv_scale).to(torch.bfloat16)
|
||
|
||
def paged_kv_write(kv_data, slot_mapping, cache, block_size):
    """Write rows into a paged cache in place. Works for fp8 or bf16.

    kv_data:      (T, D) tensor to write
    slot_mapping: (T,) integer tensor of flat slot indices; slot ``s`` maps to
                  ``cache[s // block_size, s % block_size]``
    cache:        (num_blocks, block_size, D) cache tensor, modified in place
    block_size:   number of slots per block

    Slots that are negative or fall outside the cache are skipped. Without the
    explicit negative check, a padding slot such as -1 would floor-divide to
    block -1 and silently overwrite the last block through negative indexing.
    When a slot appears more than once, the last row written wins.
    """
    num_blocks, slots_per_block = cache.shape[0], cache.shape[1]
    # One host transfer for all slots instead of one .item() sync per token.
    slots = slot_mapping.tolist()
    for t in range(kv_data.shape[0]):
        slot = slots[t]
        if slot < 0:
            continue
        block_idx, offset = divmod(slot, block_size)
        if block_idx < num_blocks and offset < slots_per_block:
            cache[block_idx, offset] = kv_data[t]
|
||
|
||
def paged_kv_read(slot_mapping, cache, block_size, num_tokens, head_dim):
    """Gather rows from a paged cache.

    slot_mapping: integer tensor with at least ``num_tokens`` flat slot indices;
                  slot ``s`` maps to ``cache[s // block_size, s % block_size]``
    cache:        (num_blocks, block_size, head_dim) cache tensor
    block_size:   number of slots per block
    num_tokens:   number of rows to read
    head_dim:     width of each returned row

    Returns a (num_tokens, head_dim) tensor with the cache's dtype and device.
    Rows whose slot is negative or outside the cache stay zero. Without the
    explicit negative check, a padding slot such as -1 would read from the last
    block through negative indexing.
    """
    kv = torch.zeros(num_tokens, head_dim, dtype=cache.dtype, device=cache.device)
    num_blocks, slots_per_block = cache.shape[0], cache.shape[1]
    # One host transfer for all slots instead of one .item() sync per token.
    slots = slot_mapping.tolist()
    for t in range(num_tokens):
        slot = slots[t]
        if slot < 0:
            continue
        block_idx, offset = divmod(slot, block_size)
        if block_idx < num_blocks and offset < slots_per_block:
            kv[t] = cache[block_idx, offset]
    return kv
|
||
|
||
|
||
# ── Attention ────────────────────────────────────────────────────────
|
||
|
||
def full_causal_attention(q, kv, scale):
|
||
"""Full causal self-attention. q: (T_q, NH, HD), kv: (T_kv, HD).
|
||
|
||
Works for prefill (T_q == T_kv) and decode (T_q == 1, T_kv > 1).
|
||
Uses SDPA for efficiency.
|
||
"""
|
||
T_q, NH, HD = q.shape
|
||
T_kv = kv.shape[0]
|
||
|
||
# q: (NH, T_q, HD), k/v: (NH, T_kv, HD) — shared KV across heads
|
||
q_t = q.permute(1, 0, 2) # (NH, T_q, HD)
|
||
kv_exp = kv.unsqueeze(0).expand(NH, -1, -1) # (NH, T_kv, HD)
|
||
v_exp = kv_exp.clone()
|
||
|
||
# Causal mask: query at position i can attend to positions <= i
|
||
# For decode (T_q=1), all T_kv positions are valid (position T_kv-1 attends to 0..T_kv-1)
|
||
if T_q == T_kv:
|
||
# Prefill: standard causal
|
||
attn_mask = torch.tril(torch.ones(T_q, T_kv, device=q.device, dtype=torch.bool)).unsqueeze(0).expand(NH, -1, -1)
|
||
out = F.scaled_dot_product_attention(q_t, kv_exp, v_exp, attn_mask=attn_mask, scale=scale)
|
||
else:
|
||
# Decode or mixed: no masking needed (all positions are in the past)
|
||
out = F.scaled_dot_product_attention(q_t, kv_exp, v_exp, is_causal=False, scale=scale)
|
||
|
||
return out.permute(1, 0, 2) # (T_q, NH, HD)
|
||
|
||
|
||
def swa_decode_attention(q_new, kv_cache_bf16, positions_new, scale, window_size=WINDOW):
    """Decode-time sliding-window attention for a single new token.

    q_new:         (1, NH, HD) query of the new token (RoPE already applied)
    kv_cache_bf16: (total_len, HD) cached KV for every position (RoPE already
                   applied); used as both keys and values
    positions_new: (1,) position of the new token
    scale:         multiplier applied to the attention logits
    window_size:   number of most recent positions, ending at the new token,
                   that are attended to

    Returns a (1, NH, HD) tensor.
    """
    pos = positions_new[0].item()
    first = max(0, pos - window_size + 1)
    window = kv_cache_bf16[first:pos + 1]                 # (window_len, HD)

    n_heads, head_dim = q_new.shape[1], q_new.shape[2]
    queries = q_new.reshape(n_heads, head_dim).unsqueeze(1)  # (NH, 1, HD)
    keys = window.unsqueeze(0).expand(n_heads, -1, -1)       # (NH, window_len, HD)
    values = keys.clone()

    logits = torch.matmul(queries, keys.transpose(-1, -2)) * scale  # (NH, 1, window_len)
    # Softmax runs in fp32, then returns to the query dtype for the value matmul.
    probs = F.softmax(logits.float(), dim=-1).to(q_new.dtype)
    attended = torch.matmul(probs, values).squeeze(1)        # (NH, HD)
    return attended.unsqueeze(0)                             # (1, NH, HD)
|
||
|
||
|
||
def _cosine(a, b):
    """Cosine similarity of two tensors, each flattened and cast to fp32."""
    return F.cosine_similarity(
        a.flatten().unsqueeze(0).float(),
        b.flatten().unsqueeze(0).float(),
    ).item()


def _abs_max(t):
    """Largest absolute value in ``t`` as a Python float (computed in fp32)."""
    return t.float().abs().max().item()


def test_prefill_decode(layer_id, compress_ratio):
    """Run prefill + single-token decode for one layer through a paged fp8 KV cache.

    Steps:
      1. PREFILL: project, normalise, RoPE and fp8-quantise the KV of
         N_PREFILL tokens and write it to the paged cache.
      2. DECODE: do the same for one more token and build its RoPE'd query.
      3. Attend from the decode query over ALL KV read back from the cache.
      4. REFERENCE: recompute Q/KV for all tokens in one pass (same projection
         runners, KV left unquantised) and keep the last token's causal
         attention output.
      5. Push both attention outputs through inverse RoPE, the grouped o_a BMM
         and the o_b projection, and compare.
      6. BUG REPRODUCTION: attend using only the decode token's own raw KV.

    layer_id:       index of the ``model.layers.*`` entry whose weights are used
    compress_ratio: only selects the label printed in the header
                    (<= 1 prints "SWA"); compression itself is not simulated

    Returns ``(decode_attention_cosine, full_output_cosine)`` of the cached
    path against the reference. Results are printed; a low cosine does not
    raise. Needs a CUDA device and the checkpoint at ``MODEL``.
    """
    torch.cuda.set_device(0)
    torch.manual_seed(42)
    torch.cuda.empty_cache()

    with open(os.path.join(MODEL, "model.safetensors.index.json")) as f:
        wm = json.load(f)["weight_map"]
    G = lambda k: P(k, wm, MODEL).to(DEV)

    p = f"model.layers.{layer_id}"; a = f"{p}.self_attn"
    layer_type = "SWA" if compress_ratio <= 1 else f"CSA(c={compress_ratio})"

    print(f"\n{'='*70}")
    print(f" Layer {layer_id} — {layer_type} — Prefill+Decode Test")
    print(f"{'='*70}")

    # Load weights
    emb = G("model.embed_tokens.weight")
    anorm = G(f"{p}.input_layernorm.weight")
    qn = G(f"{a}.q_a_norm.weight"); kvn = G(f"{a}.kv_norm.weight")
    woa = G(f"{a}.o_a_proj.weight")

    qa_w = G(f"{a}.q_a_proj.weight"); qa_sf = G(f"{a}.q_a_proj.weight_scale"); qa_gs = G(f"{a}.q_a_proj.weight_scale_2")
    qb_w = G(f"{a}.q_b_proj.weight"); qb_sf = G(f"{a}.q_b_proj.weight_scale"); qb_gs = G(f"{a}.q_b_proj.weight_scale_2")
    kv_w = G(f"{a}.kv_proj.weight"); kv_sf = G(f"{a}.kv_proj.weight_scale"); kv_gs = G(f"{a}.kv_proj.weight_scale_2")
    wob_w = G(f"{a}.o_b_proj.weight"); wob_sf = G(f"{a}.o_b_proj.weight_scale"); wob_gs = G(f"{a}.o_b_proj.weight_scale_2")

    # CuTeDSL runners
    r_qa = make_runner(qa_w, qa_sf, qa_gs, H, qa_w.shape[0])
    r_qb = make_runner(qb_w, qb_sf, qb_gs, QL, qb_w.shape[0])
    r_kv = make_runner(kv_w, kv_sf, kv_gs, H, kv_w.shape[0])
    r_wob = make_runner(wob_w, wob_sf, wob_gs, OG*OL, wob_w.shape[0])

    # Setup
    N_PREFILL = 8  # Number of prefill tokens
    N_DECODE = 1   # Single decode token
    N_TOTAL = N_PREFILL + N_DECODE

    token_ids = torch.tensor([1, 450, 8403, 315, 5413, 374, 2198, 643, 991], dtype=torch.long, device=DEV)
    assert len(token_ids) >= N_TOTAL
    cos_sin = build_cos_sin(max_pos=4096).to(DEV)

    # Paged KV cache
    block_size = 256
    num_blocks = 64
    # Cache stores fp8 KV (with per-token inv_scale stored separately)
    kv_cache_fp8 = torch.zeros(num_blocks, block_size, HD, dtype=torch.float8_e4m3fn, device=DEV)
    # Per-token inv scales (indexed by slot)
    inv_scale_cache = torch.zeros(num_blocks * block_size, 1, dtype=torch.bfloat16, device=DEV)

    with torch.no_grad():
        # ════════════════════════════════════════════════════════════════
        # STEP 1: PREFILL — process tokens 0..N_PREFILL-1
        # ════════════════════════════════════════════════════════════════
        prefill_ids = token_ids[:N_PREFILL]
        prefill_pos = torch.arange(N_PREFILL, dtype=torch.int64, device=DEV)
        prefill_slots = prefill_pos  # slot = position (simplified)

        hidden_prefill = emb[prefill_ids]
        normed_prefill = rms(hidden_prefill, anorm, EPS)

        # Project KV
        kv_prefill = r_kv.run(normed_prefill)
        kv_normed_prefill = rms(kv_prefill, kvn, EPS)

        # Apply RoPE to KV BEFORE caching
        kv_rope_prefill = apply_gptj_rope(kv_normed_prefill.unsqueeze(1), prefill_pos, cos_sin, NOPE, ROPE).squeeze(1)

        # Quantize to fp8
        kv_fp8_prefill, inv_scale_prefill = kv_quantize_fp8(kv_rope_prefill)

        # Write to paged cache
        paged_kv_write(kv_fp8_prefill, prefill_slots, kv_cache_fp8, block_size)
        # Write inv_scale to the flat cache; one indexed assignment, since the
        # slots are distinct.
        inv_scale_cache[prefill_slots] = inv_scale_prefill

        print(f" Prefill: {N_PREFILL} tokens written to KV cache")
        print(f" KV cache fp8 amax: {_abs_max(kv_fp8_prefill):.4f}")
        # Absolute max, to match the label and the fp8 line above.
        print(f" KV BF16 amax: {_abs_max(kv_rope_prefill):.4f}")

        # Verify roundtrip: read back and compare
        kv_read = paged_kv_read(prefill_slots, kv_cache_fp8, block_size, N_PREFILL, HD)
        inv_read = inv_scale_cache[prefill_slots]
        kv_dequant = kv_dequantize_fp8(kv_read, inv_read)
        c = _cosine(kv_rope_prefill, kv_dequant)
        print(f" KV cache roundtrip cosine: {c:.6f} {'✅' if c>=0.99 else '❌'}")

        # ════════════════════════════════════════════════════════════════
        # STEP 2: DECODE — process token N_PREFILL
        # ════════════════════════════════════════════════════════════════
        decode_id = token_ids[N_PREFILL:N_PREFILL + N_DECODE]
        decode_pos = torch.tensor([N_PREFILL], dtype=torch.int64, device=DEV)
        decode_slot = decode_pos

        hidden_decode = emb[decode_id]
        normed_decode = rms(hidden_decode, anorm, EPS)

        # Project Q and KV
        qa_decode = r_qa.run(normed_decode)
        kv_decode = r_kv.run(normed_decode)
        qa_n_decode = rms(qa_decode, qn, EPS)
        kv_n_decode = rms(kv_decode, kvn, EPS)
        q_decode = r_qb.run(qa_n_decode).view(N_DECODE, NH, HD)
        q_rope_decode = apply_gptj_rope(q_decode, decode_pos, cos_sin, NOPE, ROPE)

        # Apply RoPE to KV
        kv_rope_decode = apply_gptj_rope(kv_n_decode.unsqueeze(1), decode_pos, cos_sin, NOPE, ROPE).squeeze(1)

        # Write decode KV to cache
        kv_fp8_decode, inv_scale_decode = kv_quantize_fp8(kv_rope_decode)
        paged_kv_write(kv_fp8_decode, decode_slot, kv_cache_fp8, block_size)
        inv_scale_cache[decode_slot] = inv_scale_decode

        print(f"\n Decode: token {N_PREFILL} written to KV cache")

        # ════════════════════════════════════════════════════════════════
        # STEP 3: DECODE ATTENTION using KV cache
        # ════════════════════════════════════════════════════════════════

        # Read ALL KV from cache (tokens 0..N_PREFILL)
        all_slots = torch.arange(N_TOTAL, dtype=torch.int64, device=DEV)
        kv_all_fp8 = paged_kv_read(all_slots, kv_cache_fp8, block_size, N_TOTAL, HD)
        inv_scale_all = inv_scale_cache[all_slots]
        kv_all_dequant = kv_dequantize_fp8(kv_all_fp8, inv_scale_all)

        # SWA: attend to last WINDOW tokens (or all if total < WINDOW)
        if N_TOTAL <= WINDOW:
            # Full attention within window
            o_from_cache = full_causal_attention(
                q_rope_decode,   # (1, NH, HD) — only the decode token
                kv_all_dequant,  # (N_TOTAL, HD) — all cached KV
                SCALE,
            )
        else:
            o_from_cache = swa_decode_attention(
                q_rope_decode, kv_all_dequant, decode_pos, SCALE, WINDOW,
            )

        # ════════════════════════════════════════════════════════════════
        # STEP 4: BF16 REFERENCE — process ALL tokens at once
        # ════════════════════════════════════════════════════════════════
        all_ids = token_ids[:N_TOTAL]
        all_pos = torch.arange(N_TOTAL, dtype=torch.int64, device=DEV)

        hidden_all = emb[all_ids]
        normed_all = rms(hidden_all, anorm, EPS)

        qa_all = r_qa.run(normed_all)
        kv_all = r_kv.run(normed_all)
        qa_n_all = rms(qa_all, qn, EPS)
        kv_n_all = rms(kv_all, kvn, EPS)
        q_all = r_qb.run(qa_n_all).view(N_TOTAL, NH, HD)
        q_rope_all = apply_gptj_rope(q_all, all_pos, cos_sin, NOPE, ROPE)
        kv_rope_all = apply_gptj_rope(kv_n_all.unsqueeze(1), all_pos, cos_sin, NOPE, ROPE).squeeze(1)

        # Full BF16 attention on all tokens (KV never goes through fp8 here)
        o_ref_all = full_causal_attention(q_rope_all, kv_rope_all, SCALE)
        o_ref_decode = o_ref_all[N_PREFILL:]  # Only the decode token's output

        # ════════════════════════════════════════════════════════════════
        # COMPARE: cached KV decode vs BF16 reference decode
        # ════════════════════════════════════════════════════════════════
        c = _cosine(o_from_cache, o_ref_decode)
        print(f"\n Decode attention (cached KV) vs BF16 reference cosine: {c:.6f} {'✅' if c>=0.98 else '❌'}")
        print(f" Cached output amax: {_abs_max(o_from_cache):.4f} BF16 ref amax: {_abs_max(o_ref_decode):.4f}")
        print(f" Cached output NaN: {torch.isnan(o_from_cache).any()} BF16 NaN: {torch.isnan(o_ref_decode).any()}")

        # ════════════════════════════════════════════════════════════════
        # STEP 5: Full output pipeline — inverse RoPE + o_a BMM + o_b
        # ════════════════════════════════════════════════════════════════
        # Using cached attention output
        o_inv = apply_inv_gptj_rope(o_from_cache, decode_pos, cos_sin, NOPE, ROPE)
        o_grouped = o_inv.view(N_DECODE, OG, HPG * HD).permute(1, 0, 2)
        woa_3d = woa.view(OG, OL, HPG * HD)
        z_cached = torch.bmm(o_grouped, woa_3d.transpose(1, 2)).permute(1, 0, 2).reshape(N_DECODE, OG * OL)
        attn_out_cached = r_wob.run(z_cached)

        # Using BF16 reference
        o_inv_ref = apply_inv_gptj_rope(o_ref_decode, decode_pos, cos_sin, NOPE, ROPE)
        o_grouped_ref = o_inv_ref.view(N_DECODE, OG, HPG * HD).permute(1, 0, 2)
        z_ref = torch.bmm(o_grouped_ref, woa_3d.transpose(1, 2)).permute(1, 0, 2).reshape(N_DECODE, OG * OL)
        attn_out_ref = r_wob.run(z_ref)

        c_full = _cosine(attn_out_cached, attn_out_ref)
        print(f" Full output (cached) vs BF16 reference cosine: {c_full:.6f} {'✅' if c_full>=0.98 else '❌'}")

        # ════════════════════════════════════════════════════════════════
        # BUG REPRODUCTION: What vLLM currently does (uses raw kv, not cache)
        # ════════════════════════════════════════════════════════════════
        print(f"\n --- BUG REPRODUCTION: vLLM Blackwell path ---")
        # vLLM's _attention_impl_blackwell calls full_sdpa_attention(q, kv, scale)
        # where kv is the RAW projection output (not from cache)
        # For decode, this only has 1 token of KV — missing all the prior tokens!
        o_buggy = full_causal_attention(q_rope_decode, kv_n_decode, SCALE)
        c_bug = _cosine(o_buggy, o_ref_decode)
        print(f" Buggy (raw kv, no cache) cosine: {c_bug:.6f} ❌ (should be low — missing context)")
        print(f" This is why vLLM produces garbage: decode only has 1 KV vector,")
        print(f" but needs to attend to ALL prior tokens' KV from the cache.")

    # Cleanup
    del r_qa, r_qb, r_kv, r_wob
    torch.cuda.empty_cache()
    return c, c_full
|
||
|
||
|
||
def main():
    """Run the SWA-layer prefill/decode check and print a summary of its cosines."""
    rule = "=" * 70

    print(rule)
    print(" DeepSeek-V4 Decode Attention Pipeline Test")
    print(" Reproduces the vLLM Blackwell bug: KV cache not used for decode")
    print(rule)

    # SWA layer: layer 60, compress_ratio=0.
    swa_scores = test_prefill_decode(60, 0)
    c_swa, c_swa_full = swa_scores

    # The C128A layer (layer 0, compress_ratio=128) is not exercised here:
    # real compression needs the compressor/indexer, which is a separate
    # concern. Calling test_prefill_decode(0, 128) would only run
    # uncompressed attention.

    print(f"\n{rule}")
    print(f" SUMMARY")
    print(f" Layer 60 (SWA): decode attention cosine = {c_swa:.6f}, full output = {c_swa_full:.6f}")
    print(f"{rule}")

    for line in (
        "\n KEY TAKEAWAY: The KV cache write/read + attention pipeline",
        " must work for decode. Once verified, we can build the vLLM",
        " attention backend that uses this pipeline.",
    ):
        print(line)
|
||
|
||
|
||
# Run the pipeline check only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|