fix: PersistentTileSchedulerParams constructor takes (problem_shape, cluster_shape) not from_shape

This commit is contained in:
2026-05-31 23:49:12 +00:00
parent 824d054ad7
commit 210391e571

View File

@@ -137,9 +137,9 @@ class DenseRouterDecodeKernel:
grid = (num_M_tiles * num_N_tiles, 1, 1)
max_active_clusters = 0
tile_sched_params = utils.PersistentTileSchedulerParams.from_shape(
cutlass.Int32(num_M_tiles), cutlass.Int32(num_N_tiles),
cutlass.Int32(L), max_active_clusters, self.cluster_shape_mn)
tile_sched_params = utils.PersistentTileSchedulerParams(
(cutlass.Int32(num_M_tiles), cutlass.Int32(num_N_tiles), cutlass.Int32(L)),
(*self.cluster_shape_mn, 1))
self._kernel(
tiled_mma, tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b,