fix: use Sm100BlockScaledPersistentDenseGemmKernel in diag

This commit is contained in:
2026-05-23 08:30:43 +00:00
parent 6b1330ba47
commit 5572b74591

View File

@@ -29,43 +29,56 @@ def test_nvfp4_primitives():
a_sf = torch.randint(0, 256, (M, K // 16), device='cuda', dtype=torch.uint8).view(torch.float8_e4m3fn)
b_sf = torch.randint(0, 256, (N, K // 16), device='cuda', dtype=torch.uint8).view(torch.float8_e4m3fn)
from dsv4.kernels.gemm.dense import BlockScaledGEMM
from dsv4.kernels.gemm.dense import Sm100BlockScaledPersistentDenseGemmKernel
try:
gemm = BlockScaledGEMM(
gemm = Sm100BlockScaledPersistentDenseGemmKernel(
a_ptr=a_fp4,
b_ptr=b_fp4,
sfa_ptr=a_sf,
sfb_ptr=b_sf,
sf_vec_size=16,
)
print(f" BlockScaledGEMM.sf_vec_size = {gemm.sf_vec_size}")
print(f" BlockScaledGEMM.sf_dtype = {gemm.sf_dtype}")
print(f" BlockScaledGEMM.mma_inst_shape_mn = {gemm.mma_inst_shape_mn}")
print(f" DenseGEMM.sf_vec_size = {gemm.sf_vec_size}")
print(f" DenseGEMM.sf_dtype = {gemm.sf_dtype}")
if gemm.sf_dtype == Float8E4M3FN:
print(f" ✅ sf_dtype is Float8E4M3FN (NVFP4 correct)")
elif gemm.sf_dtype == Float8E8M0FNU:
print(f" ⚠️ sf_dtype is Float8E8M0FNU — this is MXFP4 scale format, NOT NVFP4!")
print(f" ⚠️ NVFP4 should use Float8E4M3FN scales at sf_vec_size=16")
else:
print(f" ❌ sf_dtype is {gemm.sf_dtype} — unexpected!")
except Exception as e:
print(f" BlockScaledGEMM construction failed: {e}")
print(f" DenseGEMM construction failed: {e}")
# Also try with E8M0 scales (MXFP4)
a_sf_e8m0 = torch.randint(0, 256, (M, K // 32), device='cuda', dtype=torch.uint8).view(torch.float8_e8m0fnu)
b_sf_e8m0 = torch.randint(0, 256, (N, K // 32), device='cuda', dtype=torch.uint8).view(torch.float8_e8m0fnu)
try:
gemm_e8m0 = BlockScaledGEMM(
gemm_e8m0 = Sm100BlockScaledPersistentDenseGemmKernel(
a_ptr=a_fp4,
b_ptr=b_fp4,
sfa_ptr=a_sf_e8m0,
sfb_ptr=b_sf_e8m0,
sf_vec_size=32,
)
print(f" BlockScaledGEMM (E8M0/vs=32).sf_dtype = {gemm_e8m0.sf_dtype}")
print(f" BlockScaledGEMM (E8M0/vs=32).sf_vec_size = {gemm_e8m0.sf_vec_size}")
print(f" DenseGEMM (E8M0/vs=32).sf_dtype = {gemm_e8m0.sf_dtype}")
print(f" DenseGEMM (E8M0/vs=32).sf_vec_size = {gemm_e8m0.sf_vec_size}")
except Exception as e:
print(f" BlockScaledGEMM (E8M0) failed: {e}")
print(f" DenseGEMM (E8M0) failed: {e}")
# Check the grouped kernel too
from dsv4.kernels.gemm.grouped import ScaledGroupedGemmKernel
try:
grp = ScaledGroupedGemmKernel(
a_ptr=a_fp4,
b_ptr=b_fp4,
sfa_ptr=a_sf,
sfb_ptr=b_sf,
sf_vec_size=16,
)
print(f" GroupedGEMM.sf_vec_size = {grp.sf_vec_size}")
print(f" GroupedGEMM.sf_dtype = {grp.sf_dtype}")
except Exception as e:
print(f" GroupedGEMM construction failed: {e}")
print()
print("=" * 60)