Simplify PART A test: compressor + FMHA at production scale
This commit is contained in:
@@ -1,13 +1,12 @@
|
||||
#!/usr/bin/env python3
|
||||
"""PART A diagnostic: Compressor + KV cache gathering at production scale.
|
||||
"""PART A diagnostic: Compressor + FMHA at production scale.
|
||||
|
||||
Tests the compressed KV pipeline with production values:
|
||||
- HCA ratio=128 (layers 0-1 of Pro)
|
||||
- CSA ratio=4 (alternating layers)
|
||||
- T=32 tokens (8 CSA blocks, 0 HCA blocks at T=32)
|
||||
- Validates: compressor output, FP8/BF16 KV round-trip, KV cache gather
|
||||
Tests:
|
||||
1. CSA compression FP8/BF16 round-trip
|
||||
2. HCA compression FP8/BF16 round-trip
|
||||
3. B1 FMHA with mixed FP8/BF16 KV vs SDPA
|
||||
|
||||
All values are production: HD=512, NOPE=448, ROPE=64.
|
||||
All values are production: HD=512, NOPE=448, ROPE=64, H=128.
|
||||
"""
|
||||
import sys, math
|
||||
import torch
|
||||
@@ -17,13 +16,14 @@ def cosine(a, b):
|
||||
return F.cosine_similarity(a.flatten().float(), b.flatten().float(), dim=0).item()
|
||||
|
||||
def main():
|
||||
HD = 512; NOPE = 448; ROPE = 64
|
||||
HD = 512; NOPE = 448; ROPE = 64; n_h = 128
|
||||
scale = 1.0 / math.sqrt(HD)
|
||||
device = "cuda:0"
|
||||
torch.manual_seed(42)
|
||||
|
||||
print("=" * 70)
|
||||
print("PART A: Compressor + KV Cache Gathering at Production Scale")
|
||||
print(f"HD={HD}, NOPE={NOPE}, ROPE={ROPE}")
|
||||
print("PART A: Compressor + FMHA at Production Scale")
|
||||
print(f"HD={HD}, NOPE={NOPE}, ROPE={ROPE}, H={n_h}")
|
||||
print("=" * 70)
|
||||
|
||||
all_pass = True
|
||||
@@ -35,13 +35,11 @@ def main():
|
||||
for T in [4, 16, 32, 64]:
|
||||
m = 4
|
||||
n_blocks = T // m
|
||||
kv_dim = HD * 2 # Compressor outputs 2*hd
|
||||
kv_dim = HD * 2
|
||||
|
||||
# Simulate compressor inputs (from NVFP4 GEMM outputs)
|
||||
kv_proj = torch.randn(T, kv_dim, dtype=torch.float32, device=device) * 0.3
|
||||
gate_proj = torch.randn(T, kv_dim, dtype=torch.float32, device=device) * 0.3
|
||||
|
||||
# Run compressor
|
||||
compressed = csa_compress_production_fp32(
|
||||
kv_proj, gate_proj, None, None, m=4)
|
||||
|
||||
@@ -49,18 +47,13 @@ def main():
|
||||
print(f" T={T}: n_blocks=0, SKIPPED")
|
||||
continue
|
||||
|
||||
# Split compressed output into KV (first HD) and check
|
||||
comp_kv = compressed[:, :HD] # (n_blocks, HD)
|
||||
comp_kv = compressed[:, :HD]
|
||||
|
||||
# Quantize to FP8 (noPE) + BF16 (RoPE) — same as production path
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
|
||||
|
||||
nope_fp32 = comp_kv[:, :NOPE].contiguous()
|
||||
rope_bf16 = comp_kv[:, NOPE:].bfloat16().contiguous()
|
||||
nope_fp8, nope_scale = kv_mod.quantize_fp8_e4m3_from_fp32(nope_fp32)
|
||||
|
||||
# Dequantize back
|
||||
nope_dequant = nope_fp8.view(torch.float8_e4m3fn).float() * nope_scale.unsqueeze(-1).float()
|
||||
comp_kv_rt = torch.cat([nope_dequant, rope_bf16.float()], dim=-1)
|
||||
|
||||
@@ -102,94 +95,26 @@ def main():
|
||||
if cos < 0.999: all_pass = False
|
||||
print(f" T={T}: n_blocks={n_blocks} FP8/BF16 round-trip cos={cos:.6f} {status}")
|
||||
|
||||
# ---- Test 3: KV Cache gathering (mixed storage) ----
|
||||
print("\n--- Test 3: KV Cache gathering with FP8/BF16 mixed storage ---")
|
||||
from single_shot_inference import KVCache
|
||||
import json
|
||||
|
||||
# Use model config
|
||||
cfg = {
|
||||
"num_attention_heads": 128,
|
||||
"head_dim": HD,
|
||||
"qk_rope_head_dim": ROPE,
|
||||
"hidden_size": 7168,
|
||||
}
|
||||
|
||||
for ratio in [4, 128]:
|
||||
cache = KVCache(
|
||||
head_dim=HD, window_size=128, max_comp=65536,
|
||||
device=device, indexer_key_dim=128,
|
||||
compress_ratio=ratio, indexer_top_k=1024, rope_dim=ROPE
|
||||
)
|
||||
# Simulate adding compressed KV entries
|
||||
n_comp = 16 if ratio == 128 else 64
|
||||
comp_nope_fp8 = torch.randint(0, 200, (n_comp, NOPE), dtype=torch.uint8, device=device)
|
||||
comp_nope_scale = torch.rand(n_comp, dtype=torch.float32, device=device) * 0.1 + 0.01
|
||||
comp_rope_bf16 = torch.randn(n_comp, ROPE, dtype=torch.bfloat16, device=device) * 0.3
|
||||
comp_pos = torch.arange(n_comp, dtype=torch.long, device=device) * ratio
|
||||
cache.set_compressed_mixed(comp_nope_fp8, comp_nope_scale, comp_rope_bf16, comp_pos)
|
||||
|
||||
# Add SWA entries
|
||||
swa_len = min(128, n_comp)
|
||||
swa_kv = torch.randn(swa_len, HD, dtype=torch.bfloat16, device=device) * 0.3
|
||||
swa_pos = torch.arange(swa_len, dtype=torch.long, device=device) + n_comp * ratio
|
||||
for i in range(swa_len):
|
||||
cache.append_swa(swa_kv[i:i+1], swa_pos[i:i+1])
|
||||
|
||||
# Gather all (HCA path)
|
||||
if ratio > 4:
|
||||
kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = cache.gather_mixed_all()
|
||||
else:
|
||||
# CSA: use top-k indices
|
||||
tk = torch.arange(min(cache.n_comp, 16), device=device)
|
||||
kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = cache.gather_mixed_selective(tk)
|
||||
|
||||
total_len = kv_nope_scale.shape[0]
|
||||
|
||||
# Validate gathered shapes
|
||||
assert kv_nope_fp8.shape == (total_len, NOPE), f"Wrong nope shape: {kv_nope_fp8.shape} vs ({total_len}, {NOPE})"
|
||||
assert kv_nope_scale.shape == (total_len,), f"Wrong scale shape: {kv_nope_scale.shape} vs ({total_len},)"
|
||||
assert kv_rope_bf16.shape == (total_len, ROPE), f"Wrong rope shape: {kv_rope_bf16.shape} vs ({total_len}, {ROPE})"
|
||||
|
||||
# Dequantize and check values
|
||||
nope_dequant = kv_nope_fp8.view(torch.float8_e4m3fn).float() * kv_nope_scale.unsqueeze(-1).float()
|
||||
|
||||
# Compare against original compressed entries (first n_comp rows)
|
||||
if ratio > 4:
|
||||
orig_nope = comp_nope_fp8[:n_comp].view(torch.float8_e4m3fn).float() * comp_nope_scale[:n_comp].unsqueeze(-1).float()
|
||||
else:
|
||||
orig_nope = comp_nope_fp8[:min(n_comp,16)].view(torch.float8_e4m3fn).float() * comp_nope_scale[:min(n_comp,16)].unsqueeze(-1).float()
|
||||
|
||||
cos = cosine(nope_dequant[:orig_nope.shape[0]], orig_nope)
|
||||
status = "PASS" if cos > 0.9999 else "FAIL"
|
||||
if cos < 0.9999: all_pass = False
|
||||
print(f" ratio={ratio}: n_comp={n_comp} swa_len={swa_len} gathered_len={total_len} "
|
||||
f"dequant cos={cos:.6f} {status}")
|
||||
|
||||
# ---- Test 4: FMHA with gathered mixed KV vs SDPA ----
|
||||
print("\n--- Test 4: B1 FMHA with mixed FP8/BF16 gathered KV vs SDPA ---")
|
||||
# ---- Test 3: FMHA with mixed FP8/BF16 KV vs SDPA ----
|
||||
print("\n--- Test 3: B1 FMHA with mixed FP8/BF16 KV vs SDPA ---")
|
||||
from dsv4.kernels.attention.production import dsv4_attention_mixed_fp8_decode
|
||||
|
||||
for N in [128, 512, 1024]:
|
||||
# Create mixed-format KV (as if gathered from cache)
|
||||
kv_nope_fp8 = torch.randint(0, 200, (N, NOPE), dtype=torch.uint8, device=device)
|
||||
kv_nope_scale = torch.rand(N, dtype=torch.float32, device=device) * 0.1 + 0.01
|
||||
kv_rope_bf16 = torch.randn(N, ROPE, dtype=torch.bfloat16, device=device) * 0.3
|
||||
|
||||
# Q: (n_h, T=1, HD) BF16
|
||||
q = torch.randn(n_h, 1, HD, dtype=torch.bfloat16, device=device) * 0.3
|
||||
|
||||
# Production FMHA
|
||||
attn_out = dsv4_attention_mixed_fp8_decode(
|
||||
q=q, k_nope_fp8=kv_nope_fp8, k_nope_scale=kv_nope_scale,
|
||||
k_rope_bf16=kv_rope_bf16, scale=scale, rope_dim=ROPE)
|
||||
|
||||
# Reference: dequantize all KV to BF16, run SDPA
|
||||
nope_dequant = kv_nope_fp8.view(torch.float8_e4m3fn).float() * kv_nope_scale.unsqueeze(-1).float()
|
||||
k_full = torch.cat([nope_dequant.bfloat16(), kv_rope_bf16], dim=-1) # (N, HD)
|
||||
k_4d = k_full.unsqueeze(0).unsqueeze(0) # (1, 1, N, HD)
|
||||
k_full = torch.cat([nope_dequant.bfloat16(), kv_rope_bf16], dim=-1)
|
||||
k_4d = k_full.unsqueeze(0).unsqueeze(0)
|
||||
v_4d = k_4d.clone()
|
||||
q_4d = q.unsqueeze(0) # (1, n_h, 1, HD)
|
||||
q_4d = q.unsqueeze(0)
|
||||
o_ref = F.scaled_dot_product_attention(q_4d, k_4d, v_4d, scale=scale)
|
||||
|
||||
cos = cosine(attn_out, o_ref.squeeze(0))
|
||||
|
||||
Reference in New Issue
Block a user