fix: NVFP4-0 diag script — import SF_VEC_SIZE from quantize.py

This commit is contained in:
2026-05-23 08:28:13 +00:00
parent f6f4ce8ec2
commit 96567ea64d

View File

@@ -8,34 +8,29 @@ import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16, Float8E4M3FN, Float8E8M0FNU, Float4E2M1FN
import sys, os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
def test_nvfp4_primitives():
print("=" * 60)
print("NVFP4-0.1: sf_dtype and sf_vec_size in GEMM runner")
print("NVFP4-0.1: sf_dtype and sf_vec_size")
print("=" * 60)
# Check what gemm_runner uses
import sys, os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from dsv4.ops.gemm_runner import SF_VEC_SIZE
print(f" gemm_runner SF_VEC_SIZE = {SF_VEC_SIZE}")
from dsv4.ops.quantize import SF_VEC_SIZE
print(f" quantize.py SF_VEC_SIZE = {SF_VEC_SIZE}")
assert SF_VEC_SIZE == 16, f"SF_VEC_SIZE should be 16 for NVFP4, got {SF_VEC_SIZE}"
print(f" ✅ SF_VEC_SIZE = 16 (NVFP4 correct)")
# Check what dense.py BlockScaledGEMM expects
from dsv4.kernels.gemm.dense import BlockScaledGEMM
# Print the docstring-specified types
print(f" dense.py NVF4 spec: A/B=Float4E2M1FN, SF=Float8E8M0FNU/Float8E4M3FN, sf_vec_size=16")
print(f" dense.py MXF4 spec: A/B=Float4E2M1FN, SF=Float8E8M0FNU, sf_vec_size=32")
# Construct a small NVFP4 GEMM and print the sf_dtype it receives
# Construct a BlockScaledGEMM with E4M3 scales (NVFP4)
M, N, K = 128, 256, 512
# Create small FP4 tensors
a_fp4 = torch.randn(M, K // 2, device='cuda', dtype=torch.float4_e2m1fn_x2)
b_fp4 = torch.randn(K // 2, N, device='cuda', dtype=torch.float4_e2m1fn_x2)
a_sf = torch.randn(M, K // 16, device='cuda', dtype=torch.float8_e4m3fn)
b_sf = torch.randn(N, K // 16, device='cuda', dtype=torch.float8_e4m3fn)
from dsv4.kernels.gemm.dense import BlockScaledGEMM
try:
from dsv4.kernels.gemm.dense import BlockScaledGEMM
gemm = BlockScaledGEMM(
a_ptr=a_fp4,
b_ptr=b_fp4,
@@ -46,10 +41,17 @@ def test_nvfp4_primitives():
print(f" BlockScaledGEMM.sf_vec_size = {gemm.sf_vec_size}")
print(f" BlockScaledGEMM.sf_dtype = {gemm.sf_dtype}")
print(f" BlockScaledGEMM.mma_inst_shape_mn = {gemm.mma_inst_shape_mn}")
if gemm.sf_dtype == Float8E4M3FN:
print(f" ✅ sf_dtype is Float8E4M3FN (NVFP4 correct)")
elif gemm.sf_dtype == Float8E8M0FNU:
print(f" ⚠️ sf_dtype is Float8E8M0FNU — this is MXFP4 scale format, NOT NVFP4!")
print(f" ⚠️ NVFP4 should use Float8E4M3FN scales at sf_vec_size=16")
else:
print(f" ❌ sf_dtype is {gemm.sf_dtype} — unexpected!")
except Exception as e:
print(f" BlockScaledGEMM construction failed: {e}")
# Also try with E8M0 scale dtype to see if that's what's being used
# Also try with E8M0 scales (MXFP4)
a_sf_e8m0 = torch.randn(M, K // 32, device='cuda', dtype=torch.float8_e8m0fnu)
b_sf_e8m0 = torch.randn(N, K // 32, device='cuda', dtype=torch.float8_e8m0fnu)
try:
@@ -60,17 +62,17 @@ def test_nvfp4_primitives():
sfb_ptr=b_sf_e8m0,
sf_vec_size=32,
)
print(f" BlockScaledGEMM (E8M0/sf_vec_size=32).sf_dtype = {gemm_e8m0.sf_dtype}")
print(f" BlockScaledGEMM (E8M0/sf_vec_size=32).sf_vec_size = {gemm_e8m0.sf_vec_size}")
print(f" BlockScaledGEMM (E8M0/vs=32).sf_dtype = {gemm_e8m0.sf_dtype}")
print(f" BlockScaledGEMM (E8M0/vs=32).sf_vec_size = {gemm_e8m0.sf_vec_size}")
except Exception as e:
print(f" BlockScaledGEMM (E8M0) construction failed: {e}")
print(f" BlockScaledGEMM (E8M0) failed: {e}")
print()
print("=" * 60)
print("NVFP4-0.3: FP4 TMA element type in quantize.py")
print("=" * 60)
from dsv4.ops.quantize import quantize_tensor_nvfp4, _create_fp4_tensor
from dsv4.ops.quantize import quantize_tensor_nvfp4
x = torch.randn(4, 512, device='cuda', dtype=torch.bfloat16)
x_fp4, x_sf = quantize_tensor_nvfp4(x)
print(f" Input dtype: {x.dtype}")
@@ -78,19 +80,17 @@ def test_nvfp4_primitives():
print(f" SF output dtype: {x_sf.dtype}")
print(f" FP4 shape: {x_fp4.shape} (expected: [4, 256])")
print(f" SF shape: {x_sf.shape} (expected: [4, 32])")
# Check that float4_e2m1fn_x2 survives into TMA
print(f" torch.float4_e2m1fn_x2 = {torch.float4_e2m1fn_x2}")
print(f" FP4 tensor is float4_e2m1fn_x2: {x_fp4.dtype == torch.float4_e2m1fn_x2}")
if x_fp4.dtype == torch.float4_e2m1fn_x2:
print(f" ✅ FP4 tensor is float4_e2m1fn_x2 — correct for TMA")
else:
print(f" FP4 tensor dtype is {x_fp4.dtype} — should be float4_e2m1fn_x2!")
print()
print("=" * 60)
print("NVFP4-0.4: MMA kind verification")
print("=" * 60)
# Try to inspect the MMA kind that tcgen05 uses for NVFP4
try:
# Create a tiled MMA with FP4 inputs
a_major = cutlass.utils.LayoutEnum.ROW_MAJOR.mma_major_mode()
b_major = cutlass.utils.LayoutEnum.COLUMN_MAJOR.mma_major_mode()
mma = cutlass.utils.sm100.make_trivial_tiled_mma(
@@ -99,8 +99,7 @@ def test_nvfp4_primitives():
tcgen05.OperandSource.SMEM,
)
print(f" FP4 MMA shape_mnk = {mma.shape_mnk}")
print(f" FP4 MMA thr_id shape = {mma.thr_id.shape}")
print(f" FP4 MMA operand A dtype = {mma.a_layout.value_type if hasattr(mma, 'a_layout') else 'N/A'}")
print(f" FP4 MMA construction succeeded")
except Exception as e:
print(f" FP4 MMA construction failed: {e}")