"""
NVFP4 GPU Quantize Kernel Test.

Tests:
  1. quantize_nvfp4_gpu: BF16 → FP4 (no deinterleave, GPU-only, no CPU sync)
  2. deinterleave_quantize_nvfp4_cuda: interleaved BF16 → FP4 (deinterleave + quantize)
  3. Integration: Both kernels match the Python quantize_activation_nvfp4 reference

Run:
  ~/.openclaw/workspace/fire_b200_test tests/unit/test_nvfp4_gpu_quantize.py
"""
import torch
import math
import os
import sys

# Make the repository root importable when this file is run directly.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from dsv4.ops.quantize import (
    quantize_activation_nvfp4,
    quantize_nvfp4_gpu,
    deinterleave_quantize_nvfp4_cuda,
    SF_VEC_SIZE,
)
from dsv4.ops.layouts import interleave_l1_weights


def test_quantize_nvfp4_gpu_basic():
    """Test GPU-only quantize kernel against Python reference.

    Checks shapes/dtypes, reports byte-level agreement of FP4 data and scale
    factors, and asserts on dequantized round-trip cosine similarity.
    """
    print("\n=== Test 1: quantize_nvfp4_gpu basic correctness ===")
    torch.manual_seed(42)
    M, N = 128, 512
    x = torch.randn(M, N, dtype=torch.bfloat16, device='cuda')
    global_scale = 1.0

    # Python reference
    ref_fp4, ref_sf = quantize_activation_nvfp4(x, global_scale)
    # GPU kernel
    gpu_fp4, gpu_sf = quantize_nvfp4_gpu(x, global_scale)

    # Compare shapes and dtypes
    assert gpu_fp4.shape == ref_fp4.shape, f"FP4 shape mismatch: {gpu_fp4.shape} vs {ref_fp4.shape}"
    assert gpu_sf.shape == ref_sf.shape, f"SF shape mismatch: {gpu_sf.shape} vs {ref_sf.shape}"
    assert gpu_fp4.dtype == torch.float4_e2m1fn_x2, f"FP4 dtype: {gpu_fp4.dtype}"
    assert gpu_sf.dtype == torch.float8_e4m3fn, f"SF dtype: {gpu_sf.dtype}"

    # Byte-exact comparison of FP4 data (reported only; the gating checks
    # below are the cosine-similarity asserts).
    ref_bytes = ref_fp4.view(torch.uint8)
    gpu_bytes = gpu_fp4.view(torch.uint8)
    fp4_match = (ref_bytes == gpu_bytes).float().mean().item()
    print(f"  FP4 byte match: {fp4_match*100:.1f}%")

    # SF comparison (reported only)
    ref_sf_bytes = ref_sf.view(torch.uint8)
    gpu_sf_bytes = gpu_sf.view(torch.uint8)
    sf_match = (ref_sf_bytes == gpu_sf_bytes).float().mean().item()
    print(f"  SF byte match: {sf_match*100:.1f}%")

    # Round-trip cosine similarity
    ref_deq = _dequantize_nvfp4(ref_fp4, ref_sf, global_scale, N)
    gpu_deq = _dequantize_nvfp4(gpu_fp4, gpu_sf, global_scale, N)

    cos_ref = torch.nn.functional.cosine_similarity(
        x.flatten().float().unsqueeze(0), ref_deq.flatten().float().unsqueeze(0)
    ).item()
    cos_gpu = torch.nn.functional.cosine_similarity(
        x.flatten().float().unsqueeze(0), gpu_deq.flatten().float().unsqueeze(0)
    ).item()
    cos_cross = torch.nn.functional.cosine_similarity(
        ref_deq.flatten().float().unsqueeze(0), gpu_deq.flatten().float().unsqueeze(0)
    ).item()

    print(f"  Python round-trip cos: {cos_ref:.6f}")
    print(f"  GPU round-trip cos: {cos_gpu:.6f}")
    print(f"  Python vs GPU cos: {cos_cross:.6f}")

    assert cos_gpu >= 0.95, f"GPU round-trip cosine too low: {cos_gpu}"
    assert cos_cross >= 0.99, f"Python vs GPU cosine too low: {cos_cross}"
    print(f"  ✅ PASS")


def test_quantize_nvfp4_gpu_larger():
    """Test GPU quantize with larger dimensions (MoE intermediate size)."""
    print("\n=== Test 2: quantize_nvfp4_gpu larger shape ===")
    torch.manual_seed(42)
    M, N = 64, 4096
    x = torch.randn(M, N, dtype=torch.bfloat16, device='cuda')

    # Use quantize_to_nvfp4 to get a proper global_scale (from data amax)
    from dsv4.ops.quantize import quantize_to_nvfp4
    _, _, global_scale = quantize_to_nvfp4(x)

    ref_fp4, ref_sf = quantize_activation_nvfp4(x, global_scale)
    gpu_fp4, gpu_sf = quantize_nvfp4_gpu(x, global_scale)

    # Byte-exact comparison
    fp4_match = (ref_fp4.view(torch.uint8) == gpu_fp4.view(torch.uint8)).float().mean().item()
    sf_match = (ref_sf.view(torch.uint8) == gpu_sf.view(torch.uint8)).float().mean().item()
    print(f"  FP4 byte match: {fp4_match*100:.1f}%")
    print(f"  SF byte match: {sf_match*100:.1f}%")

    assert fp4_match >= 0.99, f"FP4 byte match too low: {fp4_match}"
    assert sf_match >= 0.99, f"SF byte match too low: {sf_match}"
    print(f"  ✅ PASS")


def test_quantize_nvfp4_gpu_no_cpu_sync():
    """Verify quantize_nvfp4_gpu has no CPU-GPU sync points.

    The kernel call runs under ``torch.cuda.set_sync_debug_mode("error")``,
    which makes PyTorch raise on synchronizing CUDA operations it
    instruments (e.g. ``.item()``, blocking host<->device copies). Such a
    sync would also break CUDA graph capture. Note: a sync issued directly
    by a custom extension (raw ``cudaDeviceSynchronize``) bypasses this
    instrumentation and is not caught here.
    """
    print("\n=== Test 3: No CPU-GPU sync verification ===")
    torch.manual_seed(42)
    M, N = 32, 512
    x = torch.randn(M, N, dtype=torch.bfloat16, device='cuda')
    global_scale = 1.0

    # Warm-up call outside the checked region so any one-time first-call work
    # (e.g. lazy kernel compilation/loading) is not attributed to the kernel path.
    quantize_nvfp4_gpu(x, global_scale)
    torch.cuda.synchronize()

    # Any synchronizing op inside this call now raises instead of silently
    # stalling the CPU. Always restore the previous (global) debug mode.
    prev_mode = torch.cuda.get_sync_debug_mode()
    torch.cuda.set_sync_debug_mode("error")
    try:
        gpu_fp4, gpu_sf = quantize_nvfp4_gpu(x, global_scale)
    finally:
        torch.cuda.set_sync_debug_mode(prev_mode)

    # Syncing is allowed again; read back a value to confirm the result exists.
    result = gpu_fp4.view(torch.uint8).sum().item()
    print(f"  FP4 sum (verification): {result}")
    print(f"  ✅ PASS (no synchronizing op raised in kernel path)")


def test_deinterleave_quantize_correctness():
    """Test deinterleave + quantize CUDA kernel against Python reference."""
    print("\n=== Test 4: deinterleave_quantize_nvfp4_cuda correctness ===")
    torch.manual_seed(42)
    M = 64
    intermediate = 512
    # The fused input below has 2 * intermediate columns: silu(gate) and
    # swiglu interleaved.

    # Create SwiGLU-style interleaved data matching fused kernel output
    # The fused kernel outputs [silu(gate)*8, silu(gate)*up*8, ...] interleaved
    gate = torch.randn(M, intermediate, dtype=torch.bfloat16, device='cuda')
    up = torch.randn(M, intermediate, dtype=torch.bfloat16, device='cuda')
    gate_silu = torch.nn.functional.silu(gate)
    swiglu_result = gate_silu * up

    # Create interleaved layout: [silu(gate)*8, swiglu*8, ...]
    fused = torch.cat([gate_silu, swiglu_result], dim=-1).unsqueeze(0)  # (1, M, 2*intermediate)
    fused = interleave_l1_weights(fused)[0]  # (M, 2*intermediate) interleaved

    global_scale = 1.0

    # Python reference: quantize the swiglu_result (odd groups)
    ref_fp4, ref_sf = quantize_activation_nvfp4(swiglu_result, global_scale)

    # CUDA kernel: deinterleave (extract odd groups = swiglu) + quantize in one pass
    gpu_fp4, gpu_sf = deinterleave_quantize_nvfp4_cuda(fused, intermediate, global_scale)

    # Byte-exact comparison
    fp4_match = (ref_fp4.view(torch.uint8) == gpu_fp4.view(torch.uint8)).float().mean().item()
    sf_match = (ref_sf.view(torch.uint8) == gpu_sf.view(torch.uint8)).float().mean().item()
    print(f"  FP4 byte match: {fp4_match*100:.1f}%")
    print(f"  SF byte match: {sf_match*100:.1f}%")

    assert fp4_match >= 0.99, f"FP4 byte match too low: {fp4_match}"
    assert sf_match >= 0.99, f"SF byte match too low: {sf_match}"
    print(f"  ✅ PASS")


def test():
    """Run all tests in this module sequentially (script entry point)."""
    print("=== NVFP4 GPU Quantize Kernel Tests ===")
    test_quantize_nvfp4_gpu_basic()
    test_quantize_nvfp4_gpu_larger()
    test_quantize_nvfp4_gpu_no_cpu_sync()
    test_deinterleave_quantize_correctness()
    print("\n=== ALL TESTS PASSED ===")


def _dequantize_nvfp4(x_fp4, block_scale, global_scale, N):
    """Dequantize NVFP4 back to BF16 for verification.

    Args:
        x_fp4: packed FP4 tensor, two 4-bit E2M1 codes per byte; its uint8
            view must hold M * N / 2 bytes, where M = x_fp4.shape[0].
        block_scale: per-block scale factors, expanded by SF_VEC_SIZE along the
            last dim. Assumes a plain row-major (M, N // SF_VEC_SIZE) layout
            (not swizzled) — TODO confirm against quantize_activation_nvfp4.
        global_scale: scalar multiplied into every dequantized value.
        N: number of unpacked columns.

    Returns:
        (M, N) bfloat16 tensor on the same device as ``x_fp4``.
    """
    M = x_fp4.shape[0]
    device = x_fp4.device
    block_size = SF_VEC_SIZE
    raw = x_fp4.view(torch.uint8)
    # Unpack: the low nibble is treated as the first element of each pair.
    even = raw & 0x0F
    odd = (raw >> 4) & 0x0F
    indices = torch.stack([even, odd], dim=-1).reshape(M, N)
    # Bit 3 is the sign; bits 0-2 index the E2M1 magnitude table.
    signs = (indices >= 8).float() * -2 + 1
    mag = indices % 8
    # E2M1 magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6} expressed in half-steps.
    idx_to_half_step = torch.tensor(
        [0, 1, 2, 3, 4, 6, 8, 12], dtype=torch.float32, device=device
    )
    half_steps = idx_to_half_step[mag.long()]
    x_deq_fp32 = signs * half_steps / 2.0
    # Cast FP8 scales to float before expanding each scale to its block.
    block_scale_exp = block_scale.float().repeat_interleave(block_size, dim=-1)
    x_deq_fp32 = x_deq_fp32 * block_scale_exp * global_scale
    return x_deq_fp32.to(torch.bfloat16)


if __name__ == '__main__':
    test()