fix: use correct API for NVFP4-0 diag (sf_vec_size + mma_tiler_mn)

This commit is contained in:
2026-05-23 08:38:19 +00:00
parent 2b86aca551
commit 96a43b60f3

View File

@@ -14,85 +14,108 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspa
def test_nvfp4_primitives():
print("=" * 60)
print("NVFP4-0.1: sf_dtype and sf_vec_size")
print("NVFP4-0.1: sf_dtype and sf_vec_size in GEMM kernels")
print("=" * 60)
from dsv4.ops.quantize import SF_VEC_SIZE
print(f" quantize.py SF_VEC_SIZE = {SF_VEC_SIZE}")
assert SF_VEC_SIZE == 16, f"SF_VEC_SIZE should be 16 for NVFP4, got {SF_VEC_SIZE}"
assert SF_VEC_SIZE == 16
print(f" ✅ SF_VEC_SIZE = 16 (NVFP4 correct)")
# Construct a BlockScaledGEMM with E4M3 scales (NVFP4)
M, N, K = 128, 256, 512
a_fp4 = torch.randint(0, 256, (M, K // 2), device='cuda', dtype=torch.uint8).view(torch.float4_e2m1fn_x2)
b_fp4 = torch.randint(0, 256, (K // 2, N), device='cuda', dtype=torch.uint8).view(torch.float4_e2m1fn_x2)
a_sf = torch.randint(0, 256, (M, K // 16), device='cuda', dtype=torch.uint8).view(torch.float8_e4m3fn)
b_sf = torch.randint(0, 256, (N, K // 16), device='cuda', dtype=torch.uint8).view(torch.float8_e4m3fn)
# Construct dense GEMM with the real API
from dsv4.kernels.gemm.dense import Sm100BlockScaledPersistentDenseGemmKernel
try:
gemm = Sm100BlockScaledPersistentDenseGemmKernel(
a_ptr=a_fp4,
b_ptr=b_fp4,
sfa_ptr=a_sf,
sfb_ptr=b_sf,
sf_vec_size=16,
mma_tiler_mn=(128, 256),
cluster_shape_mn=(1, 1),
)
print(f" DenseGEMM.sf_vec_size = {gemm.sf_vec_size}")
print(f" DenseGEMM.sf_dtype = {gemm.sf_dtype}")
if gemm.sf_dtype == Float8E4M3FN:
print(f" ✅ sf_dtype is Float8E4M3FN (NVFP4 correct)")
elif gemm.sf_dtype == Float8E8M0FNU:
print(f" ⚠️ sf_dtype is Float8E8M0FNU — this is MXFP4 scale format, NOT NVFP4!")
else:
print(f" ❌ sf_dtype is {gemm.sf_dtype} — unexpected!")
print(f" DenseGEMM.cta_group = {gemm.cta_group}")
# sf_dtype is set in _setup_attributes from the actual tensor pointers
# Print what we can here
print(f" DenseGEMM.acc_dtype = {gemm.acc_dtype}")
print(f" DenseGEMM.use_2cta_instrs = {gemm.use_2cta_instrs}")
except Exception as e:
print(f" DenseGEMM construction failed: {e}")
# Also try with E8M0 scales (MXFP4)
a_sf_e8m0 = torch.randint(0, 256, (M, K // 32), device='cuda', dtype=torch.uint8).view(torch.float8_e8m0fnu)
b_sf_e8m0 = torch.randint(0, 256, (N, K // 32), device='cuda', dtype=torch.uint8).view(torch.float8_e8m0fnu)
# Try sf_vec_size=32 (MXFP4) for comparison
try:
gemm_e8m0 = Sm100BlockScaledPersistentDenseGemmKernel(
a_ptr=a_fp4,
b_ptr=b_fp4,
sfa_ptr=a_sf_e8m0,
sfb_ptr=b_sf_e8m0,
gemm_mxf4 = Sm100BlockScaledPersistentDenseGemmKernel(
sf_vec_size=32,
mma_tiler_mn=(128, 256),
cluster_shape_mn=(1, 1),
)
print(f" DenseGEMM (E8M0/vs=32).sf_dtype = {gemm_e8m0.sf_dtype}")
print(f" DenseGEMM (E8M0/vs=32).sf_vec_size = {gemm_e8m0.sf_vec_size}")
print(f" DenseGEMM (MXF4/vs=32).sf_vec_size = {gemm_mxf4.sf_vec_size}")
except Exception as e:
print(f" DenseGEMM (E8M0) failed: {e}")
print(f" DenseGEMM (MXF4) failed: {e}")
# Check the grouped kernel too
# Check grouped kernel
from dsv4.kernels.gemm.grouped import ScaledGroupedGemmKernel
try:
grp = ScaledGroupedGemmKernel(
a_ptr=a_fp4,
b_ptr=b_fp4,
sfa_ptr=a_sf,
sfb_ptr=b_sf,
sf_vec_size=16,
mma_tiler_mn=(128, 256),
cluster_shape_mn=(1, 1),
)
print(f" GroupedGEMM.sf_vec_size = {grp.sf_vec_size}")
print(f" GroupedGEMM.sf_dtype = {grp.sf_dtype}")
except Exception as e:
print(f" GroupedGEMM construction failed: {e}")
# Check fused SwiGLU kernel
from dsv4.kernels.gemm.fused_swiglu import FusedSwiGLUScaledGroupedGemmKernel
try:
fused = FusedSwiGLUScaledGroupedGemmKernel(
sf_vec_size=16,
mma_tiler_mn=(128, 256),
cluster_shape_mn=(1, 1),
swiglu_limit=10.0,
)
print(f" FusedSwiGLU.sf_vec_size = {fused.sf_vec_size}")
print(f" FusedSwiGLU.swiglu_limit = {fused.swiglu_limit}")
except Exception as e:
print(f" FusedSwiGLU construction failed: {e}")
# Now check sf_dtype by actually creating the tensors and calling the kernel
# sf_dtype is inferred from the sfa_ptr dtype at call time
print()
print(" --- sf_dtype inference from tensor dtypes ---")
M, N, K = 128, 256, 512
a_fp4 = torch.randint(0, 256, (M, K // 2), device='cuda', dtype=torch.uint8).view(torch.float4_e2m1fn_x2)
b_fp4 = torch.randint(0, 256, (K // 2, N), device='cuda', dtype=torch.uint8).view(torch.float4_e2m1fn_x2)
a_sf_e4m3 = torch.randint(0, 256, (M, K // 16), device='cuda', dtype=torch.uint8).view(torch.float8_e4m3fn)
b_sf_e4m3 = torch.randint(0, 256, (N, K // 16), device='cuda', dtype=torch.uint8).view(torch.float8_e4m3fn)
a_sf_e8m0 = torch.randint(0, 256, (M, K // 32), device='cuda', dtype=torch.uint8).view(torch.float8_e8m0fnu)
b_sf_e8m0 = torch.randint(0, 256, (N, K // 32), device='cuda', dtype=torch.uint8).view(torch.float8_e8m0fnu)
print(f" a_fp4 dtype = {a_fp4.dtype} (expect float4_e2m1fn_x2)")
print(f" a_sf (E4M3) dtype = {a_sf_e4m3.dtype}, shape = {a_sf_e4m3.shape}")
print(f" a_sf (E8M0) dtype = {a_sf_e8m0.dtype}, shape = {a_sf_e8m0.shape}")
# The runner uses E4M3 or E8M0? Check what quantize actually produces
from dsv4.ops.quantize import quantize_to_nvfp4
x = torch.randn(4, 512, device='cuda', dtype=torch.bfloat16)
x_fp4, x_sf = quantize_to_nvfp4(x)
print(f" quantize_to_nvfp4 output: FP4 dtype={x_fp4.dtype}, SF dtype={x_sf.dtype}")
if x_sf.dtype == torch.float8_e4m3fn:
print(f" ✅ Scale factors are E4M3 — NVFP4 correct")
elif x_sf.dtype == torch.float8_e8m0fnu:
print(f" ⚠️ Scale factors are E8M0 — this is MXFP4 format, NOT NVFP4!")
else:
print(f" ❌ Scale factors are {x_sf.dtype} — unexpected!")
# Check what the gemm_runner actually passes to the kernel
from dsv4.ops.gemm_runner import Nvfp4GemmRunner
print(f" Nvfp4GemmRunner exists: {Nvfp4GemmRunner is not None}")
print()
print("=" * 60)
print("NVFP4-0.3: FP4 TMA element type in quantize.py")
print("NVFP4-0.3: FP4 TMA element type")
print("=" * 60)
from dsv4.ops.quantize import quantize_tensor_nvfp4
x = torch.randn(4, 512, device='cuda', dtype=torch.bfloat16)
x_fp4, x_sf = quantize_tensor_nvfp4(x)
print(f" Input dtype: {x.dtype}")
print(f" FP4 output dtype: {x_fp4.dtype}")
print(f" SF output dtype: {x_sf.dtype}")
print(f" FP4 shape: {x_fp4.shape} (expected: [4, 256])")
print(f" SF shape: {x_sf.shape} (expected: [4, 32])")
print(f" FP4 tensor dtype = {x_fp4.dtype}")
print(f" FP4 tensor shape = {x_fp4.shape}")
print(f" torch.float4_e2m1fn_x2 available = {hasattr(torch, 'float4_e2m1fn_x2')}")
if x_fp4.dtype == torch.float4_e2m1fn_x2:
print(f" ✅ FP4 tensor is float4_e2m1fn_x2 — correct for TMA")
else: