fix: use Sm100BlockScaledPersistentDenseGemmKernel in diag
This commit is contained in:
@@ -29,43 +29,56 @@ def test_nvfp4_primitives():
|
||||
a_sf = torch.randint(0, 256, (M, K // 16), device='cuda', dtype=torch.uint8).view(torch.float8_e4m3fn)
|
||||
b_sf = torch.randint(0, 256, (N, K // 16), device='cuda', dtype=torch.uint8).view(torch.float8_e4m3fn)
|
||||
|
||||
from dsv4.kernels.gemm.dense import BlockScaledGEMM
|
||||
from dsv4.kernels.gemm.dense import Sm100BlockScaledPersistentDenseGemmKernel
|
||||
try:
|
||||
gemm = BlockScaledGEMM(
|
||||
gemm = Sm100BlockScaledPersistentDenseGemmKernel(
|
||||
a_ptr=a_fp4,
|
||||
b_ptr=b_fp4,
|
||||
sfa_ptr=a_sf,
|
||||
sfb_ptr=b_sf,
|
||||
sf_vec_size=16,
|
||||
)
|
||||
print(f" BlockScaledGEMM.sf_vec_size = {gemm.sf_vec_size}")
|
||||
print(f" BlockScaledGEMM.sf_dtype = {gemm.sf_dtype}")
|
||||
print(f" BlockScaledGEMM.mma_inst_shape_mn = {gemm.mma_inst_shape_mn}")
|
||||
print(f" DenseGEMM.sf_vec_size = {gemm.sf_vec_size}")
|
||||
print(f" DenseGEMM.sf_dtype = {gemm.sf_dtype}")
|
||||
if gemm.sf_dtype == Float8E4M3FN:
|
||||
print(f" ✅ sf_dtype is Float8E4M3FN (NVFP4 correct)")
|
||||
elif gemm.sf_dtype == Float8E8M0FNU:
|
||||
print(f" ⚠️ sf_dtype is Float8E8M0FNU — this is MXFP4 scale format, NOT NVFP4!")
|
||||
print(f" ⚠️ NVFP4 should use Float8E4M3FN scales at sf_vec_size=16")
|
||||
else:
|
||||
print(f" ❌ sf_dtype is {gemm.sf_dtype} — unexpected!")
|
||||
except Exception as e:
|
||||
print(f" BlockScaledGEMM construction failed: {e}")
|
||||
print(f" DenseGEMM construction failed: {e}")
|
||||
|
||||
# Also try with E8M0 scales (MXFP4)
|
||||
a_sf_e8m0 = torch.randint(0, 256, (M, K // 32), device='cuda', dtype=torch.uint8).view(torch.float8_e8m0fnu)
|
||||
b_sf_e8m0 = torch.randint(0, 256, (N, K // 32), device='cuda', dtype=torch.uint8).view(torch.float8_e8m0fnu)
|
||||
try:
|
||||
gemm_e8m0 = BlockScaledGEMM(
|
||||
gemm_e8m0 = Sm100BlockScaledPersistentDenseGemmKernel(
|
||||
a_ptr=a_fp4,
|
||||
b_ptr=b_fp4,
|
||||
sfa_ptr=a_sf_e8m0,
|
||||
sfb_ptr=b_sf_e8m0,
|
||||
sf_vec_size=32,
|
||||
)
|
||||
print(f" BlockScaledGEMM (E8M0/vs=32).sf_dtype = {gemm_e8m0.sf_dtype}")
|
||||
print(f" BlockScaledGEMM (E8M0/vs=32).sf_vec_size = {gemm_e8m0.sf_vec_size}")
|
||||
print(f" DenseGEMM (E8M0/vs=32).sf_dtype = {gemm_e8m0.sf_dtype}")
|
||||
print(f" DenseGEMM (E8M0/vs=32).sf_vec_size = {gemm_e8m0.sf_vec_size}")
|
||||
except Exception as e:
|
||||
print(f" BlockScaledGEMM (E8M0) failed: {e}")
|
||||
print(f" DenseGEMM (E8M0) failed: {e}")
|
||||
|
||||
# Check the grouped kernel too
|
||||
from dsv4.kernels.gemm.grouped import ScaledGroupedGemmKernel
|
||||
try:
|
||||
grp = ScaledGroupedGemmKernel(
|
||||
a_ptr=a_fp4,
|
||||
b_ptr=b_fp4,
|
||||
sfa_ptr=a_sf,
|
||||
sfb_ptr=b_sf,
|
||||
sf_vec_size=16,
|
||||
)
|
||||
print(f" GroupedGEMM.sf_vec_size = {grp.sf_vec_size}")
|
||||
print(f" GroupedGEMM.sf_dtype = {grp.sf_dtype}")
|
||||
except Exception as e:
|
||||
print(f" GroupedGEMM construction failed: {e}")
|
||||
|
||||
print()
|
||||
print("=" * 60)
|
||||
|
||||
Reference in New Issue
Block a user