Files
nvfp4-megamoe-kernel/tests/unit/test_compressor_position_bias.py
biondizzle 8de47e26ce Cleanup Step 1: Move root-level files to proper directories
- Move test_*.py → tests/integration/
- Move probe_*.py, dump_*.py → helpers/
- Move PERFORMANCE_AUDIT.md → docs/
- Move single_shot_PYTORCH_REFERENCE.py → dsv4/reference/
- Fix 3 import references in test_layer_comparison, test_mhc_comparison, test_compressor_position_bias
- Add helpers/import_closure.py (dead-code detection tool)
2026-06-02 19:24:39 +00:00

211 lines
7.4 KiB
Python

"""Test compressor CUDA kernel with position_bias.
Verifies that compressor_reduce.cu matches the PyTorch reference
(cosine similarity >= 0.999) both with and without position_bias.
CSA (m=4): position_bias has shape (m, 2*hd) and is added to both kv and gate.
HCA (m=128): position_bias has shape (m, hd) and is added to both kv and gate.
"""
import torch
import sys
import os
# Add kernel path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from dsv4.kernels.compressor.production_compress import csa_compress_production, hca_compress_production
# Minimum cosine similarity the kernel output must reach against the reference.
_COS_THRESHOLD = 0.999
# Epsilon used by the reference RMS norm (kv_norm).
_RMS_EPS = 1e-6


def _rms_norm(x, weight):
    """Return ``x`` RMS-normalized over its last dim and scaled by ``weight`` (as fp32)."""
    return x * x.pow(2).mean(-1, keepdim=True).add(_RMS_EPS).rsqrt() * weight.float()


def _add_cyclic_position_bias(x, position_bias, m):
    """Return a copy of ``x`` with ``position_bias`` added to each complete block of ``m`` rows.

    Row ``j`` of ``position_bias`` is added to row ``j`` of every block. Trailing rows
    that do not fill a complete block are left unchanged. ``x`` itself is not modified.
    """
    out = x.clone()
    ape = position_bias.float()
    for bi in range(x.shape[0] // m):
        out[bi * m:(bi + 1) * m] += ape[:m]
    return out


def _softmax_pool_and_norm(block_kv, block_gate, kv_norm_weight):
    """Softmax ``block_gate`` over tokens (dim 0), pool ``block_kv`` with those weights, then RMS-norm."""
    probs = torch.softmax(block_gate.float(), dim=0)
    pooled = (probs * block_kv.float()).sum(0)
    return _rms_norm(pooled, kv_norm_weight)


def _csa_reference(kv, gate, kv_norm_weight, m, hd):
    """PyTorch reference for CSA compression (per dsv4/reference/single_shot_PYTORCH_REFERENCE.py).

    Block 0 pools only the second half (``Cb``, columns ``hd:``) of its ``m`` rows.
    Block ``bi > 0`` pools the first half (``Ca``, columns ``:hd``) of block ``bi - 1``
    concatenated with ``Cb`` of block ``bi`` (2m tokens, overlapping window).

    Returns:
        bf16 tensor with one ``hd``-wide row per complete block of ``m`` rows.
    """
    out = []
    for bi in range(kv.shape[0] // m):
        cur = slice(bi * m, (bi + 1) * m)
        if bi == 0:
            block_kv, block_gate = kv[cur, hd:], gate[cur, hd:]
        else:
            prev = slice((bi - 1) * m, bi * m)
            block_kv = torch.cat([kv[prev, :hd], kv[cur, hd:]], dim=0)
            block_gate = torch.cat([gate[prev, :hd], gate[cur, hd:]], dim=0)
        out.append(_softmax_pool_and_norm(block_kv, block_gate, kv_norm_weight))
    return torch.stack(out).bfloat16()


def _hca_reference(kv, gate, kv_norm_weight, m):
    """PyTorch reference for HCA compression: pool each complete block of ``m`` rows independently.

    Returns:
        bf16 tensor with one row per complete block of ``m`` rows.
    """
    out = []
    for bi in range(kv.shape[0] // m):
        rows = slice(bi * m, (bi + 1) * m)
        out.append(_softmax_pool_and_norm(kv[rows], gate[rows], kv_norm_weight))
    return torch.stack(out).bfloat16()


def _compare(actual, expected):
    """Return ``(cosine_similarity, max_abs_diff)`` between two tensors, computed in fp32.

    Raises:
        AssertionError: if the shapes differ. Without this, a mis-shaped kernel output
            could be flattened (or broadcast in the diff) into a misleading score.
    """
    assert tuple(actual.shape) == tuple(expected.shape), (
        f"shape mismatch: kernel {tuple(actual.shape)} vs reference {tuple(expected.shape)}"
    )
    a = actual.float()
    b = expected.float()
    cos = torch.nn.functional.cosine_similarity(
        a.flatten().unsqueeze(0), b.flatten().unsqueeze(0)
    ).item()
    max_diff = (a - b).abs().max().item()
    return cos, max_diff


def test_csa_position_bias():
    """CSA compress with position_bias: CUDA kernel vs PyTorch reference."""
    torch.manual_seed(42)
    device = "cuda"
    T = 16  # 4 complete blocks with m=4
    hd = 512
    m = 4
    n_blocks = T // m
    # Draw order matters for reproducibility under the fixed seed.
    kv = torch.randn(T, 2 * hd, device=device, dtype=torch.bfloat16).float()
    gate = torch.randn(T, 2 * hd, device=device, dtype=torch.bfloat16).float()
    position_bias = torch.randn(m, 2 * hd, device=device, dtype=torch.bfloat16)
    kv_norm_weight = torch.randn(hd, device=device, dtype=torch.bfloat16)

    # --- CUDA kernel path ---
    compressed_cuda = csa_compress_production(kv, gate, position_bias, kv_norm_weight, m=m)

    # --- PyTorch reference path: bias both kv and gate, then pool per block ---
    compressed_ref = _csa_reference(
        _add_cyclic_position_bias(kv, position_bias, m),
        _add_cyclic_position_bias(gate, position_bias, m),
        kv_norm_weight,
        m,
        hd,
    )

    cos, max_diff = _compare(compressed_cuda, compressed_ref)
    print(f"CSA position_bias test (T={T}, hd={hd}, m={m}, n_blocks={n_blocks}):")
    print(f" Cosine similarity: {cos:.6f}")
    print(f" Max absolute diff: {max_diff:.6f}")
    # Written as `>=` so a NaN similarity counts as a failure (NaN < x is False).
    ok = cos >= _COS_THRESHOLD
    if not ok:
        print(f" FAIL: cos={cos:.6f} < {_COS_THRESHOLD}")
        # Per-block breakdown to localize the mismatch.
        for bi in range(n_blocks):
            cb, md = _compare(compressed_cuda[bi], compressed_ref[bi])
            print(f" Block {bi}: cos={cb:.6f}, max_diff={md:.6f}")
    assert ok, f"CSA position_bias: cos={cos:.6f} < {_COS_THRESHOLD} (max_diff={max_diff:.6f})"
    print(" PASS ✓")


def test_csa_no_position_bias():
    """CSA compress without position_bias: verify kernel works with None."""
    torch.manual_seed(123)
    device = "cuda"
    T = 8
    hd = 512
    m = 4
    kv = torch.randn(T, 2 * hd, device=device, dtype=torch.bfloat16).float()
    gate = torch.randn(T, 2 * hd, device=device, dtype=torch.bfloat16).float()
    kv_norm_weight = torch.randn(hd, device=device, dtype=torch.bfloat16)

    # CUDA kernel with None position_bias
    compressed_cuda = csa_compress_production(kv, gate, None, kv_norm_weight, m=m)

    # PyTorch reference on the unbiased inputs
    compressed_ref = _csa_reference(kv, gate, kv_norm_weight, m, hd)

    cos, max_diff = _compare(compressed_cuda, compressed_ref)
    print(f"CSA no position_bias test (T={T}, hd={hd}): cos={cos:.6f}", end=" ")
    # `>=` so that NaN fails.
    ok = cos >= _COS_THRESHOLD
    print("PASS ✓" if ok else "FAIL")
    assert ok, f"CSA no position_bias: cos={cos:.6f} < {_COS_THRESHOLD} (max_diff={max_diff:.6f})"


def test_hca_position_bias():
    """HCA compress with position_bias: CUDA kernel vs PyTorch reference."""
    torch.manual_seed(99)
    device = "cuda"
    hd = 512
    m = 128
    T = 256  # 2 complete blocks
    kv = torch.randn(T, hd, device=device, dtype=torch.bfloat16).float()
    gate = torch.randn(T, hd, device=device, dtype=torch.bfloat16).float()
    position_bias = torch.randn(m, hd, device=device, dtype=torch.bfloat16)
    kv_norm_weight = torch.randn(hd, device=device, dtype=torch.bfloat16)

    # CUDA kernel
    compressed_cuda = hca_compress_production(kv, gate, position_bias, kv_norm_weight, m=m)

    # PyTorch reference: bias both kv and gate, then pool each block independently
    compressed_ref = _hca_reference(
        _add_cyclic_position_bias(kv, position_bias, m),
        _add_cyclic_position_bias(gate, position_bias, m),
        kv_norm_weight,
        m,
    )

    cos, max_diff = _compare(compressed_cuda, compressed_ref)
    print(f"HCA position_bias test (T={T}, hd={hd}, m={m}):")
    print(f" Cosine similarity: {cos:.6f}")
    print(f" Max absolute diff: {max_diff:.6f}")
    # `>=` so that NaN fails.
    ok = cos >= _COS_THRESHOLD
    if not ok:
        print(f" FAIL: cos={cos:.6f} < {_COS_THRESHOLD}")
    assert ok, f"HCA position_bias: cos={cos:.6f} < {_COS_THRESHOLD} (max_diff={max_diff:.6f})"
    print(" PASS ✓")
if __name__ == "__main__":
    # Script entry point: run each test in order (no-bias sanity check first);
    # any failure aborts before the summary line is printed.
    for _run_test in (
        test_csa_no_position_bias,
        test_csa_position_bias,
        test_hca_position_bias,
    ):
        _run_test()
    print("\nAll compressor position_bias tests PASSED ✓")