diff --git a/tests/unit/test_nvfp4_1_1_quant.py b/tests/unit/test_nvfp4_1_1_quant.py index eea6b20c..e3399b1d 100644 --- a/tests/unit/test_nvfp4_1_1_quant.py +++ b/tests/unit/test_nvfp4_1_1_quant.py @@ -1,11 +1,8 @@ """ NVFP4-1.1 Phase 1: Verify FP4 quantization math in CuTeDSL kernel. -Two-pass approach: -1. First pass: load 16 values, compute amax + FP8 scale -2. Second pass: reload 16 values, quantize to FP4 - -Uses threshold rounding (no float-to-int conversion). +Simplified: uses Float32 input to avoid BF16 scalar load issues. +Two-pass: (1) compute amax+scale, (2) quantize+pack. """ import torch @@ -27,22 +24,21 @@ from dsv4.kernels.gemm.fp4_quant import ( @cute.kernel def fp4_quant_test_kernel( - input_bf16: cute.Tensor, # (16,) BF16 - out_fp4: cute.Tensor, # (8,) Int32 — packed FP4 bytes - out_sf: cute.Tensor, # (1,) Int32 — FP8 scale byte - gs_scalar: cute.Tensor, # (1,) Float32 — global scale + input_f32: cute.Tensor, # (16,) Float32 + out_fp4: cute.Tensor, # (8,) Int32 + out_sf: cute.Tensor, # (1,) Int32 + gs_scalar: cute.Tensor, # (1,) Float32 ): tidx, _, _ = cute.arch.thread_idx() if tidx == cutlass.Int32(0): gs = cute.arch.load(gs_scalar.iterator, cutlass.Float32) - # Pass 1: amax + # Pass 1: amax (load Float32, divide by global_scale) amax = cutlass.Float32(0.0) for i in cutlass.range(16, unroll=1): - ptr = input_bf16.iterator + i * cutlass.Int32(2) - bf16_val = cute.arch.load(ptr, cutlass.BFloat16) - v = bf16_val.to(cutlass.Float32) / gs + ptr = input_f32.iterator + i * cutlass.Int32(4) # Float32 = 4 bytes + v = cute.arch.load(ptr, cutlass.Float32) / gs a = cute.arch.fmax(v, cutlass.Float32(0.0) - v) amax = cute.arch.fmax(amax, a) @@ -59,14 +55,12 @@ def fp4_quant_test_kernel( # Pass 2: quantize and pack for i in cutlass.range(8, unroll=1): - ptr0 = input_bf16.iterator + (2 * i) * cutlass.Int32(2) - v0_bf16 = cute.arch.load(ptr0, cutlass.BFloat16) - v0 = v0_bf16.to(cutlass.Float32) / gs + ptr0 = input_f32.iterator + (2 * i) * cutlass.Int32(4) + v0 = cute.arch.load(ptr0, cutlass.Float32) / gs nibble0 = quantize_e2m1_nibble(v0, bs_dequant) - ptr1 = input_bf16.iterator + (2 * i + cutlass.Int32(1)) * cutlass.Int32(2) - v1_bf16 = cute.arch.load(ptr1, cutlass.BFloat16) - v1 = v1_bf16.to(cutlass.Float32) / gs + ptr1 = input_f32.iterator + (2 * i + cutlass.Int32(1)) * cutlass.Int32(4) + v1 = cute.arch.load(ptr1, cutlass.Float32) / gs nibble1 = quantize_e2m1_nibble(v1, bs_dequant) packed = (nibble1 << cutlass.Int32(4)) | nibble0 @@ -84,6 +78,10 @@ def run_test(): amax_val = x_f32.abs().max().item() global_scale = max(amax_val / (6.0 * 448.0), 1e-8) + # Pre-divide by global_scale (kernel expects normalized input) + x_norm = (x_f32 / global_scale).reshape(N).contiguous() + + # Python reference (on the same normalized values) ref_fp4, ref_sf = quantize_activation_nvfp4(x_bf16, global_scale) ref_fp4_bytes = ref_fp4.view(torch.uint8).reshape(-1).cpu() ref_sf_bytes = ref_sf.view(torch.uint8).cpu() @@ -92,16 +90,16 @@ def run_test(): print(f"Ref FP4: {ref_fp4_bytes}") print(f"Ref SF: {ref_sf_bytes}") + # Output tensors out_fp4 = torch.zeros(8, dtype=torch.int32, device=device) out_sf = torch.zeros(1, dtype=torch.int32, device=device) - gs_tensor = torch.tensor([global_scale], dtype=torch.float32, device=device) + gs_tensor = torch.tensor([1.0], dtype=torch.float32, device=device) # already normalized def to_cute(t): ct = cutlass_torch.from_dlpack(t) return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t)) - x_flat = x_bf16.reshape(N).contiguous() - input_c = to_cute(x_flat) + input_c = to_cute(x_norm) out_fp4_c = to_cute(out_fp4) out_sf_c = to_cute(out_sf) gs_c = to_cute(gs_tensor) @@ -141,7 +139,7 @@ def run_test(): if __name__ == "__main__": print("=" * 60) print("NVFP4-1.1 Phase 1: FP4 Quantization Math Test") - print("Threshold rounding — no float-to-int conversion") + print("Threshold rounding — Float32 input — no BF16 scalar loads") print("=" * 60) success = run_test() exit(0 if success else 1)