fix: use correct API for NVFP4-0 diag (sf_vec_size + mma_tiler_mn)
This commit is contained in:
@@ -14,85 +14,108 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspa
|
||||
|
||||
def test_nvfp4_primitives():
|
||||
print("=" * 60)
|
||||
print("NVFP4-0.1: sf_dtype and sf_vec_size")
|
||||
print("NVFP4-0.1: sf_dtype and sf_vec_size in GEMM kernels")
|
||||
print("=" * 60)
|
||||
|
||||
from dsv4.ops.quantize import SF_VEC_SIZE
|
||||
print(f" quantize.py SF_VEC_SIZE = {SF_VEC_SIZE}")
|
||||
assert SF_VEC_SIZE == 16, f"SF_VEC_SIZE should be 16 for NVFP4, got {SF_VEC_SIZE}"
|
||||
assert SF_VEC_SIZE == 16
|
||||
print(f" ✅ SF_VEC_SIZE = 16 (NVFP4 correct)")
|
||||
|
||||
# Construct a BlockScaledGEMM with E4M3 scales (NVFP4)
|
||||
M, N, K = 128, 256, 512
|
||||
a_fp4 = torch.randint(0, 256, (M, K // 2), device='cuda', dtype=torch.uint8).view(torch.float4_e2m1fn_x2)
|
||||
b_fp4 = torch.randint(0, 256, (K // 2, N), device='cuda', dtype=torch.uint8).view(torch.float4_e2m1fn_x2)
|
||||
a_sf = torch.randint(0, 256, (M, K // 16), device='cuda', dtype=torch.uint8).view(torch.float8_e4m3fn)
|
||||
b_sf = torch.randint(0, 256, (N, K // 16), device='cuda', dtype=torch.uint8).view(torch.float8_e4m3fn)
|
||||
|
||||
# Construct dense GEMM with the real API
|
||||
from dsv4.kernels.gemm.dense import Sm100BlockScaledPersistentDenseGemmKernel
|
||||
try:
|
||||
gemm = Sm100BlockScaledPersistentDenseGemmKernel(
|
||||
a_ptr=a_fp4,
|
||||
b_ptr=b_fp4,
|
||||
sfa_ptr=a_sf,
|
||||
sfb_ptr=b_sf,
|
||||
sf_vec_size=16,
|
||||
mma_tiler_mn=(128, 256),
|
||||
cluster_shape_mn=(1, 1),
|
||||
)
|
||||
print(f" DenseGEMM.sf_vec_size = {gemm.sf_vec_size}")
|
||||
print(f" DenseGEMM.sf_dtype = {gemm.sf_dtype}")
|
||||
if gemm.sf_dtype == Float8E4M3FN:
|
||||
print(f" ✅ sf_dtype is Float8E4M3FN (NVFP4 correct)")
|
||||
elif gemm.sf_dtype == Float8E8M0FNU:
|
||||
print(f" ⚠️ sf_dtype is Float8E8M0FNU — this is MXFP4 scale format, NOT NVFP4!")
|
||||
else:
|
||||
print(f" ❌ sf_dtype is {gemm.sf_dtype} — unexpected!")
|
||||
print(f" DenseGEMM.cta_group = {gemm.cta_group}")
|
||||
# sf_dtype is set in _setup_attributes from the actual tensor pointers
|
||||
# Print what we can here
|
||||
print(f" DenseGEMM.acc_dtype = {gemm.acc_dtype}")
|
||||
print(f" DenseGEMM.use_2cta_instrs = {gemm.use_2cta_instrs}")
|
||||
except Exception as e:
|
||||
print(f" DenseGEMM construction failed: {e}")
|
||||
|
||||
# Also try with E8M0 scales (MXFP4)
|
||||
a_sf_e8m0 = torch.randint(0, 256, (M, K // 32), device='cuda', dtype=torch.uint8).view(torch.float8_e8m0fnu)
|
||||
b_sf_e8m0 = torch.randint(0, 256, (N, K // 32), device='cuda', dtype=torch.uint8).view(torch.float8_e8m0fnu)
|
||||
# Try sf_vec_size=32 (MXFP4) for comparison
|
||||
try:
|
||||
gemm_e8m0 = Sm100BlockScaledPersistentDenseGemmKernel(
|
||||
a_ptr=a_fp4,
|
||||
b_ptr=b_fp4,
|
||||
sfa_ptr=a_sf_e8m0,
|
||||
sfb_ptr=b_sf_e8m0,
|
||||
gemm_mxf4 = Sm100BlockScaledPersistentDenseGemmKernel(
|
||||
sf_vec_size=32,
|
||||
mma_tiler_mn=(128, 256),
|
||||
cluster_shape_mn=(1, 1),
|
||||
)
|
||||
print(f" DenseGEMM (E8M0/vs=32).sf_dtype = {gemm_e8m0.sf_dtype}")
|
||||
print(f" DenseGEMM (E8M0/vs=32).sf_vec_size = {gemm_e8m0.sf_vec_size}")
|
||||
print(f" DenseGEMM (MXF4/vs=32).sf_vec_size = {gemm_mxf4.sf_vec_size}")
|
||||
except Exception as e:
|
||||
print(f" DenseGEMM (E8M0) failed: {e}")
|
||||
print(f" DenseGEMM (MXF4) failed: {e}")
|
||||
|
||||
# Check the grouped kernel too
|
||||
# Check grouped kernel
|
||||
from dsv4.kernels.gemm.grouped import ScaledGroupedGemmKernel
|
||||
try:
|
||||
grp = ScaledGroupedGemmKernel(
|
||||
a_ptr=a_fp4,
|
||||
b_ptr=b_fp4,
|
||||
sfa_ptr=a_sf,
|
||||
sfb_ptr=b_sf,
|
||||
sf_vec_size=16,
|
||||
mma_tiler_mn=(128, 256),
|
||||
cluster_shape_mn=(1, 1),
|
||||
)
|
||||
print(f" GroupedGEMM.sf_vec_size = {grp.sf_vec_size}")
|
||||
print(f" GroupedGEMM.sf_dtype = {grp.sf_dtype}")
|
||||
except Exception as e:
|
||||
print(f" GroupedGEMM construction failed: {e}")
|
||||
|
||||
# Check fused SwiGLU kernel
|
||||
from dsv4.kernels.gemm.fused_swiglu import FusedSwiGLUScaledGroupedGemmKernel
|
||||
try:
|
||||
fused = FusedSwiGLUScaledGroupedGemmKernel(
|
||||
sf_vec_size=16,
|
||||
mma_tiler_mn=(128, 256),
|
||||
cluster_shape_mn=(1, 1),
|
||||
swiglu_limit=10.0,
|
||||
)
|
||||
print(f" FusedSwiGLU.sf_vec_size = {fused.sf_vec_size}")
|
||||
print(f" FusedSwiGLU.swiglu_limit = {fused.swiglu_limit}")
|
||||
except Exception as e:
|
||||
print(f" FusedSwiGLU construction failed: {e}")
|
||||
|
||||
# Now check sf_dtype by actually creating the tensors and calling the kernel
|
||||
# sf_dtype is inferred from the sfa_ptr dtype at call time
|
||||
print()
|
||||
print(" --- sf_dtype inference from tensor dtypes ---")
|
||||
M, N, K = 128, 256, 512
|
||||
a_fp4 = torch.randint(0, 256, (M, K // 2), device='cuda', dtype=torch.uint8).view(torch.float4_e2m1fn_x2)
|
||||
b_fp4 = torch.randint(0, 256, (K // 2, N), device='cuda', dtype=torch.uint8).view(torch.float4_e2m1fn_x2)
|
||||
a_sf_e4m3 = torch.randint(0, 256, (M, K // 16), device='cuda', dtype=torch.uint8).view(torch.float8_e4m3fn)
|
||||
b_sf_e4m3 = torch.randint(0, 256, (N, K // 16), device='cuda', dtype=torch.uint8).view(torch.float8_e4m3fn)
|
||||
a_sf_e8m0 = torch.randint(0, 256, (M, K // 32), device='cuda', dtype=torch.uint8).view(torch.float8_e8m0fnu)
|
||||
b_sf_e8m0 = torch.randint(0, 256, (N, K // 32), device='cuda', dtype=torch.uint8).view(torch.float8_e8m0fnu)
|
||||
|
||||
print(f" a_fp4 dtype = {a_fp4.dtype} (expect float4_e2m1fn_x2)")
|
||||
print(f" a_sf (E4M3) dtype = {a_sf_e4m3.dtype}, shape = {a_sf_e4m3.shape}")
|
||||
print(f" a_sf (E8M0) dtype = {a_sf_e8m0.dtype}, shape = {a_sf_e8m0.shape}")
|
||||
|
||||
# The runner uses E4M3 or E8M0? Check what quantize actually produces
|
||||
from dsv4.ops.quantize import quantize_to_nvfp4
|
||||
x = torch.randn(4, 512, device='cuda', dtype=torch.bfloat16)
|
||||
x_fp4, x_sf = quantize_to_nvfp4(x)
|
||||
print(f" quantize_to_nvfp4 output: FP4 dtype={x_fp4.dtype}, SF dtype={x_sf.dtype}")
|
||||
if x_sf.dtype == torch.float8_e4m3fn:
|
||||
print(f" ✅ Scale factors are E4M3 — NVFP4 correct")
|
||||
elif x_sf.dtype == torch.float8_e8m0fnu:
|
||||
print(f" ⚠️ Scale factors are E8M0 — this is MXFP4 format, NOT NVFP4!")
|
||||
else:
|
||||
print(f" ❌ Scale factors are {x_sf.dtype} — unexpected!")
|
||||
|
||||
# Check what the gemm_runner actually passes to the kernel
|
||||
from dsv4.ops.gemm_runner import Nvfp4GemmRunner
|
||||
print(f" Nvfp4GemmRunner exists: {Nvfp4GemmRunner is not None}")
|
||||
|
||||
print()
|
||||
print("=" * 60)
|
||||
print("NVFP4-0.3: FP4 TMA element type in quantize.py")
|
||||
print("NVFP4-0.3: FP4 TMA element type")
|
||||
print("=" * 60)
|
||||
|
||||
from dsv4.ops.quantize import quantize_tensor_nvfp4
|
||||
x = torch.randn(4, 512, device='cuda', dtype=torch.bfloat16)
|
||||
x_fp4, x_sf = quantize_tensor_nvfp4(x)
|
||||
print(f" Input dtype: {x.dtype}")
|
||||
print(f" FP4 output dtype: {x_fp4.dtype}")
|
||||
print(f" SF output dtype: {x_sf.dtype}")
|
||||
print(f" FP4 shape: {x_fp4.shape} (expected: [4, 256])")
|
||||
print(f" SF shape: {x_sf.shape} (expected: [4, 32])")
|
||||
print(f" FP4 tensor dtype = {x_fp4.dtype}")
|
||||
print(f" FP4 tensor shape = {x_fp4.shape}")
|
||||
print(f" torch.float4_e2m1fn_x2 available = {hasattr(torch, 'float4_e2m1fn_x2')}")
|
||||
if x_fp4.dtype == torch.float4_e2m1fn_x2:
|
||||
print(f" ✅ FP4 tensor is float4_e2m1fn_x2 — correct for TMA")
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user