Fix test 4: build the fused input from interleaved silu(gate) and swiglu groups (matching the fused kernel output)

This commit is contained in:
2026-05-25 16:24:04 +00:00
parent e76ea36337
commit a064b99d3d

View File

@@ -125,36 +125,35 @@ def test_deinterleave_quantize_correctness():
torch.manual_seed(42)
M = 64
intermediate = 512
N = 2 * intermediate # fused output has gate + up interleaved
N = 2 * intermediate # fused output has silu(gate) + swiglu interleaved
# Create SwiGLU-style interleaved data
# Create SwiGLU-style interleaved data matching fused kernel output
# The fused kernel outputs [silu(gate)*8, silu(gate)*up*8, ...] interleaved
gate = torch.randn(M, intermediate, dtype=torch.bfloat16, device='cuda')
up = torch.randn(M, intermediate, dtype=torch.bfloat16, device='cuda')
gate_silu = torch.nn.functional.silu(gate)
swiglu_result = gate_silu * up
# Create interleaved layout: [silu(gate)*8, up*8, ...]
fused = torch.cat([gate_silu, up], dim=-1).unsqueeze(0) # (1, M, N)
# Create interleaved layout: [silu(gate)*8, swiglu*8, ...]
fused = torch.cat([gate_silu, swiglu_result], dim=-1).unsqueeze(0) # (1, M, N)
fused = interleave_l1_weights(fused)[0] # (M, N) interleaved
global_scale = 1.0
# Python reference: deinterleave then quantize
# Python reference: quantize the swiglu_result (odd groups)
ref_fp4, ref_sf = quantize_activation_nvfp4(swiglu_result, global_scale)
# CUDA kernel: deinterleave + quantize in one pass
# CUDA kernel: deinterleave (extract odd groups = swiglu) + quantize in one pass
gpu_fp4, gpu_sf = deinterleave_quantize_nvfp4_cuda(fused, intermediate, global_scale)
# Compare round-trip
ref_deq = _dequantize_nvfp4(ref_fp4, ref_sf, global_scale, intermediate)
gpu_deq = _dequantize_nvfp4(gpu_fp4, gpu_sf, global_scale, intermediate)
# Byte-exact comparison
fp4_match = (ref_fp4.view(torch.uint8) == gpu_fp4.view(torch.uint8)).float().mean().item()
sf_match = (ref_sf.view(torch.uint8) == gpu_sf.view(torch.uint8)).float().mean().item()
print(f" FP4 byte match: {fp4_match*100:.1f}%")
print(f" SF byte match: {sf_match*100:.1f}%")
cos_cross = torch.nn.functional.cosine_similarity(
ref_deq.flatten().float().unsqueeze(0), gpu_deq.flatten().float().unsqueeze(0)
).item()
print(f" Python vs CUDA kernel cos: {cos_cross:.6f}")
assert cos_cross >= 0.99, f"Python vs CUDA kernel cosine too low: {cos_cross}"
assert fp4_match >= 0.99, f"FP4 byte match too low: {fp4_match}"
assert sf_match >= 0.99, f"SF byte match too low: {sf_match}"
print(f" ✅ PASS")