From d5d2b7b4b8e14eb8242cbd601b67a441d6c17021 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 31 May 2026 23:44:00 +0000 Subject: [PATCH] fix: defer router MMA/TMA setup into cute.compile context (matches MoE pattern) --- .../router/dense_router_decode_kernel.py | 84 +++++++++---------- 1 file changed, 42 insertions(+), 42 deletions(-) diff --git a/dsv4/kernels/router/dense_router_decode_kernel.py b/dsv4/kernels/router/dense_router_decode_kernel.py index 97b39b61..f1ebcb94 100644 --- a/dsv4/kernels/router/dense_router_decode_kernel.py +++ b/dsv4/kernels/router/dense_router_decode_kernel.py @@ -103,52 +103,52 @@ class DenseRouterDecodeKernel: def run(self, X, W_gate, e_bias, out_w, out_ids, M, E, K, scaling, top_k, stream=None): self.a_major_mode = OperandMajorMode.K self.b_major_mode = OperandMajorMode.K - self._setup_attributes() - - X_cu = cutlass_torch.to_cuTe_tensor(X, major_mode=self.a_major_mode) - W_cu = cutlass_torch.to_cuTe_tensor(W_gate, major_mode=self.b_major_mode) - e_bias_cu = cutlass_torch.to_cuTe_tensor(e_bias) - out_w_cu = cutlass_torch.to_cuTe_tensor(out_w) - out_ids_cu = cutlass_torch.to_cuTe_tensor(out_ids) - - tiled_mma = self._tiled_mma - atom_thr_size = cute.size(tiled_mma.thr_id.shape) - - a_smem = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) - a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) - tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( - a_op, X_cu, a_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape) - - b_smem = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) - b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id) - tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( - b_op, W_cu, b_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape) - - a_copy = cute.size_in_bytes(self.a_dtype, a_smem) - b_copy = cute.size_in_bytes(self.b_dtype, b_smem) - self.num_tma_load_bytes = (a_copy + b_copy) * atom_thr_size - - num_M_tiles = cute.ceil_div(M, self.cta_tile_shape_mnk[0]) - num_N_tiles = cute.ceil_div(E, self.cta_tile_shape_mnk[1]) - L = 1 - grid = (num_M_tiles * num_N_tiles, 1, 1) - - max_active_clusters = 0 - tile_sched_params = utils.PersistentTileSchedulerParams.from_shape( - cutlass.Int32(num_M_tiles), cutlass.Int32(num_N_tiles), - cutlass.Int32(L), max_active_clusters, self.cluster_shape_mn) if stream is None: stream = cuda.CUstream(0) - self._kernel( - tiled_mma, tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b, - self.cluster_layout_vmnk, self.a_smem_layout_staged, - self.b_smem_layout_staged, self.epi_tile, - e_bias_cu, out_w_cu, out_ids_cu, tile_sched_params, - M, E, K, top_k, scaling, - ).launch(grid=grid, block=[self.threads_per_cta, 1, 1], - cluster=(*self.cluster_shape_mn, 1), stream=stream, min_blocks_per_mp=1) + # All MLIR-dependent setup (tiled_mma, TMA atoms, CuTe tensor conversion) + # must happen inside cute.compile context. This matches the MoE kernel pattern. + def _compiled_fn(X, W_gate, e_bias, out_w, out_ids): + self._setup_attributes() + tiled_mma = self._tiled_mma + + X_cu = cutlass_torch.to_cuTe_tensor(X, major_mode=self.a_major_mode) + W_cu = cutlass_torch.to_cuTe_tensor(W_gate, major_mode=self.b_major_mode) + e_bias_cu = cutlass_torch.to_cuTe_tensor(e_bias) + out_w_cu = cutlass_torch.to_cuTe_tensor(out_w) + out_ids_cu = cutlass_torch.to_cuTe_tensor(out_ids) + + a_smem = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, X_cu, a_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape) + + b_smem = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, W_cu, b_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape) + + num_M_tiles = cute.ceil_div(M, self.cta_tile_shape_mnk[0]) + num_N_tiles = cute.ceil_div(E, self.cta_tile_shape_mnk[1]) + L = 1 + grid = (num_M_tiles * num_N_tiles, 1, 1) + + max_active_clusters = 0 + tile_sched_params = utils.PersistentTileSchedulerParams.from_shape( + cutlass.Int32(num_M_tiles), cutlass.Int32(num_N_tiles), + cutlass.Int32(L), max_active_clusters, self.cluster_shape_mn) + + self._kernel( + tiled_mma, tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b, + self.cluster_layout_vmnk, self.a_smem_layout_staged, + self.b_smem_layout_staged, self.epi_tile, + e_bias_cu, out_w_cu, out_ids_cu, tile_sched_params, + M, E, K, top_k, scaling, + ).launch(grid=grid, block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), stream=stream, min_blocks_per_mp=1) + + cute.compile(_compiled_fn, X, W_gate, e_bias, out_w, out_ids) @cute.kernel def _kernel(self, tiled_mma, tma_atom_a, mA_mkl, tma_atom_b, mB_nkl,