Add decode vs prefill consistency test
This commit is contained in:
274
tests/test_decode_vs_prefill_b200.py
Normal file
274
tests/test_decode_vs_prefill_b200.py
Normal file
@@ -0,0 +1,274 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
DeepSeek-V4 Decode vs Prefill Consistency Test
|
||||
|
||||
Verifies that:
|
||||
1. Decode attention (using KV cache) produces the same output as
|
||||
prefill attention (raw KV) for the same token position
|
||||
2. The cosine similarity between decode and prefill outputs is > 0.98
|
||||
|
||||
This is the CRITICAL test: if it passes, the KV cache pipeline is correct
|
||||
and the vLLM container should produce valid output.
|
||||
|
||||
Usage (on B200):
|
||||
cd /root/nvfp4-megamoe-kernel
|
||||
PYTHONPATH=/root/nvfp4-megamoe-kernel tests/venv/bin/python tests/test_decode_vs_prefill_b200.py
|
||||
"""
|
||||
|
||||
import sys, os, json, torch, torch.nn.functional as F, time
|
||||
from safetensors import safe_open
|
||||
|
||||
REPO = "/root/nvfp4-megamoe-kernel"
|
||||
sys.path.insert(0, REPO)
|
||||
MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
|
||||
DEV = "cuda:0"
|
||||
|
||||
H = 7168; NH = 128; HD = 512; NOPE = 448; ROPE = 64
|
||||
QL = 1536; OL = 1024; OG = 16; HPG = NH // OG
|
||||
EPS = 1e-6; WINDOW = 128; SCALE = HD ** -0.5
|
||||
NUM_LAYERS = 61
|
||||
|
||||
E2M1 = torch.tensor([0,.5,1.,1.5,2.,3.,4.,6.,-0,-.5,-1.,-1.5,-2.,-3.,-4.,-6.], dtype=torch.float32)
|
||||
|
||||
_cache = {}
|
||||
def P(k, wm, md):
|
||||
if k in _cache: return _cache[k]
|
||||
with safe_open(os.path.join(md, wm[k]), framework="pt") as f:
|
||||
t = f.get_tensor(k)
|
||||
_cache[k] = t
|
||||
return t
|
||||
|
||||
def rms(x, w, eps=1e-6):
|
||||
v = x.float().pow(2).mean(-1, keepdim=True)
|
||||
return (w.float() * (x * torch.rsqrt(v+eps)).float()).to(x.dtype)
|
||||
|
||||
def make_runner(w, sf, gs_t, inf, outf, fused=False, lw=None):
|
||||
from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear
|
||||
fp4 = w.view(torch.float4_e2m1fn_x2).permute(1,0).contiguous()
|
||||
s = sf.to(torch.float8_e4m3fn) if sf.dtype != torch.float8_e4m3fn else sf
|
||||
s = s.permute(1,0).contiguous()
|
||||
if fused and gs_t.numel() == 2:
|
||||
g1,g2 = gs_t[0].item(), gs_t[1].item(); gs = max(g1,g2)
|
||||
if g1 != g2:
|
||||
s32 = s.float(); sp = lw[0] if lw else outf//2
|
||||
s32[:sp] *= g1/gs; s32[sp:] *= g2/gs; s = s32.to(torch.float8_e4m3fn)
|
||||
else:
|
||||
gs = gs_t.max().item() if gs_t.numel() > 1 else gs_t.item()
|
||||
r = CuTeDSLNvfp4Linear(in_features=inf, out_features=outf, max_num_tokens=8192, device=str(w.device))
|
||||
r.fp4 = [fp4]; r.sf = [s]; r.gs = [gs]
|
||||
r.finalize_weights(); r._ensure_initialized()
|
||||
return r
|
||||
|
||||
def build_cos_sin(max_pos=4096, rope_dim=ROPE):
|
||||
half = rope_dim // 2
|
||||
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, half, dtype=torch.float32) / half))
|
||||
freqs = torch.outer(torch.arange(max_pos, dtype=torch.float32), inv_freq)
|
||||
return torch.cat([freqs.cos(), freqs.sin()], dim=-1)
|
||||
|
||||
def apply_gptj_rope(x, positions, cos_sin, nope_dim, rope_dim):
|
||||
if rope_dim == 0 or x.numel() == 0: return x
|
||||
half = rope_dim // 2
|
||||
cos = cos_sin[positions, :half].to(x.dtype)
|
||||
sin = cos_sin[positions, half:2*half].to(x.dtype)
|
||||
if x.dim() == 3: cos = cos.unsqueeze(1); sin = sin.unsqueeze(1)
|
||||
x_rope = x[..., nope_dim:].clone()
|
||||
even = x_rope[..., 0::2]; odd = x_rope[..., 1::2]
|
||||
out = x.clone()
|
||||
out[..., nope_dim:][..., 0::2] = even * cos - odd * sin
|
||||
out[..., nope_dim:][..., 1::2] = even * sin + odd * cos
|
||||
return out
|
||||
|
||||
def apply_inv_gptj_rope(x, positions, cos_sin, nope_dim, rope_dim):
|
||||
if rope_dim == 0 or x.numel() == 0: return x
|
||||
half = rope_dim // 2
|
||||
cos = cos_sin[positions, :half].to(x.dtype)
|
||||
sin = cos_sin[positions, half:2*half].to(x.dtype)
|
||||
if x.dim() == 3: cos = cos.unsqueeze(1); sin = sin.unsqueeze(1)
|
||||
x_rope = x[..., nope_dim:].clone()
|
||||
even = x_rope[..., 0::2]; odd = x_rope[..., 1::2]
|
||||
out = x.clone()
|
||||
out[..., nope_dim:][..., 0::2] = even * cos + odd * sin
|
||||
out[..., nope_dim:][..., 1::2] = -even * sin + odd * cos
|
||||
return out
|
||||
|
||||
def kv_quantize_fp8(kv_bf16):
|
||||
amax = kv_bf16.float().abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
|
||||
fp8_max = torch.tensor(448.0, dtype=torch.float32, device=kv_bf16.device)
|
||||
scale = fp8_max / amax
|
||||
kv_fp8 = (kv_bf16.float() * scale).to(torch.float8_e4m3fn)
|
||||
inv_scale = (amax / fp8_max).to(torch.bfloat16)
|
||||
return kv_fp8, inv_scale
|
||||
|
||||
def kv_dequantize_fp8(kv_fp8, inv_scale):
|
||||
return (kv_fp8.to(torch.bfloat16) * inv_scale).to(torch.bfloat16)
|
||||
|
||||
def causal_prefill_attention(q, kv, scale):
|
||||
T, NH, HD = q.shape
|
||||
q_t = q.permute(1, 0, 2)
|
||||
kv_exp = kv.unsqueeze(0).expand(NH, -1, -1)
|
||||
out = F.scaled_dot_product_attention(q_t, kv_exp, kv_exp, is_causal=True, scale=scale)
|
||||
return out.permute(1, 0, 2)
|
||||
|
||||
def decode_attention(q, kv, scale):
|
||||
NH = q.shape[1]; HD = q.shape[2]
|
||||
q_t = q.permute(1, 0, 2)
|
||||
kv_exp = kv.unsqueeze(0).expand(NH, -1, -1)
|
||||
out = F.scaled_dot_product_attention(q_t, kv_exp, kv_exp, is_causal=False, scale=scale)
|
||||
return out.permute(1, 0, 2)
|
||||
|
||||
|
||||
def test_layer_decode_vs_prefill(layer_id):
|
||||
"""For a single layer, verify decode matches prefill."""
|
||||
torch.cuda.set_device(0)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
with open(os.path.join(MODEL, "model.safetensors.index.json")) as f:
|
||||
wm = json.load(f)["weight_map"]
|
||||
G = lambda k: P(k, wm, MODEL).to(DEV)
|
||||
|
||||
p = f"model.layers.{layer_id}"; a = f"{p}.self_attn"
|
||||
cr = 128 if layer_id == 0 else (0 if layer_id == 60 else 4)
|
||||
lt = f"C{cr}A" if cr > 1 else "SWA"
|
||||
|
||||
emb = G("model.embed_tokens.weight")
|
||||
anorm = G(f"{p}.input_layernorm.weight")
|
||||
qn = G(f"{a}.q_a_norm.weight"); kvn = G(f"{a}.kv_norm.weight")
|
||||
woa = G(f"{a}.o_a_proj.weight")
|
||||
qa_w = G(f"{a}.q_a_proj.weight"); qa_sf = G(f"{a}.q_a_proj.weight_scale"); qa_gs = G(f"{a}.q_a_proj.weight_scale_2")
|
||||
qb_w = G(f"{a}.q_b_proj.weight"); qb_sf = G(f"{a}.q_b_proj.weight_scale"); qb_gs = G(f"{a}.q_b_proj.weight_scale_2")
|
||||
kv_w = G(f"{a}.kv_proj.weight"); kv_sf = G(f"{a}.kv_proj.weight_scale"); kv_gs = G(f"{a}.kv_proj.weight_scale_2")
|
||||
wob_w = G(f"{a}.o_b_proj.weight"); wob_sf = G(f"{a}.o_b_proj.weight_scale"); wob_gs = G(f"{a}.o_b_proj.weight_scale_2")
|
||||
|
||||
r_qa = make_runner(qa_w, qa_sf, qa_gs, H, qa_w.shape[0])
|
||||
r_qb = make_runner(qb_w, qb_sf, qb_gs, QL, qb_w.shape[0])
|
||||
r_kv = make_runner(kv_w, kv_sf, kv_gs, H, kv_w.shape[0])
|
||||
r_wob = make_runner(wob_w, wob_sf, wob_gs, OG*OL, wob_w.shape[0])
|
||||
cos_sin = build_cos_sin(max_pos=4096).to(DEV)
|
||||
|
||||
# Paged KV cache
|
||||
block_size = 64; max_tokens = 32; num_blocks = (max_tokens + block_size - 1) // block_size
|
||||
kv_cache = torch.zeros(num_blocks, block_size, HD, dtype=torch.float8_e4m3fn, device=DEV)
|
||||
inv_scale_cache = torch.zeros(max_tokens, 1, dtype=torch.bfloat16, device=DEV)
|
||||
|
||||
N = 8 # Prefill tokens
|
||||
token_ids = torch.tensor([1, 450, 8403, 315, 5413, 374, 2198, 643], dtype=torch.long, device=DEV)
|
||||
|
||||
with torch.no_grad():
|
||||
# ── PREFILL: process all N tokens at once ───────────────
|
||||
positions_p = torch.arange(N, dtype=torch.int64, device=DEV)
|
||||
hidden_p = emb[token_ids]
|
||||
normed_p = rms(hidden_p, anorm, EPS)
|
||||
qa_p = r_qa.run(normed_p); kv_p = r_kv.run(normed_p)
|
||||
qa_n_p = rms(qa_p, qn, EPS); kv_n_p = rms(kv_p, kvn, EPS)
|
||||
q_p = r_qb.run(qa_n_p).view(N, NH, HD)
|
||||
q_rope_p = apply_gptj_rope(q_p, positions_p, cos_sin, NOPE, ROPE)
|
||||
kv_rope_p = apply_gptj_rope(kv_n_p.unsqueeze(1), positions_p, cos_sin, NOPE, ROPE).squeeze(1)
|
||||
|
||||
# Write prefill KV to cache
|
||||
kv_fp8_p, inv_s_p = kv_quantize_fp8(kv_rope_p)
|
||||
slots_p = positions_p
|
||||
bi_p = slots_p // block_size; oi_p = slots_p % block_size
|
||||
kv_cache[bi_p, oi_p] = kv_fp8_p
|
||||
for t in range(N):
|
||||
inv_scale_cache[slots_p[t]] = inv_s_p[t]
|
||||
|
||||
# Prefill attention (raw KV)
|
||||
o_prefill = causal_prefill_attention(q_rope_p, kv_rope_p, SCALE)
|
||||
o_inv_p = apply_inv_gptj_rope(o_prefill, positions_p, cos_sin, NOPE, ROPE)
|
||||
o_grp_p = o_inv_p.reshape(N, OG, HPG * HD).permute(1, 0, 2)
|
||||
woa_3d = woa.view(OG, OL, HPG * HD)
|
||||
z_p = torch.bmm(o_grp_p, woa_3d.transpose(1, 2)).permute(1, 0, 2).reshape(N, OG * OL)
|
||||
attn_prefill = r_wob.run(z_p)
|
||||
|
||||
# ── DECODE: process token N (one at a time) ────────────
|
||||
decode_id = torch.tensor([991], dtype=torch.long, device=DEV)
|
||||
pos_d = torch.tensor([N], dtype=torch.int64, device=DEV)
|
||||
hidden_d = emb[decode_id]
|
||||
normed_d = rms(hidden_d, anorm, EPS)
|
||||
qa_d = r_qa.run(normed_d); kv_d = r_kv.run(normed_d)
|
||||
qa_n_d = rms(qa_d, qn, EPS); kv_n_d = rms(kv_d, kvn, EPS)
|
||||
q_d = r_qb.run(qa_n_d).view(1, NH, HD)
|
||||
q_rope_d = apply_gptj_rope(q_d, pos_d, cos_sin, NOPE, ROPE)
|
||||
kv_rope_d = apply_gptj_rope(kv_n_d.unsqueeze(1), pos_d, cos_sin, NOPE, ROPE).squeeze(1)
|
||||
|
||||
# Write decode KV to cache
|
||||
kv_fp8_d, inv_s_d = kv_quantize_fp8(kv_rope_d)
|
||||
slot_d = pos_d[0].item()
|
||||
bi_d = slot_d // block_size; oi_d = slot_d % block_size
|
||||
kv_cache[bi_d, oi_d] = kv_fp8_d[0]
|
||||
inv_scale_cache[slot_d] = inv_s_d[0]
|
||||
|
||||
# Decode attention: read from cache
|
||||
all_slots = torch.arange(N + 1, dtype=torch.int64, device=DEV)
|
||||
all_bi = all_slots // block_size; all_oi = all_slots % block_size
|
||||
kv_cached_fp8 = kv_cache[all_bi, all_oi]
|
||||
kv_cached_inv = inv_scale_cache[all_slots]
|
||||
kv_cached = kv_dequantize_fp8(kv_cached_fp8, kv_cached_inv)
|
||||
|
||||
# SWA window
|
||||
ws = max(0, N - WINDOW + 1)
|
||||
kv_window = kv_cached[ws:]
|
||||
o_decode = decode_attention(q_rope_d, kv_window, SCALE)
|
||||
|
||||
# Full output pipeline for decode
|
||||
o_inv_d = apply_inv_gptj_rope(o_decode, pos_d, cos_sin, NOPE, ROPE)
|
||||
o_grp_d = o_inv_d.reshape(1, OG, HPG * HD).permute(1, 0, 2)
|
||||
z_d = torch.bmm(o_grp_d, woa_3d.transpose(1, 2)).permute(1, 0, 2).reshape(1, OG * OL)
|
||||
attn_decode = r_wob.run(z_d)
|
||||
|
||||
# ── REFERENCE: prefill all N+1 tokens, take the last ────
|
||||
all_ids = torch.cat([token_ids, decode_id])
|
||||
all_pos = torch.arange(N + 1, dtype=torch.int64, device=DEV)
|
||||
hidden_ref = emb[all_ids]
|
||||
normed_ref = rms(hidden_ref, anorm, EPS)
|
||||
qa_ref = r_qa.run(normed_ref); kv_ref = r_kv.run(normed_ref)
|
||||
qa_n_ref = rms(qa_ref, qn, EPS); kv_n_ref = rms(kv_ref, kvn, EPS)
|
||||
q_ref = r_qb.run(qa_n_ref).view(N + 1, NH, HD)
|
||||
q_rope_ref = apply_gptj_rope(q_ref, all_pos, cos_sin, NOPE, ROPE)
|
||||
kv_rope_ref = apply_gptj_rope(kv_n_ref.unsqueeze(1), all_pos, cos_sin, NOPE, ROPE).squeeze(1)
|
||||
o_ref = causal_prefill_attention(q_rope_ref, kv_rope_ref, SCALE)
|
||||
o_inv_ref = apply_inv_gptj_rope(o_ref[-1:], pos_d, cos_sin, NOPE, ROPE)
|
||||
o_grp_ref = o_inv_ref.reshape(1, OG, HPG * HD).permute(1, 0, 2)
|
||||
z_ref = torch.bmm(o_grp_ref, woa_3d.transpose(1, 2)).permute(1, 0, 2).reshape(1, OG * OL)
|
||||
attn_ref = r_wob.run(z_ref)
|
||||
|
||||
# ── COMPARE ─────────────────────────────────────────────
|
||||
# Decode attention output vs reference
|
||||
c_attn = F.cosine_similarity(o_decode.flatten().unsqueeze(0).float(), o_ref[-1:].flatten().unsqueeze(0).float()).item()
|
||||
# Full output vs reference
|
||||
c_full = F.cosine_similarity(attn_decode.flatten().unsqueeze(0).float(), attn_ref.flatten().unsqueeze(0).float()).item()
|
||||
|
||||
del r_qa, r_qb, r_kv, r_wob
|
||||
torch.cuda.empty_cache()
|
||||
_cache.clear()
|
||||
|
||||
return c_attn, c_full
|
||||
|
||||
|
||||
def main():
|
||||
print("=" * 70)
|
||||
print(" DeepSeek-V4 Decode vs Prefill Consistency Test")
|
||||
print(" Verifies KV cache produces same output as full prefill")
|
||||
print("=" * 70)
|
||||
|
||||
test_layers = [
|
||||
(0, "C128A"),
|
||||
(1, "C4A"),
|
||||
(2, "C4A"),
|
||||
(30, "C4A"),
|
||||
(60, "SWA"),
|
||||
]
|
||||
|
||||
for layer_id, lt in test_layers:
|
||||
c_attn, c_full = test_layer_decode_vs_prefill(layer_id)
|
||||
status = "✅" if c_full >= 0.98 else "❌"
|
||||
print(f" Layer {layer_id} ({lt}): attn={c_attn:.4f} full={c_full:.4f} {status}")
|
||||
|
||||
print(f"\n{'='*70}")
|
||||
print(f" If all layers pass (≥0.98), the KV cache pipeline is correct.")
|
||||
print(f" The vLLM container should produce valid output.")
|
||||
print(f"{'='*70}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user