116 lines
5.3 KiB
Python
116 lines
5.3 KiB
Python
#!/usr/bin/env python3
|
|
"""PART A diagnostic: Compressor + FMHA at production scale."""
|
|
import sys, math
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
def cosine(a, b):
    """Return the cosine similarity of ``a`` and ``b`` as a Python float.

    Both tensors are flattened to 1-D and upcast to float32 first, so inputs of
    any shape or floating dtype can be compared as long as their flattened
    lengths are compatible for ``F.cosine_similarity``.
    """
    flat_a = a.flatten().float()
    flat_b = b.flatten().float()
    similarity = F.cosine_similarity(flat_a, flat_b, dim=0)
    return similarity.item()
|
|
|
|
def _fp8_roundtrip_cosine(comp_kv, kv_mod, nope_dim):
    """Score how well a compressed KV slab survives the mixed FP8/bf16 encoding.

    Columns ``[:nope_dim]`` are quantized with
    ``kv_mod.quantize_fp8_e4m3_from_fp32`` and dequantized again; the remaining
    columns are round-tripped through bf16.

    Returns:
        Cosine similarity between ``comp_kv`` and its reconstruction.
    """
    nope_fp32 = comp_kv[:, :nope_dim].contiguous()
    rope_bf16 = comp_kv[:, nope_dim:].bfloat16().contiguous()
    nope_fp8, nope_scale = kv_mod.quantize_fp8_e4m3_from_fp32(nope_fp32)
    # The FP8 payload is reinterpreted as float8_e4m3fn and multiplied by a
    # scale broadcast over the last axis (presumably one scale per row --
    # confirm against kv_quantize.cu).
    nope_dequant = (
        nope_fp8.view(torch.float8_e4m3fn).float()
        * nope_scale.unsqueeze(-1).float()
    )
    comp_kv_rt = torch.cat([nope_dequant, rope_bf16.float()], dim=-1)
    return cosine(comp_kv, comp_kv_rt)


def _run_compress_roundtrip(compress_fn, seq_lens, m, kv_mod, *, hd, nope, device):
    """Compress random projections for each T and check the FP8 round trip.

    For every T in ``seq_lens`` this draws ``(T, 2 * hd)`` fp32 KV and gate
    projections, calls ``compress_fn(kv, gate, None, None, m=m)``, keeps the
    first ``hd`` columns of the result and scores them with
    ``_fp8_roundtrip_cosine``. A case is printed as SKIP and ignored when T
    holds no complete block of ``m`` tokens or the compressor returns no rows.

    Returns:
        True if every non-skipped case reaches cosine > 0.999.
    """
    kv_dim = hd * 2
    all_ok = True
    for T in seq_lens:
        n_blocks = T // m
        if n_blocks == 0:
            print(f" T={T}: SKIP")
            continue
        kv_proj = torch.randn(T, kv_dim, dtype=torch.float32, device=device) * 0.3
        gate_proj = torch.randn(T, kv_dim, dtype=torch.float32, device=device) * 0.3
        # Pass the same `m` used for n_blocks so the two cannot drift apart.
        compressed = compress_fn(kv_proj, gate_proj, None, None, m=m)
        # Both CSA and HCA get this guard: an empty result has nothing to
        # compare and must not be reported as a FAIL.
        if compressed.shape[0] == 0:
            print(f" T={T}: SKIP")
            continue
        cos = _fp8_roundtrip_cosine(compressed[:, :hd], kv_mod, nope)
        ok = cos > 0.999
        all_ok = all_ok and ok
        print(f" T={T}: n_blocks={n_blocks} cos={cos:.6f} {'PASS' if ok else 'FAIL'}")
    return all_ok


def _decode_vs_sdpa_cosine(decode_fn, N, *, n_h, hd, nope, rope, scale, device):
    """Compare one MQA decode step of ``decode_fn`` with a dequantized SDPA reference.

    Builds N cached KV rows: ``nope`` fp32 columns quantized to FP8 E4M3 with
    a per-row absmax scale, plus ``rope`` bf16 columns. It also builds ``n_h``
    single-token bf16 queries of width ``hd``. All heads attend to that one
    shared KV head, and V is the same tensor as K.

    Returns:
        Cosine similarity between ``decode_fn``'s output and the reference.
    """
    # Realistic FP8 quantized KV: per-row absmax scaling into the E4M3 range.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0, unrelated to NOPE
    kv_nope_fp32 = torch.randn(N, nope, dtype=torch.float32, device=device) * 0.3
    kv_rope_bf16 = torch.randn(N, rope, dtype=torch.bfloat16, device=device) * 0.3
    amax = kv_nope_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    nope_scale = (amax / fp8_max).squeeze(-1)
    nope_clamped = (kv_nope_fp32 / nope_scale.unsqueeze(-1)).clamp(-fp8_max, fp8_max)
    kv_nope_fp8 = nope_clamped.to(torch.float8_e4m3fn).view(torch.uint8).contiguous()
    kv_nope_scale = nope_scale.contiguous()

    q = torch.randn(n_h, 1, hd, dtype=torch.bfloat16, device=device) * 0.3

    # Production FMHA (n_h heads, each attends to the same KV).
    attn_out = decode_fn(
        q=q, k_nope_fp8=kv_nope_fp8, k_nope_scale=kv_nope_scale,
        k_rope_bf16=kv_rope_bf16, scale=scale, rope_dim=rope)

    # Reference: dequantize K with the exact scales the kernel received.
    nope_dequant = (
        kv_nope_fp8.view(torch.float8_e4m3fn).float()
        * kv_nope_scale.unsqueeze(-1).float()
    )
    k_full = torch.cat([nope_dequant.bfloat16(), kv_rope_bf16], dim=-1)  # (N, hd)
    # MQA with a single shared KV head and one unmasked query token per head:
    # each head's softmax is independent. Stacking the n_h queries along
    # SDPA's query-length axis therefore yields every head's output in one
    # call, instead of n_h separate calls.
    kv_4d = k_full.reshape(1, 1, N, hd)
    q_4d = q.reshape(1, 1, n_h, hd)
    o_ref = F.scaled_dot_product_attention(q_4d, kv_4d, kv_4d, scale=scale)
    return cosine(attn_out, o_ref.reshape(n_h, 1, hd))


def main():
    """Run the PART A diagnostics and exit with status 0 (all pass) or 1.

    Tests:
        1. FP8/bf16 round trip of CSA-compressed KV (ratio 4).
        2. The same round trip for HCA-compressed KV (ratio 128).
        3. The production mixed-FP8 MQA decode kernel with 128 query heads,
           compared against an SDPA reference.

    Requires the CUDA device "cuda:0" and the dsv4 kernel modules.
    """
    HD = 512
    NOPE = 448
    ROPE = 64
    n_h = 128
    scale = 1.0 / math.sqrt(HD)
    device = "cuda:0"
    torch.manual_seed(42)

    print("=" * 70)
    print("PART A: Compressor + FMHA at Production Scale")
    print("=" * 70)

    all_pass = True

    # ---- Test 1: CSA compression round-trip ----
    print("\n--- Test 1: CSA compression (ratio=4) ---")
    from dsv4.kernels.compressor.production_compress import csa_compress_production_fp32
    from dsv4.kernels.cuda.loader import get_cuda_module
    kv_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
    if not _run_compress_roundtrip(
            csa_compress_production_fp32, [4, 16, 32, 64], 4, kv_mod,
            hd=HD, nope=NOPE, device=device):
        all_pass = False

    # ---- Test 2: HCA compression round-trip ----
    print("\n--- Test 2: HCA compression (ratio=128) ---")
    from dsv4.kernels.compressor.production_compress import hca_compress_production_fp32
    if not _run_compress_roundtrip(
            hca_compress_production_fp32, [128, 256], 128, kv_mod,
            hd=HD, nope=NOPE, device=device):
        all_pass = False

    # ---- Test 3: B1 FMHA decode vs SDPA (H=128, MQA) ----
    print("\n--- Test 3: B1 FMHA decode vs SDPA (H=128, MQA) ---")
    from dsv4.kernels.attention.production import dsv4_attention_mixed_fp8_decode
    for N in [128, 512, 1024]:
        cos = _decode_vs_sdpa_cosine(
            dsv4_attention_mixed_fp8_decode, N,
            n_h=n_h, hd=HD, nope=NOPE, rope=ROPE, scale=scale, device=device)
        ok = cos > 0.999
        all_pass = all_pass and ok
        print(f" N={N}: cos={cos:.6f} {'PASS' if ok else 'FAIL'}")

    # ---- Summary ----
    print("\n" + "=" * 70)
    print(f"OVERALL: {'PASS' if all_pass else 'FAIL'}")
    print("=" * 70)
    sys.exit(0 if all_pass else 1)
|
|
|
|
if __name__ == "__main__":
    # Script entry point only; importing this module runs nothing.
    # main() ends the process itself via sys.exit with the overall status.
    main()
|