diff --git a/tests/unit/test_nvfp4_primitives.py b/tests/unit/test_nvfp4_primitives.py index f0c3d2da..07288205 100644 --- a/tests/unit/test_nvfp4_primitives.py +++ b/tests/unit/test_nvfp4_primitives.py @@ -128,9 +128,11 @@ def test_nvfp4_primitives(): print("NVFP4-0.4: MMA kind verification") print("=" * 60) + # Test with correct LayoutEnum + from cutlass.utils import LayoutEnum try: - a_major = cutlass.utils.LayoutEnum.ROW_MAJOR.mma_major_mode() - b_major = cutlass.utils.LayoutEnum.COLUMN_MAJOR.mma_major_mode() + a_major = LayoutEnum.ROW_MAJOR.mma_major_mode() + b_major = LayoutEnum.COLUMN_MAJOR.mma_major_mode() mma = cutlass.utils.sm100.make_trivial_tiled_mma( Float4E2M1FN, Float4E2M1FN, a_major, b_major, Float32, tcgen05.CtaGroup.ONE, (128, 256), @@ -141,6 +143,24 @@ def test_nvfp4_primitives(): except Exception as e: print(f" FP4 MMA construction failed: {e}") + # Check use_2cta_instrs behavior + print() + print(" --- use_2cta_instrs check ---") + try: + gemm_1cta = Sm100BlockScaledPersistentDenseGemmKernel( + sf_vec_size=16, mma_tiler_mn=(128, 256), cluster_shape_mn=(1, 1), + ) + print(f" mma_tiler_mn=(128,256) → use_2cta_instrs={gemm_1cta.use_2cta_instrs}") + except Exception as e: + print(f" 1-CTA failed: {e}") + try: + gemm_2cta = Sm100BlockScaledPersistentDenseGemmKernel( + sf_vec_size=16, mma_tiler_mn=(256, 256), cluster_shape_mn=(2, 1), + ) + print(f" mma_tiler_mn=(256,256) → use_2cta_instrs={gemm_2cta.use_2cta_instrs}") + except Exception as e: + print(f" 2-CTA failed: {e}") + print() print("DONE — NVFP4-0 verification complete")