diff --git a/tests/unit/test_nvfp4_1_1_quant.py b/tests/unit/test_nvfp4_1_1_quant.py index 7591272f..eea6b20c 100644 --- a/tests/unit/test_nvfp4_1_1_quant.py +++ b/tests/unit/test_nvfp4_1_1_quant.py @@ -1,11 +1,11 @@ """ NVFP4-1.1 Phase 1: Verify FP4 quantization math in CuTeDSL kernel. -Two-pass approach to avoid Python list indexing with CuTeDSL loop variables: +Two-pass approach: 1. First pass: load 16 values, compute amax + FP8 scale 2. Second pass: reload 16 values, quantize to FP4 -cute.arch.store confirmed available on B200. +Uses threshold rounding (no float-to-int conversion). """ import torch @@ -32,20 +32,15 @@ def fp4_quant_test_kernel( out_sf: cute.Tensor, # (1,) Int32 — FP8 scale byte gs_scalar: cute.Tensor, # (1,) Float32 — global scale ): - """Quantize 16 BF16 values to NVFP4. - - Two-pass: (1) compute amax+scale, (2) quantize+pack. - Only thread 0 works. Grid: (1,1,1), Block: (32,1,1). - """ tidx, _, _ = cute.arch.thread_idx() if tidx == cutlass.Int32(0): gs = cute.arch.load(gs_scalar.iterator, cutlass.Float32) - # ── Pass 1: Compute per-16-element amax ── + # Pass 1: amax amax = cutlass.Float32(0.0) for i in cutlass.range(16, unroll=1): - ptr = input_bf16.iterator + i * cutlass.Int32(2) # BF16=2 bytes + ptr = input_bf16.iterator + i * cutlass.Int32(2) bf16_val = cute.arch.load(ptr, cutlass.BFloat16) v = bf16_val.to(cutlass.Float32) / gs a = cute.arch.fmax(v, cutlass.Float32(0.0) - v) @@ -56,35 +51,29 @@ def fp4_quant_test_kernel( if amax < cutlass.Float32(6.0 * (2.0 ** -9)): bsf_f32 = cutlass.Float32(0.0) - # FP8 E4M3 cast + dequant sf_bits = fp8_e4m3_from_float32(bsf_f32) bs_dequant = fp8_e4m3_to_float32(sf_bits) - # Write FP8 scale byte (Int32 holding uint8) + # Write SF cute.arch.store(out_sf.iterator, sf_bits) - # ── Pass 2: Quantize and pack ── + # Pass 2: quantize and pack for i in cutlass.range(8, unroll=1): - # Load even element (2*i) ptr0 = input_bf16.iterator + (2 * i) * cutlass.Int32(2) v0_bf16 = cute.arch.load(ptr0, cutlass.BFloat16) v0 = v0_bf16.to(cutlass.Float32) / gs nibble0 = quantize_e2m1_nibble(v0, bs_dequant) - # Load odd element (2*i+1) ptr1 = input_bf16.iterator + (2 * i + cutlass.Int32(1)) * cutlass.Int32(2) v1_bf16 = cute.arch.load(ptr1, cutlass.BFloat16) v1 = v1_bf16.to(cutlass.Float32) / gs nibble1 = quantize_e2m1_nibble(v1, bs_dequant) packed = (nibble1 << cutlass.Int32(4)) | nibble0 - - # Write packed byte as Int32 cute.arch.store(out_fp4.iterator + i * cutlass.Int32(4), packed) def run_test(): - """Run the FP4 quantization test on GPU.""" device = "cuda" N = 16 @@ -152,6 +141,7 @@ def run_test(): if __name__ == "__main__": print("=" * 60) print("NVFP4-1.1 Phase 1: FP4 Quantization Math Test") + print("Threshold rounding — no float-to-int conversion") print("=" * 60) success = run_test() exit(0 if success else 1)