fix: call the PersistentTileSchedulerParams constructor with (problem_shape, cluster_shape) instead of from_shape(...)
This commit is contained in:
@@ -137,9 +137,9 @@ class DenseRouterDecodeKernel:
|
||||
grid = (num_M_tiles * num_N_tiles, 1, 1)
|
||||
|
||||
max_active_clusters = 0
|
||||
tile_sched_params = utils.PersistentTileSchedulerParams.from_shape(
|
||||
cutlass.Int32(num_M_tiles), cutlass.Int32(num_N_tiles),
|
||||
cutlass.Int32(L), max_active_clusters, self.cluster_shape_mn)
|
||||
tile_sched_params = utils.PersistentTileSchedulerParams(
|
||||
(cutlass.Int32(num_M_tiles), cutlass.Int32(num_N_tiles), cutlass.Int32(L)),
|
||||
(*self.cluster_shape_mn, 1))
|
||||
|
||||
self._kernel(
|
||||
tiled_mma, tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b,
|
||||
|
||||
Reference in New Issue
Block a user