Fix: set mma_inst_shape_mn in __init__ before _create_tiled_mma call

This commit is contained in:
2026-06-01 09:22:24 +00:00
parent 28f78420c2
commit df48dacc2b

View File

@@ -69,6 +69,13 @@ class Nvfp4FusedRouterKernel:
self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
self.arch = "sm_100"
# Set up MMA instruction shapes before any MMA creation calls
self.mma_inst_shape_mn = (mma_tiler_mnk[0], mma_tiler_mnk[1])
self.mma_inst_shape_mn_sfb = (
self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1),
cute.round_up(self.mma_inst_shape_mn[1], 128),
)
# 6-warp specialization (no scheduler warp for dense GEMM)
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4
@@ -109,11 +116,7 @@ class Nvfp4FusedRouterKernel:
def _setup_attributes(self, tiled_mma, tiled_mma_sfb, a_dtype, b_dtype, sf_dtype):
"""Set up kernel attributes. Mirrors FusedSwiGLUScaledGroupedGemmKernel._setup_attributes."""
self.mma_inst_shape_mn = (self.mma_tiler_mnk[0], self.mma_tiler_mnk[1])
self.mma_inst_shape_mn_sfb = (
self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1),
cute.round_up(self.mma_inst_shape_mn[1], 128),
)
# mma_inst_shape_mn is set in __init__ before _create_tiled_mma is called
mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
mma_inst_tile_k = self.mma_tiler_mnk[2] // mma_inst_shape_k