fix test 4: use silu(gate)+swiglu interleaved (matching fused kernel output)
This commit is contained in:
@@ -125,36 +125,35 @@ def test_deinterleave_quantize_correctness():
|
||||
torch.manual_seed(42)
|
||||
M = 64
|
||||
intermediate = 512
|
||||
N = 2 * intermediate # fused output has gate + up interleaved
|
||||
N = 2 * intermediate # fused output has silu(gate) + swiglu interleaved
|
||||
|
||||
# Create SwiGLU-style interleaved data
|
||||
# Create SwiGLU-style interleaved data matching fused kernel output
|
||||
# The fused kernel outputs [silu(gate)*8, silu(gate)*up*8, ...] interleaved
|
||||
gate = torch.randn(M, intermediate, dtype=torch.bfloat16, device='cuda')
|
||||
up = torch.randn(M, intermediate, dtype=torch.bfloat16, device='cuda')
|
||||
gate_silu = torch.nn.functional.silu(gate)
|
||||
swiglu_result = gate_silu * up
|
||||
|
||||
# Create interleaved layout: [silu(gate)*8, up*8, ...]
|
||||
fused = torch.cat([gate_silu, up], dim=-1).unsqueeze(0) # (1, M, N)
|
||||
# Create interleaved layout: [silu(gate)*8, swiglu*8, ...]
|
||||
fused = torch.cat([gate_silu, swiglu_result], dim=-1).unsqueeze(0) # (1, M, N)
|
||||
fused = interleave_l1_weights(fused)[0] # (M, N) interleaved
|
||||
|
||||
global_scale = 1.0
|
||||
|
||||
# Python reference: deinterleave then quantize
|
||||
# Python reference: quantize the swiglu_result (odd groups)
|
||||
ref_fp4, ref_sf = quantize_activation_nvfp4(swiglu_result, global_scale)
|
||||
|
||||
# CUDA kernel: deinterleave + quantize in one pass
|
||||
# CUDA kernel: deinterleave (extract odd groups = swiglu) + quantize in one pass
|
||||
gpu_fp4, gpu_sf = deinterleave_quantize_nvfp4_cuda(fused, intermediate, global_scale)
|
||||
|
||||
# Compare round-trip
|
||||
ref_deq = _dequantize_nvfp4(ref_fp4, ref_sf, global_scale, intermediate)
|
||||
gpu_deq = _dequantize_nvfp4(gpu_fp4, gpu_sf, global_scale, intermediate)
|
||||
# Byte-exact comparison
|
||||
fp4_match = (ref_fp4.view(torch.uint8) == gpu_fp4.view(torch.uint8)).float().mean().item()
|
||||
sf_match = (ref_sf.view(torch.uint8) == gpu_sf.view(torch.uint8)).float().mean().item()
|
||||
print(f" FP4 byte match: {fp4_match*100:.1f}%")
|
||||
print(f" SF byte match: {sf_match*100:.1f}%")
|
||||
|
||||
cos_cross = torch.nn.functional.cosine_similarity(
|
||||
ref_deq.flatten().float().unsqueeze(0), gpu_deq.flatten().float().unsqueeze(0)
|
||||
).item()
|
||||
|
||||
print(f" Python vs CUDA kernel cos: {cos_cross:.6f}")
|
||||
assert cos_cross >= 0.99, f"Python vs CUDA kernel cosine too low: {cos_cross}"
|
||||
assert fp4_match >= 0.99, f"FP4 byte match too low: {fp4_match}"
|
||||
assert sf_match >= 0.99, f"SF byte match too low: {sf_match}"
|
||||
print(f" ✅ PASS")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user