Fix: set mma_inst_shape_mn in __init__ before _create_tiled_mma call
This commit is contained in:
@@ -69,6 +69,13 @@ class Nvfp4FusedRouterKernel:
|
||||
self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
self.arch = "sm_100"
|
||||
|
||||
# Set up MMA instruction shapes before any MMA creation calls
|
||||
self.mma_inst_shape_mn = (mma_tiler_mnk[0], mma_tiler_mnk[1])
|
||||
self.mma_inst_shape_mn_sfb = (
|
||||
self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1),
|
||||
cute.round_up(self.mma_inst_shape_mn[1], 128),
|
||||
)
|
||||
|
||||
# 6-warp specialization (no scheduler warp for dense GEMM)
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4
|
||||
@@ -109,11 +116,7 @@ class Nvfp4FusedRouterKernel:
|
||||
|
||||
def _setup_attributes(self, tiled_mma, tiled_mma_sfb, a_dtype, b_dtype, sf_dtype):
|
||||
"""Set up kernel attributes. Mirrors FusedSwiGLUScaledGroupedGemmKernel._setup_attributes."""
|
||||
self.mma_inst_shape_mn = (self.mma_tiler_mnk[0], self.mma_tiler_mnk[1])
|
||||
self.mma_inst_shape_mn_sfb = (
|
||||
self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1),
|
||||
cute.round_up(self.mma_inst_shape_mn[1], 128),
|
||||
)
|
||||
# mma_inst_shape_mn is set in __init__ before _create_tiled_mma is called
|
||||
|
||||
mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
|
||||
mma_inst_tile_k = self.mma_tiler_mnk[2] // mma_inst_shape_k
|
||||
|
||||
Reference in New Issue
Block a user