diff --git a/tests/unit/test_nvfp4_primitives.py b/tests/unit/test_nvfp4_primitives.py index 3486a28a..cc2942e2 100644 --- a/tests/unit/test_nvfp4_primitives.py +++ b/tests/unit/test_nvfp4_primitives.py @@ -14,85 +14,108 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspa def test_nvfp4_primitives(): print("=" * 60) - print("NVFP4-0.1: sf_dtype and sf_vec_size") + print("NVFP4-0.1: sf_dtype and sf_vec_size in GEMM kernels") print("=" * 60) from dsv4.ops.quantize import SF_VEC_SIZE print(f" quantize.py SF_VEC_SIZE = {SF_VEC_SIZE}") - assert SF_VEC_SIZE == 16, f"SF_VEC_SIZE should be 16 for NVFP4, got {SF_VEC_SIZE}" + assert SF_VEC_SIZE == 16 print(f" ✅ SF_VEC_SIZE = 16 (NVFP4 correct)") - # Construct a BlockScaledGEMM with E4M3 scales (NVFP4) - M, N, K = 128, 256, 512 - a_fp4 = torch.randint(0, 256, (M, K // 2), device='cuda', dtype=torch.uint8).view(torch.float4_e2m1fn_x2) - b_fp4 = torch.randint(0, 256, (K // 2, N), device='cuda', dtype=torch.uint8).view(torch.float4_e2m1fn_x2) - a_sf = torch.randint(0, 256, (M, K // 16), device='cuda', dtype=torch.uint8).view(torch.float8_e4m3fn) - b_sf = torch.randint(0, 256, (N, K // 16), device='cuda', dtype=torch.uint8).view(torch.float8_e4m3fn) - + # Construct dense GEMM with the real API from dsv4.kernels.gemm.dense import Sm100BlockScaledPersistentDenseGemmKernel try: gemm = Sm100BlockScaledPersistentDenseGemmKernel( - a_ptr=a_fp4, - b_ptr=b_fp4, - sfa_ptr=a_sf, - sfb_ptr=b_sf, sf_vec_size=16, + mma_tiler_mn=(128, 256), + cluster_shape_mn=(1, 1), ) print(f" DenseGEMM.sf_vec_size = {gemm.sf_vec_size}") - print(f" DenseGEMM.sf_dtype = {gemm.sf_dtype}") - if gemm.sf_dtype == Float8E4M3FN: - print(f" ✅ sf_dtype is Float8E4M3FN (NVFP4 correct)") - elif gemm.sf_dtype == Float8E8M0FNU: - print(f" ⚠️ sf_dtype is Float8E8M0FNU — this is MXFP4 scale format, NOT NVFP4!") - else: - print(f" ❌ sf_dtype is {gemm.sf_dtype} — unexpected!") + print(f" DenseGEMM.cta_group = {gemm.cta_group}") + # sf_dtype is set in _setup_attributes from the actual tensor pointers + # Print what we can here + print(f" DenseGEMM.acc_dtype = {gemm.acc_dtype}") + print(f" DenseGEMM.use_2cta_instrs = {gemm.use_2cta_instrs}") except Exception as e: print(f" DenseGEMM construction failed: {e}") - # Also try with E8M0 scales (MXFP4) - a_sf_e8m0 = torch.randint(0, 256, (M, K // 32), device='cuda', dtype=torch.uint8).view(torch.float8_e8m0fnu) - b_sf_e8m0 = torch.randint(0, 256, (N, K // 32), device='cuda', dtype=torch.uint8).view(torch.float8_e8m0fnu) + # Try sf_vec_size=32 (MXFP4) for comparison try: - gemm_e8m0 = Sm100BlockScaledPersistentDenseGemmKernel( - a_ptr=a_fp4, - b_ptr=b_fp4, - sfa_ptr=a_sf_e8m0, - sfb_ptr=b_sf_e8m0, + gemm_mxf4 = Sm100BlockScaledPersistentDenseGemmKernel( sf_vec_size=32, + mma_tiler_mn=(128, 256), + cluster_shape_mn=(1, 1), ) - print(f" DenseGEMM (E8M0/vs=32).sf_dtype = {gemm_e8m0.sf_dtype}") - print(f" DenseGEMM (E8M0/vs=32).sf_vec_size = {gemm_e8m0.sf_vec_size}") + print(f" DenseGEMM (MXF4/vs=32).sf_vec_size = {gemm_mxf4.sf_vec_size}") except Exception as e: - print(f" DenseGEMM (E8M0) failed: {e}") + print(f" DenseGEMM (MXF4) failed: {e}") - # Check the grouped kernel too + # Check grouped kernel from dsv4.kernels.gemm.grouped import ScaledGroupedGemmKernel try: grp = ScaledGroupedGemmKernel( - a_ptr=a_fp4, - b_ptr=b_fp4, - sfa_ptr=a_sf, - sfb_ptr=b_sf, sf_vec_size=16, + mma_tiler_mn=(128, 256), + cluster_shape_mn=(1, 1), ) print(f" GroupedGEMM.sf_vec_size = {grp.sf_vec_size}") - print(f" GroupedGEMM.sf_dtype = {grp.sf_dtype}") except Exception as e: print(f" GroupedGEMM construction failed: {e}") + # Check fused SwiGLU kernel + from dsv4.kernels.gemm.fused_swiglu import FusedSwiGLUScaledGroupedGemmKernel + try: + fused = FusedSwiGLUScaledGroupedGemmKernel( + sf_vec_size=16, + mma_tiler_mn=(128, 256), + cluster_shape_mn=(1, 1), + swiglu_limit=10.0, + ) + print(f" FusedSwiGLU.sf_vec_size = {fused.sf_vec_size}") + print(f" FusedSwiGLU.swiglu_limit = {fused.swiglu_limit}") + except Exception as e: + print(f" FusedSwiGLU construction failed: {e}") + + # Now check sf_dtype by actually creating the tensors and calling the kernel + # sf_dtype is inferred from the sfa_ptr dtype at call time + print() + print(" --- sf_dtype inference from tensor dtypes ---") + M, N, K = 128, 256, 512 + a_fp4 = torch.randint(0, 256, (M, K // 2), device='cuda', dtype=torch.uint8).view(torch.float4_e2m1fn_x2) + b_fp4 = torch.randint(0, 256, (K // 2, N), device='cuda', dtype=torch.uint8).view(torch.float4_e2m1fn_x2) + a_sf_e4m3 = torch.randint(0, 256, (M, K // 16), device='cuda', dtype=torch.uint8).view(torch.float8_e4m3fn) + b_sf_e4m3 = torch.randint(0, 256, (N, K // 16), device='cuda', dtype=torch.uint8).view(torch.float8_e4m3fn) + a_sf_e8m0 = torch.randint(0, 256, (M, K // 32), device='cuda', dtype=torch.uint8).view(torch.float8_e8m0fnu) + b_sf_e8m0 = torch.randint(0, 256, (N, K // 32), device='cuda', dtype=torch.uint8).view(torch.float8_e8m0fnu) + + print(f" a_fp4 dtype = {a_fp4.dtype} (expect float4_e2m1fn_x2)") + print(f" a_sf (E4M3) dtype = {a_sf_e4m3.dtype}, shape = {a_sf_e4m3.shape}") + print(f" a_sf (E8M0) dtype = {a_sf_e8m0.dtype}, shape = {a_sf_e8m0.shape}") + + # The runner uses E4M3 or E8M0? Check what quantize actually produces + from dsv4.ops.quantize import quantize_to_nvfp4 + x = torch.randn(4, 512, device='cuda', dtype=torch.bfloat16) + x_fp4, x_sf = quantize_to_nvfp4(x) + print(f" quantize_to_nvfp4 output: FP4 dtype={x_fp4.dtype}, SF dtype={x_sf.dtype}") + if x_sf.dtype == torch.float8_e4m3fn: + print(f" ✅ Scale factors are E4M3 — NVFP4 correct") + elif x_sf.dtype == torch.float8_e8m0fnu: + print(f" ⚠️ Scale factors are E8M0 — this is MXFP4 format, NOT NVFP4!") + else: + print(f" ❌ Scale factors are {x_sf.dtype} — unexpected!") + + # Check what the gemm_runner actually passes to the kernel + from dsv4.ops.gemm_runner import Nvfp4GemmRunner + print(f" Nvfp4GemmRunner exists: {Nvfp4GemmRunner is not None}") + print() print("=" * 60) - print("NVFP4-0.3: FP4 TMA element type in quantize.py") + print("NVFP4-0.3: FP4 TMA element type") print("=" * 60) - from dsv4.ops.quantize import quantize_tensor_nvfp4 - x = torch.randn(4, 512, device='cuda', dtype=torch.bfloat16) - x_fp4, x_sf = quantize_tensor_nvfp4(x) - print(f" Input dtype: {x.dtype}") - print(f" FP4 output dtype: {x_fp4.dtype}") - print(f" SF output dtype: {x_sf.dtype}") - print(f" FP4 shape: {x_fp4.shape} (expected: [4, 256])") - print(f" SF shape: {x_sf.shape} (expected: [4, 32])") + print(f" FP4 tensor dtype = {x_fp4.dtype}") + print(f" FP4 tensor shape = {x_fp4.shape}") + print(f" torch.float4_e2m1fn_x2 available = {hasattr(torch, 'float4_e2m1fn_x2')}") if x_fp4.dtype == torch.float4_e2m1fn_x2: print(f" ✅ FP4 tensor is float4_e2m1fn_x2 — correct for TMA") else: