diff --git a/dsv4/kernels/router/nvfp4_fused_router_kernel.py b/dsv4/kernels/router/nvfp4_fused_router_kernel.py index f7243d33..3c730206 100644 --- a/dsv4/kernels/router/nvfp4_fused_router_kernel.py +++ b/dsv4/kernels/router/nvfp4_fused_router_kernel.py @@ -69,6 +69,13 @@ class Nvfp4FusedRouterKernel: self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE self.arch = "sm_100" + # Set up MMA instruction shapes before any MMA creation calls + self.mma_inst_shape_mn = (mma_tiler_mnk[0], mma_tiler_mnk[1]) + self.mma_inst_shape_mn_sfb = ( + self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1), + cute.round_up(self.mma_inst_shape_mn[1], 128), + ) + # 6-warp specialization (no scheduler warp for dense GEMM) self.epilogue_warp_id = (0, 1, 2, 3) self.mma_warp_id = 4 @@ -109,11 +116,7 @@ class Nvfp4FusedRouterKernel: def _setup_attributes(self, tiled_mma, tiled_mma_sfb, a_dtype, b_dtype, sf_dtype): """Set up kernel attributes. Mirrors FusedSwiGLUScaledGroupedGemmKernel._setup_attributes.""" - self.mma_inst_shape_mn = (self.mma_tiler_mnk[0], self.mma_tiler_mnk[1]) - self.mma_inst_shape_mn_sfb = ( - self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1), - cute.round_up(self.mma_inst_shape_mn[1], 128), - ) + # mma_inst_shape_mn is set in __init__ before _create_tiled_mma is called mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) mma_inst_tile_k = self.mma_tiler_mnk[2] // mma_inst_shape_k