diff --git a/tests/unit/test_nvfp4_1_1_quant.py b/tests/unit/test_nvfp4_1_1_quant.py index 86993754..e44a66ff 100644 --- a/tests/unit/test_nvfp4_1_1_quant.py +++ b/tests/unit/test_nvfp4_1_1_quant.py @@ -4,11 +4,8 @@ NVFP4-1.1 Phase 1: Verify FP4 quantization math in CuTeDSL kernel. Tests that fp4_quant.py functions produce bit-exact matches with the Python reference (quantize_activation_nvfp4). Runs on B200 only. -Strategy: Launch a kernel that processes 16 BF16 values through the -quantization pipeline and writes results to GMEM. Compare with Python. - -Uses cute.arch.load for scalar GMEM reads (proven pattern). -For GMEM writes, uses cute.copy with a simple CopyUniversalOp atom. +Uses cute.arch.load for scalar GMEM reads. +Uses cute.copy with CopyUniversalOp for GMEM writes (no cute.arch.store). """ import torch @@ -32,61 +29,74 @@ from dsv4.kernels.gemm.fp4_quant import ( @cute.kernel def fp4_quant_test_kernel( - input_bf16: cute.Tensor, # (16,) BF16 — 16 input values + input_bf16: cute.Tensor, # (16,) BF16 out_fp4: cute.Tensor, # (8,) Int32 — packed FP4 bytes out_sf: cute.Tensor, # (1,) Int32 — FP8 scale byte gs_scalar: cute.Tensor, # (1,) Float32 — global scale ): - """Quantize 16 BF16 values to NVFP4 using fp4_quant functions. + """Quantize 16 BF16 values to NVFP4. - Single-thread kernel (only thread 0 does work). - Grid: (1, 1, 1), Block: (32, 1, 1) + Thread 0 does all work. Results written via cute.copy. """ tidx, _, _ = cute.arch.thread_idx() + # Create a copy atom for Int32 GMEM writes (1 element per copy) + copy_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), cutlass.Int32, num_bits_per_copy=32, + ) + if tidx == cutlass.Int32(0): # Load global scale gs = cute.arch.load(gs_scalar.iterator, cutlass.Float32) - # Load 16 BF16 values, convert to FP32, normalize by global_scale + # Load 16 BF16 values, convert to FP32, normalize vals_f32 = [cutlass.Float32(0.0)] * 16 for i in cutlass.range(16, unroll=1): bf16_val = cute.arch.load( - input_bf16.iterator + i * cutlass.Int32(2), # BF16 = 2 bytes + input_bf16.iterator + i * cutlass.Int32(2), cutlass.BFloat16, ) vals_f32[i] = bf16_val.to(cutlass.Float32) / gs - # Compute per-16-element amax + # Per-16-element amax amax = cutlass.Float32(0.0) for i in cutlass.range(16, unroll=1): v = vals_f32[i] - a = cute.math.fmax(v, cutlass.Float32(0.0) - v) # abs + a = cute.math.fmax(v, cutlass.Float32(0.0) - v) amax = cute.math.fmax(amax, a) - # Block scale = amax / 6 + # Block scale bsf_f32 = amax / cutlass.Float32(6.0) - # Underflow: if amax < 6 * 2^-9, force scale = 0 underflow_threshold = cutlass.Float32(6.0 * (2.0 ** -9)) if amax < underflow_threshold: bsf_f32 = cutlass.Float32(0.0) - # FP8 E4M3 cast + # FP8 E4M3 cast + dequant sf_bits = fp8_e4m3_from_float32(bsf_f32) - - # Dequantize FP8 scale (round-trip) bs_dequant = fp8_e4m3_to_float32(sf_bits) - # Quantize each value to E2M1 and pack + # E2M1 quantize + pack for i in cutlass.range(8, unroll=1): nibble0 = quantize_e2m1_nibble(vals_f32[2 * i], bs_dequant) nibble1 = quantize_e2m1_nibble(vals_f32[2 * i + 1], bs_dequant) packed = (nibble1 << cutlass.Int32(4)) | nibble0 - # Write packed byte as Int32 - cute.arch.store(out_fp4.iterator + i * cutlass.Int32(4), packed, cutlass.Int32) + # Write to GMEM via cute.copy (1 Int32 element) + rmem = cute.make_rmem_tensor((1,), cutlass.Int32) + rmem[cutlass.Int32(0)] = packed + gmem = cute.make_tensor( + out_fp4.iterator + i * cutlass.Int32(4), + cute.make_layout((1,)), + ) + cute.copy(copy_atom, rmem, gmem) - # Write FP8 scale byte - cute.arch.store(out_sf.iterator, sf_bits, cutlass.Int32) + # Write FP8 scale + rmem_sf = cute.make_rmem_tensor((1,), cutlass.Int32) + rmem_sf[cutlass.Int32(0)] = sf_bits + gmem_sf = cute.make_tensor( + out_sf.iterator, + cute.make_layout((1,)), + ) + cute.copy(copy_atom, rmem_sf, gmem_sf) def run_test(): @@ -94,31 +104,25 @@ def run_test(): device = "cuda" N = 16 - # Generate test input torch.manual_seed(42) x_bf16 = torch.randn(1, N, dtype=torch.bfloat16, device=device) - # Compute global scale x_f32 = x_bf16.float() amax_val = x_f32.abs().max().item() global_scale = max(amax_val / (6.0 * 448.0), 1e-8) - # Python reference ref_fp4, ref_sf = quantize_activation_nvfp4(x_bf16, global_scale) ref_fp4_bytes = ref_fp4.view(torch.uint8).reshape(-1).cpu() ref_sf_bytes = ref_sf.view(torch.uint8).cpu() - print(f"Input BF16 (first 8): {x_bf16[0, :8].cpu()}") print(f"Global scale: {global_scale:.8f}") print(f"Ref FP4: {ref_fp4_bytes}") print(f"Ref SF: {ref_sf_bytes}") - # Prepare output tensors out_fp4 = torch.zeros(8, dtype=torch.int32, device=device) out_sf = torch.zeros(1, dtype=torch.int32, device=device) gs_tensor = torch.tensor([global_scale], dtype=torch.float32, device=device) - # Convert to CuTe tensors def to_cute(t): ct = cutlass_torch.from_dlpack(t) return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t)) @@ -129,44 +133,30 @@ def run_test(): out_sf_c = to_cute(out_sf) gs_c = to_cute(gs_tensor) - # Compile and run - import cuda.bindings.driver as cuda - stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - - print("\nCompiling kernel (first run may take a minute)...") + print("\nCompiling kernel...") try: compiled = cute.compile( fp4_quant_test_kernel, input_c, out_fp4_c, out_sf_c, gs_c, - stream, ) print("Compiled. Running...") - compiled(input_c, out_fp4_c, out_sf_c, gs_c, stream) + compiled(input_c, out_fp4_c, out_sf_c, gs_c) torch.cuda.synchronize() - # Extract results our_fp4 = out_fp4[:8].to(torch.uint8).cpu() our_sf = out_sf[0].to(torch.uint8).cpu().item() - print(f"\nOur FP4: {our_fp4}") + print(f"Our FP4: {our_fp4}") print(f"Our SF: {our_sf}") fp4_match = torch.equal(our_fp4, ref_fp4_bytes[:8]) sf_match = our_sf == ref_sf_bytes[0].item() if fp4_match and sf_match: - print("\n✅ PASS: FP4 quantization matches Python reference!") + print("\n✅ PASS!") return True else: - print(f"\n❌ FAIL: FP4 match={fp4_match}, SF match={sf_match}") - if not fp4_match: - for i in range(8): - o = our_fp4[i].item() - r = ref_fp4_bytes[i].item() - if o != r: - print(f" Byte {i}: ours=0x{o:02x}, ref=0x{r:02x}") - if not sf_match: - print(f" SF: ours=0x{our_sf:02x}, ref=0x{ref_sf_bytes[0].item():02x}") + print(f"\n❌ FAIL: FP4={fp4_match} SF={sf_match}") return False except Exception as e: print(f"\n❌ ERROR: {e}") @@ -178,7 +168,6 @@ def run_test(): if __name__ == "__main__": print("=" * 60) print("NVFP4-1.1 Phase 1: FP4 Quantization Math Test") - print("Verifies fp4_quant.py functions match Python reference") print("=" * 60) success = run_test() exit(0 if success else 1)