From df48dacc2beb03dde5b808cc0d1b9c463165cdaa Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 1 Jun 2026 09:22:24 +0000 Subject: [PATCH] Fix: set mma_inst_shape_mn in __init__ before _create_tiled_mma call --- dsv4/kernels/router/nvfp4_fused_router_kernel.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/dsv4/kernels/router/nvfp4_fused_router_kernel.py b/dsv4/kernels/router/nvfp4_fused_router_kernel.py index f7243d33..3c730206 100644 --- a/dsv4/kernels/router/nvfp4_fused_router_kernel.py +++ b/dsv4/kernels/router/nvfp4_fused_router_kernel.py @@ -69,6 +69,13 @@ class Nvfp4FusedRouterKernel: self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE self.arch = "sm_100" + # Set up MMA instruction shapes before any MMA creation calls + self.mma_inst_shape_mn = (mma_tiler_mnk[0], mma_tiler_mnk[1]) + self.mma_inst_shape_mn_sfb = ( + self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1), + cute.round_up(self.mma_inst_shape_mn[1], 128), + ) + # 6-warp specialization (no scheduler warp for dense GEMM) self.epilogue_warp_id = (0, 1, 2, 3) self.mma_warp_id = 4 @@ -109,11 +116,7 @@ class Nvfp4FusedRouterKernel: def _setup_attributes(self, tiled_mma, tiled_mma_sfb, a_dtype, b_dtype, sf_dtype): """Set up kernel attributes. Mirrors FusedSwiGLUScaledGroupedGemmKernel._setup_attributes.""" - self.mma_inst_shape_mn = (self.mma_tiler_mnk[0], self.mma_tiler_mnk[1]) - self.mma_inst_shape_mn_sfb = ( - self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1), - cute.round_up(self.mma_inst_shape_mn[1], 128), - ) + # mma_inst_shape_mn is set in __init__ before _create_tiled_mma is called mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) mma_inst_tile_k = self.mma_tiler_mnk[2] // mma_inst_shape_k