Files
nvfp4-megamoe-kernel/tests/unit/test_nvfp4_primitives.py

170 lines
6.9 KiB
Python

"""NVFP4-0: Verify Blackwell FP4 primitives are correct.
Quick diagnostic — prints sf_dtype, sf_vec_size, MMA kind, TMA element type.
Run on B200 only.
"""
import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16, Float8E4M3FN, Float8E8M0FNU, Float4E2M1FN
import sys, os
# Prepend the directory three levels above this file (dirname applied three
# times to the absolute file path) to the import search path.
# NOTE(review): presumably this is the repository root, so that the `dsv4`
# package imported inside the test resolves when run as a script — confirm.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
def test_nvfp4_primitives():
    """Print a diagnostic report about the NVFP4 building blocks.

    The report has three banner-delimited sections:

    * NVFP4-0.1 -- scale-factor vector size and scale-factor dtype: checks
      ``SF_VEC_SIZE``, constructs the dense / grouped / fused-SwiGLU GEMM
      kernel objects, and inspects what ``quantize_activation_nvfp4`` returns.
    * NVFP4-0.3 -- dtype and shape of the quantized FP4 tensor.
    * NVFP4-0.4 -- construction of an FP4 tiled MMA and the
      ``use_2cta_instrs`` attribute for two tiler/cluster configurations.

    Kernel and MMA construction errors are caught and printed rather than
    raised, so the report continues past them.  The only explicit check is
    the ``SF_VEC_SIZE == 16`` assertion; failures of the in-function imports,
    the tensor allocations or the quantize call propagate to the caller.

    Tensors are created with ``device='cuda'``, so a CUDA device is required.

    Raises:
        AssertionError: if ``SF_VEC_SIZE`` is not 16.
    """
    rule = "=" * 60

    def _banner(title):
        # Section header: title framed by two 60-character rules.
        print(rule)
        print(title)
        print(rule)

    def _random_bytes(*shape):
        # Random uint8 payload on the GPU; callers reinterpret it via .view().
        return torch.randint(0, 256, shape, device='cuda', dtype=torch.uint8)

    # Tiler / cluster settings shared by every single-CTA construction below.
    single_cta_cfg = {"mma_tiler_mn": (128, 256), "cluster_shape_mn": (1, 1)}

    _banner("NVFP4-0.1: sf_dtype and sf_vec_size in GEMM kernels")

    from dsv4.ops.quantize import SF_VEC_SIZE
    print(f" quantize.py SF_VEC_SIZE = {SF_VEC_SIZE}")
    assert SF_VEC_SIZE == 16
    print(" ✅ SF_VEC_SIZE = 16 (NVFP4 correct)")

    # Dense GEMM, built through the real constructor.
    from dsv4.kernels.gemm.dense import Sm100BlockScaledPersistentDenseGemmKernel
    try:
        dense_gemm = Sm100BlockScaledPersistentDenseGemmKernel(
            sf_vec_size=16, **single_cta_cfg
        )
        print(f" DenseGEMM.sf_vec_size = {dense_gemm.sf_vec_size}")
        print(f" DenseGEMM.cta_group = {dense_gemm.cta_group}")
        # Original author's note: sf_dtype is only set later, in
        # _setup_attributes, from the actual tensor pointers -- so only the
        # attributes available right after construction are printed here.
        print(f" DenseGEMM.acc_dtype = {dense_gemm.acc_dtype}")
        print(f" DenseGEMM.use_2cta_instrs = {dense_gemm.use_2cta_instrs}")
    except Exception as exc:
        print(f" DenseGEMM construction failed: {exc}")

    # Same kernel with sf_vec_size=32 (MXFP4), for comparison.
    try:
        dense_gemm_mxf4 = Sm100BlockScaledPersistentDenseGemmKernel(
            sf_vec_size=32, **single_cta_cfg
        )
        print(f" DenseGEMM (MXF4/vs=32).sf_vec_size = {dense_gemm_mxf4.sf_vec_size}")
    except Exception as exc:
        print(f" DenseGEMM (MXF4) failed: {exc}")

    # Grouped kernel.
    from dsv4.kernels.gemm.grouped import ScaledGroupedGemmKernel
    try:
        grouped_gemm = ScaledGroupedGemmKernel(sf_vec_size=16, **single_cta_cfg)
        print(f" GroupedGEMM.sf_vec_size = {grouped_gemm.sf_vec_size}")
    except Exception as exc:
        print(f" GroupedGEMM construction failed: {exc}")

    # Fused SwiGLU kernel.
    from dsv4.kernels.gemm.fused_swiglu import FusedSwiGLUScaledGroupedGemmKernel
    try:
        fused_gemm = FusedSwiGLUScaledGroupedGemmKernel(
            sf_vec_size=16, **single_cta_cfg, swiglu_limit=10.0
        )
        print(f" FusedSwiGLU.sf_vec_size = {fused_gemm.sf_vec_size}")
        print(f" FusedSwiGLU.swiglu_limit = {fused_gemm.swiglu_limit}")
    except Exception as exc:
        print(f" FusedSwiGLU construction failed: {exc}")

    # Build raw operand / scale-factor tensors in both candidate SF dtypes.
    # Original author's note: sf_dtype is inferred from the sfa_ptr dtype at
    # call time.
    print()
    print(" --- sf_dtype inference from tensor dtypes ---")
    m, n, k = 128, 256, 512
    # The B-side tensors are allocated alongside the A-side ones but are not
    # inspected further in this report.
    a_fp4 = _random_bytes(m, k // 2).view(torch.float4_e2m1fn_x2)
    b_fp4 = _random_bytes(k // 2, n).view(torch.float4_e2m1fn_x2)
    a_sf_e4m3 = _random_bytes(m, k // 16).view(torch.float8_e4m3fn)
    b_sf_e4m3 = _random_bytes(n, k // 16).view(torch.float8_e4m3fn)
    a_sf_e8m0 = _random_bytes(m, k // 32).view(torch.float8_e8m0fnu)
    b_sf_e8m0 = _random_bytes(n, k // 32).view(torch.float8_e8m0fnu)
    print(f" a_fp4 dtype = {a_fp4.dtype} (expect float4_e2m1fn_x2)")
    print(f" a_sf (E4M3) dtype = {a_sf_e4m3.dtype}, shape = {a_sf_e4m3.shape}")
    print(f" a_sf (E8M0) dtype = {a_sf_e8m0.dtype}, shape = {a_sf_e8m0.shape}")

    # Which scale-factor dtype does the quantizer actually emit?
    from dsv4.ops.quantize import quantize_activation_nvfp4
    activations = torch.randn(4, 512, device='cuda', dtype=torch.bfloat16)
    global_scale = torch.tensor(448.0, device='cuda', dtype=torch.float32)
    packed_fp4, scale_factors = quantize_activation_nvfp4(activations, global_scale)
    print(
        f" quantize_activation_nvfp4 output: FP4 dtype={packed_fp4.dtype},"
        f" SF dtype={scale_factors.dtype}"
    )
    if scale_factors.dtype == torch.float8_e4m3fn:
        sf_verdict = " ✅ Scale factors are E4M3 — NVFP4 correct"
    elif scale_factors.dtype == torch.float8_e8m0fnu:
        sf_verdict = " ⚠️ Scale factors are E8M0 — this is MXFP4 format, NOT NVFP4!"
    else:
        sf_verdict = f" ❌ Scale factors are {scale_factors.dtype} — unexpected!"
    print(sf_verdict)

    # List the runner-like names exposed by the gemm_runner module.
    import dsv4.ops.gemm_runner as runner_module
    runner_classes = [
        attr
        for attr in dir(runner_module)
        if any(tag in attr for tag in ('unner', 'Gemm'))
    ]
    print(f" gemm_runner classes: {runner_classes}")

    print()
    _banner("NVFP4-0.3: FP4 TMA element type")
    print(f" FP4 tensor dtype = {packed_fp4.dtype}")
    print(f" FP4 tensor shape = {packed_fp4.shape}")
    fp4_dtype_available = hasattr(torch, 'float4_e2m1fn_x2')
    print(f" torch.float4_e2m1fn_x2 available = {fp4_dtype_available}")
    if packed_fp4.dtype == torch.float4_e2m1fn_x2:
        print(" ✅ FP4 tensor is float4_e2m1fn_x2 — correct for TMA")
    else:
        print(f" ❌ FP4 tensor dtype is {packed_fp4.dtype} — should be float4_e2m1fn_x2!")

    print()
    _banner("NVFP4-0.4: MMA kind verification")
    # Build the tiled MMA using major modes derived from LayoutEnum.
    from cutlass.utils import LayoutEnum
    try:
        major_a = LayoutEnum.ROW_MAJOR.mma_major_mode()
        major_b = LayoutEnum.COLUMN_MAJOR.mma_major_mode()
        tiled_mma = cutlass.utils.sm100.make_trivial_tiled_mma(
            Float4E2M1FN,
            Float4E2M1FN,
            major_a,
            major_b,
            Float32,
            tcgen05.CtaGroup.ONE,
            (128, 256),
            tcgen05.OperandSource.SMEM,
        )
        print(f" FP4 MMA shape_mnk = {tiled_mma.shape_mnk}")
        print(" ✅ FP4 MMA construction succeeded")
    except Exception as exc:
        print(f" FP4 MMA construction failed: {exc}")

    # Report use_2cta_instrs for a 1-CTA and a 2-CTA configuration.
    print()
    print(" --- use_2cta_instrs check ---")
    try:
        gemm_1cta = Sm100BlockScaledPersistentDenseGemmKernel(
            sf_vec_size=16, **single_cta_cfg
        )
        print(f" mma_tiler_mn=(128,256) → use_2cta_instrs={gemm_1cta.use_2cta_instrs}")
    except Exception as exc:
        print(f" 1-CTA failed: {exc}")
    try:
        gemm_2cta = Sm100BlockScaledPersistentDenseGemmKernel(
            sf_vec_size=16,
            mma_tiler_mn=(256, 256),
            cluster_shape_mn=(2, 1),
        )
        print(f" mma_tiler_mn=(256,256) → use_2cta_instrs={gemm_2cta.use_2cta_instrs}")
    except Exception as exc:
        print(f" 2-CTA failed: {exc}")

    print()
    print("DONE — NVFP4-0 verification complete")
# Run the diagnostic when this file is executed directly as a script.
if __name__ == "__main__":
    test_nvfp4_primitives()