Fix BF16 router mma_tiler: use cutlass.Int32 for CuTe DSL compatibility
This commit is contained in:
@@ -67,7 +67,8 @@ class DenseRouterDecodeKernel:
|
||||
self._tiled_mma = self._create_tiled_mma()
|
||||
mma_inst_shape_k = cute.size(self._tiled_mma.shape_mnk, mode=[2])
|
||||
mma_inst_tile_k = 4
|
||||
self.mma_tiler = (*self.mma_tiler_mn, mma_inst_shape_k * mma_inst_tile_k)
|
||||
k_tile = mma_inst_shape_k * mma_inst_tile_k
|
||||
self.mma_tiler = (cutlass.Int32(self.mma_tiler_mn[0]), cutlass.Int32(self.mma_tiler_mn[1]), cutlass.Int32(k_tile))
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.mma_tiler[0] // cute.size(self._tiled_mma.thr_id.shape),
|
||||
self.mma_tiler[1], self.mma_tiler[2],
|
||||
|
||||
Reference in New Issue
Block a user