From 28d0cb4f412299c498747871b249d2dadba21818 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 1 Jun 2026 09:35:03 +0000 Subject: [PATCH] =?UTF-8?q?Revert=20cutlass.Int32=20wrapping=20=E2=80=94?= =?UTF-8?q?=20now=20inside=20@cute.jit,=20cute.round=5Fup=20works?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit All CuTe DSL calls now happen inside @cute.jit context, so cute.round_up and all layout operations have proper MLIR context. No need for manual Int32 wrapping or Python math workarounds. --- .../router/nvfp4_fused_router_kernel.py | 33 +++++-------------- 1 file changed, 9 insertions(+), 24 deletions(-) diff --git a/dsv4/kernels/router/nvfp4_fused_router_kernel.py b/dsv4/kernels/router/nvfp4_fused_router_kernel.py index 308c1b11..9f88dbec 100644 --- a/dsv4/kernels/router/nvfp4_fused_router_kernel.py +++ b/dsv4/kernels/router/nvfp4_fused_router_kernel.py @@ -69,19 +69,12 @@ class Nvfp4FusedRouterKernel: self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE self.arch = "sm_100" - # Set up MMA instruction shapes before any MMA creation calls - # ALL shape values must be cutlass.Int32 so CuTe layout construction - # produces static (not dynamic) dimensions. Python ints cause - # "Expected an MLIR object (got None)" in _pack_shape. - self.mma_inst_shape_mn = ( - cutlass.Int32(mma_tiler_mnk[0]), - cutlass.Int32(mma_tiler_mnk[1]), - ) - # round_up with Python math, then wrap in cutlass.Int32 - sfb_n = ((mma_tiler_mnk[1] + 128 - 1) // 128) * 128 + # Set up MMA instruction shapes + # These are now used inside @cute.jit context, so cute.round_up is fine + self.mma_inst_shape_mn = (mma_tiler_mnk[0], mma_tiler_mnk[1]) self.mma_inst_shape_mn_sfb = ( - cutlass.Int32(mma_tiler_mnk[0] // (2 if self.use_2cta_instrs else 1)), - cutlass.Int32(sfb_n), + self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1), + cute.round_up(self.mma_inst_shape_mn[1], 128), ) # 6-warp specialization (no scheduler warp for dense GEMM) @@ -132,12 +125,12 @@ class Nvfp4FusedRouterKernel: self.mma_tiler = ( self.mma_inst_shape_mn[0], self.mma_inst_shape_mn[1], - cutlass.Int32(self.mma_tiler_mnk[2]), + self.mma_tiler_mnk[2], ) self.mma_tiler_sfb = ( self.mma_inst_shape_mn_sfb[0], self.mma_inst_shape_mn_sfb[1], - cutlass.Int32(self.mma_tiler_mnk[2]), + self.mma_tiler_mnk[2], ) self.cta_tile_shape_mnk = ( self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), @@ -151,18 +144,10 @@ class Nvfp4FusedRouterKernel: ) self.cluster_layout_vmnk = cute.tiled_divide( - cute.make_layout(( - cutlass.Int32(self.cluster_shape_mn[0]), - cutlass.Int32(self.cluster_shape_mn[1]), - cutlass.Int32(1), - )), + cute.make_layout((*self.cluster_shape_mn, 1)), (tiled_mma.thr_id.shape,)) self.cluster_layout_sfb_vmnk = cute.tiled_divide( - cute.make_layout(( - cutlass.Int32(self.cluster_shape_mn[0]), - cutlass.Int32(self.cluster_shape_mn[1]), - cutlass.Int32(1), - )), + cute.make_layout((*self.cluster_shape_mn, 1)), (tiled_mma_sfb.thr_id.shape,)) self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])