fix: defer router MMA/TMA setup into cute.compile context (matches MoE pattern)

This commit is contained in:
2026-05-31 23:44:00 +00:00
parent 157f1c5258
commit d5d2b7b4b8

View File

@@ -103,52 +103,52 @@ class DenseRouterDecodeKernel:
def run(self, X, W_gate, e_bias, out_w, out_ids, M, E, K, scaling, top_k, stream=None):
self.a_major_mode = OperandMajorMode.K
self.b_major_mode = OperandMajorMode.K
self._setup_attributes()
X_cu = cutlass_torch.to_cuTe_tensor(X, major_mode=self.a_major_mode)
W_cu = cutlass_torch.to_cuTe_tensor(W_gate, major_mode=self.b_major_mode)
e_bias_cu = cutlass_torch.to_cuTe_tensor(e_bias)
out_w_cu = cutlass_torch.to_cuTe_tensor(out_w)
out_ids_cu = cutlass_torch.to_cuTe_tensor(out_ids)
tiled_mma = self._tiled_mma
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
a_smem = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
a_op, X_cu, a_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
b_smem = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id)
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
b_op, W_cu, b_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
a_copy = cute.size_in_bytes(self.a_dtype, a_smem)
b_copy = cute.size_in_bytes(self.b_dtype, b_smem)
self.num_tma_load_bytes = (a_copy + b_copy) * atom_thr_size
num_M_tiles = cute.ceil_div(M, self.cta_tile_shape_mnk[0])
num_N_tiles = cute.ceil_div(E, self.cta_tile_shape_mnk[1])
L = 1
grid = (num_M_tiles * num_N_tiles, 1, 1)
max_active_clusters = 0
tile_sched_params = utils.PersistentTileSchedulerParams.from_shape(
cutlass.Int32(num_M_tiles), cutlass.Int32(num_N_tiles),
cutlass.Int32(L), max_active_clusters, self.cluster_shape_mn)
if stream is None:
stream = cuda.CUstream(0)
self._kernel(
tiled_mma, tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b,
self.cluster_layout_vmnk, self.a_smem_layout_staged,
self.b_smem_layout_staged, self.epi_tile,
e_bias_cu, out_w_cu, out_ids_cu, tile_sched_params,
M, E, K, top_k, scaling,
).launch(grid=grid, block=[self.threads_per_cta, 1, 1],
cluster=(*self.cluster_shape_mn, 1), stream=stream, min_blocks_per_mp=1)
# All MLIR-dependent setup (tiled_mma, TMA atoms, CuTe tensor conversion)
# must happen inside cute.compile context. This matches the MoE kernel pattern.
def _compiled_fn(X, W_gate, e_bias, out_w, out_ids):
self._setup_attributes()
tiled_mma = self._tiled_mma
X_cu = cutlass_torch.to_cuTe_tensor(X, major_mode=self.a_major_mode)
W_cu = cutlass_torch.to_cuTe_tensor(W_gate, major_mode=self.b_major_mode)
e_bias_cu = cutlass_torch.to_cuTe_tensor(e_bias)
out_w_cu = cutlass_torch.to_cuTe_tensor(out_w)
out_ids_cu = cutlass_torch.to_cuTe_tensor(out_ids)
a_smem = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
a_op, X_cu, a_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
b_smem = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id)
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
b_op, W_cu, b_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
num_M_tiles = cute.ceil_div(M, self.cta_tile_shape_mnk[0])
num_N_tiles = cute.ceil_div(E, self.cta_tile_shape_mnk[1])
L = 1
grid = (num_M_tiles * num_N_tiles, 1, 1)
max_active_clusters = 0
tile_sched_params = utils.PersistentTileSchedulerParams.from_shape(
cutlass.Int32(num_M_tiles), cutlass.Int32(num_N_tiles),
cutlass.Int32(L), max_active_clusters, self.cluster_shape_mn)
self._kernel(
tiled_mma, tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b,
self.cluster_layout_vmnk, self.a_smem_layout_staged,
self.b_smem_layout_staged, self.epi_tile,
e_bias_cu, out_w_cu, out_ids_cu, tile_sched_params,
M, E, K, top_k, scaling,
).launch(grid=grid, block=[self.threads_per_cta, 1, 1],
cluster=(*self.cluster_shape_mn, 1), stream=stream, min_blocks_per_mp=1)
cute.compile(_compiled_fn, X, W_gate, e_bias, out_w, out_ids)
@cute.kernel
def _kernel(self, tiled_mma, tma_atom_a, mA_mkl, tma_atom_b, mB_nkl,