# File-viewer header residue (was duplicated): 83 lines, 2.8 KiB, Python
"""Test production compressor kernel (CSA + HCA reduce)."""

import math
import unittest

import torch


def test_csa_compress():
    """CSA: ratio=4, overlapping Ca/Cb streams.

    Compare the production CSA compressor kernel against a pure-PyTorch
    reference on random inputs.

    The reference splits the ``(T, 2 * hd)`` kv and gate projections into an
    "a" half (first ``hd`` columns) and a "b" half (last ``hd`` columns), each
    reshaped to ``(n_blocks, m, hd)``. Compressed entry ``bi`` is a
    softmax-weighted sum over tokens (softmax taken per channel along the
    token axis). Its tokens are the a-stream tokens of block ``bi - 1``
    concatenated with the b-stream tokens of block ``bi``. Block 0 has no
    predecessor, so it pools only its own b-stream tokens.

    Raises ``unittest.SkipTest`` (reported as a skip by pytest) when no CUDA
    device is available, because the inputs are allocated on ``cuda``.
    """
    if not torch.cuda.is_available():
        raise unittest.SkipTest("CSA compressor test requires a CUDA device")

    torch.manual_seed(42)
    device = 'cuda'
    hd = 512
    m = 4
    T = 16  # 4 blocks of 4 tokens
    n_blocks = T // m

    # Create synthetic kv and gate projections
    kv = torch.randn(T, 2 * hd, dtype=torch.float32, device=device)
    gate = torch.randn(T, 2 * hd, dtype=torch.float32, device=device)

    # Reference: PyTorch
    Ca = kv[:, :hd].reshape(n_blocks, m, hd)
    Cb = kv[:, hd:].reshape(n_blocks, m, hd)
    Ga = gate[:, :hd].reshape(n_blocks, m, hd)
    Gb = gate[:, hd:].reshape(n_blocks, m, hd)

    ref = []
    for bi in range(n_blocks):
        if bi > 0:
            # Overlap: previous block's a-stream followed by this block's
            # b-stream (2 * m tokens).
            block_kv = torch.cat([Ca[bi-1], Cb[bi]], dim=0)
            block_gate = torch.cat([Ga[bi-1], Gb[bi]], dim=0)
        else:
            # First block has no predecessor: pool its b-stream only.
            block_kv = Cb[bi]
            block_gate = Gb[bi]
        # Softmax over the token axis, independently for each channel.
        probs = torch.softmax(block_gate, dim=0)
        compressed = (probs * block_kv).sum(0)
        ref.append(compressed)
    ref = torch.stack(ref)  # (n_blocks, hd)

    # Production: CUDA kernel. Imported inside the test, presumably so the
    # module can be collected without the dsv4 extension installed.
    from dsv4.kernels.compressor.production_compress import csa_compress_production
    # NOTE(review): the two positional None arguments are not explained
    # here. They are presumably optional inputs of the kernel; confirm
    # against csa_compress_production's signature.
    prod = csa_compress_production(kv, gate, None, None, m=m)

    ref_f = ref.flatten().float()
    prod_f = prod.flatten().float()
    cos = torch.nn.functional.cosine_similarity(ref_f, prod_f, dim=0).item()
    max_err = (ref - prod).abs().max().item()
    print(f"CSA compress: cos={cos:.6f} max_err={max_err:.6f} ref_max={ref.abs().max().item():.4f} prod_max={prod.abs().max().item():.4f}")
    assert cos > 0.999, f"CSA compress cosine too low: {cos}"
    # Cosine similarity is blind to a global scale factor, so also require
    # the overall magnitude to match. The tolerance is loose in case the
    # kernel computes in reduced precision.
    ref_norm = ref_f.norm().item()
    prod_norm = prod_f.norm().item()
    assert math.isclose(prod_norm, ref_norm, rel_tol=1e-2), (
        f"CSA compress norm mismatch: ref={ref_norm} prod={prod_norm}"
    )
    print(" PASSED")


def test_hca_compress():
    """HCA: ratio=128, single stream.

    Compare the production HCA compressor kernel against a pure-PyTorch
    reference on random inputs.

    The reference splits the ``T`` tokens into non-overlapping blocks of
    ``m`` consecutive rows. For each block it takes a per-channel softmax of
    the gate over the block's tokens and uses it to weight-sum the kv rows,
    producing one ``hd``-wide vector per block. The test uses ``m = 8``
    instead of the ratio of 128 to keep it fast.

    Raises ``unittest.SkipTest`` (reported as a skip by pytest) when no CUDA
    device is available, because the inputs are allocated on ``cuda``.
    """
    if not torch.cuda.is_available():
        raise unittest.SkipTest("HCA compressor test requires a CUDA device")

    torch.manual_seed(42)
    device = 'cuda'
    hd = 512
    m = 8  # Use 8 instead of 128 for test speed
    T = 24  # 3 blocks
    n_blocks = T // m

    kv = torch.randn(T, hd, dtype=torch.float32, device=device)
    gate = torch.randn(T, hd, dtype=torch.float32, device=device)

    # Reference: one softmax-pooled vector per block of m tokens.
    ref = []
    for bi in range(n_blocks):
        block_kv = kv[bi*m:(bi+1)*m]
        block_gate = gate[bi*m:(bi+1)*m]
        # Softmax over the token axis, independently for each channel.
        probs = torch.softmax(block_gate, dim=0)
        compressed = (probs * block_kv).sum(0)
        ref.append(compressed)
    ref = torch.stack(ref)  # (n_blocks, hd)

    # Production kernel. Imported inside the test, presumably so the module
    # can be collected without the dsv4 extension installed.
    from dsv4.kernels.compressor.production_compress import hca_compress_production
    # NOTE(review): the two positional None arguments are not explained
    # here. They are presumably optional inputs of the kernel; confirm
    # against hca_compress_production's signature.
    prod = hca_compress_production(kv, gate, None, None, m=m)

    ref_f = ref.flatten().float()
    prod_f = prod.flatten().float()
    cos = torch.nn.functional.cosine_similarity(ref_f, prod_f, dim=0).item()
    max_err = (ref - prod).abs().max().item()
    print(f"HCA compress: cos={cos:.6f} max_err={max_err:.6f}")
    assert cos > 0.999, f"HCA compress cosine too low: {cos}"
    # Cosine similarity is blind to a global scale factor, so also require
    # the overall magnitude to match. The tolerance is loose in case the
    # kernel computes in reduced precision.
    ref_norm = ref_f.norm().item()
    prod_norm = prod_f.norm().item()
    assert math.isclose(prod_norm, ref_norm, rel_tol=1e-2), (
        f"HCA compress norm mismatch: ref={ref_norm} prod={prod_norm}"
    )
    print(" PASSED")


if __name__ == "__main__":
    # Run the tests directly, outside pytest. A test that raises
    # unittest.SkipTest is reported as skipped instead of aborting the
    # script. The final "PASSED" banner is printed only if nothing was
    # skipped.
    skipped = []
    for test_fn in (test_csa_compress, test_hca_compress):
        try:
            test_fn()
        except unittest.SkipTest as exc:
            print(f"{test_fn.__name__}: SKIPPED ({exc})")
            skipped.append(test_fn.__name__)

    if skipped:
        print(f"\nCompressor tests finished with {len(skipped)} skipped")
    else:
        print("\nAll compressor tests PASSED")