- Move test_*.py → tests/integration/
- Move probe_*.py, dump_*.py → helpers/
- Move PERFORMANCE_AUDIT.md → docs/
- Move single_shot_PYTORCH_REFERENCE.py → dsv4/reference/
- Fix 3 import references in test_layer_comparison, test_mhc_comparison, test_compressor_position_bias
- Add helpers/import_closure.py (dead-code detection tool)
211 lines
7.4 KiB
Python
211 lines
7.4 KiB
Python
"""Test compressor CUDA kernel with position_bias.
|
|
|
|
Verifies that compressor_reduce.cu produces identical output to the
|
|
PyTorch reference when position_bias is provided.
|
|
|
|
CSA (m=4): position_bias is (m, 2*hd), added to both kv and gate
|
|
HCA (m=128): position_bias is (m, hd), added to both kv and gate
|
|
"""
|
|
|
|
import os
import sys
import unittest

import torch
|
|
|
|
# Add kernel path
|
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
|
|
|
from dsv4.kernels.compressor.production_compress import csa_compress_production, hca_compress_production
|
|
|
|
|
|
def _require_cuda():
    """Skip the calling test when no CUDA device is available.

    Raises:
        unittest.SkipTest: if ``torch.cuda.is_available()`` is False; pytest
            reports this as a skipped test rather than an error.
    """
    if not torch.cuda.is_available():
        raise unittest.SkipTest("compressor kernel tests require a CUDA device")


def _rms_norm(x, weight, eps=1e-6):
    """Return *x* RMS-normalized over its last dim, scaled by ``weight.float()``."""
    return x * x.pow(2).mean(-1, keepdim=True).add(eps).rsqrt() * weight.float()


def _add_cyclic_bias(x, position_bias, m):
    """Return a copy of *x* with ``position_bias[:m]`` added to each complete block of m rows.

    Row ``j`` of every block receives ``position_bias[j]``; rows after the last
    complete block are left unchanged. *x* itself is not modified.
    """
    n_full = x.shape[0] // m
    out = x.clone()
    out[: n_full * m] += position_bias[:m].float().repeat(n_full, 1)
    return out


def _softmax_pool(block_kv, block_gate):
    """Softmax *block_gate* over tokens (dim 0) and return the weighted sum of *block_kv*."""
    probs = torch.softmax(block_gate.float(), dim=0)
    return (probs * block_kv.float()).sum(0)


def _csa_reference(kv, gate, position_bias, kv_norm_weight, m):
    """PyTorch reference for ``csa_compress_production``.

    Columns ``[:hd]`` of kv/gate hold the Ca/Ga half and ``[hd:]`` the Cb/Gb
    half, with ``hd = kv.shape[-1] // 2``. Block 0 pools only Cb[0]; block
    ``bi > 0`` pools the overlap Ca[bi-1] ++ Cb[bi] (2m tokens). Each pooled
    vector is RMS-normalized and scaled by *kv_norm_weight*.

    Args:
        kv, gate: (T, 2*hd) tensors; only complete blocks of m rows are used.
        position_bias: (m, 2*hd) bias added to both kv and gate per block, or None.
        kv_norm_weight: (hd,) RMS-norm scale.
        m: block size.

    Returns:
        bfloat16 tensor of shape (T // m, hd).
    """
    hd = kv.shape[-1] // 2
    if position_bias is not None:
        kv = _add_cyclic_bias(kv, position_bias, m)
        gate = _add_cyclic_bias(gate, position_bias, m)

    pooled = []
    for bi in range(kv.shape[0] // m):
        cur = slice(bi * m, (bi + 1) * m)
        block_kv, block_gate = kv[cur, hd:], gate[cur, hd:]
        if bi > 0:
            # Overlap: prepend the previous block's Ca/Ga half to this block's Cb/Gb.
            prev = slice((bi - 1) * m, bi * m)
            block_kv = torch.cat([kv[prev, :hd], block_kv], dim=0)
            block_gate = torch.cat([gate[prev, :hd], block_gate], dim=0)
        pooled.append(_softmax_pool(block_kv, block_gate))
    return _rms_norm(torch.stack(pooled), kv_norm_weight).bfloat16()


def _hca_reference(kv, gate, position_bias, kv_norm_weight, m):
    """PyTorch reference for ``hca_compress_production``.

    Each complete block of m rows is pooled independently (no overlap), then
    RMS-normalized and scaled by *kv_norm_weight*.

    Args:
        kv, gate: (T, hd) tensors; only complete blocks of m rows are used.
        position_bias: (m, hd) bias added to both kv and gate per block, or None.
        kv_norm_weight: (hd,) RMS-norm scale.
        m: block size.

    Returns:
        bfloat16 tensor of shape (T // m, hd).
    """
    n_blocks = kv.shape[0] // m
    if position_bias is not None:
        kv = _add_cyclic_bias(kv, position_bias, m)
        gate = _add_cyclic_bias(gate, position_bias, m)
    # View the complete blocks as (n_blocks, m, hd) and pool over the m tokens.
    kv_blocks = kv[: n_blocks * m].reshape(n_blocks, m, -1).float()
    gate_blocks = gate[: n_blocks * m].reshape(n_blocks, m, -1).float()
    pooled = (torch.softmax(gate_blocks, dim=1) * kv_blocks).sum(1)
    return _rms_norm(pooled, kv_norm_weight).bfloat16()


def _check_close(label, compressed_cuda, compressed_ref, min_cos=0.999):
    """Print a kernel-vs-reference comparison and assert that the two agree.

    Agreement means equal shapes and a cosine similarity of the flattened
    outputs of at least *min_cos*. The condition is written as
    ``cos >= min_cos`` so that a NaN similarity (e.g. NaNs in the kernel
    output) fails instead of passing. On failure, per-block (per-row)
    diagnostics are printed before the assertion fires.

    Raises:
        AssertionError: if the shapes differ or the similarity is too low.
    """
    out = compressed_cuda.float()
    ref = compressed_ref.float()
    assert out.shape == ref.shape, (
        f"{label}: shape mismatch {tuple(out.shape)} vs {tuple(ref.shape)}"
    )

    cos = torch.nn.functional.cosine_similarity(out.flatten(), ref.flatten(), dim=0).item()
    max_diff = (out - ref).abs().max().item()
    print(f"{label}:")
    print(f" Cosine similarity: {cos:.6f}")
    print(f" Max absolute diff: {max_diff:.6f}")

    passed = cos >= min_cos
    if not passed:
        print(f" FAIL: cos={cos:.6f} is not >= {min_cos}")
        for bi, (out_row, ref_row) in enumerate(zip(out, ref)):
            block_cos = torch.nn.functional.cosine_similarity(
                out_row.flatten(), ref_row.flatten(), dim=0
            ).item()
            block_diff = (out_row - ref_row).abs().max().item()
            print(f" Block {bi}: cos={block_cos:.6f}, max_diff={block_diff:.6f}")
    assert passed, f"{label}: cos={cos:.6f} is not >= {min_cos}"
    print(" PASS ✓")


def test_csa_position_bias():
    """CSA compress with position_bias: CUDA kernel vs PyTorch reference."""
    _require_cuda()
    torch.manual_seed(42)
    device = "cuda"
    T, hd, m = 16, 512, 4  # 4 complete blocks with m=4

    # Draw order (kv, gate, position_bias, kv_norm_weight) fixes the test data for this seed.
    kv = torch.randn(T, 2 * hd, device=device, dtype=torch.bfloat16).float()
    gate = torch.randn(T, 2 * hd, device=device, dtype=torch.bfloat16).float()
    position_bias = torch.randn(m, 2 * hd, device=device, dtype=torch.bfloat16)
    kv_norm_weight = torch.randn(hd, device=device, dtype=torch.bfloat16)

    compressed_cuda = csa_compress_production(kv, gate, position_bias, kv_norm_weight, m=m)
    compressed_ref = _csa_reference(kv, gate, position_bias, kv_norm_weight, m)
    _check_close(
        f"CSA position_bias test (T={T}, hd={hd}, m={m}, n_blocks={T // m})",
        compressed_cuda,
        compressed_ref,
    )


def test_csa_no_position_bias():
    """CSA compress without position_bias: verify the kernel works with None."""
    _require_cuda()
    torch.manual_seed(123)
    device = "cuda"
    T, hd, m = 8, 512, 4

    kv = torch.randn(T, 2 * hd, device=device, dtype=torch.bfloat16).float()
    gate = torch.randn(T, 2 * hd, device=device, dtype=torch.bfloat16).float()
    kv_norm_weight = torch.randn(hd, device=device, dtype=torch.bfloat16)

    compressed_cuda = csa_compress_production(kv, gate, None, kv_norm_weight, m=m)
    compressed_ref = _csa_reference(kv, gate, None, kv_norm_weight, m)
    _check_close(
        f"CSA no position_bias test (T={T}, hd={hd})",
        compressed_cuda,
        compressed_ref,
    )


def test_hca_position_bias():
    """HCA compress with position_bias: CUDA kernel vs PyTorch reference."""
    _require_cuda()
    torch.manual_seed(99)
    device = "cuda"
    hd, m = 512, 128
    T = 256  # 2 complete blocks

    kv = torch.randn(T, hd, device=device, dtype=torch.bfloat16).float()
    gate = torch.randn(T, hd, device=device, dtype=torch.bfloat16).float()
    position_bias = torch.randn(m, hd, device=device, dtype=torch.bfloat16)
    kv_norm_weight = torch.randn(hd, device=device, dtype=torch.bfloat16)

    compressed_cuda = hca_compress_production(kv, gate, position_bias, kv_norm_weight, m=m)
    compressed_ref = _hca_reference(kv, gate, position_bias, kv_norm_weight, m)
    _check_close(
        f"HCA position_bias test (T={T}, hd={hd}, m={m})",
        compressed_cuda,
        compressed_ref,
    )
|
|
|
|
|
|
if __name__ == "__main__":
|
|
test_csa_no_position_bias()
|
|
test_csa_position_bias()
|
|
test_hca_position_bias()
|
|
print("\nAll compressor position_bias tests PASSED ✓")
|