Revert cutlass.Int32 wrapping — now inside @cute.jit, cute.round_up works
All CuTe DSL calls now happen inside a @cute.jit context, so cute.round_up and all layout operations have a proper MLIR context. Manual cutlass.Int32 wrapping and Python-math workarounds are no longer needed.
This commit is contained in:
@@ -69,19 +69,12 @@ class Nvfp4FusedRouterKernel:
|
||||
self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
self.arch = "sm_100"
|
||||
|
||||
# Set up MMA instruction shapes before any MMA creation calls
|
||||
# ALL shape values must be cutlass.Int32 so CuTe layout construction
|
||||
# produces static (not dynamic) dimensions. Python ints cause
|
||||
# "Expected an MLIR object (got None)" in _pack_shape.
|
||||
self.mma_inst_shape_mn = (
|
||||
cutlass.Int32(mma_tiler_mnk[0]),
|
||||
cutlass.Int32(mma_tiler_mnk[1]),
|
||||
)
|
||||
# round_up with Python math, then wrap in cutlass.Int32
|
||||
sfb_n = ((mma_tiler_mnk[1] + 128 - 1) // 128) * 128
|
||||
# Set up MMA instruction shapes
|
||||
# These are now used inside @cute.jit context, so cute.round_up is fine
|
||||
self.mma_inst_shape_mn = (mma_tiler_mnk[0], mma_tiler_mnk[1])
|
||||
self.mma_inst_shape_mn_sfb = (
|
||||
cutlass.Int32(mma_tiler_mnk[0] // (2 if self.use_2cta_instrs else 1)),
|
||||
cutlass.Int32(sfb_n),
|
||||
self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1),
|
||||
cute.round_up(self.mma_inst_shape_mn[1], 128),
|
||||
)
|
||||
|
||||
# 6-warp specialization (no scheduler warp for dense GEMM)
|
||||
@@ -132,12 +125,12 @@ class Nvfp4FusedRouterKernel:
|
||||
self.mma_tiler = (
|
||||
self.mma_inst_shape_mn[0],
|
||||
self.mma_inst_shape_mn[1],
|
||||
cutlass.Int32(self.mma_tiler_mnk[2]),
|
||||
self.mma_tiler_mnk[2],
|
||||
)
|
||||
self.mma_tiler_sfb = (
|
||||
self.mma_inst_shape_mn_sfb[0],
|
||||
self.mma_inst_shape_mn_sfb[1],
|
||||
cutlass.Int32(self.mma_tiler_mnk[2]),
|
||||
self.mma_tiler_mnk[2],
|
||||
)
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
|
||||
@@ -151,18 +144,10 @@ class Nvfp4FusedRouterKernel:
|
||||
)
|
||||
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(
|
||||
cute.make_layout((
|
||||
cutlass.Int32(self.cluster_shape_mn[0]),
|
||||
cutlass.Int32(self.cluster_shape_mn[1]),
|
||||
cutlass.Int32(1),
|
||||
)),
|
||||
cute.make_layout((*self.cluster_shape_mn, 1)),
|
||||
(tiled_mma.thr_id.shape,))
|
||||
self.cluster_layout_sfb_vmnk = cute.tiled_divide(
|
||||
cute.make_layout((
|
||||
cutlass.Int32(self.cluster_shape_mn[0]),
|
||||
cutlass.Int32(self.cluster_shape_mn[1]),
|
||||
cutlass.Int32(1),
|
||||
)),
|
||||
cute.make_layout((*self.cluster_shape_mn, 1)),
|
||||
(tiled_mma_sfb.thr_id.shape,))
|
||||
|
||||
self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
|
||||
|
||||
Reference in New Issue
Block a user