Revert cutlass.Int32 wrapping — now inside @cute.jit, cute.round_up works

All CuTe DSL calls now happen inside @cute.jit context, so
cute.round_up and all layout operations have proper MLIR context.
No need for manual Int32 wrapping or Python math workarounds.
This commit is contained in:
2026-06-01 09:35:03 +00:00
parent b536f99192
commit 28d0cb4f41

View File

@@ -69,19 +69,12 @@ class Nvfp4FusedRouterKernel:
self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
self.arch = "sm_100"
# Set up MMA instruction shapes before any MMA creation calls
# ALL shape values must be cutlass.Int32 so CuTe layout construction
# produces static (not dynamic) dimensions. Python ints cause
# "Expected an MLIR object (got None)" in _pack_shape.
self.mma_inst_shape_mn = (
cutlass.Int32(mma_tiler_mnk[0]),
cutlass.Int32(mma_tiler_mnk[1]),
)
# round_up with Python math, then wrap in cutlass.Int32
sfb_n = ((mma_tiler_mnk[1] + 128 - 1) // 128) * 128
# Set up MMA instruction shapes
# These are now used inside @cute.jit context, so cute.round_up is fine
self.mma_inst_shape_mn = (mma_tiler_mnk[0], mma_tiler_mnk[1])
self.mma_inst_shape_mn_sfb = (
cutlass.Int32(mma_tiler_mnk[0] // (2 if self.use_2cta_instrs else 1)),
cutlass.Int32(sfb_n),
self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1),
cute.round_up(self.mma_inst_shape_mn[1], 128),
)
# 6-warp specialization (no scheduler warp for dense GEMM)
@@ -132,12 +125,12 @@ class Nvfp4FusedRouterKernel:
self.mma_tiler = (
self.mma_inst_shape_mn[0],
self.mma_inst_shape_mn[1],
cutlass.Int32(self.mma_tiler_mnk[2]),
self.mma_tiler_mnk[2],
)
self.mma_tiler_sfb = (
self.mma_inst_shape_mn_sfb[0],
self.mma_inst_shape_mn_sfb[1],
cutlass.Int32(self.mma_tiler_mnk[2]),
self.mma_tiler_mnk[2],
)
self.cta_tile_shape_mnk = (
self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
@@ -151,18 +144,10 @@ class Nvfp4FusedRouterKernel:
)
self.cluster_layout_vmnk = cute.tiled_divide(
cute.make_layout((
cutlass.Int32(self.cluster_shape_mn[0]),
cutlass.Int32(self.cluster_shape_mn[1]),
cutlass.Int32(1),
)),
cute.make_layout((*self.cluster_shape_mn, 1)),
(tiled_mma.thr_id.shape,))
self.cluster_layout_sfb_vmnk = cute.tiled_divide(
cute.make_layout((
cutlass.Int32(self.cluster_shape_mn[0]),
cutlass.Int32(self.cluster_shape_mn[1]),
cutlass.Int32(1),
)),
cute.make_layout((*self.cluster_shape_mn, 1)),
(tiled_mma_sfb.thr_id.shape,))
self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])