fix: NVFP4-0 diag script — import SF_VEC_SIZE from quantize.py
This commit is contained in:
@@ -8,34 +8,29 @@ import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.nvgpu import tcgen05
|
||||
from cutlass import Float32, BFloat16, Float8E4M3FN, Float8E8M0FNU, Float4E2M1FN
|
||||
import sys, os
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
|
||||
|
||||
def test_nvfp4_primitives():
|
||||
print("=" * 60)
|
||||
print("NVFP4-0.1: sf_dtype and sf_vec_size in GEMM runner")
|
||||
print("NVFP4-0.1: sf_dtype and sf_vec_size")
|
||||
print("=" * 60)
|
||||
|
||||
# Check what gemm_runner uses
|
||||
import sys, os
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
from dsv4.ops.gemm_runner import SF_VEC_SIZE
|
||||
print(f" gemm_runner SF_VEC_SIZE = {SF_VEC_SIZE}")
|
||||
from dsv4.ops.quantize import SF_VEC_SIZE
|
||||
print(f" quantize.py SF_VEC_SIZE = {SF_VEC_SIZE}")
|
||||
assert SF_VEC_SIZE == 16, f"SF_VEC_SIZE should be 16 for NVFP4, got {SF_VEC_SIZE}"
|
||||
print(f" ✅ SF_VEC_SIZE = 16 (NVFP4 correct)")
|
||||
|
||||
# Check what dense.py BlockScaledGEMM expects
|
||||
from dsv4.kernels.gemm.dense import BlockScaledGEMM
|
||||
# Print the docstring-specified types
|
||||
print(f" dense.py NVF4 spec: A/B=Float4E2M1FN, SF=Float8E8M0FNU/Float8E4M3FN, sf_vec_size=16")
|
||||
print(f" dense.py MXF4 spec: A/B=Float4E2M1FN, SF=Float8E8M0FNU, sf_vec_size=32")
|
||||
|
||||
# Construct a small NVFP4 GEMM and print the sf_dtype it receives
|
||||
# Construct a BlockScaledGEMM with E4M3 scales (NVFP4)
|
||||
M, N, K = 128, 256, 512
|
||||
# Create small FP4 tensors
|
||||
a_fp4 = torch.randn(M, K // 2, device='cuda', dtype=torch.float4_e2m1fn_x2)
|
||||
b_fp4 = torch.randn(K // 2, N, device='cuda', dtype=torch.float4_e2m1fn_x2)
|
||||
a_sf = torch.randn(M, K // 16, device='cuda', dtype=torch.float8_e4m3fn)
|
||||
b_sf = torch.randn(N, K // 16, device='cuda', dtype=torch.float8_e4m3fn)
|
||||
|
||||
from dsv4.kernels.gemm.dense import BlockScaledGEMM
|
||||
try:
|
||||
from dsv4.kernels.gemm.dense import BlockScaledGEMM
|
||||
gemm = BlockScaledGEMM(
|
||||
a_ptr=a_fp4,
|
||||
b_ptr=b_fp4,
|
||||
@@ -46,10 +41,17 @@ def test_nvfp4_primitives():
|
||||
print(f" BlockScaledGEMM.sf_vec_size = {gemm.sf_vec_size}")
|
||||
print(f" BlockScaledGEMM.sf_dtype = {gemm.sf_dtype}")
|
||||
print(f" BlockScaledGEMM.mma_inst_shape_mn = {gemm.mma_inst_shape_mn}")
|
||||
if gemm.sf_dtype == Float8E4M3FN:
|
||||
print(f" ✅ sf_dtype is Float8E4M3FN (NVFP4 correct)")
|
||||
elif gemm.sf_dtype == Float8E8M0FNU:
|
||||
print(f" ⚠️ sf_dtype is Float8E8M0FNU — this is MXFP4 scale format, NOT NVFP4!")
|
||||
print(f" ⚠️ NVFP4 should use Float8E4M3FN scales at sf_vec_size=16")
|
||||
else:
|
||||
print(f" ❌ sf_dtype is {gemm.sf_dtype} — unexpected!")
|
||||
except Exception as e:
|
||||
print(f" BlockScaledGEMM construction failed: {e}")
|
||||
|
||||
# Also try with E8M0 scale dtype to see if that's what's being used
|
||||
# Also try with E8M0 scales (MXFP4)
|
||||
a_sf_e8m0 = torch.randn(M, K // 32, device='cuda', dtype=torch.float8_e8m0fnu)
|
||||
b_sf_e8m0 = torch.randn(N, K // 32, device='cuda', dtype=torch.float8_e8m0fnu)
|
||||
try:
|
||||
@@ -60,17 +62,17 @@ def test_nvfp4_primitives():
|
||||
sfb_ptr=b_sf_e8m0,
|
||||
sf_vec_size=32,
|
||||
)
|
||||
print(f" BlockScaledGEMM (E8M0/sf_vec_size=32).sf_dtype = {gemm_e8m0.sf_dtype}")
|
||||
print(f" BlockScaledGEMM (E8M0/sf_vec_size=32).sf_vec_size = {gemm_e8m0.sf_vec_size}")
|
||||
print(f" BlockScaledGEMM (E8M0/vs=32).sf_dtype = {gemm_e8m0.sf_dtype}")
|
||||
print(f" BlockScaledGEMM (E8M0/vs=32).sf_vec_size = {gemm_e8m0.sf_vec_size}")
|
||||
except Exception as e:
|
||||
print(f" BlockScaledGEMM (E8M0) construction failed: {e}")
|
||||
print(f" BlockScaledGEMM (E8M0) failed: {e}")
|
||||
|
||||
print()
|
||||
print("=" * 60)
|
||||
print("NVFP4-0.3: FP4 TMA element type in quantize.py")
|
||||
print("=" * 60)
|
||||
|
||||
from dsv4.ops.quantize import quantize_tensor_nvfp4, _create_fp4_tensor
|
||||
from dsv4.ops.quantize import quantize_tensor_nvfp4
|
||||
x = torch.randn(4, 512, device='cuda', dtype=torch.bfloat16)
|
||||
x_fp4, x_sf = quantize_tensor_nvfp4(x)
|
||||
print(f" Input dtype: {x.dtype}")
|
||||
@@ -78,19 +80,17 @@ def test_nvfp4_primitives():
|
||||
print(f" SF output dtype: {x_sf.dtype}")
|
||||
print(f" FP4 shape: {x_fp4.shape} (expected: [4, 256])")
|
||||
print(f" SF shape: {x_sf.shape} (expected: [4, 32])")
|
||||
|
||||
# Check that float4_e2m1fn_x2 survives into TMA
|
||||
print(f" torch.float4_e2m1fn_x2 = {torch.float4_e2m1fn_x2}")
|
||||
print(f" FP4 tensor is float4_e2m1fn_x2: {x_fp4.dtype == torch.float4_e2m1fn_x2}")
|
||||
if x_fp4.dtype == torch.float4_e2m1fn_x2:
|
||||
print(f" ✅ FP4 tensor is float4_e2m1fn_x2 — correct for TMA")
|
||||
else:
|
||||
print(f" ❌ FP4 tensor dtype is {x_fp4.dtype} — should be float4_e2m1fn_x2!")
|
||||
|
||||
print()
|
||||
print("=" * 60)
|
||||
print("NVFP4-0.4: MMA kind verification")
|
||||
print("=" * 60)
|
||||
|
||||
# Try to inspect the MMA kind that tcgen05 uses for NVFP4
|
||||
try:
|
||||
# Create a tiled MMA with FP4 inputs
|
||||
a_major = cutlass.utils.LayoutEnum.ROW_MAJOR.mma_major_mode()
|
||||
b_major = cutlass.utils.LayoutEnum.COLUMN_MAJOR.mma_major_mode()
|
||||
mma = cutlass.utils.sm100.make_trivial_tiled_mma(
|
||||
@@ -99,8 +99,7 @@ def test_nvfp4_primitives():
|
||||
tcgen05.OperandSource.SMEM,
|
||||
)
|
||||
print(f" FP4 MMA shape_mnk = {mma.shape_mnk}")
|
||||
print(f" FP4 MMA thr_id shape = {mma.thr_id.shape}")
|
||||
print(f" FP4 MMA operand A dtype = {mma.a_layout.value_type if hasattr(mma, 'a_layout') else 'N/A'}")
|
||||
print(f" ✅ FP4 MMA construction succeeded")
|
||||
except Exception as e:
|
||||
print(f" FP4 MMA construction failed: {e}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user