Fix BF16 router mma_tiler: use cutlass.Int32 for CuTe DSL compatibility

This commit is contained in:
2026-06-01 07:29:30 +00:00
parent 79be9cb8da
commit ef4c0ad489

View File

@@ -67,7 +67,8 @@ class DenseRouterDecodeKernel:
self._tiled_mma = self._create_tiled_mma()
mma_inst_shape_k = cute.size(self._tiled_mma.shape_mnk, mode=[2])
mma_inst_tile_k = 4
self.mma_tiler = (*self.mma_tiler_mn, mma_inst_shape_k * mma_inst_tile_k)
k_tile = mma_inst_shape_k * mma_inst_tile_k
self.mma_tiler = (cutlass.Int32(self.mma_tiler_mn[0]), cutlass.Int32(self.mma_tiler_mn[1]), cutlass.Int32(k_tile))
self.cta_tile_shape_mnk = (
self.mma_tiler[0] // cute.size(self._tiled_mma.thr_id.shape),
self.mma_tiler[1], self.mma_tiler[2],