Files
nvfp4-megamoe-kernel/tests/unit/test_production_compress.py

83 lines
2.8 KiB
Python

"""Test production compressor kernel (CSA + HCA reduce)."""
import torch
import math
def test_csa_compress():
    """Check ``csa_compress_production`` against a pure-PyTorch reference.

    The reference splits the last dimension of ``kv`` and ``gate`` into two
    halves (an "a" stream and a "b" stream) and groups tokens into blocks
    of ``ratio`` rows.  Block 0 pools only its own "b" rows; every later
    block pools the previous block's "a" rows together with its own "b"
    rows.  Pooling is a softmax over the token axis of the gate values,
    used as per-feature weights for a weighted sum of the kv values.

    The production result must reach a cosine similarity above 0.999
    against that reference.  The maximum absolute error is printed for
    diagnostics only and is not asserted.
    """
    torch.manual_seed(42)
    device = 'cuda'
    head_dim = 512
    ratio = 4
    num_tokens = 16  # 4 blocks of 4 tokens
    n_blocks = num_tokens // ratio

    # Synthetic kv and gate projections; each carries both streams
    # side by side in the last dimension.
    kv = torch.randn(num_tokens, 2 * head_dim, dtype=torch.float32, device=device)
    gate = torch.randn(num_tokens, 2 * head_dim, dtype=torch.float32, device=device)

    # Reference (PyTorch): view each stream as (block, token-in-block, feature).
    blocked = (n_blocks, ratio, head_dim)
    kv_a = kv[:, :head_dim].reshape(blocked)
    kv_b = kv[:, head_dim:].reshape(blocked)
    gate_a = gate[:, :head_dim].reshape(blocked)
    gate_b = gate[:, head_dim:].reshape(blocked)

    pooled = []
    for idx in range(n_blocks):
        if idx == 0:
            # The first block has no predecessor, so only the "b" stream is used.
            window_kv = kv_b[idx]
            window_gate = gate_b[idx]
        else:
            # Overlap: previous block's "a" rows followed by this block's "b" rows.
            window_kv = torch.cat([kv_a[idx - 1], kv_b[idx]], dim=0)
            window_gate = torch.cat([gate_a[idx - 1], gate_b[idx]], dim=0)
        weights = torch.softmax(window_gate, dim=0)
        pooled.append((weights * window_kv).sum(0))
    ref = torch.stack(pooled)

    # Production: CUDA kernel.
    # NOTE(review): the two ``None`` positional arguments are optional inputs
    # whose meaning is not visible from this test - confirm against the kernel.
    from dsv4.kernels.compressor.production_compress import csa_compress_production
    prod = csa_compress_production(kv, gate, None, None, m=ratio)

    ref_flat = ref.flatten().float()
    prod_flat = prod.flatten().float()
    cos = torch.nn.functional.cosine_similarity(ref_flat, prod_flat, dim=0).item()
    max_err = (ref - prod).abs().max().item()
    print(f"CSA compress: cos={cos:.6f} max_err={max_err:.6f} ref_max={ref.abs().max().item():.4f} prod_max={prod.abs().max().item():.4f}")
    assert cos > 0.999, f"CSA compress cosine too low: {cos}"
    print(" PASSED")
def test_hca_compress():
    """Check ``hca_compress_production`` against a pure-PyTorch reference.

    Tokens are grouped into consecutive, non-overlapping blocks of
    ``ratio`` rows from a single stream.  Each block is reduced to one row
    by a softmax over the token axis of the gate values, used as
    per-feature weights for a weighted sum of the kv values.

    The production result must reach a cosine similarity above 0.999
    against that reference.  The maximum absolute error is printed for
    diagnostics only and is not asserted.
    """
    torch.manual_seed(42)
    device = 'cuda'
    head_dim = 512
    ratio = 8  # Use 8 instead of 128 for test speed
    num_tokens = 24  # 3 blocks
    n_blocks = num_tokens // ratio

    kv = torch.randn(num_tokens, head_dim, dtype=torch.float32, device=device)
    gate = torch.randn(num_tokens, head_dim, dtype=torch.float32, device=device)

    # Reference
    def pool_block(block_idx):
        # Softmax-weighted sum over the rows belonging to one block.
        rows = slice(block_idx * ratio, (block_idx + 1) * ratio)
        weights = torch.softmax(gate[rows], dim=0)
        return (weights * kv[rows]).sum(0)

    ref = torch.stack([pool_block(block_idx) for block_idx in range(n_blocks)])

    # Production
    # NOTE(review): the two ``None`` positional arguments are optional inputs
    # whose meaning is not visible from this test - confirm against the kernel.
    from dsv4.kernels.compressor.production_compress import hca_compress_production
    prod = hca_compress_production(kv, gate, None, None, m=ratio)

    ref_flat = ref.flatten().float()
    prod_flat = prod.flatten().float()
    cos = torch.nn.functional.cosine_similarity(ref_flat, prod_flat, dim=0).item()
    max_err = (ref - prod).abs().max().item()
    print(f"HCA compress: cos={cos:.6f} max_err={max_err:.6f}")
    assert cos > 0.999, f"HCA compress cosine too low: {cos}"
    print(" PASSED")
if __name__ == "__main__":
    # Script mode: run the tests in order without a test runner.  A failing
    # assertion propagates and stops the run before the summary is printed.
    for _run_case in (test_csa_compress, test_hca_compress):
        _run_case()
    print("\nAll compressor tests PASSED")