fix: defer router MMA/TMA setup into cute.compile context (matches MoE pattern)
This commit is contained in:
@@ -103,52 +103,52 @@ class DenseRouterDecodeKernel:
|
||||
def run(self, X, W_gate, e_bias, out_w, out_ids, M, E, K, scaling, top_k, stream=None):
|
||||
self.a_major_mode = OperandMajorMode.K
|
||||
self.b_major_mode = OperandMajorMode.K
|
||||
self._setup_attributes()
|
||||
|
||||
X_cu = cutlass_torch.to_cuTe_tensor(X, major_mode=self.a_major_mode)
|
||||
W_cu = cutlass_torch.to_cuTe_tensor(W_gate, major_mode=self.b_major_mode)
|
||||
e_bias_cu = cutlass_torch.to_cuTe_tensor(e_bias)
|
||||
out_w_cu = cutlass_torch.to_cuTe_tensor(out_w)
|
||||
out_ids_cu = cutlass_torch.to_cuTe_tensor(out_ids)
|
||||
|
||||
tiled_mma = self._tiled_mma
|
||||
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
|
||||
a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
|
||||
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
a_op, X_cu, a_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
|
||||
|
||||
b_smem = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
|
||||
b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id)
|
||||
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
b_op, W_cu, b_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
|
||||
|
||||
a_copy = cute.size_in_bytes(self.a_dtype, a_smem)
|
||||
b_copy = cute.size_in_bytes(self.b_dtype, b_smem)
|
||||
self.num_tma_load_bytes = (a_copy + b_copy) * atom_thr_size
|
||||
|
||||
num_M_tiles = cute.ceil_div(M, self.cta_tile_shape_mnk[0])
|
||||
num_N_tiles = cute.ceil_div(E, self.cta_tile_shape_mnk[1])
|
||||
L = 1
|
||||
grid = (num_M_tiles * num_N_tiles, 1, 1)
|
||||
|
||||
max_active_clusters = 0
|
||||
tile_sched_params = utils.PersistentTileSchedulerParams.from_shape(
|
||||
cutlass.Int32(num_M_tiles), cutlass.Int32(num_N_tiles),
|
||||
cutlass.Int32(L), max_active_clusters, self.cluster_shape_mn)
|
||||
|
||||
if stream is None:
|
||||
stream = cuda.CUstream(0)
|
||||
|
||||
self._kernel(
|
||||
tiled_mma, tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b,
|
||||
self.cluster_layout_vmnk, self.a_smem_layout_staged,
|
||||
self.b_smem_layout_staged, self.epi_tile,
|
||||
e_bias_cu, out_w_cu, out_ids_cu, tile_sched_params,
|
||||
M, E, K, top_k, scaling,
|
||||
).launch(grid=grid, block=[self.threads_per_cta, 1, 1],
|
||||
cluster=(*self.cluster_shape_mn, 1), stream=stream, min_blocks_per_mp=1)
|
||||
# All MLIR-dependent setup (tiled_mma, TMA atoms, CuTe tensor conversion)
|
||||
# must happen inside cute.compile context. This matches the MoE kernel pattern.
|
||||
def _compiled_fn(X, W_gate, e_bias, out_w, out_ids):
|
||||
self._setup_attributes()
|
||||
tiled_mma = self._tiled_mma
|
||||
|
||||
X_cu = cutlass_torch.to_cuTe_tensor(X, major_mode=self.a_major_mode)
|
||||
W_cu = cutlass_torch.to_cuTe_tensor(W_gate, major_mode=self.b_major_mode)
|
||||
e_bias_cu = cutlass_torch.to_cuTe_tensor(e_bias)
|
||||
out_w_cu = cutlass_torch.to_cuTe_tensor(out_w)
|
||||
out_ids_cu = cutlass_torch.to_cuTe_tensor(out_ids)
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
|
||||
a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
|
||||
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
a_op, X_cu, a_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
|
||||
|
||||
b_smem = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
|
||||
b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id)
|
||||
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
b_op, W_cu, b_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
|
||||
|
||||
num_M_tiles = cute.ceil_div(M, self.cta_tile_shape_mnk[0])
|
||||
num_N_tiles = cute.ceil_div(E, self.cta_tile_shape_mnk[1])
|
||||
L = 1
|
||||
grid = (num_M_tiles * num_N_tiles, 1, 1)
|
||||
|
||||
max_active_clusters = 0
|
||||
tile_sched_params = utils.PersistentTileSchedulerParams.from_shape(
|
||||
cutlass.Int32(num_M_tiles), cutlass.Int32(num_N_tiles),
|
||||
cutlass.Int32(L), max_active_clusters, self.cluster_shape_mn)
|
||||
|
||||
self._kernel(
|
||||
tiled_mma, tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b,
|
||||
self.cluster_layout_vmnk, self.a_smem_layout_staged,
|
||||
self.b_smem_layout_staged, self.epi_tile,
|
||||
e_bias_cu, out_w_cu, out_ids_cu, tile_sched_params,
|
||||
M, E, K, top_k, scaling,
|
||||
).launch(grid=grid, block=[self.threads_per_cta, 1, 1],
|
||||
cluster=(*self.cluster_shape_mn, 1), stream=stream, min_blocks_per_mp=1)
|
||||
|
||||
cute.compile(_compiled_fn, X, W_gate, e_bias, out_w, out_ids)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, tiled_mma, tma_atom_a, mA_mkl, tma_atom_b, mB_nkl,
|
||||
|
||||
Reference in New Issue
Block a user