From 96567ea64d6ca79571d3cfae5433829ba2bed2a5 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 08:28:13 +0000 Subject: [PATCH] =?UTF-8?q?fix:=20NVFP4-0=20diag=20script=20=E2=80=94=20im?= =?UTF-8?q?port=20SF=5FVEC=5FSIZE=20from=20quantize.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/unit/test_nvfp4_primitives.py | 55 ++++++++++++++--------------- 1 file changed, 27 insertions(+), 28 deletions(-) diff --git a/tests/unit/test_nvfp4_primitives.py b/tests/unit/test_nvfp4_primitives.py index a1bed490..1b27d3dc 100644 --- a/tests/unit/test_nvfp4_primitives.py +++ b/tests/unit/test_nvfp4_primitives.py @@ -8,34 +8,29 @@ import cutlass import cutlass.cute as cute from cutlass.cute.nvgpu import tcgen05 from cutlass import Float32, BFloat16, Float8E4M3FN, Float8E8M0FNU, Float4E2M1FN +import sys, os +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) + def test_nvfp4_primitives(): print("=" * 60) - print("NVFP4-0.1: sf_dtype and sf_vec_size in GEMM runner") + print("NVFP4-0.1: sf_dtype and sf_vec_size") print("=" * 60) - # Check what gemm_runner uses - import sys, os - sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) - from dsv4.ops.gemm_runner import SF_VEC_SIZE - print(f" gemm_runner SF_VEC_SIZE = {SF_VEC_SIZE}") + from dsv4.ops.quantize import SF_VEC_SIZE + print(f" quantize.py SF_VEC_SIZE = {SF_VEC_SIZE}") + assert SF_VEC_SIZE == 16, f"SF_VEC_SIZE should be 16 for NVFP4, got {SF_VEC_SIZE}" + print(f" ✅ SF_VEC_SIZE = 16 (NVFP4 correct)") - # Check what dense.py BlockScaledGEMM expects - from dsv4.kernels.gemm.dense import BlockScaledGEMM - # Print the docstring-specified types - print(f" dense.py NVF4 spec: A/B=Float4E2M1FN, SF=Float8E8M0FNU/Float8E4M3FN, sf_vec_size=16") - print(f" dense.py MXF4 spec: A/B=Float4E2M1FN, SF=Float8E8M0FNU, sf_vec_size=32") - - # Construct a small NVFP4 GEMM and print the sf_dtype it receives + # Construct a BlockScaledGEMM with E4M3 scales (NVFP4) M, N, K = 128, 256, 512 - # Create small FP4 tensors a_fp4 = torch.randn(M, K // 2, device='cuda', dtype=torch.float4_e2m1fn_x2) b_fp4 = torch.randn(K // 2, N, device='cuda', dtype=torch.float4_e2m1fn_x2) a_sf = torch.randn(M, K // 16, device='cuda', dtype=torch.float8_e4m3fn) b_sf = torch.randn(N, K // 16, device='cuda', dtype=torch.float8_e4m3fn) + from dsv4.kernels.gemm.dense import BlockScaledGEMM try: - from dsv4.kernels.gemm.dense import BlockScaledGEMM gemm = BlockScaledGEMM( a_ptr=a_fp4, b_ptr=b_fp4, @@ -46,10 +41,17 @@ def test_nvfp4_primitives(): print(f" BlockScaledGEMM.sf_vec_size = {gemm.sf_vec_size}") print(f" BlockScaledGEMM.sf_dtype = {gemm.sf_dtype}") print(f" BlockScaledGEMM.mma_inst_shape_mn = {gemm.mma_inst_shape_mn}") + if gemm.sf_dtype == Float8E4M3FN: + print(f" ✅ sf_dtype is Float8E4M3FN (NVFP4 correct)") + elif gemm.sf_dtype == Float8E8M0FNU: + print(f" ⚠️ sf_dtype is Float8E8M0FNU — this is MXFP4 scale format, NOT NVFP4!") + print(f" ⚠️ NVFP4 should use Float8E4M3FN scales at sf_vec_size=16") + else: + print(f" ❌ sf_dtype is {gemm.sf_dtype} — unexpected!") except Exception as e: print(f" BlockScaledGEMM construction failed: {e}") - # Also try with E8M0 scale dtype to see if that's what's being used + # Also try with E8M0 scales (MXFP4) a_sf_e8m0 = torch.randn(M, K // 32, device='cuda', dtype=torch.float8_e8m0fnu) b_sf_e8m0 = torch.randn(N, K // 32, device='cuda', dtype=torch.float8_e8m0fnu) try: @@ -60,17 +62,17 @@ def test_nvfp4_primitives(): sfb_ptr=b_sf_e8m0, sf_vec_size=32, ) - print(f" BlockScaledGEMM (E8M0/sf_vec_size=32).sf_dtype = {gemm_e8m0.sf_dtype}") - print(f" BlockScaledGEMM (E8M0/sf_vec_size=32).sf_vec_size = {gemm_e8m0.sf_vec_size}") + print(f" BlockScaledGEMM (E8M0/vs=32).sf_dtype = {gemm_e8m0.sf_dtype}") + print(f" BlockScaledGEMM (E8M0/vs=32).sf_vec_size = {gemm_e8m0.sf_vec_size}") except Exception as e: - print(f" BlockScaledGEMM (E8M0) construction failed: {e}") + print(f" BlockScaledGEMM (E8M0) failed: {e}") print() print("=" * 60) print("NVFP4-0.3: FP4 TMA element type in quantize.py") print("=" * 60) - from dsv4.ops.quantize import quantize_tensor_nvfp4, _create_fp4_tensor + from dsv4.ops.quantize import quantize_tensor_nvfp4 x = torch.randn(4, 512, device='cuda', dtype=torch.bfloat16) x_fp4, x_sf = quantize_tensor_nvfp4(x) print(f" Input dtype: {x.dtype}") @@ -78,19 +80,17 @@ def test_nvfp4_primitives(): print(f" SF output dtype: {x_sf.dtype}") print(f" FP4 shape: {x_fp4.shape} (expected: [4, 256])") print(f" SF shape: {x_sf.shape} (expected: [4, 32])") - - # Check that float4_e2m1fn_x2 survives into TMA - print(f" torch.float4_e2m1fn_x2 = {torch.float4_e2m1fn_x2}") - print(f" FP4 tensor is float4_e2m1fn_x2: {x_fp4.dtype == torch.float4_e2m1fn_x2}") + if x_fp4.dtype == torch.float4_e2m1fn_x2: + print(f" ✅ FP4 tensor is float4_e2m1fn_x2 — correct for TMA") + else: + print(f" ❌ FP4 tensor dtype is {x_fp4.dtype} — should be float4_e2m1fn_x2!") print() print("=" * 60) print("NVFP4-0.4: MMA kind verification") print("=" * 60) - # Try to inspect the MMA kind that tcgen05 uses for NVFP4 try: - # Create a tiled MMA with FP4 inputs a_major = cutlass.utils.LayoutEnum.ROW_MAJOR.mma_major_mode() b_major = cutlass.utils.LayoutEnum.COLUMN_MAJOR.mma_major_mode() mma = cutlass.utils.sm100.make_trivial_tiled_mma( @@ -99,8 +99,7 @@ def test_nvfp4_primitives(): tcgen05.OperandSource.SMEM, ) print(f" FP4 MMA shape_mnk = {mma.shape_mnk}") - print(f" FP4 MMA thr_id shape = {mma.thr_id.shape}") - print(f" FP4 MMA operand A dtype = {mma.a_layout.value_type if hasattr(mma, 'a_layout') else 'N/A'}") + print(f" ✅ FP4 MMA construction succeeded") except Exception as e: print(f" FP4 MMA construction failed: {e}")