Files
nvfp4-megamoe-kernel/cutedsl/kernel/moe/moe_sched_extension.py
biondizzle ca28f1335d refactor: copy CuTeDSL kernel into repo with local imports
Copied from CUTLASS examples (no more runtime dependency on
/root/cutlass/examples/). Fixed all imports to use cutedsl.kernel.*
instead of blackwell.kernel.*.

Structure:
  cutedsl/__init__.py
  cutedsl/kernel/__init__.py
  cutedsl/kernel/moe/  (the MoE scaled grouped GEMM)
  cutedsl/kernel/blockscaled_gemm/  (dense blockscaled GEMM)

test_cutedsl.py updated to import from our local copy.
2026-05-16 02:57:54 +00:00

444 lines
20 KiB
Python

# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
MoE Scheduler Extension.

Bridges the MoE tile scheduler (MoEStaticPersistentTileScheduler) with tensor-level
domain conversion and TMA descriptor selection. This is the "glue" layer between:

- Scheduler: produces MoEWorkTileInfo (expert_idx, tile_m, tile_n, k_tile_cnt)
- OnlineTensormapDescCreator: builds/retrieves TMA descriptors from workspace
- Kernel: orchestrates everything

Different kernel types (grouped_mm, scaled_grouped_mm, etc.) provide their own
MoESchedExtension subclass with kernel-specific domain conversion logic.

Key design principles:

- Unified interface: get_gmem_tensor() for all tensor types
- Free implementation: no role-based templates, each subclass writes its own logic
- Composable utilities: compute_expert_token_range, rewrite_tensor_shape, etc.
  are available as tools but not mandatory

Architecture:

    Scheduler ──(produces)──> MoEWorkTileInfo
                              (expert_idx, tile_m, tile_n, k_cnt)
                                   │
      ┌────────────────────────────┘
      v
    Extension ──(uses)──> OnlineTensormapDescCreator
      │                         │
      │ get_gmem_tensor()       │ get_desc_ptr()
      │ prefetch_for_expert()   │ construct_and_write()
      │                         │
      └──── internal calls ─────┘

    Kernel (caller): the only place that knows all three exist
"""
from abc import ABC, abstractmethod
from typing import Literal, Tuple, Union
import cutlass
import cutlass.cute as cute
from cutlass.cute.typing import Pointer
from cutlass.cutlass_dsl import Int32
from dataclasses import dataclass
from cutlass.utils.blockscaled_layout import tile_atom_to_shape_SF
from cutedsl.kernel.moe.moe_utils import (
OnlineTensormapDescCreator,
tensormap_ptr_for_copy,
compute_expert_token_range,
rewrite_tensor_shape,
prefetch_tma_descriptor,
)
from cutedsl.kernel.moe.moe_persistent_scheduler import MoEWorkTileInfo
@dataclass(frozen=True, eq=False)
class MoESchedExtension(ABC):
    """
    Abstract base class for MoE scheduler extensions.

    Bridges MoEWorkTileInfo with tensor-level domain conversion and TMA
    descriptor selection. Each kernel type (grouped_mm, scaled_grouped_mm, etc.)
    provides its own subclass with kernel-specific logic.

    The extension:
    - Holds a reference to an OnlineTensormapDescCreator for expert-wise desc retrieval
    - Implements get_gmem_tensor() to convert MoE-view tensors to per-expert tensors
    - Implements prefetch_for_expert() to prefetch expert-wise TMA descriptors

    Subclasses are free to add any additional attributes in __init__ (scenario,
    codegen configs, etc.) and implement get_gmem_tensor with arbitrary logic
    per tensor_name. No role-based templates or rigid patterns are imposed.

    Note on the ``@dataclass`` decorator: this class declares no dataclass
    fields and keeps its hand-written ``__init__``.
    - ``eq=False`` keeps identity-based ``__eq__``/``__hash__``. With the
      default ``eq=True``, the generated methods compare an empty field tuple,
      so every instance of a subclass would compare (and hash) equal
      regardless of its scenario or tensormap_ctor.
    - ``frozen=True`` only rejects attribute assignment on instances whose
      exact type is this (abstract) class. Subclasses can still set attributes
      in their own ``__init__``.

    Usage in kernel (caller):
        ext = ConcreteSchedExtension(scenario=..., tensormap_ctor=tensormap_ctor)
        while work_tile_info.is_valid_tile:
            real_a, desc_a = ext.get_gmem_tensor("a", tma_tensor_a, offs, work_tile_info)
            real_b, desc_b = ext.get_gmem_tensor("b", tma_tensor_b, offs, work_tile_info)
            # Use real_a, desc_a in cute.copy ...
    """

    def __init__(self, tensormap_ctor: OnlineTensormapDescCreator):
        """
        :param tensormap_ctor: Source of expert-wise TMA descriptor pointers
            (queried through ``get_desc_ptr(name, expert_idx)`` by subclasses)
        """
        super().__init__()
        self.tensormap_ctor = tensormap_ctor

    @abstractmethod
    def get_gmem_tensor(
        self,
        tensor_name: str,
        gmem_tensor_in_moe_view: cute.Tensor,
        offs: Union[cute.Tensor, Tuple[cute.Tensor, cute.Tensor]],
        work_tile_info: MoEWorkTileInfo,
    ) -> Tuple[cute.Tensor, "Pointer | None"]:
        """
        Convert an MoE-view tensor to the real per-expert tensor for the
        current work tile, and return the appropriate TMA descriptor pointer.

        The MoE-view tensor uses "fake" GEMM domain dimensions that span all
        experts (e.g., fake_m = tokens_sum). This method slices/offsets it
        to the current expert's actual region.

        :param tensor_name: Identifies which tensor (e.g., "a", "b", "c", "sfa")
        :param gmem_tensor_in_moe_view: Tensor in fake GEMM MNKL domain
        :param offs: Either a single cumsum tensor (experts,), or a tuple of
                     (offs_token, offs_padded) where offs_padded provides
                     padded offsets for scale-factor domain conversion.
        :param work_tile_info: Current work tile from the scheduler
        :return: (real_tensor, tma_desc_ptr_or_none)
            - real_tensor: domain-offset and shape-rewritten tensor for this expert
            - tma_desc_ptr: expert-wise desc ptr (already converted for cute.copy),
              or None if the caller should use the global TMA descriptor
        """
        ...

    @abstractmethod
    def prefetch_for_expert(self, expert_idx: Int32) -> None:
        """
        Prefetch expert-wise TMA descriptors for the given expert.

        Called when the scheduler advances to a new expert, allowing the TMA
        descriptor cache to be warmed up before the descriptors are needed.

        :param expert_idx: Index of the expert whose descriptors to prefetch
        """
        ...
# =============================================================================
# Grouped MM Extension
# =============================================================================
class GroupedMmSchedExtension(MoESchedExtension):
    """
    MoE scheduler extension for grouped_mm: handles tensors a, b, c.

    Domain conversion logic per scenario:

    2Dx3D:
        A: (fake_m, k, 1) -> offset fake_m by token_offset, global desc
        B: (n, k, fake_l) -> offset fake_l by expert_idx, global desc
        C: (fake_m, n, 1) -> rewrite shape only, expert-wise desc
    2Dx2D:
        A: (m, fake_k, 1) -> rewrite shape only, expert-wise desc
        B: (n, fake_k, 1) -> rewrite shape only, expert-wise desc
        C: (m, n, fake_l) -> offset fake_l by expert_idx, global desc

    "global desc": get_gmem_tensor returns None as the descriptor, so the
    caller keeps using its global TMA descriptor.
    "expert-wise desc": the descriptor is
    tensormap_ctor.get_desc_ptr(name, expert_idx), passed through
    tensormap_ptr_for_copy.
    """

    # Scenario strings understood by get_gmem_tensor / prefetch_for_expert.
    _SUPPORTED_SCENARIOS = ("2Dx3D", "2Dx2D")

    def __init__(
        self,
        scenario: Literal["2Dx3D", "2Dx2D"],
        tensormap_ctor: OnlineTensormapDescCreator,
    ):
        """
        :param scenario: "2Dx3D" or "2Dx2D" (see the class docstring)
        :param tensormap_ctor: Provider of expert-wise TMA descriptors
        :raises ValueError: If scenario is not a supported value. Checking here
            reports a bad scenario at construction time instead of only when
            get_gmem_tensor / prefetch_for_expert are traced.
        """
        if scenario not in self._SUPPORTED_SCENARIOS:
            raise ValueError(
                f"Invalid scenario {scenario!r}; expected one of "
                f"{self._SUPPORTED_SCENARIOS}."
            )
        super().__init__(tensormap_ctor)
        self.scenario = scenario

    @cute.jit
    def get_gmem_tensor(
        self,
        tensor_name: str,
        gmem_tensor_in_moe_view: cute.Tensor,
        offs: cute.Tensor,
        work_tile_info: MoEWorkTileInfo,
    ):
        """
        Return (real_tensor, desc_ptr_or_None) for tensor "a", "b" or "c" of
        the expert in work_tile_info, following the per-scenario table in the
        class docstring.

        tensor_name and self.scenario are compared inside cutlass.const_expr,
        so the branch is selected while the kernel is traced.

        :raises ValueError: For an unsupported scenario / tensor_name combination
        """
        expert_idx = work_tile_info.expert_idx
        # token_offset is used below as this expert's start coordinate along
        # the fake (token) mode, and tokens_i as its extent.
        token_offset, tokens_i = compute_expert_token_range(offs, expert_idx)
        shape = gmem_tensor_in_moe_view.shape
        c1 = cutlass.Int32(1)
        if cutlass.const_expr(self.scenario == "2Dx3D"):
            if cutlass.const_expr(tensor_name == "a"):
                # A: (fake_m, k, 1) -> offset fake_m, global desc
                real = cute.domain_offset((token_offset, 0, 0), gmem_tensor_in_moe_view)
                real = rewrite_tensor_shape(real, (tokens_i, shape[1], c1))  # type: ignore[index]
                return (real, None)
            elif cutlass.const_expr(tensor_name == "b"):
                # B: (n, k, fake_l) -> offset fake_l, global desc
                real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view)
                real = rewrite_tensor_shape(real, (shape[0], shape[1], c1))  # type: ignore[index]
                return (real, None)
            elif cutlass.const_expr(tensor_name == "c"):
                # C: (fake_m, n, 1) -> expert-wise desc, no offset on the tensor
                real = rewrite_tensor_shape(
                    gmem_tensor_in_moe_view,
                    (tokens_i, shape[1], c1),  # type: ignore[index]
                )
                desc = tensormap_ptr_for_copy(
                    self.tensormap_ctor.get_desc_ptr("c", expert_idx)
                )
                return (real, desc)
        elif cutlass.const_expr(self.scenario == "2Dx2D"):
            if cutlass.const_expr(tensor_name == "a"):
                # A: (m, fake_k, 1) -> expert-wise desc, no offset on the tensor
                real = rewrite_tensor_shape(
                    gmem_tensor_in_moe_view,
                    (shape[0], tokens_i, c1),  # type: ignore[index]
                )
                desc = tensormap_ptr_for_copy(
                    self.tensormap_ctor.get_desc_ptr("a", expert_idx)
                )
                return (real, desc)
            elif cutlass.const_expr(tensor_name == "b"):
                # B: (n, fake_k, 1) -> expert-wise desc, no offset on the tensor
                real = rewrite_tensor_shape(
                    gmem_tensor_in_moe_view,
                    (shape[0], tokens_i, c1),  # type: ignore[index]
                )
                desc = tensormap_ptr_for_copy(
                    self.tensormap_ctor.get_desc_ptr("b", expert_idx)
                )
                return (real, desc)
            elif cutlass.const_expr(tensor_name == "c"):
                # C: (m, n, fake_l) -> offset fake_l, global desc
                real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view)
                real = rewrite_tensor_shape(real, (shape[0], shape[1], c1))  # type: ignore[index]
                return (real, None)
        raise ValueError("Invalid scenario or GEMM tensor name.")

    @cute.jit
    def prefetch_for_expert(self, expert_idx: Int32) -> None:
        """
        Prefetch exactly the expert-wise descriptors that get_gmem_tensor hands
        out for this scenario: "c" in 2Dx3D; "a" and "b" in 2Dx2D.

        :param expert_idx: Expert whose descriptors are prefetched
        :raises ValueError: For an unsupported scenario
        """
        if cutlass.const_expr(self.scenario == "2Dx3D"):
            prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("c", expert_idx))
        elif cutlass.const_expr(self.scenario == "2Dx2D"):
            prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("a", expert_idx))
            prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("b", expert_idx))
        else:
            raise ValueError("Invalid scenario.")
# =============================================================================
# Scaled Grouped MM Extension (block-scaled MoE)
# =============================================================================
class ScaledGroupedMmSchedExtension(MoESchedExtension):
    """
    MoE scheduler extension for scaled_grouped_mm: handles a, b, c, sfa, sfb.

    Performs the same a/b/c conversions as GroupedMmSchedExtension (the logic
    is duplicated here; this class derives directly from MoESchedExtension)
    and adds the scale-factor tensors sfa/sfb. SFA/SFB are passed as flat
    GEMM-domain tensors and atom-tiled per expert via tile_atom_to_shape_SF.

    The offs parameter is always a tuple (offs_token, offs_padded):
    - offs_token: cumsum offsets in data (activation) domain
    - offs_padded: cumsum offsets in scale-factor domain (padded to atom granularity)

    sf_vec_size is obtained from self.tensormap_ctor.sf_vec_size.

    Domain conversion logic per scenario:

    2Dx3D:
        A:   (fake_m, k, 1)         -> offset fake_m by token_offset, global desc
        B:   (n, k, fake_l)         -> offset fake_l by expert_idx, global desc
        C:   (fake_m, n, 1)         -> rewrite shape, expert-wise desc
        SFA: (fake_m_pad, k_pad, 1) -> offset fake_m_pad by padded_offset,
                                       atom-tile, global desc
        SFB: (n_pad, k_pad, fake_l) -> offset fake_l by expert_idx,
                                       atom-tile, global desc
    2Dx2D:
        A:   (m, fake_k, 1)         -> rewrite shape, expert-wise desc
        B:   (n, fake_k, 1)         -> rewrite shape, expert-wise desc
        C:   (m, n, fake_l)         -> offset fake_l by expert_idx, global desc
        SFA: (m_pad, fake_k_pad, 1) -> atom-tile to this expert's padded size,
                                       expert-wise desc
        SFB: (n_pad, fake_k_pad, 1) -> atom-tile to this expert's padded size,
                                       expert-wise desc

    In 2Dx2D, no domain_offset is applied to the SFA/SFB tensors here. The
    per-expert start along fake_k_pad presumably comes from the expert-wise
    descriptor. NOTE(review): confirm against OnlineTensormapDescCreator.
    """

    # Scenario strings understood by get_gmem_tensor / prefetch_for_expert.
    _SUPPORTED_SCENARIOS = ("2Dx3D", "2Dx2D")

    def __init__(
        self,
        scenario: Literal["2Dx3D", "2Dx2D"],
        tensormap_ctor: OnlineTensormapDescCreator,
    ):
        """
        :param scenario: "2Dx3D" or "2Dx2D" (see the class docstring)
        :param tensormap_ctor: Provider of expert-wise TMA descriptors; its
            sf_vec_size attribute is read by get_gmem_tensor
        :raises ValueError: If scenario is not a supported value. Checking here
            reports a bad scenario at construction time instead of only when
            get_gmem_tensor / prefetch_for_expert are traced.
        """
        if scenario not in self._SUPPORTED_SCENARIOS:
            raise ValueError(
                f"Invalid scenario {scenario!r}; expected one of "
                f"{self._SUPPORTED_SCENARIOS}."
            )
        super().__init__(tensormap_ctor)
        self.scenario = scenario

    @cute.jit
    def get_gmem_tensor(
        self,
        tensor_name: str,
        gmem_tensor_in_moe_view: cute.Tensor,
        offs: Tuple[cute.Tensor, cute.Tensor],
        work_tile_info: MoEWorkTileInfo,
    ):
        """
        Return (real_tensor, desc_ptr_or_None) for tensor "a", "b", "c", "sfa"
        or "sfb" of the expert in work_tile_info, following the per-scenario
        table in the class docstring.

        tensor_name and self.scenario are compared inside cutlass.const_expr,
        so the branch is selected while the kernel is traced.

        :param offs: (offs_token, offs_padded) cumsum offset tensors
        :raises ValueError: For an unsupported scenario / tensor_name combination
        """
        # Unpack the offs tuple
        offs_token, offs_padded = offs
        expert_idx = work_tile_info.expert_idx
        # Start coordinate / extent of this expert in the data domain ...
        token_offset, tokens_i = compute_expert_token_range(offs_token, expert_idx)
        # ... and in the padded scale-factor domain.
        padded_offset, padded_size_i = compute_expert_token_range(
            offs_padded, expert_idx
        )
        shape = gmem_tensor_in_moe_view.shape
        stride = gmem_tensor_in_moe_view.stride
        c1 = cutlass.Int32(1)
        sf_vec_size = self.tensormap_ctor.sf_vec_size
        if cutlass.const_expr(self.scenario == "2Dx3D"):
            if cutlass.const_expr(tensor_name == "a"):
                # A: (fake_m, k, 1) -> offset fake_m, global desc
                real = cute.domain_offset((token_offset, 0, 0), gmem_tensor_in_moe_view)
                real = rewrite_tensor_shape(real, (tokens_i, shape[1], c1))  # type: ignore[index]
                return (real, None)
            elif cutlass.const_expr(tensor_name == "b"):
                # B: (n, k, fake_l) -> offset fake_l, global desc
                real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view)
                real = rewrite_tensor_shape(real, (shape[0], shape[1], c1))  # type: ignore[index]
                return (real, None)
            elif cutlass.const_expr(tensor_name == "c"):
                # C: (fake_m, n, 1) -> expert-wise desc, no offset on the tensor
                real = rewrite_tensor_shape(
                    gmem_tensor_in_moe_view,
                    (tokens_i, shape[1], c1),  # type: ignore[index]
                )
                desc = tensormap_ptr_for_copy(
                    self.tensormap_ctor.get_desc_ptr("c", expert_idx)
                )
                return (real, desc)
            elif cutlass.const_expr(tensor_name == "sfa"):
                # SFA: (fake_m_pad, k_pad, 1) -> offset fake_m_pad, atom-tile, global desc
                real = cute.domain_offset(
                    (padded_offset, 0, 0), gmem_tensor_in_moe_view
                )
                per_expert_shape = (padded_size_i, shape[1], c1)  # type: ignore[index]
                sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size)
                # Atom-tiled shape paired with the MoE-view stride.
                # NOTE(review): assumes that stride is congruent with
                # sf_layout.shape; confirm against how the SF tensor is built.
                real = cute.make_tensor(
                    real.iterator, cute.make_layout(sf_layout.shape, stride=stride)
                )
                return (real, None)
            elif cutlass.const_expr(tensor_name == "sfb"):
                # SFB: (n_pad, k_pad, fake_l) -> offset fake_l, atom-tile, global desc
                real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view)
                per_expert_shape = (shape[0], shape[1], c1)  # type: ignore[index]
                sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size)
                real = cute.make_tensor(
                    real.iterator, cute.make_layout(sf_layout.shape, stride=stride)
                )
                return (real, None)
        elif cutlass.const_expr(self.scenario == "2Dx2D"):
            if cutlass.const_expr(tensor_name == "a"):
                # A: (m, fake_k, 1) -> expert-wise desc
                real = rewrite_tensor_shape(
                    gmem_tensor_in_moe_view,
                    (shape[0], tokens_i, c1),  # type: ignore[index]
                )
                desc = tensormap_ptr_for_copy(
                    self.tensormap_ctor.get_desc_ptr("a", expert_idx)
                )
                return (real, desc)
            elif cutlass.const_expr(tensor_name == "b"):
                # B: (n, fake_k, 1) -> expert-wise desc
                real = rewrite_tensor_shape(
                    gmem_tensor_in_moe_view,
                    (shape[0], tokens_i, c1),  # type: ignore[index]
                )
                desc = tensormap_ptr_for_copy(
                    self.tensormap_ctor.get_desc_ptr("b", expert_idx)
                )
                return (real, desc)
            elif cutlass.const_expr(tensor_name == "c"):
                # C: (m, n, fake_l) -> offset fake_l, global desc
                real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view)
                real = rewrite_tensor_shape(real, (shape[0], shape[1], c1))  # type: ignore[index]
                return (real, None)
            elif cutlass.const_expr(tensor_name == "sfa"):
                # SFA: (m_pad, fake_k_pad, 1) -> atom-tile to padded_size_i,
                # expert-wise desc. No domain_offset is applied to the tensor
                # itself (see class docstring).
                per_expert_shape = (shape[0], padded_size_i, c1)  # type: ignore[index]
                sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size)
                real = rewrite_tensor_shape(gmem_tensor_in_moe_view, sf_layout.shape)
                desc = tensormap_ptr_for_copy(
                    self.tensormap_ctor.get_desc_ptr("sfa", expert_idx)
                )
                return (real, desc)
            elif cutlass.const_expr(tensor_name == "sfb"):
                # SFB: (n_pad, fake_k_pad, 1) -> atom-tile to padded_size_i,
                # expert-wise desc. No domain_offset is applied to the tensor
                # itself (see class docstring).
                per_expert_shape = (shape[0], padded_size_i, c1)  # type: ignore[index]
                sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size)
                real = rewrite_tensor_shape(gmem_tensor_in_moe_view, sf_layout.shape)
                desc = tensormap_ptr_for_copy(
                    self.tensormap_ctor.get_desc_ptr("sfb", expert_idx)
                )
                return (real, desc)
        raise ValueError("Invalid scenario or tensor name.")

    @cute.jit
    def prefetch_for_expert(self, expert_idx: Int32) -> None:
        """
        Prefetch exactly the expert-wise descriptors that get_gmem_tensor hands
        out for this scenario: "c" in 2Dx3D; "a", "b", "sfa" and "sfb" in 2Dx2D.

        :param expert_idx: Expert whose descriptors are prefetched
        :raises ValueError: For an unsupported scenario
        """
        if cutlass.const_expr(self.scenario == "2Dx3D"):
            prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("c", expert_idx))
        elif cutlass.const_expr(self.scenario == "2Dx2D"):
            prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("a", expert_idx))
            prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("b", expert_idx))
            prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("sfa", expert_idx))
            prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("sfb", expert_idx))
        else:
            raise ValueError("Invalid scenario.")