"""NVFP4-0: Verify Blackwell FP4 primitives are correct. Quick diagnostic — prints sf_dtype, sf_vec_size, MMA kind, TMA element type. Run on B200 only. """ import torch import cutlass import cutlass.cute as cute from cutlass.cute.nvgpu import tcgen05 from cutlass import Float32, BFloat16, Float8E4M3FN, Float8E8M0FNU, Float4E2M1FN import sys, os sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) def test_nvfp4_primitives(): print("=" * 60) print("NVFP4-0.1: sf_dtype and sf_vec_size in GEMM kernels") print("=" * 60) from dsv4.ops.quantize import SF_VEC_SIZE print(f" quantize.py SF_VEC_SIZE = {SF_VEC_SIZE}") assert SF_VEC_SIZE == 16 print(f" ✅ SF_VEC_SIZE = 16 (NVFP4 correct)") # Construct dense GEMM with the real API from dsv4.kernels.gemm.dense import Sm100BlockScaledPersistentDenseGemmKernel try: gemm = Sm100BlockScaledPersistentDenseGemmKernel( sf_vec_size=16, mma_tiler_mn=(128, 256), cluster_shape_mn=(1, 1), ) print(f" DenseGEMM.sf_vec_size = {gemm.sf_vec_size}") print(f" DenseGEMM.cta_group = {gemm.cta_group}") # sf_dtype is set in _setup_attributes from the actual tensor pointers # Print what we can here print(f" DenseGEMM.acc_dtype = {gemm.acc_dtype}") print(f" DenseGEMM.use_2cta_instrs = {gemm.use_2cta_instrs}") except Exception as e: print(f" DenseGEMM construction failed: {e}") # Try sf_vec_size=32 (MXFP4) for comparison try: gemm_mxf4 = Sm100BlockScaledPersistentDenseGemmKernel( sf_vec_size=32, mma_tiler_mn=(128, 256), cluster_shape_mn=(1, 1), ) print(f" DenseGEMM (MXF4/vs=32).sf_vec_size = {gemm_mxf4.sf_vec_size}") except Exception as e: print(f" DenseGEMM (MXF4) failed: {e}") # Check grouped kernel from dsv4.kernels.gemm.grouped import ScaledGroupedGemmKernel try: grp = ScaledGroupedGemmKernel( sf_vec_size=16, mma_tiler_mn=(128, 256), cluster_shape_mn=(1, 1), ) print(f" GroupedGEMM.sf_vec_size = {grp.sf_vec_size}") except Exception as e: print(f" GroupedGEMM construction failed: {e}") # Check fused SwiGLU kernel from dsv4.kernels.gemm.fused_swiglu import FusedSwiGLUScaledGroupedGemmKernel try: fused = FusedSwiGLUScaledGroupedGemmKernel( sf_vec_size=16, mma_tiler_mn=(128, 256), cluster_shape_mn=(1, 1), swiglu_limit=10.0, ) print(f" FusedSwiGLU.sf_vec_size = {fused.sf_vec_size}") print(f" FusedSwiGLU.swiglu_limit = {fused.swiglu_limit}") except Exception as e: print(f" FusedSwiGLU construction failed: {e}") # Now check sf_dtype by actually creating the tensors and calling the kernel # sf_dtype is inferred from the sfa_ptr dtype at call time print() print(" --- sf_dtype inference from tensor dtypes ---") M, N, K = 128, 256, 512 a_fp4 = torch.randint(0, 256, (M, K // 2), device='cuda', dtype=torch.uint8).view(torch.float4_e2m1fn_x2) b_fp4 = torch.randint(0, 256, (K // 2, N), device='cuda', dtype=torch.uint8).view(torch.float4_e2m1fn_x2) a_sf_e4m3 = torch.randint(0, 256, (M, K // 16), device='cuda', dtype=torch.uint8).view(torch.float8_e4m3fn) b_sf_e4m3 = torch.randint(0, 256, (N, K // 16), device='cuda', dtype=torch.uint8).view(torch.float8_e4m3fn) a_sf_e8m0 = torch.randint(0, 256, (M, K // 32), device='cuda', dtype=torch.uint8).view(torch.float8_e8m0fnu) b_sf_e8m0 = torch.randint(0, 256, (N, K // 32), device='cuda', dtype=torch.uint8).view(torch.float8_e8m0fnu) print(f" a_fp4 dtype = {a_fp4.dtype} (expect float4_e2m1fn_x2)") print(f" a_sf (E4M3) dtype = {a_sf_e4m3.dtype}, shape = {a_sf_e4m3.shape}") print(f" a_sf (E8M0) dtype = {a_sf_e8m0.dtype}, shape = {a_sf_e8m0.shape}") # The runner uses E4M3 or E8M0? Check what quantize actually produces from dsv4.ops.quantize import quantize_activation_nvfp4 x = torch.randn(4, 512, device='cuda', dtype=torch.bfloat16) global_scale = torch.tensor(448.0, device='cuda', dtype=torch.float32) x_fp4, x_sf = quantize_activation_nvfp4(x, global_scale) print(f" quantize_activation_nvfp4 output: FP4 dtype={x_fp4.dtype}, SF dtype={x_sf.dtype}") if x_sf.dtype == torch.float8_e4m3fn: print(f" ✅ Scale factors are E4M3 — NVFP4 correct") elif x_sf.dtype == torch.float8_e8m0fnu: print(f" ⚠️ Scale factors are E8M0 — this is MXFP4 format, NOT NVFP4!") else: print(f" ❌ Scale factors are {x_sf.dtype} — unexpected!") # Check the gemm_runner class name import dsv4.ops.gemm_runner as gr runner_classes = [name for name in dir(gr) if 'unner' in name or 'Gemm' in name] print(f" gemm_runner classes: {runner_classes}") print() print("=" * 60) print("NVFP4-0.3: FP4 TMA element type") print("=" * 60) print(f" FP4 tensor dtype = {x_fp4.dtype}") print(f" FP4 tensor shape = {x_fp4.shape}") print(f" torch.float4_e2m1fn_x2 available = {hasattr(torch, 'float4_e2m1fn_x2')}") if x_fp4.dtype == torch.float4_e2m1fn_x2: print(f" ✅ FP4 tensor is float4_e2m1fn_x2 — correct for TMA") else: print(f" ❌ FP4 tensor dtype is {x_fp4.dtype} — should be float4_e2m1fn_x2!") print() print("=" * 60) print("NVFP4-0.4: MMA kind verification") print("=" * 60) # Test with correct LayoutEnum from cutlass.utils import LayoutEnum try: a_major = LayoutEnum.ROW_MAJOR.mma_major_mode() b_major = LayoutEnum.COLUMN_MAJOR.mma_major_mode() mma = cutlass.utils.sm100.make_trivial_tiled_mma( Float4E2M1FN, Float4E2M1FN, a_major, b_major, Float32, tcgen05.CtaGroup.ONE, (128, 256), tcgen05.OperandSource.SMEM, ) print(f" FP4 MMA shape_mnk = {mma.shape_mnk}") print(f" ✅ FP4 MMA construction succeeded") except Exception as e: print(f" FP4 MMA construction failed: {e}") # Check use_2cta_instrs behavior print() print(" --- use_2cta_instrs check ---") try: gemm_1cta = Sm100BlockScaledPersistentDenseGemmKernel( sf_vec_size=16, mma_tiler_mn=(128, 256), cluster_shape_mn=(1, 1), ) print(f" mma_tiler_mn=(128,256) → use_2cta_instrs={gemm_1cta.use_2cta_instrs}") except Exception as e: print(f" 1-CTA failed: {e}") try: gemm_2cta = Sm100BlockScaledPersistentDenseGemmKernel( sf_vec_size=16, mma_tiler_mn=(256, 256), cluster_shape_mn=(2, 1), ) print(f" mma_tiler_mn=(256,256) → use_2cta_instrs={gemm_2cta.use_2cta_instrs}") except Exception as e: print(f" 2-CTA failed: {e}") print() print("DONE — NVFP4-0 verification complete") if __name__ == "__main__": test_nvfp4_primitives()