diff --git a/tests/unit/test_nvfp4_gpu_quantize.py b/tests/unit/test_nvfp4_gpu_quantize.py index a6350fb3..c3294472 100644 --- a/tests/unit/test_nvfp4_gpu_quantize.py +++ b/tests/unit/test_nvfp4_gpu_quantize.py @@ -125,36 +125,35 @@ def test_deinterleave_quantize_correctness(): torch.manual_seed(42) M = 64 intermediate = 512 - N = 2 * intermediate # fused output has gate + up interleaved + N = 2 * intermediate # fused output has silu(gate) + swiglu interleaved - # Create SwiGLU-style interleaved data + # Create SwiGLU-style interleaved data matching fused kernel output + # The fused kernel outputs [silu(gate)*8, silu(gate)*up*8, ...] interleaved gate = torch.randn(M, intermediate, dtype=torch.bfloat16, device='cuda') up = torch.randn(M, intermediate, dtype=torch.bfloat16, device='cuda') gate_silu = torch.nn.functional.silu(gate) swiglu_result = gate_silu * up - # Create interleaved layout: [silu(gate)*8, up*8, ...] - fused = torch.cat([gate_silu, up], dim=-1).unsqueeze(0) # (1, M, N) + # Create interleaved layout: [silu(gate)*8, swiglu*8, ...] + fused = torch.cat([gate_silu, swiglu_result], dim=-1).unsqueeze(0) # (1, M, N) fused = interleave_l1_weights(fused)[0] # (M, N) interleaved global_scale = 1.0 - # Python reference: deinterleave then quantize + # Python reference: quantize the swiglu_result (odd groups) ref_fp4, ref_sf = quantize_activation_nvfp4(swiglu_result, global_scale) - # CUDA kernel: deinterleave + quantize in one pass + # CUDA kernel: deinterleave (extract odd groups = swiglu) + quantize in one pass gpu_fp4, gpu_sf = deinterleave_quantize_nvfp4_cuda(fused, intermediate, global_scale) - # Compare round-trip - ref_deq = _dequantize_nvfp4(ref_fp4, ref_sf, global_scale, intermediate) - gpu_deq = _dequantize_nvfp4(gpu_fp4, gpu_sf, global_scale, intermediate) + # Byte-exact comparison + fp4_match = (ref_fp4.view(torch.uint8) == gpu_fp4.view(torch.uint8)).float().mean().item() + sf_match = (ref_sf.view(torch.uint8) == gpu_sf.view(torch.uint8)).float().mean().item() + print(f" FP4 byte match: {fp4_match*100:.1f}%") + print(f" SF byte match: {sf_match*100:.1f}%") - cos_cross = torch.nn.functional.cosine_similarity( - ref_deq.flatten().float().unsqueeze(0), gpu_deq.flatten().float().unsqueeze(0) - ).item() - - print(f" Python vs CUDA kernel cos: {cos_cross:.6f}") - assert cos_cross >= 0.99, f"Python vs CUDA kernel cosine too low: {cos_cross}" + assert fp4_match >= 0.99, f"FP4 byte match too low: {fp4_match}" + assert sf_match >= 0.99, f"SF byte match too low: {sf_match}" print(f" ✅ PASS")