diff --git a/tests/unit/test_nvfp4_primitives.py b/tests/unit/test_nvfp4_primitives.py index 692080d6..3486a28a 100644 --- a/tests/unit/test_nvfp4_primitives.py +++ b/tests/unit/test_nvfp4_primitives.py @@ -29,43 +29,56 @@ def test_nvfp4_primitives(): a_sf = torch.randint(0, 256, (M, K // 16), device='cuda', dtype=torch.uint8).view(torch.float8_e4m3fn) b_sf = torch.randint(0, 256, (N, K // 16), device='cuda', dtype=torch.uint8).view(torch.float8_e4m3fn) - from dsv4.kernels.gemm.dense import BlockScaledGEMM + from dsv4.kernels.gemm.dense import Sm100BlockScaledPersistentDenseGemmKernel try: - gemm = BlockScaledGEMM( + gemm = Sm100BlockScaledPersistentDenseGemmKernel( a_ptr=a_fp4, b_ptr=b_fp4, sfa_ptr=a_sf, sfb_ptr=b_sf, sf_vec_size=16, ) - print(f" BlockScaledGEMM.sf_vec_size = {gemm.sf_vec_size}") - print(f" BlockScaledGEMM.sf_dtype = {gemm.sf_dtype}") - print(f" BlockScaledGEMM.mma_inst_shape_mn = {gemm.mma_inst_shape_mn}") + print(f" DenseGEMM.sf_vec_size = {gemm.sf_vec_size}") + print(f" DenseGEMM.sf_dtype = {gemm.sf_dtype}") if gemm.sf_dtype == Float8E4M3FN: print(f" ✅ sf_dtype is Float8E4M3FN (NVFP4 correct)") elif gemm.sf_dtype == Float8E8M0FNU: print(f" ⚠️ sf_dtype is Float8E8M0FNU — this is MXFP4 scale format, NOT NVFP4!") - print(f" ⚠️ NVFP4 should use Float8E4M3FN scales at sf_vec_size=16") else: print(f" ❌ sf_dtype is {gemm.sf_dtype} — unexpected!") except Exception as e: - print(f" BlockScaledGEMM construction failed: {e}") + print(f" DenseGEMM construction failed: {e}") # Also try with E8M0 scales (MXFP4) a_sf_e8m0 = torch.randint(0, 256, (M, K // 32), device='cuda', dtype=torch.uint8).view(torch.float8_e8m0fnu) b_sf_e8m0 = torch.randint(0, 256, (N, K // 32), device='cuda', dtype=torch.uint8).view(torch.float8_e8m0fnu) try: - gemm_e8m0 = BlockScaledGEMM( + gemm_e8m0 = Sm100BlockScaledPersistentDenseGemmKernel( a_ptr=a_fp4, b_ptr=b_fp4, sfa_ptr=a_sf_e8m0, sfb_ptr=b_sf_e8m0, sf_vec_size=32, ) - print(f" BlockScaledGEMM (E8M0/vs=32).sf_dtype = {gemm_e8m0.sf_dtype}") - print(f" BlockScaledGEMM (E8M0/vs=32).sf_vec_size = {gemm_e8m0.sf_vec_size}") + print(f" DenseGEMM (E8M0/vs=32).sf_dtype = {gemm_e8m0.sf_dtype}") + print(f" DenseGEMM (E8M0/vs=32).sf_vec_size = {gemm_e8m0.sf_vec_size}") except Exception as e: - print(f" BlockScaledGEMM (E8M0) failed: {e}") + print(f" DenseGEMM (E8M0) failed: {e}") + + # Check the grouped kernel too + from dsv4.kernels.gemm.grouped import ScaledGroupedGemmKernel + try: + grp = ScaledGroupedGemmKernel( + a_ptr=a_fp4, + b_ptr=b_fp4, + sfa_ptr=a_sf, + sfb_ptr=b_sf, + sf_vec_size=16, + ) + print(f" GroupedGEMM.sf_vec_size = {grp.sf_vec_size}") + print(f" GroupedGEMM.sf_dtype = {grp.sf_dtype}") + except Exception as e: + print(f" GroupedGEMM construction failed: {e}") print() print("=" * 60)