diag: add 2-CTA check + fix LayoutEnum in MMA test

This commit is contained in:
2026-05-23 08:45:26 +00:00
parent eca84bdcb5
commit 3b98007093

View File

@@ -128,9 +128,11 @@ def test_nvfp4_primitives():
print("NVFP4-0.4: MMA kind verification")
print("=" * 60)
# Test with correct LayoutEnum
from cutlass.utils import LayoutEnum
try:
a_major = cutlass.utils.LayoutEnum.ROW_MAJOR.mma_major_mode()
b_major = cutlass.utils.LayoutEnum.COLUMN_MAJOR.mma_major_mode()
a_major = LayoutEnum.ROW_MAJOR.mma_major_mode()
b_major = LayoutEnum.COLUMN_MAJOR.mma_major_mode()
mma = cutlass.utils.sm100.make_trivial_tiled_mma(
Float4E2M1FN, Float4E2M1FN, a_major, b_major, Float32,
tcgen05.CtaGroup.ONE, (128, 256),
@@ -141,6 +143,24 @@ def test_nvfp4_primitives():
except Exception as e:
print(f" FP4 MMA construction failed: {e}")
# Check use_2cta_instrs behavior
print()
print(" --- use_2cta_instrs check ---")
try:
gemm_1cta = Sm100BlockScaledPersistentDenseGemmKernel(
sf_vec_size=16, mma_tiler_mn=(128, 256), cluster_shape_mn=(1, 1),
)
print(f" mma_tiler_mn=(128,256) → use_2cta_instrs={gemm_1cta.use_2cta_instrs}")
except Exception as e:
print(f" 1-CTA failed: {e}")
try:
gemm_2cta = Sm100BlockScaledPersistentDenseGemmKernel(
sf_vec_size=16, mma_tiler_mn=(256, 256), cluster_shape_mn=(2, 1),
)
print(f" mma_tiler_mn=(256,256) → use_2cta_instrs={gemm_2cta.use_2cta_instrs}")
except Exception as e:
print(f" 2-CTA failed: {e}")
print()
print("DONE — NVFP4-0 verification complete")