diff --git a/dsv4/kernels/router/nvfp4_fused_router_kernel.py b/dsv4/kernels/router/nvfp4_fused_router_kernel.py index 308c1b11..9f88dbec 100644 --- a/dsv4/kernels/router/nvfp4_fused_router_kernel.py +++ b/dsv4/kernels/router/nvfp4_fused_router_kernel.py @@ -69,19 +69,12 @@ class Nvfp4FusedRouterKernel: self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE self.arch = "sm_100" - # Set up MMA instruction shapes before any MMA creation calls - # ALL shape values must be cutlass.Int32 so CuTe layout construction - # produces static (not dynamic) dimensions. Python ints cause - # "Expected an MLIR object (got None)" in _pack_shape. - self.mma_inst_shape_mn = ( - cutlass.Int32(mma_tiler_mnk[0]), - cutlass.Int32(mma_tiler_mnk[1]), - ) - # round_up with Python math, then wrap in cutlass.Int32 - sfb_n = ((mma_tiler_mnk[1] + 128 - 1) // 128) * 128 + # Set up MMA instruction shapes + # These are now used inside @cute.jit context, so cute.round_up is fine + self.mma_inst_shape_mn = (mma_tiler_mnk[0], mma_tiler_mnk[1]) self.mma_inst_shape_mn_sfb = ( - cutlass.Int32(mma_tiler_mnk[0] // (2 if self.use_2cta_instrs else 1)), - cutlass.Int32(sfb_n), + self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1), + cute.round_up(self.mma_inst_shape_mn[1], 128), ) # 6-warp specialization (no scheduler warp for dense GEMM) @@ -132,12 +125,12 @@ class Nvfp4FusedRouterKernel: self.mma_tiler = ( self.mma_inst_shape_mn[0], self.mma_inst_shape_mn[1], - cutlass.Int32(self.mma_tiler_mnk[2]), + self.mma_tiler_mnk[2], ) self.mma_tiler_sfb = ( self.mma_inst_shape_mn_sfb[0], self.mma_inst_shape_mn_sfb[1], - cutlass.Int32(self.mma_tiler_mnk[2]), + self.mma_tiler_mnk[2], ) self.cta_tile_shape_mnk = ( self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), @@ -151,18 +144,10 @@ class Nvfp4FusedRouterKernel: ) self.cluster_layout_vmnk = cute.tiled_divide( - cute.make_layout(( - cutlass.Int32(self.cluster_shape_mn[0]), - cutlass.Int32(self.cluster_shape_mn[1]), - cutlass.Int32(1), - )), + cute.make_layout((*self.cluster_shape_mn, 1)), (tiled_mma.thr_id.shape,)) self.cluster_layout_sfb_vmnk = cute.tiled_divide( - cute.make_layout(( - cutlass.Int32(self.cluster_shape_mn[0]), - cutlass.Int32(self.cluster_shape_mn[1]), - cutlass.Int32(1), - )), + cute.make_layout((*self.cluster_shape_mn, 1)), (tiled_mma_sfb.thr_id.shape,)) self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])