diff --git a/dsv4/kernels/router/dense_router_decode_kernel.py b/dsv4/kernels/router/dense_router_decode_kernel.py index 717615da..872a9f5c 100644 --- a/dsv4/kernels/router/dense_router_decode_kernel.py +++ b/dsv4/kernels/router/dense_router_decode_kernel.py @@ -137,9 +137,9 @@ class DenseRouterDecodeKernel: grid = (num_M_tiles * num_N_tiles, 1, 1) max_active_clusters = 0 - tile_sched_params = utils.PersistentTileSchedulerParams.from_shape( - cutlass.Int32(num_M_tiles), cutlass.Int32(num_N_tiles), - cutlass.Int32(L), max_active_clusters, self.cluster_shape_mn) + tile_sched_params = utils.PersistentTileSchedulerParams( + (cutlass.Int32(num_M_tiles), cutlass.Int32(num_N_tiles), cutlass.Int32(L)), + (*self.cluster_shape_mn, 1)) self._kernel( tiled_mma, tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b,