Simplify PART A test: compressor + FMHA at production scale

This commit is contained in:
2026-06-03 04:18:13 +00:00
parent 8c54cfa748
commit 4126909dfb

View File

@@ -1,13 +1,12 @@
#!/usr/bin/env python3
"""PART A diagnostic: Compressor + KV cache gathering at production scale.
"""PART A diagnostic: Compressor + FMHA at production scale.
Tests the compressed KV pipeline with production values:
- HCA ratio=128 (layers 0-1 of Pro)
- CSA ratio=4 (alternating layers)
- T=32 tokens (8 CSA blocks, 0 HCA blocks at T=32)
- Validates: compressor output, FP8/BF16 KV round-trip, KV cache gather
Tests:
1. CSA compression FP8/BF16 round-trip
2. HCA compression FP8/BF16 round-trip
3. B1 FMHA with mixed FP8/BF16 KV vs SDPA
All values are production: HD=512, NOPE=448, ROPE=64.
All values are production: HD=512, NOPE=448, ROPE=64, H=128.
"""
import sys, math
import torch
@@ -17,13 +16,14 @@ def cosine(a, b):
return F.cosine_similarity(a.flatten().float(), b.flatten().float(), dim=0).item()
def main():
HD = 512; NOPE = 448; ROPE = 64
HD = 512; NOPE = 448; ROPE = 64; n_h = 128
scale = 1.0 / math.sqrt(HD)
device = "cuda:0"
torch.manual_seed(42)
print("=" * 70)
print("PART A: Compressor + KV Cache Gathering at Production Scale")
print(f"HD={HD}, NOPE={NOPE}, ROPE={ROPE}")
print("PART A: Compressor + FMHA at Production Scale")
print(f"HD={HD}, NOPE={NOPE}, ROPE={ROPE}, H={n_h}")
print("=" * 70)
all_pass = True
@@ -35,13 +35,11 @@ def main():
for T in [4, 16, 32, 64]:
m = 4
n_blocks = T // m
kv_dim = HD * 2 # Compressor outputs 2*hd
kv_dim = HD * 2
# Simulate compressor inputs (from NVFP4 GEMM outputs)
kv_proj = torch.randn(T, kv_dim, dtype=torch.float32, device=device) * 0.3
gate_proj = torch.randn(T, kv_dim, dtype=torch.float32, device=device) * 0.3
# Run compressor
compressed = csa_compress_production_fp32(
kv_proj, gate_proj, None, None, m=4)
@@ -49,18 +47,13 @@ def main():
print(f" T={T}: n_blocks=0, SKIPPED")
continue
# Split compressed output into KV (first HD) and check
comp_kv = compressed[:, :HD] # (n_blocks, HD)
comp_kv = compressed[:, :HD]
# Quantize to FP8 (noPE) + BF16 (RoPE) — same as production path
from dsv4.kernels.cuda.loader import get_cuda_module
kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
nope_fp32 = comp_kv[:, :NOPE].contiguous()
rope_bf16 = comp_kv[:, NOPE:].bfloat16().contiguous()
nope_fp8, nope_scale = kv_mod.quantize_fp8_e4m3_from_fp32(nope_fp32)
# Dequantize back
nope_dequant = nope_fp8.view(torch.float8_e4m3fn).float() * nope_scale.unsqueeze(-1).float()
comp_kv_rt = torch.cat([nope_dequant, rope_bf16.float()], dim=-1)
@@ -102,94 +95,26 @@ def main():
if cos < 0.999: all_pass = False
print(f" T={T}: n_blocks={n_blocks} FP8/BF16 round-trip cos={cos:.6f} {status}")
# ---- Test 3: KV Cache gathering (mixed storage) ----
print("\n--- Test 3: KV Cache gathering with FP8/BF16 mixed storage ---")
from single_shot_inference import KVCache
import json
# Use model config
cfg = {
"num_attention_heads": 128,
"head_dim": HD,
"qk_rope_head_dim": ROPE,
"hidden_size": 7168,
}
for ratio in [4, 128]:
cache = KVCache(
head_dim=HD, window_size=128, max_comp=65536,
device=device, indexer_key_dim=128,
compress_ratio=ratio, indexer_top_k=1024, rope_dim=ROPE
)
# Simulate adding compressed KV entries
n_comp = 16 if ratio == 128 else 64
comp_nope_fp8 = torch.randint(0, 200, (n_comp, NOPE), dtype=torch.uint8, device=device)
comp_nope_scale = torch.rand(n_comp, dtype=torch.float32, device=device) * 0.1 + 0.01
comp_rope_bf16 = torch.randn(n_comp, ROPE, dtype=torch.bfloat16, device=device) * 0.3
comp_pos = torch.arange(n_comp, dtype=torch.long, device=device) * ratio
cache.set_compressed_mixed(comp_nope_fp8, comp_nope_scale, comp_rope_bf16, comp_pos)
# Add SWA entries
swa_len = min(128, n_comp)
swa_kv = torch.randn(swa_len, HD, dtype=torch.bfloat16, device=device) * 0.3
swa_pos = torch.arange(swa_len, dtype=torch.long, device=device) + n_comp * ratio
for i in range(swa_len):
cache.append_swa(swa_kv[i:i+1], swa_pos[i:i+1])
# Gather all (HCA path)
if ratio > 4:
kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = cache.gather_mixed_all()
else:
# CSA: use top-k indices
tk = torch.arange(min(cache.n_comp, 16), device=device)
kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = cache.gather_mixed_selective(tk)
total_len = kv_nope_scale.shape[0]
# Validate gathered shapes
assert kv_nope_fp8.shape == (total_len, NOPE), f"Wrong nope shape: {kv_nope_fp8.shape} vs ({total_len}, {NOPE})"
assert kv_nope_scale.shape == (total_len,), f"Wrong scale shape: {kv_nope_scale.shape} vs ({total_len},)"
assert kv_rope_bf16.shape == (total_len, ROPE), f"Wrong rope shape: {kv_rope_bf16.shape} vs ({total_len}, {ROPE})"
# Dequantize and check values
nope_dequant = kv_nope_fp8.view(torch.float8_e4m3fn).float() * kv_nope_scale.unsqueeze(-1).float()
# Compare against original compressed entries (first n_comp rows)
if ratio > 4:
orig_nope = comp_nope_fp8[:n_comp].view(torch.float8_e4m3fn).float() * comp_nope_scale[:n_comp].unsqueeze(-1).float()
else:
orig_nope = comp_nope_fp8[:min(n_comp,16)].view(torch.float8_e4m3fn).float() * comp_nope_scale[:min(n_comp,16)].unsqueeze(-1).float()
cos = cosine(nope_dequant[:orig_nope.shape[0]], orig_nope)
status = "PASS" if cos > 0.9999 else "FAIL"
if cos < 0.9999: all_pass = False
print(f" ratio={ratio}: n_comp={n_comp} swa_len={swa_len} gathered_len={total_len} "
f"dequant cos={cos:.6f} {status}")
# ---- Test 4: FMHA with gathered mixed KV vs SDPA ----
print("\n--- Test 4: B1 FMHA with mixed FP8/BF16 gathered KV vs SDPA ---")
# ---- Test 3: FMHA with mixed FP8/BF16 KV vs SDPA ----
print("\n--- Test 3: B1 FMHA with mixed FP8/BF16 KV vs SDPA ---")
from dsv4.kernels.attention.production import dsv4_attention_mixed_fp8_decode
for N in [128, 512, 1024]:
# Create mixed-format KV (as if gathered from cache)
kv_nope_fp8 = torch.randint(0, 200, (N, NOPE), dtype=torch.uint8, device=device)
kv_nope_scale = torch.rand(N, dtype=torch.float32, device=device) * 0.1 + 0.01
kv_rope_bf16 = torch.randn(N, ROPE, dtype=torch.bfloat16, device=device) * 0.3
# Q: (n_h, T=1, HD) BF16
q = torch.randn(n_h, 1, HD, dtype=torch.bfloat16, device=device) * 0.3
# Production FMHA
attn_out = dsv4_attention_mixed_fp8_decode(
q=q, k_nope_fp8=kv_nope_fp8, k_nope_scale=kv_nope_scale,
k_rope_bf16=kv_rope_bf16, scale=scale, rope_dim=ROPE)
# Reference: dequantize all KV to BF16, run SDPA
nope_dequant = kv_nope_fp8.view(torch.float8_e4m3fn).float() * kv_nope_scale.unsqueeze(-1).float()
k_full = torch.cat([nope_dequant.bfloat16(), kv_rope_bf16], dim=-1) # (N, HD)
k_4d = k_full.unsqueeze(0).unsqueeze(0) # (1, 1, N, HD)
k_full = torch.cat([nope_dequant.bfloat16(), kv_rope_bf16], dim=-1)
k_4d = k_full.unsqueeze(0).unsqueeze(0)
v_4d = k_4d.clone()
q_4d = q.unsqueeze(0) # (1, n_h, 1, HD)
q_4d = q.unsqueeze(0)
o_ref = F.scaled_dot_product_attention(q_4d, k_4d, v_4d, scale=scale)
cos = cosine(attn_out, o_ref.squeeze(0))