The C++ CUTLASS kernel is fundamentally broken (cosine similarity 0.05 with real data). Switching to NVIDIA's CuTeDSL approach, based on their official MoE scaled grouped GEMM example.

Reference files copied:
- moe_torch_scaled_grouped_mm.py (3900 lines — our new kernel)
- moe_utils.py, moe_persistent_scheduler.py, moe_sched_extension.py
- grouped_blockscaled_gemm.py, dense_blockscaled_gemm_persistent.py
- blockscaled_layout.py
696 lines
27 KiB
Python
696 lines
27 KiB
Python
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
# Redistribution and use in source and binary forms, with or without
|
|
# modification, are permitted provided that the following conditions are met:
|
|
|
|
# 1. Redistributions of source code must retain the above copyright notice, this
|
|
# list of conditions and the following disclaimer.
|
|
|
|
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
# this list of conditions and the following disclaimer in the documentation
|
|
# and/or other materials provided with the distribution.
|
|
|
|
# 3. Neither the name of the copyright holder nor the names of its
|
|
# contributors may be used to endorse or promote products derived from
|
|
# this software without specific prior written permission.
|
|
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
"""
|
|
MoE Persistent Tile Scheduler
|
|
|
|
A specialized tile scheduler for MoE (Mixture of Experts) grouped GEMM operations.
|
|
This scheduler handles tile iteration across all experts, producing MoEWorkTileInfo
|
|
(expert_idx, tile_m_idx, tile_n_idx, k_tile_cnt) for each tile.
|
|
|
|
Scenarios:
|
|
- 2Dx3D (Forward): A(tokens_sum, hidden) x B(experts, intermediate, hidden) -> C(tokens_sum, intermediate)
|
|
- 2Dx2D (Backward): A(intermediate, tokens_sum) x B(hidden, tokens_sum) -> C(experts, intermediate, hidden)
|
|
|
|
Key design principle:
|
|
- Scheduler is ONLY responsible for tile iteration (tensor-agnostic, TMA-agnostic)
|
|
- Domain conversion (fake tensor -> real expert tensor) is handled by MoESchedExtension
|
|
- TMA descriptor management is handled by OnlineTensormapDescCreator
|
|
- The kernel orchestrates all three components
|
|
"""
|
|
|
|
from typing import List, Tuple, Literal
|
|
|
|
import cutlass
|
|
import cutlass.cute as cute
|
|
from cutlass.cutlass_dsl import (
|
|
Boolean,
|
|
Int32,
|
|
Integer,
|
|
extract_mlir_values,
|
|
new_from_mlir_values,
|
|
const_expr,
|
|
dsl_user_op,
|
|
)
|
|
from cutlass._mlir import ir
|
|
|
|
# =============================================================================
|
|
# Work Tile Info
|
|
# =============================================================================
|
|
|
|
class MoEWorkTileInfo:
    """
    One unit of work handed from the MoE tile scheduler to the executor warps.

    Fields (all CTA-level indices, not cluster-level):
    - expert_idx: expert owning the tile; -1 marks an invalid tile ("no more work")
    - tile_m_idx: CTA tile index along the GEMM M dimension
    - tile_n_idx: CTA tile index along the GEMM N dimension
    - k_tile_cnt: number of CTA tiles along the K dimension

    There is no tile_l_idx field; per the original author it is always 0 for
    MoE, so executors can hardcode it.

    Dimension meaning per scenario:
    - 2Dx3D (Forward): M = tokens_i (dynamic), N = intermediate (fixed),
      K = hidden (fixed)
    - 2Dx2D (Backward): M and N are fixed, K = tokens_i (dynamic).
      NOTE(review): the module docstring describes the 2Dx2D output as
      (experts, intermediate, hidden), i.e. M = intermediate / N = hidden,
      whereas MoEStaticPersistentTileScheduler derives the M tile count from
      hidden and the N tile count from intermediate. Confirm which is intended.
    """

    def __init__(
        self,
        expert_idx: Int32,  # -1 means invalid tile
        tile_m_idx: Int32,
        tile_n_idx: Int32,
        k_tile_cnt: Int32,
    ):
        self.expert_idx = expert_idx
        self.tile_m_idx = tile_m_idx
        self.tile_n_idx = tile_n_idx
        self.k_tile_cnt = k_tile_cnt

    @property
    def is_valid_tile(self) -> Boolean:
        """True when the tile refers to a real expert, i.e. expert_idx >= 0."""
        lowest_valid_expert = Int32(0)
        return self.expert_idx >= lowest_valid_expert

    def __extract_mlir_values__(self) -> List[ir.Value]:
        """Flatten the four fields into MLIR values, in declaration order."""
        flat = extract_mlir_values(self.expert_idx)
        for field in (self.tile_m_idx, self.tile_n_idx, self.k_tile_cnt):
            flat.extend(extract_mlir_values(field))
        return flat

    def __new_from_mlir_values__(self, values: List[ir.Value]) -> "MoEWorkTileInfo":
        """Rebuild an instance from the four values produced by __extract_mlir_values__."""
        assert len(values) == 4
        expert_val, tile_m_val, tile_n_val, k_cnt_val = values
        return MoEWorkTileInfo(
            expert_idx=new_from_mlir_values(self.expert_idx, [expert_val]),
            tile_m_idx=new_from_mlir_values(self.tile_m_idx, [tile_m_val]),
            tile_n_idx=new_from_mlir_values(self.tile_n_idx, [tile_n_val]),
            k_tile_cnt=new_from_mlir_values(self.k_tile_cnt, [k_cnt_val]),
        )

    def to_rmem_tensor(self):
        """Pack the four fields into an Int32 rmem tensor of shape (4,).

        Slot order is (expert_idx, tile_m_idx, tile_n_idx, k_tile_cnt); the
        original author's stated purpose is a vectorized smem copy.
        """
        packed = cute.make_rmem_tensor((4,), Int32)
        packed[0] = self.expert_idx
        packed[1] = self.tile_m_idx
        packed[2] = self.tile_n_idx
        packed[3] = self.k_tile_cnt
        return packed

    @staticmethod
    def from_rmem_tensor(rmem) -> "MoEWorkTileInfo":
        """Inverse of to_rmem_tensor: read slots 0..3 of an rmem tensor of shape (4,)."""
        expert = rmem[0]
        tile_m = rmem[1]
        tile_n = rmem[2]
        k_cnt = rmem[3]
        return MoEWorkTileInfo(
            expert_idx=expert,  # type: ignore[arg-type]
            tile_m_idx=tile_m,  # type: ignore[arg-type]
            tile_n_idx=tile_n,  # type: ignore[arg-type]
            k_tile_cnt=k_cnt,  # type: ignore[arg-type]
        )
|
|
|
|
|
|
# =============================================================================
|
|
# Scheduler Parameters
|
|
# =============================================================================
|
|
|
|
class MoEStaticSchedulerParams:
    """
    Parameters for MoE tile scheduler.

    Uses unified semantics for both scenarios:
    - expert_shape: (expert_cnt, intermediate, hidden)

    For 2Dx3D: GEMM is (M=tokens_i, N=intermediate, K=hidden) per expert
    For 2Dx2D: GEMM is (M=hidden, N=intermediate, K=tokens_i) per expert

    Tile hierarchy:
    - cta_tile_shape_mnk: Single CTA tile shape (tile_m, tile_n, tile_k)
    - cluster_shape_mn: CTAs per cluster (cluster_m, cluster_n)
    - cluster_tile_shape_mn: Cluster tile shape = cta_tile_shape * cluster_shape

    This class is used both on host (for grid shape calculation) and on device
    (stored in scheduler). Codegen-time constants (scenario, cta_tile_shape_mnk,
    cluster_shape_mn) are NOT serialized to MLIR values; only expert_cnt,
    intermediate and hidden are.
    """

    # The only scenario strings the scheduler in this file branches on.
    _SUPPORTED_SCENARIOS = ("2Dx3D", "2Dx2D")

    def __init__(
        self,
        scenario: Literal["2Dx3D", "2Dx2D"],
        expert_shape: Tuple[int | Int32, int | Int32, int | Int32],  # (expert_cnt, intermediate, hidden)
        cta_tile_shape_mnk: Tuple[int, int, int],  # (tile_m, tile_n, tile_k)
        cluster_shape_mn: Tuple[int, int],  # (cluster_m, cluster_n)
    ):
        """
        :param scenario: "2Dx3D" or "2Dx2D"
        :param expert_shape: (expert_cnt, intermediate, hidden); plain ints are
            wrapped in Int32, Int32 values are stored as given
        :param cta_tile_shape_mnk: (tile_m, tile_n, tile_k) of a single CTA tile
        :param cluster_shape_mn: (cluster_m, cluster_n) CTAs per cluster
        :raises ValueError: if scenario is not one of the supported strings
        """
        # Fail fast on an unknown scenario. The scheduler's tile-count helpers
        # treat anything that is not "2Dx2D" as 2Dx3D, while its K-tile helper
        # treats anything that is not "2Dx3D" as 2Dx2D, so a misspelt value
        # would otherwise silently mix the two code paths.
        if scenario not in self._SUPPORTED_SCENARIOS:
            raise ValueError(
                f"Unsupported MoE scenario {scenario!r}; "
                f"expected one of {self._SUPPORTED_SCENARIOS}"
            )
        self.scenario = scenario
        e, i, h = expert_shape
        self.expert_cnt = self._as_int32(e)
        self.intermediate = self._as_int32(i)
        self.hidden = self._as_int32(h)
        self.cta_tile_shape_mnk = cta_tile_shape_mnk
        self.cluster_shape_mn = cluster_shape_mn

    @staticmethod
    def _as_int32(value):
        """Return value unchanged if it is already an Int32, else wrap it in Int32."""
        return value if isinstance(value, Int32) else Int32(value)

    @property
    def cluster_tile_m(self) -> int:
        """Cluster tile size along M = cta_tile_m * cluster_m."""
        return self.cta_tile_shape_mnk[0] * self.cluster_shape_mn[0]

    @property
    def cluster_tile_n(self) -> int:
        """Cluster tile size along N = cta_tile_n * cluster_n."""
        return self.cta_tile_shape_mnk[1] * self.cluster_shape_mn[1]

    @property
    def cta_tile_k(self) -> int:
        """CTA tile size along K (same as cluster since cluster_k = 1)."""
        return self.cta_tile_shape_mnk[2]

    def __extract_mlir_values__(self) -> List[ir.Value]:
        """Only serialize runtime values, not codegen-time constants.

        Order is (expert_cnt, intermediate, hidden); __new_from_mlir_values__
        relies on it.
        """
        values = []
        values.extend(extract_mlir_values(self.expert_cnt))
        values.extend(extract_mlir_values(self.intermediate))
        values.extend(extract_mlir_values(self.hidden))
        return values

    def __new_from_mlir_values__(self, values: List[ir.Value]) -> "MoEStaticSchedulerParams":
        """Rebuild params from three MLIR values, carrying over the codegen-time constants."""
        assert len(values) == 3
        return MoEStaticSchedulerParams(
            scenario=self.scenario,
            expert_shape=(
                new_from_mlir_values(self.expert_cnt, [values[0]]),
                new_from_mlir_values(self.intermediate, [values[1]]),
                new_from_mlir_values(self.hidden, [values[2]]),
            ),
            cta_tile_shape_mnk=self.cta_tile_shape_mnk,
            cluster_shape_mn=self.cluster_shape_mn,
        )

    @staticmethod
    def get_grid_shape(
        params: "MoEStaticSchedulerParams",
        max_active_clusters: int,
    ) -> Tuple[int, int, int]:
        """
        Compute grid shape for kernel launch.

        Since host doesn't know token distribution across experts,
        we launch max_active_clusters and let device-side scheduler
        determine which tiles are valid.

        :param params: Scheduler parameters; only cluster_shape_mn is read
        :param max_active_clusters: Number of clusters to launch
        :return: (cluster_m, cluster_n, max_active_clusters)
        """
        return (
            params.cluster_shape_mn[0],
            params.cluster_shape_mn[1],
            max_active_clusters,
        )
|
|
|
|
|
|
# =============================================================================
|
|
# Scheduler (Device-side)
|
|
# =============================================================================
|
|
|
|
class MoEStaticPersistentTileScheduler:
    """
    Persistent tile scheduler specialized for MoE grouped GEMM.

    This scheduler is ONLY responsible for tile iteration. It does NOT know
    about tensor types, TMA descriptors, or domain conversion. Those concerns
    are handled by MoESchedExtension and OnlineTensormapDescCreator respectively.

    Architecture:
    - Scheduler warp: Holds scheduler instance, iterates tiles, broadcasts work_tile_info
    - Executor warps: Read work_tile_info from smem, use MoESchedExtension for
      domain conversion and TMA desc selection

    The scheduler handles:
    - 2Dx3D: Dynamic M per expert (from offs), fixed N (intermediate) and K (hidden)
    - 2Dx2D: Fixed M and N, dynamic K per expert (reduction axis, from offs).
      As implemented below, the M tile count is derived from hidden and the
      N tile count from intermediate.
      NOTE(review): the module docstring describes the 2Dx2D output as
      (experts, intermediate, hidden), which would make M = intermediate and
      N = hidden. Confirm which mapping is intended.

    Work distribution: each persistent cluster starts at linear tile index
    block_idx.z and strides by num_persistent_clusters. Linear indices run
    over the concatenation of every expert's cluster tiles, in expert order.

    Usage (Scheduler warp):
        scheduler = MoEStaticPersistentTileScheduler.create(params, offs, block_idx, grid_dim)
        work_tile_info = scheduler.initial_work_tile_info()
        # Broadcast work_tile_info to smem...

        while work_tile_info.is_valid_tile:
            # ... do work ...
            work_tile_info = scheduler.advance_to_next_work()
            # Broadcast work_tile_info to smem...

    Usage (Executor warps - via MoESchedExtension):
        # Read work_tile_info from smem...
        real_a, desc_a = ext.get_gmem_tensor("a", tma_tensor_a, offs, work_tile_info)
        real_b, desc_b = ext.get_gmem_tensor("b", tma_tensor_b, offs, work_tile_info)
        real_c, desc_c = ext.get_gmem_tensor("c", tma_tensor_c, offs, work_tile_info)
    """

    def __init__(
        self,
        # Params (contains scenario, expert_cnt, intermediate, hidden, tile/cluster shapes)
        params: MoEStaticSchedulerParams,
        # Runtime tensor for scheduling
        offs: cute.Tensor,  # (experts,) cumsum of token counts
        # Scheduling state
        num_persistent_clusters: Int32,
        current_work_linear_idx: Int32,
        cta_id_in_cluster: cute.Coord,
        # Expert tracking state (for O(1) advance within same expert)
        current_expert_idx: Int32,
        expert_tile_start: Int32,  # cumsum of tiles before current expert
        expert_tile_end: Int32,  # cumsum of tiles including current expert
    ):
        """Store the given state verbatim; prefer create() for a fresh scheduler."""
        self.params = params
        self.offs = offs
        self.num_persistent_clusters = num_persistent_clusters
        self._current_work_linear_idx = current_work_linear_idx
        self.cta_id_in_cluster = cta_id_in_cluster
        # Expert tracking
        self.current_expert_idx = current_expert_idx
        self.expert_tile_start = expert_tile_start
        self.expert_tile_end = expert_tile_end

    # =========================================================================
    # Convenience accessors for params
    # =========================================================================

    @property
    def scenario(self) -> Literal["2Dx3D", "2Dx2D"]:
        """Codegen-time scenario string, forwarded from params."""
        return self.params.scenario

    @property
    def expert_cnt(self) -> Int32:
        """Number of experts, forwarded from params."""
        return self.params.expert_cnt

    @property
    def intermediate(self) -> Int32:
        """Intermediate size, forwarded from params."""
        return self.params.intermediate

    @property
    def hidden(self) -> Int32:
        """Hidden size, forwarded from params."""
        return self.params.hidden

    @property
    def cta_tile_shape_mnk(self) -> Tuple[int, int, int]:
        """CTA tile shape (tile_m, tile_n, tile_k), forwarded from params."""
        return self.params.cta_tile_shape_mnk

    @property
    def cluster_shape_mn(self) -> Tuple[int, int]:
        """Cluster shape (cluster_m, cluster_n), forwarded from params."""
        return self.params.cluster_shape_mn

    @property
    def cluster_tile_m(self) -> int:
        """Cluster tile size along M, forwarded from params."""
        return self.params.cluster_tile_m

    @property
    def cluster_tile_n(self) -> int:
        """Cluster tile size along N, forwarded from params."""
        return self.params.cluster_tile_n

    @property
    def cta_tile_k(self) -> int:
        """CTA tile size along K, forwarded from params."""
        return self.params.cta_tile_k

    # =========================================================================
    # MLIR value serialization (for SSA value passing in device code)
    # =========================================================================

    def __extract_mlir_values__(self) -> List[ir.Value]:
        """Flatten all runtime state into a list of MLIR values.

        Order: params, offs, num_persistent_clusters, current work linear
        index, cta_id_in_cluster, current_expert_idx, expert_tile_start,
        expert_tile_end. __new_from_mlir_values__ consumes the same order.
        """
        values = []
        # Params (only runtime values are extracted)
        values.extend(extract_mlir_values(self.params))
        # Runtime tensor for scheduling
        values.extend(extract_mlir_values(self.offs))
        # Scheduling state
        values.extend(extract_mlir_values(self.num_persistent_clusters))
        values.extend(extract_mlir_values(self._current_work_linear_idx))
        values.extend(extract_mlir_values(self.cta_id_in_cluster))
        # Expert tracking state
        values.extend(extract_mlir_values(self.current_expert_idx))
        values.extend(extract_mlir_values(self.expert_tile_start))
        values.extend(extract_mlir_values(self.expert_tile_end))
        return values

    def __new_from_mlir_values__(
        self, values: List[ir.Value]
    ) -> "MoEStaticPersistentTileScheduler":
        """Rebuild a scheduler from the flat list produced by __extract_mlir_values__.

        NOTE(review): the slice widths for params (3) and cta_id_in_cluster (3)
        are hard-coded, unlike offs whose width is measured. They must be kept
        in sync with MoEStaticSchedulerParams.__extract_mlir_values__ and with
        the 3-element coordinate built in create().
        """
        idx = 0

        # Params (3 values: expert_cnt, intermediate, hidden)
        new_params = new_from_mlir_values(self.params, values[idx:idx + 3])
        idx += 3

        # Runtime tensor for scheduling (variable size)
        offs_len = len(extract_mlir_values(self.offs))
        new_offs = new_from_mlir_values(self.offs, values[idx:idx + offs_len])
        idx += offs_len

        # Scheduling state
        new_num_persistent_clusters = new_from_mlir_values(
            self.num_persistent_clusters, [values[idx]]
        )
        idx += 1
        new_current_work_linear_idx = new_from_mlir_values(
            self._current_work_linear_idx, [values[idx]]
        )
        idx += 1

        # cta_id_in_cluster (3 values for Coord)
        new_cta_id_in_cluster = new_from_mlir_values(
            self.cta_id_in_cluster, values[idx:idx + 3]
        )
        idx += 3

        # Expert tracking state
        new_current_expert_idx = new_from_mlir_values(
            self.current_expert_idx, [values[idx]]
        )
        idx += 1
        new_expert_tile_start = new_from_mlir_values(
            self.expert_tile_start, [values[idx]]
        )
        idx += 1
        new_expert_tile_end = new_from_mlir_values(
            self.expert_tile_end, [values[idx]]
        )
        idx += 1

        return MoEStaticPersistentTileScheduler(
            params=new_params,
            offs=new_offs,
            num_persistent_clusters=new_num_persistent_clusters,
            current_work_linear_idx=new_current_work_linear_idx,
            cta_id_in_cluster=new_cta_id_in_cluster,
            current_expert_idx=new_current_expert_idx,
            expert_tile_start=new_expert_tile_start,
            expert_tile_end=new_expert_tile_end,
        )

    # =========================================================================
    # Factory method
    # =========================================================================

    @staticmethod
    @dsl_user_op
    def create(
        params: MoEStaticSchedulerParams,
        offs: cute.Tensor,
        block_idx: Tuple[Integer, Integer, Integer],
        grid_dim: Tuple[Integer, Integer, Integer],
        *,
        loc=None,
        ip=None,
    ) -> "MoEStaticPersistentTileScheduler":
        """
        Create a MoE persistent tile scheduler.

        :param params: Scheduler parameters (from host)
        :param offs: Cumsum tensor of token counts per expert, shape (experts,)
        :param block_idx: CUDA block index
        :param grid_dim: CUDA grid dimensions
        :return: Scheduler positioned at linear tile index block_idx.z, with
            expert tracking not yet initialized
        """
        # Total CTAs in the grid divided by CTAs per cluster. With the grid from
        # MoEStaticSchedulerParams.get_grid_shape this equals max_active_clusters.
        num_persistent_clusters = cute.size(grid_dim, loc=loc, ip=ip) // cute.size(
            params.cluster_shape_mn, loc=loc, ip=ip
        )

        bidx, bidy, bidz = block_idx
        # The z block index selects this cluster's first linear tile index.
        current_work_linear_idx = Int32(bidz)

        # Position of this CTA inside its cluster; the third coordinate is fixed at 0.
        cta_id_in_cluster = (
            Int32(bidx % params.cluster_shape_mn[0]),
            Int32(bidy % params.cluster_shape_mn[1]),
            Int32(0),
        )

        # Start at expert 0 with an empty tile range [0, 0).
        # expert_tile_end == 0 is the "not yet computed" marker that
        # _advance_expert_to_contain checks on its first call.
        current_expert_idx = Int32(0)
        expert_tile_start = Int32(0)
        expert_tile_end = Int32(0)  # Will be computed on first access

        return MoEStaticPersistentTileScheduler(
            params=params,
            offs=offs,
            num_persistent_clusters=num_persistent_clusters,
            current_work_linear_idx=current_work_linear_idx,
            cta_id_in_cluster=cta_id_in_cluster,
            current_expert_idx=current_expert_idx,
            expert_tile_start=expert_tile_start,
            expert_tile_end=expert_tile_end,
        )

    # =========================================================================
    # Tile iteration methods
    # =========================================================================

    @dsl_user_op
    @cute.jit
    def initial_work_tile_info(self, *, loc=None, ip=None) -> MoEWorkTileInfo:
        """Get the work tile info for the current linear index without advancing it."""
        return self._get_work_tile_for_linear_idx(
            self._current_work_linear_idx, loc=loc, ip=ip
        )

    @dsl_user_op
    @cute.jit
    def advance_to_next_work(self, *, loc=None, ip=None) -> MoEWorkTileInfo:
        """Advance to the next work tile and return its info.

        The linear index moves forward by num_persistent_clusters.
        """
        self._current_work_linear_idx += self.num_persistent_clusters
        return self._get_work_tile_for_linear_idx(
            self._current_work_linear_idx, loc=loc, ip=ip
        )

    @dsl_user_op
    @cute.jit
    def _get_work_tile_for_linear_idx(
        self,
        cluster_linear_idx: Int32,
        *,
        loc=None,
        ip=None
    ) -> MoEWorkTileInfo:
        """
        Convert a linear cluster index to MoEWorkTileInfo.

        Uses cached expert tracking state for O(1) fast path when staying
        within the same expert. Advances expert state when needed.

        Returns an invalid tile (expert_idx = -1) if cluster_linear_idx is out of range.

        Side effect: may update current_expert_idx, expert_tile_start and
        expert_tile_end. The tracking only moves forward, so this is meant to
        be called with non-decreasing indices.
        """
        # Ensure expert tracking is initialized and up-to-date
        self._advance_expert_to_contain(cluster_linear_idx, loc=loc, ip=ip)

        # Check if valid (still within expert range after advancing)
        is_valid = self.current_expert_idx < self.expert_cnt

        # Default result is the invalid tile; it is replaced below only when is_valid.
        work_tile_info = MoEWorkTileInfo(
            expert_idx=Int32(-1),
            tile_m_idx=Int32(0),
            tile_n_idx=Int32(0),
            k_tile_cnt=Int32(0),
        )

        if is_valid:
            # Compute local cluster tile indices within current expert
            local_idx = cluster_linear_idx - self.expert_tile_start
            cluster_tile_m_idx, cluster_tile_n_idx = self._decompose_local_idx(
                local_idx, self.current_expert_idx, loc=loc, ip=ip
            )

            # Convert cluster tile indices to CTA tile indices
            # cta_tile_idx = cluster_tile_idx * cluster_shape + cta_id_in_cluster
            cta_tile_m_idx = (
                cluster_tile_m_idx * self.cluster_shape_mn[0]
                + self.cta_id_in_cluster[0]  # type: ignore[index]
            )
            cta_tile_n_idx = (
                cluster_tile_n_idx * self.cluster_shape_mn[1]
                + self.cta_id_in_cluster[1]  # type: ignore[index]
            )
            # Compute k_tile_cnt
            k_tile_cnt = self._compute_k_tile_cnt(self.current_expert_idx, loc=loc, ip=ip)

            work_tile_info = MoEWorkTileInfo(
                expert_idx=self.current_expert_idx,
                tile_m_idx=cta_tile_m_idx,
                tile_n_idx=cta_tile_n_idx,
                k_tile_cnt=k_tile_cnt,
            )
        return work_tile_info

    @dsl_user_op
    @cute.jit
    def _advance_expert_to_contain(
        self,
        cluster_linear_idx: Int32,
        *,
        loc=None,
        ip=None,
    ) -> None:
        """
        Advance expert tracking state until current expert contains cluster_linear_idx,
        or we run out of experts.

        Fast path: If already in correct expert, no work needed.

        On exit either cluster_linear_idx < expert_tile_end, or
        current_expert_idx == expert_cnt (no expert contains the index).
        """
        # Initialize expert_tile_end if this is the first call (expert_tile_end == 0)
        # NOTE(review): 0 is also the real value when expert 0 has no tiles; in
        # that case this branch just recomputes the same value on later calls.
        # NOTE(review): in the 2Dx3D scenario this reads offs[0] regardless of
        # expert_cnt — confirm callers guarantee at least one expert.
        if self.expert_tile_end == Int32(0):
            tiles_for_expert_0 = self._compute_tiles_for_expert(Int32(0), loc=loc, ip=ip)
            self.expert_tile_end = tiles_for_expert_0

        # Advance until cluster_linear_idx < expert_tile_end or no more experts
        while cluster_linear_idx >= self.expert_tile_end and self.current_expert_idx < self.expert_cnt:
            self.current_expert_idx = self.current_expert_idx + 1
            self.expert_tile_start = self.expert_tile_end

            # Only extend the range when the new index is a real expert; past
            # the last expert the range stays empty and the loop terminates.
            if self.current_expert_idx < self.expert_cnt:
                tiles_for_expert = self._compute_tiles_for_expert(
                    self.current_expert_idx, loc=loc, ip=ip
                )
                self.expert_tile_end = self.expert_tile_end + tiles_for_expert

    @dsl_user_op
    @cute.jit
    def _compute_tiles_for_expert(
        self,
        expert_idx: Int32,
        *,
        loc=None,
        ip=None,
    ) -> Int32:
        """Compute total cluster tiles for a given expert.

        The result is ceil(M / cluster_tile_m) * ceil(N / cluster_tile_n) with
        M and N chosen per scenario. The per-dimension counts duplicate the
        logic in _get_cluster_tile_counts; keep the two in sync.
        """
        if const_expr(self.scenario == "2Dx2D"):
            # Fixed M=hidden, N=intermediate
            cluster_tile_m_cnt = (self.hidden + self.cluster_tile_m - 1) // self.cluster_tile_m
            cluster_tile_n_cnt = (self.intermediate + self.cluster_tile_n - 1) // self.cluster_tile_n
            return cluster_tile_m_cnt * cluster_tile_n_cnt
        else:  # 2Dx3D
            # Variable M (tokens), fixed N
            # offs is a running total, so expert i's token count is
            # offs[i] - offs[i - 1], with offs[0] used directly for expert 0.
            tokens_i = self.offs[expert_idx]
            if expert_idx > 0:
                tokens_i = tokens_i - self.offs[expert_idx - 1]  # type: ignore[operator]
            cluster_tile_m_cnt = (
                tokens_i + self.cluster_tile_m - 1  # type: ignore[operator]
            ) // self.cluster_tile_m
            cluster_tile_n_cnt = (
                self.intermediate + self.cluster_tile_n - 1
            ) // self.cluster_tile_n
            return cluster_tile_m_cnt * cluster_tile_n_cnt

    @dsl_user_op
    @cute.jit
    def _decompose_local_idx(
        self,
        local_idx: Int32,
        expert_idx: Int32,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32]:
        """
        Decompose local cluster tile index within expert to (cluster_tile_m_idx, cluster_tile_n_idx).

        Uses "short side first" strategy: the shorter dimension changes faster.
        This maximizes overlap between adjacent clusters for better L2 cache utilization.

        For example, if m_cnt=2, n_cnt=8:
        - N is longer, so M changes faster: local_idx = n_idx * m_cnt + m_idx
        - Linearization order: (0,0), (1,0), (0,1), (1,1), (0,2), (1,2), ...

        :param local_idx: Tile index relative to the expert's first tile
        :param expert_idx: Expert whose tile counts define the decomposition
        :return: (cluster_tile_m_idx, cluster_tile_n_idx)
        """
        # Get tile counts for M and N
        cluster_tile_m_cnt, cluster_tile_n_cnt = self._get_cluster_tile_counts(
            expert_idx, loc=loc, ip=ip
        )
        # Placeholders so both names are bound before the branch; each branch
        # below overwrites both (presumably needed by the DSL for values
        # assigned inside a runtime `if` — confirm).
        cluster_tile_m_idx = -1
        cluster_tile_n_idx = -1

        # Short side first: shorter dimension changes faster
        # If m_cnt <= n_cnt: m is shorter, m changes faster
        #   local_idx = n_idx * m_cnt + m_idx
        # If n_cnt < m_cnt: n is shorter, n changes faster
        #   local_idx = m_idx * n_cnt + n_idx
        if cluster_tile_m_cnt <= cluster_tile_n_cnt:
            # M is shorter or equal, M changes faster
            cluster_tile_m_idx = local_idx % cluster_tile_m_cnt
            cluster_tile_n_idx = local_idx // cluster_tile_m_cnt
        else:
            # N is shorter, N changes faster
            cluster_tile_n_idx = local_idx % cluster_tile_n_cnt
            cluster_tile_m_idx = local_idx // cluster_tile_n_cnt

        return (cluster_tile_m_idx, cluster_tile_n_idx)

    @dsl_user_op
    @cute.jit
    def _get_cluster_tile_counts(
        self,
        expert_idx: Int32,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32]:
        """Get (cluster_tile_m_cnt, cluster_tile_n_cnt) for a given expert.

        Same per-dimension arithmetic as _compute_tiles_for_expert, returned
        separately instead of multiplied; keep the two in sync.
        """
        if const_expr(self.scenario == "2Dx2D"):
            # Fixed M=hidden, N=intermediate
            cluster_tile_m_cnt = (self.hidden + self.cluster_tile_m - 1) // self.cluster_tile_m
            cluster_tile_n_cnt = (self.intermediate + self.cluster_tile_n - 1) // self.cluster_tile_n
        else:  # 2Dx3D
            # Variable M (tokens), fixed N
            # Token count of expert i is offs[i] - offs[i - 1] (offs[0] for expert 0).
            tokens_i = self.offs[expert_idx]
            if expert_idx > 0:
                tokens_i = tokens_i - self.offs[expert_idx - 1]  # type: ignore[operator]
            cluster_tile_m_cnt = (
                tokens_i + self.cluster_tile_m - 1  # type: ignore[operator]
            ) // self.cluster_tile_m
            cluster_tile_n_cnt = (
                self.intermediate + self.cluster_tile_n - 1
            ) // self.cluster_tile_n
        return (cluster_tile_m_cnt, cluster_tile_n_cnt)

    @dsl_user_op
    @cute.jit
    def _compute_k_tile_cnt(
        self,
        expert_idx: Int32,
        *,
        loc=None,
        ip=None,
    ) -> Int32:
        """
        Compute the number of K tiles for this expert.

        2Dx3D: K = hidden (fixed) -> k_tile_cnt = ceil(hidden / cta_tile_k)
        2Dx2D: K = tokens_i (variable) -> k_tile_cnt = ceil(tokens_i / cta_tile_k)

        Note the branch test here is on "2Dx3D", whereas the tile-count helpers
        test on "2Dx2D"; the two agree only for the two supported scenarios.
        """
        if const_expr(self.scenario == "2Dx3D"):
            # K is hidden (fixed)
            return (self.hidden + self.cta_tile_k - 1) // self.cta_tile_k
        else:  # 2Dx2D
            # K is tokens_i (variable per expert)
            # Token count of expert i is offs[i] - offs[i - 1] (offs[0] for expert 0).
            tokens_i = self.offs[expert_idx]
            if expert_idx > cutlass.Int32(0):
                tokens_i = tokens_i - self.offs[expert_idx - 1]  # type: ignore[operator]
            return (tokens_i + self.cta_tile_k - 1) // self.cta_tile_k  # type: ignore[return-value, operator]
|