"""
NVFP4-1.1 Phase 1: Verify FP4 quantization math in CuTeDSL kernel.

Simplified: uses Float32 input to avoid BF16 scalar load issues.
Two-pass: (1) compute amax+scale, (2) quantize+pack.

Run as a script; the process exit status is 0 when the kernel output matches
the Python reference and 1 otherwise (including on any exception).
"""

import os
import sys
import traceback

import torch
import cutlass
import cutlass.cute as cute
import cutlass.torch as cutlass_torch

# Put the directory two levels above this file on sys.path so that the
# `dsv4` imports below resolve when the file is executed directly.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))

from dsv4.ops.quantize import quantize_activation_nvfp4, SF_VEC_SIZE
from dsv4.kernels.gemm.fp4_quant import (
    fp8_e4m3_from_float32,
    fp8_e4m3_to_float32,
    quantize_e2m1_nibble,
)


@cute.kernel
def fp4_quant_test_kernel(
    input_f32: cute.Tensor,  # (16,) Float32
    out_fp4: cute.Tensor,    # (8,) Int32
    out_sf: cute.Tensor,     # (1,) Int32
    gs_scalar: cute.Tensor,  # (1,) Float32
):
    """Quantize one 16-element block to packed FP4 using a single thread.

    Only the thread with index 0 does any work. It makes two passes over the
    input: the first finds the block's absolute maximum (after dividing by
    the global scale) and derives the block scale factor from it; the second
    quantizes each pair of values and packs the two nibbles into one word.

    Args:
        input_f32: 16 Float32 input values.
        out_fp4: 8 Int32 slots; slot ``i`` receives
            ``(nibble[2*i + 1] << 4) | nibble[2*i]``.
        out_sf: 1 Int32 slot receiving the value returned by
            ``fp8_e4m3_from_float32`` for the block scale.
        gs_scalar: 1 Float32 holding the global scale each input is divided by.
    """
    tidx, _, _ = cute.arch.thread_idx()

    if tidx == cutlass.Int32(0):
        gs = cute.arch.load(gs_scalar.iterator, cutlass.Float32)

        # NOTE(review): every pointer offset below is multiplied by 4 on the
        # assumption that pointer arithmetic here is in bytes. If the pointer
        # type advances in elements instead, these accesses stride by four
        # elements and run out of bounds -- confirm against the CuTe DSL docs.

        # Pass 1: amax (load Float32, divide by global_scale)
        amax = cutlass.Float32(0.0)
        for i in cutlass.range(16, unroll=1):
            ptr = input_f32.iterator + i * cutlass.Int32(4)  # Float32 = 4 bytes
            v = cute.arch.load(ptr, cutlass.Float32) / gs
            # |v| computed as max(v, -v).
            a = cute.arch.fmax(v, cutlass.Float32(0.0) - v)
            amax = cute.arch.fmax(amax, a)

        # Block scale: amax / 6.0, forced to zero when amax is below
        # 6.0 * 2**-9 (presumably because the scale would then fall under the
        # smallest value representable in FP8 E4M3 -- TODO confirm).
        bsf_f32 = amax / cutlass.Float32(6.0)
        if amax < cutlass.Float32(6.0 * (2.0 ** -9)):
            bsf_f32 = cutlass.Float32(0.0)

        # Round-trip the scale through the FP8 helpers so that quantization
        # uses the same scale value that is written out.
        sf_bits = fp8_e4m3_from_float32(bsf_f32)
        bs_dequant = fp8_e4m3_to_float32(sf_bits)

        # Write SF
        cute.arch.store(out_sf.iterator, sf_bits)

        # Pass 2: quantize and pack
        for i in cutlass.range(8, unroll=1):
            ptr0 = input_f32.iterator + (2 * i) * cutlass.Int32(4)
            v0 = cute.arch.load(ptr0, cutlass.Float32) / gs
            nibble0 = quantize_e2m1_nibble(v0, bs_dequant)

            ptr1 = input_f32.iterator + (2 * i + cutlass.Int32(1)) * cutlass.Int32(4)
            v1 = cute.arch.load(ptr1, cutlass.Float32) / gs
            nibble1 = quantize_e2m1_nibble(v1, bs_dequant)

            # Even-indexed element goes in the low nibble, odd in the high.
            packed = (nibble1 << cutlass.Int32(4)) | nibble0
            cute.arch.store(out_fp4.iterator + i * cutlass.Int32(4), packed)


def run_test():
    """Compile and run the kernel on one random block and check it.

    Generates 16 random BF16 values with a fixed seed, derives a global scale
    from their absolute maximum, and compares the kernel's packed FP4 bytes
    and scale-factor byte against ``quantize_activation_nvfp4``.

    Returns:
        bool: True when both the FP4 bytes and the scale factor match the
        reference; False on a mismatch or if compiling/running raised.
    """
    device = "cuda"
    N = 16
    num_packed = N // 2  # two 4-bit values per output word

    torch.manual_seed(42)
    x_bf16 = torch.randn(1, N, dtype=torch.bfloat16, device=device)
    x_f32 = x_bf16.float()

    amax_val = x_f32.abs().max().item()
    # The floor of 1e-8 avoids a zero divisor for an all-zero input.
    global_scale = max(amax_val / (6.0 * 448.0), 1e-8)

    # Pre-divide by global_scale (kernel expects normalized input)
    x_norm = (x_f32 / global_scale).reshape(N).contiguous()

    # Python reference: called on the original (un-normalized) BF16 input
    # together with global_scale, whereas the kernel receives pre-normalized
    # values and a global scale of 1.0.
    ref_fp4, ref_sf = quantize_activation_nvfp4(x_bf16, global_scale)
    ref_fp4_bytes = ref_fp4.view(torch.uint8).reshape(-1).cpu()
    ref_sf_bytes = ref_sf.view(torch.uint8).cpu()

    print(f"Global scale: {global_scale:.8f}")
    print(f"Ref FP4: {ref_fp4_bytes}")
    print(f"Ref SF: {ref_sf_bytes}")

    # Output tensors
    out_fp4 = torch.zeros(num_packed, dtype=torch.int32, device=device)
    out_sf = torch.zeros(1, dtype=torch.int32, device=device)
    gs_tensor = torch.tensor([1.0], dtype=torch.float32, device=device)  # already normalized

    def to_cute(t):
        """Wrap a torch tensor as a CuTe tensor with a dynamic layout."""
        ct = cutlass_torch.from_dlpack(t)
        return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))

    input_c = to_cute(x_norm)
    out_fp4_c = to_cute(out_fp4)
    out_sf_c = to_cute(out_sf)
    gs_c = to_cute(gs_tensor)

    print("\nCompiling kernel...")
    try:
        # NOTE(review): this compiles the @cute.kernel function directly and
        # calls it without an explicit grid/block launch configuration --
        # confirm this is supported rather than requiring a @cute.jit host
        # wrapper that launches the kernel.
        compiled = cute.compile(
            fp4_quant_test_kernel,
            input_c, out_fp4_c, out_sf_c, gs_c,
        )
        print("Compiled. Running...")

        compiled(input_c, out_fp4_c, out_sf_c, gs_c)
        torch.cuda.synchronize()

        # Each Int32 slot holds one packed byte / one scale byte.
        our_fp4 = out_fp4[:num_packed].to(torch.uint8).cpu()
        our_sf = out_sf[0].to(torch.uint8).cpu().item()

        print(f"Our FP4: {our_fp4}")
        print(f"Our SF: {our_sf}")

        fp4_match = torch.equal(our_fp4, ref_fp4_bytes[:num_packed])
        sf_match = our_sf == ref_sf_bytes[0].item()

        if fp4_match and sf_match:
            print("\n✅ PASS!")
            return True

        print(f"\n❌ FAIL: FP4={fp4_match} SF={sf_match}")
        return False
    except Exception as e:
        # Top-level boundary of the test: report the failure with a full
        # traceback and turn it into a failing result instead of crashing.
        print(f"\n❌ ERROR: {e}")
        traceback.print_exc()
        return False


if __name__ == "__main__":
    print("=" * 60)
    print("NVFP4-1.1 Phase 1: FP4 Quantization Math Test")
    print("Threshold rounding — Float32 input — no BF16 scalar loads")
    print("=" * 60)
    success = run_test()
    # sys.exit rather than the `exit` builtin, which the `site` module only
    # injects for interactive use and which may be absent (e.g. `python -S`).
    sys.exit(0 if success else 1)