diff --git a/tests/unit/test_nvfp4_1_1_quant.py b/tests/unit/test_nvfp4_1_1_quant.py index 33f39242..86993754 100644 --- a/tests/unit/test_nvfp4_1_1_quant.py +++ b/tests/unit/test_nvfp4_1_1_quant.py @@ -1,20 +1,21 @@ """ -NVFP4-1.1 Phase 1: Verify FP4 quantization math in CuTeDSL. +NVFP4-1.1 Phase 1: Verify FP4 quantization math in CuTeDSL kernel. -Tests the fp4_quant.py functions on B200. Compares CuTeDSL kernel output -with Python reference (quantize_activation_nvfp4). +Tests that fp4_quant.py functions produce bit-exact matches with the +Python reference (quantize_activation_nvfp4). Runs on B200 only. -The kernel takes 16 BF16 values + global_scale, quantizes to NVFP4, -and writes FP4 packed bytes + FP8 scale byte to output tensors. +Strategy: Launch a kernel that processes 16 BF16 values through the +quantization pipeline and writes results to GMEM. Compare with Python. -Uses cute.arch.load for scalar GMEM reads (proven pattern from the codebase). -For writes, uses the output tensor's iterator + offset pattern. +Uses cute.arch.load for scalar GMEM reads (proven pattern). +For GMEM writes, uses cute.arch.store on the output tensors' iterators (plus offset). 
""" import torch import cutlass import cutlass.cute as cute import cutlass.torch as cutlass_torch +from cutlass.cute.nvgpu import cpasync import sys import os @@ -22,7 +23,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../..")) from dsv4.ops.quantize import quantize_activation_nvfp4, SF_VEC_SIZE from dsv4.kernels.gemm.fp4_quant import ( - fp8_e4m3_from_float32_manual, + fp8_e4m3_from_float32, fp8_e4m3_to_float32, half_step_to_e2m1_idx, quantize_e2m1_nibble, @@ -32,7 +33,8 @@ from dsv4.kernels.gemm.fp4_quant import ( @cute.kernel def fp4_quant_test_kernel( input_bf16: cute.Tensor, # (16,) BF16 — 16 input values - out_data: cute.Tensor, # (10,) Int32 — [0..7] = FP4 packed bytes, [8] = SF byte, [9] = debug + out_fp4: cute.Tensor, # (8,) Int32 — packed FP4 bytes + out_sf: cute.Tensor, # (1,) Int32 — FP8 scale byte gs_scalar: cute.Tensor, # (1,) Float32 — global scale ): """Quantize 16 BF16 values to NVFP4 using fp4_quant functions. @@ -55,43 +57,40 @@ def fp4_quant_test_kernel( ) vals_f32[i] = bf16_val.to(cutlass.Float32) / gs - # ── Compute per-16-element amax ── + # Compute per-16-element amax amax = cutlass.Float32(0.0) for i in cutlass.range(16, unroll=1): v = vals_f32[i] a = cute.math.fmax(v, cutlass.Float32(0.0) - v) # abs amax = cute.math.fmax(amax, a) - # ── Block scale = amax / 6 ── + # Block scale = amax / 6 bsf_f32 = amax / cutlass.Float32(6.0) # Underflow: if amax < 6 * 2^-9, force scale = 0 underflow_threshold = cutlass.Float32(6.0 * (2.0 ** -9)) if amax < underflow_threshold: bsf_f32 = cutlass.Float32(0.0) - # ── FP8 E4M3 cast ── - sf_bits = fp8_e4m3_from_float32_manual(bsf_f32) + # FP8 E4M3 cast + sf_bits = fp8_e4m3_from_float32(bsf_f32) - # ── Dequantize FP8 scale (round-trip) ── + # Dequantize FP8 scale (round-trip) bs_dequant = fp8_e4m3_to_float32(sf_bits) - # ── Quantize each value to E2M1 and pack ── + # Quantize each value to E2M1 and pack for i in cutlass.range(8, unroll=1): nibble0 = quantize_e2m1_nibble(vals_f32[2 * i], 
bs_dequant) nibble1 = quantize_e2m1_nibble(vals_f32[2 * i + 1], bs_dequant) packed = (nibble1 << cutlass.Int32(4)) | nibble0 # Write packed byte as Int32 - cute.arch.store(out_data.iterator + i * cutlass.Int32(4), packed, cutlass.Int32) + cute.arch.store(out_fp4.iterator + i * cutlass.Int32(4), packed, cutlass.Int32) - # ── Write FP8 scale byte ── - cute.arch.store(out_data.iterator + cutlass.Int32(8) * cutlass.Int32(4), sf_bits, cutlass.Int32) - - # ── Debug: write bsf_f32 and bs_dequant as float ── - # out_data[9] is unused — let's skip for simplicity + # Write FP8 scale byte + cute.arch.store(out_sf.iterator, sf_bits, cutlass.Int32) def run_test(): - """Run the FP4 quantization test.""" + """Run the FP4 quantization test on GPU.""" device = "cuda" N = 16 @@ -99,7 +98,7 @@ def run_test(): torch.manual_seed(42) x_bf16 = torch.randn(1, N, dtype=torch.bfloat16, device=device) - # Compute global scale (matching quantize_activation_nvfp4) + # Compute global scale x_f32 = x_bf16.float() amax_val = x_f32.abs().max().item() global_scale = max(amax_val / (6.0 * 448.0), 1e-8) @@ -111,11 +110,12 @@ def run_test(): print(f"Input BF16 (first 8): {x_bf16[0, :8].cpu()}") print(f"Global scale: {global_scale:.8f}") - print(f"Ref FP4 bytes: {ref_fp4_bytes}") - print(f"Ref SF byte: {ref_sf_bytes}") + print(f"Ref FP4: {ref_fp4_bytes}") + print(f"Ref SF: {ref_sf_bytes}") - # Prepare output tensor - out_data = torch.zeros(10, dtype=torch.int32, device=device) + # Prepare output tensors + out_fp4 = torch.zeros(8, dtype=torch.int32, device=device) + out_sf = torch.zeros(1, dtype=torch.int32, device=device) gs_tensor = torch.tensor([global_scale], dtype=torch.float32, device=device) # Convert to CuTe tensors @@ -125,7 +125,8 @@ def run_test(): x_flat = x_bf16.reshape(N).contiguous() input_c = to_cute(x_flat) - out_c = to_cute(out_data) + out_fp4_c = to_cute(out_fp4) + out_sf_c = to_cute(out_sf) gs_c = to_cute(gs_tensor) # Compile and run @@ -133,39 +134,44 @@ def run_test(): stream = 
cuda.CUstream(torch.cuda.current_stream().cuda_stream) print("\nCompiling kernel (first run may take a minute)...") - compiled = cute.compile( - fp4_quant_test_kernel, - input_c, out_c, gs_c, - stream, - ) - print("Compiled. Running...") - compiled(input_c, out_c, gs_c, stream) - torch.cuda.synchronize() - - # Extract results - our_fp4 = out_data[:8].to(torch.uint8).cpu() - our_sf = out_data[8].to(torch.uint8).cpu().item() - - print(f"\nOur FP4 bytes: {our_fp4}") - print(f"Our SF byte: {our_sf}") - - # Compare - fp4_match = torch.equal(our_fp4, ref_fp4_bytes[:8]) - sf_match = our_sf == ref_sf_bytes[0].item() - - if fp4_match and sf_match: - print("\n✅ PASS: FP4 quantization matches Python reference!") - return True - else: - print(f"\n❌ FAIL: FP4 match={fp4_match}, SF match={sf_match}") - if not fp4_match: - for i in range(8): - o = our_fp4[i].item() - r = ref_fp4_bytes[i].item() - if o != r: - print(f" Byte {i}: ours=0x{o:02x}, ref=0x{r:02x}") - if not sf_match: - print(f" SF: ours=0x{our_sf:02x}, ref=0x{ref_sf_bytes[0].item():02x}") + try: + compiled = cute.compile( + fp4_quant_test_kernel, + input_c, out_fp4_c, out_sf_c, gs_c, + stream, + ) + print("Compiled. 
Running...") + compiled(input_c, out_fp4_c, out_sf_c, gs_c, stream) + torch.cuda.synchronize() + + # Extract results + our_fp4 = out_fp4[:8].to(torch.uint8).cpu() + our_sf = out_sf[0].to(torch.uint8).cpu().item() + + print(f"\nOur FP4: {our_fp4}") + print(f"Our SF: {our_sf}") + + fp4_match = torch.equal(our_fp4, ref_fp4_bytes[:8]) + sf_match = our_sf == ref_sf_bytes[0].item() + + if fp4_match and sf_match: + print("\n✅ PASS: FP4 quantization matches Python reference!") + return True + else: + print(f"\n❌ FAIL: FP4 match={fp4_match}, SF match={sf_match}") + if not fp4_match: + for i in range(8): + o = our_fp4[i].item() + r = ref_fp4_bytes[i].item() + if o != r: + print(f" Byte {i}: ours=0x{o:02x}, ref=0x{r:02x}") + if not sf_match: + print(f" SF: ours=0x{our_sf:02x}, ref=0x{ref_sf_bytes[0].item():02x}") + return False + except Exception as e: + print(f"\n❌ ERROR: {e}") + import traceback + traceback.print_exc() return False