diff --git a/tests/unit/test_nvfp4_primitives.py b/tests/unit/test_nvfp4_primitives.py new file mode 100644 index 00000000..a1bed490 --- /dev/null +++ b/tests/unit/test_nvfp4_primitives.py @@ -0,0 +1,112 @@ +"""NVFP4-0: Verify Blackwell FP4 primitives are correct. + +Quick diagnostic — prints sf_dtype, sf_vec_size, MMA kind, TMA element type. +Run on B200 only. +""" +import torch +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import tcgen05 +from cutlass import Float32, BFloat16, Float8E4M3FN, Float8E8M0FNU, Float4E2M1FN + +def test_nvfp4_primitives(): + print("=" * 60) + print("NVFP4-0.1: sf_dtype and sf_vec_size in GEMM runner") + print("=" * 60) + + # Check what gemm_runner uses + import sys, os + sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) + from dsv4.ops.gemm_runner import SF_VEC_SIZE + print(f" gemm_runner SF_VEC_SIZE = {SF_VEC_SIZE}") + + # Check what dense.py BlockScaledGEMM expects + from dsv4.kernels.gemm.dense import BlockScaledGEMM + # Print the docstring-specified types + print(f" dense.py NVF4 spec: A/B=Float4E2M1FN, SF=Float8E8M0FNU/Float8E4M3FN, sf_vec_size=16") + print(f" dense.py MXF4 spec: A/B=Float4E2M1FN, SF=Float8E8M0FNU, sf_vec_size=32") + + # Construct a small NVFP4 GEMM and print the sf_dtype it receives + M, N, K = 128, 256, 512 + # Create small FP4 tensors + a_fp4 = torch.randn(M, K // 2, device='cuda', dtype=torch.float4_e2m1fn_x2) + b_fp4 = torch.randn(K // 2, N, device='cuda', dtype=torch.float4_e2m1fn_x2) + a_sf = torch.randn(M, K // 16, device='cuda', dtype=torch.float8_e4m3fn) + b_sf = torch.randn(N, K // 16, device='cuda', dtype=torch.float8_e4m3fn) + + try: + from dsv4.kernels.gemm.dense import BlockScaledGEMM + gemm = BlockScaledGEMM( + a_ptr=a_fp4, + b_ptr=b_fp4, + sfa_ptr=a_sf, + sfb_ptr=b_sf, + sf_vec_size=16, + ) + print(f" BlockScaledGEMM.sf_vec_size = {gemm.sf_vec_size}") + print(f" BlockScaledGEMM.sf_dtype = {gemm.sf_dtype}") + print(f" BlockScaledGEMM.mma_inst_shape_mn = {gemm.mma_inst_shape_mn}") + except Exception as e: + print(f" BlockScaledGEMM construction failed: {e}") + + # Also try with E8M0 scale dtype to see if that's what's being used + a_sf_e8m0 = torch.randn(M, K // 32, device='cuda', dtype=torch.float8_e8m0fnu) + b_sf_e8m0 = torch.randn(N, K // 32, device='cuda', dtype=torch.float8_e8m0fnu) + try: + gemm_e8m0 = BlockScaledGEMM( + a_ptr=a_fp4, + b_ptr=b_fp4, + sfa_ptr=a_sf_e8m0, + sfb_ptr=b_sf_e8m0, + sf_vec_size=32, + ) + print(f" BlockScaledGEMM (E8M0/sf_vec_size=32).sf_dtype = {gemm_e8m0.sf_dtype}") + print(f" BlockScaledGEMM (E8M0/sf_vec_size=32).sf_vec_size = {gemm_e8m0.sf_vec_size}") + except Exception as e: + print(f" BlockScaledGEMM (E8M0) construction failed: {e}") + + print() + print("=" * 60) + print("NVFP4-0.3: FP4 TMA element type in quantize.py") + print("=" * 60) + + from dsv4.ops.quantize import quantize_tensor_nvfp4, _create_fp4_tensor + x = torch.randn(4, 512, device='cuda', dtype=torch.bfloat16) + x_fp4, x_sf = quantize_tensor_nvfp4(x) + print(f" Input dtype: {x.dtype}") + print(f" FP4 output dtype: {x_fp4.dtype}") + print(f" SF output dtype: {x_sf.dtype}") + print(f" FP4 shape: {x_fp4.shape} (expected: [4, 256])") + print(f" SF shape: {x_sf.shape} (expected: [4, 32])") + + # Check that float4_e2m1fn_x2 survives into TMA + print(f" torch.float4_e2m1fn_x2 = {torch.float4_e2m1fn_x2}") + print(f" FP4 tensor is float4_e2m1fn_x2: {x_fp4.dtype == torch.float4_e2m1fn_x2}") + + print() + print("=" * 60) + print("NVFP4-0.4: MMA kind verification") + print("=" * 60) + + # Try to inspect the MMA kind that tcgen05 uses for NVFP4 + try: + # Create a tiled MMA with FP4 inputs + a_major = cutlass.utils.LayoutEnum.ROW_MAJOR.mma_major_mode() + b_major = cutlass.utils.LayoutEnum.COLUMN_MAJOR.mma_major_mode() + mma = cutlass.utils.sm100.make_trivial_tiled_mma( + Float4E2M1FN, Float4E2M1FN, a_major, b_major, Float32, + tcgen05.CtaGroup.ONE, (128, 256), + tcgen05.OperandSource.SMEM, + ) + print(f" FP4 MMA shape_mnk = {mma.shape_mnk}") + print(f" FP4 MMA thr_id shape = {mma.thr_id.shape}") + print(f" FP4 MMA operand A dtype = {mma.a_layout.value_type if hasattr(mma, 'a_layout') else 'N/A'}") + except Exception as e: + print(f" FP4 MMA construction failed: {e}") + + print() + print("DONE — NVFP4-0 verification complete") + + +if __name__ == "__main__": + test_nvfp4_primitives()