Ensure mma_tiler contains CuTe Ints for cute.slice_ compatibility

This commit is contained in:
2026-06-01 07:16:47 +00:00
parent 57cc20d5ad
commit 39b481e52b

View File

@@ -227,6 +227,10 @@ class Nvfp4FusedRouterKernel:
self._setup_attributes()
tiled_mma = self._tiled_mma
tiled_mma_sfb = self._tiled_mma_sfb
# Ensure mma_tiler contains CuTe Ints (not Python ints)
self.mma_tiler = (cutlass.Int32(self.mma_tiler[0]), cutlass.Int32(self.mma_tiler[1]), cutlass.Int32(self.mma_tiler[2]))
self.mma_tiler_sfb = (cutlass.Int32(self.mma_tiler_sfb[0]), cutlass.Int32(self.mma_tiler_sfb[1]), cutlass.Int32(self.mma_tiler_sfb[2]))
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
a_smem_0 = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))