diag: add 2-CTA check + fix LayoutEnum in MMA test
This commit is contained in:
@@ -128,9 +128,11 @@ def test_nvfp4_primitives():
|
||||
print("NVFP4-0.4: MMA kind verification")
|
||||
print("=" * 60)
|
||||
|
||||
+ # Test with correct LayoutEnum
|
||||
+ from cutlass.utils import LayoutEnum
|
||||
try:
|
||||
- a_major = cutlass.utils.LayoutEnum.ROW_MAJOR.mma_major_mode()
|
||||
- b_major = cutlass.utils.LayoutEnum.COLUMN_MAJOR.mma_major_mode()
|
||||
+ a_major = LayoutEnum.ROW_MAJOR.mma_major_mode()
|
||||
+ b_major = LayoutEnum.COLUMN_MAJOR.mma_major_mode()
|
||||
mma = cutlass.utils.sm100.make_trivial_tiled_mma(
|
||||
Float4E2M1FN, Float4E2M1FN, a_major, b_major, Float32,
|
||||
tcgen05.CtaGroup.ONE, (128, 256),
|
||||
@@ -141,6 +143,24 @@ def test_nvfp4_primitives():
|
||||
except Exception as e:
|
||||
print(f" FP4 MMA construction failed: {e}")
|
||||
|
||||
+ # Check use_2cta_instrs behavior
|
||||
+ print()
|
||||
+ print(" --- use_2cta_instrs check ---")
|
||||
+ try:
|
||||
+ gemm_1cta = Sm100BlockScaledPersistentDenseGemmKernel(
|
||||
+ sf_vec_size=16, mma_tiler_mn=(128, 256), cluster_shape_mn=(1, 1),
|
||||
+ )
|
||||
+ print(f" mma_tiler_mn=(128,256) → use_2cta_instrs={gemm_1cta.use_2cta_instrs}")
|
||||
+ except Exception as e:
|
||||
+ print(f" 1-CTA failed: {e}")
|
||||
+ try:
|
||||
+ gemm_2cta = Sm100BlockScaledPersistentDenseGemmKernel(
|
||||
+ sf_vec_size=16, mma_tiler_mn=(256, 256), cluster_shape_mn=(2, 1),
|
||||
+ )
|
||||
+ print(f" mma_tiler_mn=(256,256) → use_2cta_instrs={gemm_2cta.use_2cta_instrs}")
|
||||
+ except Exception as e:
|
||||
+ print(f" 2-CTA failed: {e}")
|
||||
|
||||
print()
|
||||
print("DONE — NVFP4-0 verification complete")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user