docs: add CuTeDSL rewrite plan + reference files

The C++ CUTLASS kernel is fundamentally broken (cosine 0.05 with real
data). Switching to NVIDIA's CuTeDSL approach based on their official
MoE scaled grouped GEMM example.

Reference files copied:
- moe_torch_scaled_grouped_mm.py (3900 lines — our new kernel)
- moe_utils.py, moe_persistent_scheduler.py, moe_sched_extension.py
- grouped_blockscaled_gemm.py, dense_blockscaled_gemm_persistent.py
- blockscaled_layout.py
This commit is contained in:
2026-05-16 02:41:51 +00:00
parent c4a262bd54
commit a2ea836c74
9 changed files with 15148 additions and 0 deletions

93
REWRITE_PLAN.md Normal file
View File

@@ -0,0 +1,93 @@
# NVFP4 MegaMoE Kernel Rewrite — CuTeDSL
## Decision
The C++ CUTLASS kernel is fundamentally broken. The GEMM produces cosine 0.05
with real data (despite SF remap = 0 errors and uniform tests passing). The
root cause is likely in how CUTLASS's C++ API handles FP4 packing/tiling
internally — something we can't easily debug or fix.
We're replacing it with NVIDIA's CuTeDSL approach (Python-based CUTLASS
kernels compiled via MLIR → PTX). This is what the NVIDIA CUTLASS team
recommends for Blackwell, and they have a working reference:
`cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/moe/torch_scaled_grouped_mm.py`
This is a 3900-line MoE scaled grouped GEMM that supports NVFP4 out of the box.
## Reference Files (copied to reference/)
- `reference/moe_torch_scaled_grouped_mm.py` — the kernel (3900 lines)
- `reference/moe_moe_utils.py` — tensormap constructor
- `reference/moe_moe_persistent_scheduler.py` — MoE tile scheduler
- `reference/moe_moe_sched_extension.py` — scheduler extension for scaled GEMM
- `reference/grouped_blockscaled_gemm.py` — simpler grouped blockscaled reference
- `reference/dense_blockscaled_gemm_persistent.py` — dense (non-grouped) reference
- `reference/blockscaled_layout.py` — SF layout utilities
## Architecture
### The Reference Kernel: ScaledGroupedGemmKernel
7-warp persistent kernel:
- Warps 0-3: Epilogue (TMEM → RMEM → SMEM → GMEM)
- Warp 4: MMA (tcgen05.mma.block_scale with SFA/SFB in TMEM)
- Warp 5: TMA load (A, B, SFA, SFB from GMEM → SMEM)
- Warp 6: Scheduler (MoE persistent tile scheduler)
Supports:
- NVFP4 (Float4E2M1FN + Float8E4M3FN + sf_vec_size=16)
- 2Dx3D scenario (our case): A(tokens, K) x B(experts, K, N) → C(tokens, N)
- PyTorch interface via torch.nn.functional.scaled_grouped_mm
### Our Adaptation
We need to:
1. Use ScaledGroupedGemmKernel directly (or with minimal adaptation)
2. Replace our Python pipeline (nvfp4_mega_moe.py) to call the CuTeDSL kernel
3. Handle the MoE-specific stuff (routing, gate/up fusion, SiLU, scatter)
The CuTeDSL kernel handles A/B/SFA/SFB tiling and MMA internally.
We NO LONGER need:
- Our custom SF remap kernel (CuTeDSL handles SF layouts natively)
- Our C++ CUTLASS extension (cutlass_nvfp4_gemm.cu)
- transform_nvfp4_weights_for_mega_moe (CuTeDSL uses the natural layout)
We STILL need:
- stage_activation (BF16 → FP4 quantization)
- The MoE routing logic (top-k selection, slot mapping)
- Gate/up fusion with up_correction
- SiLU activation
- Scatter with routing weights
## Plan
### Phase 1: Get the reference kernel running standalone
- Set up CuTeDSL build environment on B200
- Run the reference example with NVFP4 config
- Verify it produces correct output
### Phase 2: Integrate into our pipeline
- Create a Python module that wraps ScaledGroupedGemmKernel
- Replace cutlass_nvfp4_gemm calls with CuTeDSL kernel calls
- Handle weight format (the CuTeDSL kernel expects raw FP4 + SF, no transpose)
### Phase 3: Full MoE pipeline
- Wire up stage_activation → L1 GEMM → SiLU → stage_activation → L2 GEMM → scatter
- Test with layertest.py (should get cosine ~0.995)
### Phase 4: vLLM integration
- Update the vLLM patch to use the CuTeDSL kernel
- Remove the C++ extension build from the Dockerfile
- Test full inference
## Key Differences from Current Approach
| Current (C++ CUTLASS) | New (CuTeDSL) |
|----------------------|---------------|
| C++ .cu file | Python .py file |
| Manual SF remap kernel | CuTe handles SF natively |
| Manual weight transpose | CuTe handles layout |
| pip install C++ extension | cute.compile() JIT |
| Broken FP4 handling | Reference-verified |
| No Blackwell pipeline | Full TMA/MMA/Epilogue overlap |

View File

@@ -0,0 +1,657 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.
from dataclasses import dataclass, field
from typing import Optional
from cutlass.cutlass_dsl import dsl_user_op
import cutlass.cute as cute
from cutlass.cute.nvgpu import OperandMajorMode
from cutlass._mlir import ir
import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir
@dataclass(frozen=True)
class BlockScaledBasicChunk:
"""
The basic scale factor atom layout decided by tcgen05 BlockScaled MMA Ops.
This class represents the fixed layout pattern for scale factors used in
tcgen05 BlockScaled MMA Ops. The layout is determined by the
instruction specification and cannot be modified.
See `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x>`.
"""
sf_vec_size: int
major_mode: OperandMajorMode = OperandMajorMode.K
_layout: cute.Layout = field(init=False, repr=False)
def __post_init__(self) -> None:
if self.major_mode == OperandMajorMode.K:
# K-major layout: (AtomMN, AtomK)
atom_shape = ((32, 4), (self.sf_vec_size, 4))
atom_stride = ((16, 4), (0, 1))
else:
# MN-major layout: (AtomK, AtomMN)
atom_shape = ((self.sf_vec_size, 4), (32, 4))
atom_stride = ((0, 1), (16, 4))
object.__setattr__(
self, "_layout", cute.make_layout(atom_shape, stride=atom_stride)
)
@property
def layout(self) -> cute.Layout:
"""
Get the layout for this block scaled chunk.
:return: The layout representing the scale factor atom
:rtype: cute.Layout
"""
return self._layout
@dsl_user_op
def tile_atom_to_shape_SF(
Shape: cute.Shape,
sf_vec_size: int,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
"""
A helper function to get dynamic SFA/SFB layout by filling dynamic A/B shape to the scale factor atom layout.
:param Shape: The shape of the A/B tensor
:param sf_vec_size: Scale factor vector size
:return: The layout of the SFA/SFB tensor
:rtype: cute.Layout
"""
# ((Atom_MN, Rest_MN),(Atom_K, Rest_K),RestL)
sf_layout = cute.tile_to_shape(
BlockScaledBasicChunk(sf_vec_size).layout, Shape, (2, 1, 3)
)
return sf_layout
@dsl_user_op
def make_smem_layout_sf(
tile_shape: cute.Tile,
sf_vec_size: int,
num_stages: int,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
"""
A helper function to get dynamic SFA/SFB layout by filling dynamic A/B shape to the scale factor atom layout.
:param Shape: The shape of the A/B tensor
:param sf_vec_size: Scale factor vector size
:param num_stages: Number of stages
:return: The layout of the SFA/SFB tensor
:rtype: cute.Layout
"""
smem_layout = cute.tile_to_shape(
BlockScaledBasicChunk(sf_vec_size).layout,
tile_shape, # type: ignore[arg-type]
(2, 1),
loc=loc,
ip=ip,
)
smem_layout_staged = cute.append(
smem_layout,
cute.make_layout(
num_stages,
stride=cute.cosize(cute.filter_zeros(smem_layout)),
loc=loc,
ip=ip,
),
loc=loc,
ip=ip,
)
return smem_layout_staged
@dsl_user_op
def make_smem_layout_sfa(
tiled_mma: cute.TiledMma,
mma_tiler_mnk: cute.Tile,
sf_vec_size: int,
num_stages: int,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
"""
Make smem layout for SFA based on:
1. BlockScaledBasicChunk
2. MMA tiler shape
3. Scale factor vector size
4. Number of stages
:param tiled_mma: The tiled MMA
:type tiled_mma: cute.TiledMma
:param mma_tiler_mnk: The mma tiler shape
:type mma_tiler_mnk: cute.Tile
:param sf_vec_size: The scale factor vector size
:type sf_vec_size: int
:param num_stages: The number of stages
:type num_stages: int
:return: Smem layout for SFA
:rtype: cute.Layout
"""
# (CTA_Tile_Shape_M, MMA_Tile_Shape_K)
sfa_tile_shape = (
mma_tiler_mnk[0] // cute.size(tiled_mma.thr_id.shape), # type: ignore[index]
mma_tiler_mnk[2], # type: ignore[index]
)
# ((Atom_M, Rest_M),(Atom_K, Rest_K))
smem_layout = cute.tile_to_shape(
BlockScaledBasicChunk(sf_vec_size).layout,
sfa_tile_shape, # type: ignore[arg-type]
(2, 1),
)
# Number of MMA instructions to cover all k-tiles
mma_tile_inst_m = mma_tiler_mnk[0] // cute.size(tiled_mma.shape_mnk, mode=[0]) # type: ignore[index]
mma_tile_inst_k = mma_tiler_mnk[2] // cute.size(tiled_mma.shape_mnk, mode=[2]) # type: ignore[index]
# (CTA_Tile_Shape_M, MMA_Inst_Shape_K)
sfa_tile_shape = cute.shape_div(sfa_tile_shape, (mma_tile_inst_m, mma_tile_inst_k))
# ((Atom_Inst_M, Atom_Inst_K), MMA_M, MMA_K))
smem_layout = cute.tiled_divide(smem_layout, sfa_tile_shape)
atom_m = 128
tiler_inst = ((atom_m, sf_vec_size),)
# (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K)
smem_layout = cute.logical_divide(smem_layout, tiler_inst)
# (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K, STAGE)
sfa_smem_layout_staged = cute.append(
smem_layout,
cute.make_layout(
num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout))
),
)
return sfa_smem_layout_staged
@dsl_user_op
def make_smem_layout_sfb(
tiled_mma: cute.TiledMma,
mma_tiler_mnk: cute.Tile,
sf_vec_size: int,
num_stages: int,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
"""
Make smem layout for SFB based on:
1. BlockScaledBasicChunk
2. MMA tiler shape
3. Scale factor vector size
4. Number of stages
:param tiled_mma: The tiled MMA
:type tiled_mma: cute.TiledMma
:param mma_tiler_mnk: The mma tiler shape
:type mma_tiler_mnk: cute.Tile
:param sf_vec_size: The scale factor vector size
:type sf_vec_size: int
:param num_stages: The number of stages
:type num_stages: int
:return: Smem layout for SFA
:rtype: cute.Layout
"""
# (Round_Up(CTA_Tile_Shape_N, 128), MMA_Tile_Shape_K)
sfb_tile_shape = (
cute.round_up(mma_tiler_mnk[1], 128), # type: ignore[index, arg-type]
mma_tiler_mnk[2], # type: ignore[index]
)
# ((Atom_N, Rest_N),(Atom_K, Rest_K))
smem_layout = cute.tile_to_shape(
BlockScaledBasicChunk(sf_vec_size).layout,
sfb_tile_shape, # type: ignore[arg-type]
(2, 1),
)
# Number of MMA instructions to cover all k-tiles
mma_tile_inst_n = mma_tiler_mnk[1] // cute.size(tiled_mma.shape_mnk, mode=[1]) # type: ignore[index]
mma_tile_inst_k = mma_tiler_mnk[2] // cute.size(tiled_mma.shape_mnk, mode=[2]) # type: ignore[index]
# (CTA_Tile_Shape_N, MMA_Inst_Shape_K)
sfb_tile_shape = cute.shape_div(sfb_tile_shape, (mma_tile_inst_n, mma_tile_inst_k))
# ((Atom_Inst_N, Atom_Inst_K), MMA_N, MMA_K)
smem_layout = cute.tiled_divide(smem_layout, sfb_tile_shape)
atom_n = 128
tiler_inst = ((atom_n, sf_vec_size),)
# (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K)
smem_layout = cute.logical_divide(smem_layout, tiler_inst)
# (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K, STAGE)
sfb_smem_layout_staged = cute.append(
smem_layout,
cute.make_layout(
num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout))
),
)
return sfb_smem_layout_staged
@dsl_user_op
def sm120_make_smem_layout_sfa(
tiled_mma: cute.TiledMma,
tile_shape_mnk: cute.Tile,
sf_vec_size: int,
num_stages: int,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
"""
Make smem layout for SFA based on:
1. BlockScaledBasicChunk
2. MMA tiler shape
3. Scale factor vector size
4. Number of stages
:param tiled_mma: The tiled MMA
:type tiled_mma: cute.TiledMma
:param mma_tiler_mnk: The mma tiler shape
:type mma_tiler_mnk: cute.Tile
:param sf_vec_size: The scale factor vector size
:type sf_vec_size: int
:param num_stages: The number of stages
:type num_stages: int
:return: Smem layout for SFA
:rtype: cute.Layout
"""
assert sf_vec_size == 16 or sf_vec_size == 32, "sf_vec_size must be 16 or 32"
blk_mn = 128
blk_sf = 4
blk_elems = blk_mn * blk_sf
mma_nsf = tiled_mma.shape_mnk[2] // sf_vec_size
mn_basic_block_shape = (32, 4)
mn_basic_block_stride = (16, 4)
k_basic_block_shape = (sf_vec_size, mma_nsf)
k_basic_block_stride = (0, 1)
assert tile_shape_mnk[0] % blk_mn == 0, ( # type: ignore[index, operator]
"tile_shape_mnk[0] must be divisible by blk_mn"
)
sSFA_shapeM = (mn_basic_block_shape, tile_shape_mnk[0] // blk_mn) # type: ignore[index, operator]
sSF_strideM = (mn_basic_block_stride, blk_elems)
assert tile_shape_mnk[2] % (blk_sf * mma_nsf) == 0, ( # type: ignore[index]
"tile_shape_mnk[2] must be divisible by blk_sf * mma_nsf"
)
sSFA_shapeK = (
k_basic_block_shape,
blk_sf // mma_nsf,
tile_shape_mnk[2] // sf_vec_size // blk_sf, # type: ignore[index, operator]
)
sSF_strideK = (
k_basic_block_stride,
mma_nsf,
tile_shape_mnk[0] // blk_mn * blk_elems, # type: ignore[index, operator]
)
sSFA_shape = (sSFA_shapeM, sSFA_shapeK)
sSFA_stride = (sSF_strideM, sSF_strideK)
smem_layout = cute.make_layout(sSFA_shape, stride=sSFA_stride)
# (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K, STAGE)
sfa_smem_layout_staged = cute.append(
smem_layout,
cute.make_layout(
num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout))
),
)
return sfa_smem_layout_staged
@dsl_user_op
def sm120_make_smem_layout_sfb(
tiled_mma: cute.TiledMma,
tile_shape_mnk: cute.Tile,
sf_vec_size: int,
num_stages: int,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
"""
Make smem layout for SFB based on:
1. BlockScaledBasicChunk
2. MMA tiler shape
3. Scale factor vector size
4. Number of stages
:param tiled_mma: The tiled MMA
:type tiled_mma: cute.TiledMma
:param mma_tiler_mnk: The mma tiler shape
:type mma_tiler_mnk: cute.Tile
:param sf_vec_size: The scale factor vector size
:type sf_vec_size: int
:param num_stages: The number of stages
:type num_stages: int
:return: Smem layout for SFA
:rtype: cute.Layout
"""
# A single indivisible block will hold 4 scale factors of 128 rows/columns (A/B matrix).
# 4 is chosen to make consecutive 32bits of data to have scale factors for only a single row(col).
blk_mn = 128
blk_sf = 4
blk_elems = blk_mn * blk_sf
assert sf_vec_size == 16 or sf_vec_size == 32, "sf_vec_size must be 16 or 32"
assert tile_shape_mnk[1] % blk_mn == 0, ( # type: ignore[index, operator]
"tile_shape_mnk[1] must be divisible by blk_mn"
)
assert tile_shape_mnk[2] % sf_vec_size == 0, ( # type: ignore[index, operator]
"tile_shape_mnk[2] must be divisible by sf_vec_size"
)
mma_nsf = tiled_mma.shape_mnk[2] // sf_vec_size
mn_basic_block_shape = (32, 4)
mn_basic_block_stride = (16, 4)
k_basic_block_shape = (sf_vec_size, mma_nsf)
k_basic_block_stride = (0, 1)
assert tile_shape_mnk[1] % blk_mn == 0, ( # type: ignore[index, operator]
"tile_shape_mnk[1] must be divisible by blk_mn"
)
sSFA_shapeN = (mn_basic_block_shape, tile_shape_mnk[1] // blk_mn) # type: ignore[index, operator]
sSF_strideN = (mn_basic_block_stride, blk_elems)
assert tile_shape_mnk[2] % (blk_sf * mma_nsf) == 0, ( # type: ignore[index]
"tile_shape_mnk[2] must be divisible by blk_sf * mma_nsf"
)
sSFA_shapeK = (
k_basic_block_shape,
blk_sf // mma_nsf,
tile_shape_mnk[2] // sf_vec_size // blk_sf, # type: ignore[index, operator]
)
sSF_strideK = (
k_basic_block_stride,
mma_nsf,
tile_shape_mnk[1] // blk_mn * blk_elems, # type: ignore[index, operator]
)
sSFA_shape = (sSFA_shapeN, sSFA_shapeK)
sSFA_stride = (sSF_strideN, sSF_strideK)
smem_layout = cute.make_layout(sSFA_shape, stride=sSFA_stride)
# (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K, STAGE)
sfb_smem_layout_staged = cute.append(
smem_layout,
cute.make_layout(
num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout))
),
)
return sfb_smem_layout_staged
@dsl_user_op
def make_tmem_layout_sfa(
tiled_mma: cute.TiledMma,
mma_tiler_mnk: cute.Tile,
sf_vec_size: int,
smem_layout: cute.Layout,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
"""Make tmem layout for SFA based on:
1. SFA smem layout per stage
2. Cta tile shape m
3. tiled MMA atom thr size
4. Scale factor vector size
:param tiled_mma: The tiled MMA
:type tiled_mma: cute.TiledMma
:param mma_tiler_mnk: The mma tiler shape
:type mma_tiler_mnk: cute.Tile
:param sf_vec_size: The scale factor vector size
:type sf_vec_size: int
:param smem_layout: The smem layout of SFA per stage
:type smem_layout: cute.Layout
:return: TMEM layout for SFA
:rtype: cute.Layout
"""
atom_thr_size = cute.size(tiled_mma.thr_id.shape, loc=loc, ip=ip)
cta_tile_shape_m = mma_tiler_mnk[0] // atom_thr_size # type: ignore[index]
sfa_layout_ty = _cute_nvgpu_ir.make_tmem_layout_sfa(
smem_layout, cta_tile_shape_m, atom_thr_size, sf_vec_size
)
return _cute_ir.static(sfa_layout_ty, loc=loc, ip=ip)
@dsl_user_op
def make_tmem_layout_sfb(
tiled_mma: cute.TiledMma,
mma_tiler_mnk: cute.Tile,
sf_vec_size: int,
smem_layout: cute.Layout,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
"""Make tmem layout for SFB based on:
1. SFB smem layout per stage
2. Cta tile shape m
3. tiled MMA atom thr size
4. Scale factor vector size
:param tiled_mma: The tiled MMA
:type tiled_mma: cute.TiledMma
:param mma_tiler_mnk: The mma tiler shape
:type mma_tiler_mnk: cute.Tile
:param sf_vec_size: The scale factor vector size
:type sf_vec_size: int
:param smem_layout: The smem layout of SFB per stage
:type smem_layout: cute.Layout
:return: TMEM layout for SFB
:rtype: cute.Layout
"""
atom_thr_size = cute.size(tiled_mma.thr_id.shape, loc=loc, ip=ip)
cta_tile_shape_m = mma_tiler_mnk[0] // atom_thr_size # type: ignore[index]
sfb_layout_ty = _cute_nvgpu_ir.make_tmem_layout_sfb(
smem_layout, cta_tile_shape_m, atom_thr_size, sf_vec_size
)
return _cute_ir.static(sfb_layout_ty, loc=loc, ip=ip)
@dataclass(frozen=True)
class Sm103BlockScaledBasicChunk:
"""
Basic scale-factor atom layout decided by tcgen05 BlockScaled MMA Ops on SM103.
Represents the fixed layout pattern for scale factors used by tcgen05
BlockScaled MMA Ops on SM103. The layout is determined by the instruction
specification and is not configurable.
"""
sf_vec_size: int
major_mode: OperandMajorMode = OperandMajorMode.K
_layout: cute.Layout = field(init=False, repr=False)
def __post_init__(self) -> None:
atom_shape: cute.Shape
atom_stride: cute.Stride
if self.major_mode == OperandMajorMode.K:
atom_shape = ((8, 4, 4), (self.sf_vec_size, 4))
atom_stride = ((16, 128, 4), (0, 1))
else:
atom_shape = ((self.sf_vec_size, 4), (8, 4, 4))
atom_stride = ((0, 1), (16, 128, 4))
object.__setattr__(
self, "_layout", cute.make_layout(shape=atom_shape, stride=atom_stride)
)
@property
def layout(self) -> cute.Layout:
return self._layout
@dsl_user_op
def sm103_make_smem_layout_sfa(
tiled_mma: cute.TiledMma,
mma_tiler: cute.Tile,
sf_vec_size: int,
num_stages: int,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
"""
Make SMEM layout for SFA based on:
1) Sm103BlockScaledBasicChunk, 2) MMA tiler, 3) sf_vec_size, 4) stages.
:param tiled_mma: The tiled MMA
:type tiled_mma: cute.TiledMma
:param mma_tiler: The mma tiler shape
:type mma_tiler: cute.Tile
:param sf_vec_size: The scale factor vector size
:type sf_vec_size: int
:param num_stages: The number of stages
:type num_stages: int
:return: Smem layout for SFA
:rtype: cute.Layout
"""
mma_shape_mk = tiled_mma.partition_shape_A((mma_tiler[0], mma_tiler[2])) # type: ignore[index]
sf_atom = Sm103BlockScaledBasicChunk(sf_vec_size, tiled_mma.op.a_major_mode).layout # type: ignore[attr-defined]
k_divisor = 4 if sf_vec_size == 16 else 2
mma_sfa_tiler = (
mma_shape_mk[0][0] * mma_shape_mk[1],
mma_shape_mk[0][1] * mma_shape_mk[2] // k_divisor,
)
sfa_smem_atom_layout = cute.tiled_product(
sf_atom,
cute.make_layout(
cute.shape_div(mma_sfa_tiler, cute.product_each(sf_atom.shape))
),
)
sfa_smem_layout_staged = cute.make_layout(
shape=cute.append(sfa_smem_atom_layout.shape, num_stages),
stride=cute.append(
sfa_smem_atom_layout.stride,
cute.size(cute.filter_zeros(sfa_smem_atom_layout)),
),
)
return sfa_smem_layout_staged
@dsl_user_op
def sm103_make_smem_layout_sfb(
tiled_mma: cute.TiledMma,
mma_tiler: cute.Tile,
sf_vec_size: int,
num_stages: int,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
"""
Make SMEM layout for SFB based on the basic chunk, MMA tiler, sf_vec_size, stages.
:param tiled_mma: The tiled MMA
:type tiled_mma: cute.TiledMma
:param mma_tiler: The mma tiler shape
:type mma_tiler: cute.Tile
:param sf_vec_size: The scale factor vector size
:type sf_vec_size: int
:param num_stages: The number of stages
:type num_stages: int
:return: Smem layout for SFB
:rtype: cute.Layout
"""
sf_atom = Sm103BlockScaledBasicChunk(sf_vec_size, tiled_mma.op.b_major_mode).layout # type: ignore[attr-defined]
k_divisor = 4 if sf_vec_size == 16 else 2
mma_sfb_tiler = (mma_tiler[1], mma_tiler[2] // k_divisor) # type: ignore[index, operator]
if mma_sfb_tiler[0] == 128:
sfb_smem_atom_layout = cute.tiled_product(
sf_atom,
cute.make_layout(
cute.shape_div(mma_sfb_tiler, cute.product_each(sf_atom.shape))
),
)
else:
sf_k_major_atom256 = cute.make_layout(
shape=(
(32, 4, 2),
(sf_vec_size, 4),
),
stride=(
(16, 4, mma_sfb_tiler[1] // sf_vec_size // 4 * 512),
(0, 1),
),
)
sfb_smem_atom_layout = cute.tiled_product(
sf_k_major_atom256,
cute.make_layout(
cute.shape_div(
mma_sfb_tiler, cute.product_each(sf_k_major_atom256.shape)
)
),
)
sfb_smem_layout_staged = cute.make_layout(
shape=cute.append(sfb_smem_atom_layout.shape, num_stages),
stride=cute.append(
sfb_smem_atom_layout.stride,
cute.size(cute.filter_zeros(sfb_smem_atom_layout)),
),
)
return sfb_smem_layout_staged

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,695 @@
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
MoE Persistent Tile Scheduler
A specialized tile scheduler for MoE (Mixture of Experts) grouped GEMM operations.
This scheduler handles tile iteration across all experts, producing MoEWorkTileInfo
(expert_idx, tile_m_idx, tile_n_idx, k_tile_cnt) for each tile.
Scenarios:
- 2Dx3D (Forward): A(tokens_sum, hidden) x B(experts, intermediate, hidden) -> C(tokens_sum, intermediate)
- 2Dx2D (Backward): A(intermediate, tokens_sum) x B(hidden, tokens_sum) -> C(experts, intermediate, hidden)
Key design principle:
- Scheduler is ONLY responsible for tile iteration (tensor-agnostic, TMA-agnostic)
- Domain conversion (fake tensor -> real expert tensor) is handled by MoESchedExtension
- TMA descriptor management is handled by OnlineTensormapDescCreator
- The kernel orchestrates all three components
"""
from typing import List, Tuple, Literal
import cutlass
import cutlass.cute as cute
from cutlass.cutlass_dsl import (
Boolean,
Int32,
Integer,
extract_mlir_values,
new_from_mlir_values,
const_expr,
dsl_user_op,
)
from cutlass._mlir import ir
# =============================================================================
# Work Tile Info
# =============================================================================
class MoEWorkTileInfo:
"""
Work tile information for MoE scheduler.
Contains CTA-level tile information for executor warps:
- expert_idx: Which expert (-1 means invalid/done)
- tile_m_idx: CTA tile index along GEMM M dimension
- tile_n_idx: CTA tile index along GEMM N dimension
- k_tile_cnt: Number of CTA tiles along K dimension
Note: These are CTA-level indices, not cluster-level.
tile_l_idx is always 0 for MoE, executor can hardcode it.
For 2Dx3D (Forward):
M = tokens_i (dynamic), N = intermediate (fixed), K = hidden (fixed)
For 2Dx2D (Backward):
M = intermediate (fixed), N = hidden (fixed), K = tokens_i (dynamic)
"""
def __init__(
self,
expert_idx: Int32, # -1 means invalid tile
tile_m_idx: Int32,
tile_n_idx: Int32,
k_tile_cnt: Int32,
):
self.expert_idx = expert_idx
self.tile_m_idx = tile_m_idx
self.tile_n_idx = tile_n_idx
self.k_tile_cnt = k_tile_cnt
@property
def is_valid_tile(self) -> Boolean:
"""Check if this is a valid work tile (expert_idx >= 0)."""
return self.expert_idx >= Int32(0)
def __extract_mlir_values__(self) -> List[ir.Value]:
values = extract_mlir_values(self.expert_idx)
values.extend(extract_mlir_values(self.tile_m_idx))
values.extend(extract_mlir_values(self.tile_n_idx))
values.extend(extract_mlir_values(self.k_tile_cnt))
return values
def __new_from_mlir_values__(self, values: List[ir.Value]) -> "MoEWorkTileInfo":
assert len(values) == 4
return MoEWorkTileInfo(
expert_idx=new_from_mlir_values(self.expert_idx, [values[0]]),
tile_m_idx=new_from_mlir_values(self.tile_m_idx, [values[1]]),
tile_n_idx=new_from_mlir_values(self.tile_n_idx, [values[2]]),
k_tile_cnt=new_from_mlir_values(self.k_tile_cnt, [values[3]]),
)
def to_rmem_tensor(self):
"""Pack work tile info fields into an rmem tensor of shape (4,) for vectorized smem copy."""
rmem = cute.make_rmem_tensor((4,), Int32)
rmem[0] = self.expert_idx
rmem[1] = self.tile_m_idx
rmem[2] = self.tile_n_idx
rmem[3] = self.k_tile_cnt
return rmem
@staticmethod
def from_rmem_tensor(rmem) -> "MoEWorkTileInfo":
"""Unpack work tile info from an rmem tensor of shape (4,)."""
return MoEWorkTileInfo(
expert_idx=rmem[0], # type: ignore[arg-type]
tile_m_idx=rmem[1], # type: ignore[arg-type]
tile_n_idx=rmem[2], # type: ignore[arg-type]
k_tile_cnt=rmem[3], # type: ignore[arg-type]
)
# =============================================================================
# Scheduler Parameters
# =============================================================================
class MoEStaticSchedulerParams:
"""
Parameters for MoE tile scheduler.
Uses unified semantics for both scenarios:
- expert_shape: (expert_cnt, intermediate, hidden)
For 2Dx3D: GEMM is (M=tokens_i, N=intermediate, K=hidden) per expert
For 2Dx2D: GEMM is (M=hidden, N=intermediate, K=tokens_i) per expert
Tile hierarchy:
- cta_tile_shape_mnk: Single CTA tile shape (tile_m, tile_n, tile_k)
- cluster_shape_mn: CTAs per cluster (cluster_m, cluster_n)
- cluster_tile_shape_mn: Cluster tile shape = cta_tile_shape * cluster_shape
This class is used both on host (for grid shape calculation) and on device
(stored in scheduler). Codegen-time constants (scenario, cta_tile_shape_mnk,
cluster_shape_mn) are NOT serialized to MLIR values.
"""
def __init__(
self,
scenario: Literal["2Dx3D", "2Dx2D"],
expert_shape: Tuple[int | Int32, int | Int32, int | Int32], # (expert_cnt, intermediate, hidden)
cta_tile_shape_mnk: Tuple[int, int, int], # (tile_m, tile_n, tile_k)
cluster_shape_mn: Tuple[int, int], # (cluster_m, cluster_n)
):
self.scenario = scenario
e, i, h = expert_shape
self.expert_cnt = e if isinstance(e, Int32) else Int32(e)
self.intermediate = i if isinstance(i, Int32) else Int32(i)
self.hidden = h if isinstance(h, Int32) else Int32(h)
self.cta_tile_shape_mnk = cta_tile_shape_mnk
self.cluster_shape_mn = cluster_shape_mn
@property
def cluster_tile_m(self) -> int:
"""Cluster tile size along M = cta_tile_m * cluster_m."""
return self.cta_tile_shape_mnk[0] * self.cluster_shape_mn[0]
@property
def cluster_tile_n(self) -> int:
"""Cluster tile size along N = cta_tile_n * cluster_n."""
return self.cta_tile_shape_mnk[1] * self.cluster_shape_mn[1]
@property
def cta_tile_k(self) -> int:
"""CTA tile size along K (same as cluster since cluster_k = 1)."""
return self.cta_tile_shape_mnk[2]
def __extract_mlir_values__(self) -> List[ir.Value]:
"""Only serialize runtime values, not codegen-time constants."""
values = []
values.extend(extract_mlir_values(self.expert_cnt))
values.extend(extract_mlir_values(self.intermediate))
values.extend(extract_mlir_values(self.hidden))
return values
def __new_from_mlir_values__(self, values: List[ir.Value]) -> "MoEStaticSchedulerParams":
assert len(values) == 3
return MoEStaticSchedulerParams(
scenario=self.scenario,
expert_shape=(
new_from_mlir_values(self.expert_cnt, [values[0]]),
new_from_mlir_values(self.intermediate, [values[1]]),
new_from_mlir_values(self.hidden, [values[2]]),
),
cta_tile_shape_mnk=self.cta_tile_shape_mnk,
cluster_shape_mn=self.cluster_shape_mn,
)
@staticmethod
def get_grid_shape(
params: "MoEStaticSchedulerParams",
max_active_clusters: int,
) -> Tuple[int, int, int]:
"""
Compute grid shape for kernel launch.
Since host doesn't know token distribution across experts,
we launch max_active_clusters and let device-side scheduler
determine which tiles are valid.
"""
return (
params.cluster_shape_mn[0],
params.cluster_shape_mn[1],
max_active_clusters,
)
# =============================================================================
# Scheduler (Device-side)
# =============================================================================
class MoEStaticPersistentTileScheduler:
"""
Persistent tile scheduler specialized for MoE grouped GEMM.
This scheduler is ONLY responsible for tile iteration. It does NOT know
about tensor types, TMA descriptors, or domain conversion. Those concerns
are handled by MoESchedExtension and OnlineTensormapDescCreator respectively.
Architecture:
- Scheduler warp: Holds scheduler instance, iterates tiles, broadcasts work_tile_info
- Executor warps: Read work_tile_info from smem, use MoESchedExtension for
domain conversion and TMA desc selection
The scheduler handles:
- 2Dx3D: Dynamic M per expert (from offs), fixed N (intermediate) and K (hidden)
- 2Dx2D: Fixed M (intermediate) and N (hidden), dynamic K per expert (reduction axis)
Usage (Scheduler warp):
scheduler = MoEStaticPersistentTileScheduler.create(params, offs, block_idx, grid_dim)
work_tile_info = scheduler.initial_work_tile_info()
# Broadcast work_tile_info to smem...
while work_tile_info.is_valid_tile:
# ... do work ...
work_tile_info = scheduler.advance_to_next_work()
# Broadcast work_tile_info to smem...
Usage (Executor warps - via MoESchedExtension):
# Read work_tile_info from smem...
real_a, desc_a = ext.get_gmem_tensor("a", tma_tensor_a, offs, work_tile_info)
real_b, desc_b = ext.get_gmem_tensor("b", tma_tensor_b, offs, work_tile_info)
real_c, desc_c = ext.get_gmem_tensor("c", tma_tensor_c, offs, work_tile_info)
"""
def __init__(
self,
# Params (contains scenario, expert_cnt, intermediate, hidden, tile/cluster shapes)
params: MoEStaticSchedulerParams,
# Runtime tensor for scheduling
offs: cute.Tensor, # (experts,) cumsum of token counts
# Scheduling state
num_persistent_clusters: Int32,
current_work_linear_idx: Int32,
cta_id_in_cluster: cute.Coord,
# Expert tracking state (for O(1) advance within same expert)
current_expert_idx: Int32,
expert_tile_start: Int32, # cumsum of tiles before current expert
expert_tile_end: Int32, # cumsum of tiles including current expert
):
self.params = params
self.offs = offs
self.num_persistent_clusters = num_persistent_clusters
self._current_work_linear_idx = current_work_linear_idx
self.cta_id_in_cluster = cta_id_in_cluster
# Expert tracking
self.current_expert_idx = current_expert_idx
self.expert_tile_start = expert_tile_start
self.expert_tile_end = expert_tile_end
# =========================================================================
# Convenience accessors for params
# =========================================================================
@property
def scenario(self) -> Literal["2Dx3D", "2Dx2D"]:
return self.params.scenario
@property
def expert_cnt(self) -> Int32:
return self.params.expert_cnt
@property
def intermediate(self) -> Int32:
return self.params.intermediate
@property
def hidden(self) -> Int32:
return self.params.hidden
@property
def cta_tile_shape_mnk(self) -> Tuple[int, int, int]:
return self.params.cta_tile_shape_mnk
@property
def cluster_shape_mn(self) -> Tuple[int, int]:
return self.params.cluster_shape_mn
@property
def cluster_tile_m(self) -> int:
return self.params.cluster_tile_m
@property
def cluster_tile_n(self) -> int:
return self.params.cluster_tile_n
@property
def cta_tile_k(self) -> int:
return self.params.cta_tile_k
# =========================================================================
# MLIR value serialization (for SSA value passing in device code)
# =========================================================================
def __extract_mlir_values__(self) -> List[ir.Value]:
values = []
# Params (only runtime values are extracted)
values.extend(extract_mlir_values(self.params))
# Runtime tensor for scheduling
values.extend(extract_mlir_values(self.offs))
# Scheduling state
values.extend(extract_mlir_values(self.num_persistent_clusters))
values.extend(extract_mlir_values(self._current_work_linear_idx))
values.extend(extract_mlir_values(self.cta_id_in_cluster))
# Expert tracking state
values.extend(extract_mlir_values(self.current_expert_idx))
values.extend(extract_mlir_values(self.expert_tile_start))
values.extend(extract_mlir_values(self.expert_tile_end))
return values
def __new_from_mlir_values__(
self, values: List[ir.Value]
) -> "MoEStaticPersistentTileScheduler":
idx = 0
# Params (3 values: expert_cnt, intermediate, hidden)
new_params = new_from_mlir_values(self.params, values[idx:idx + 3])
idx += 3
# Runtime tensor for scheduling (variable size)
offs_len = len(extract_mlir_values(self.offs))
new_offs = new_from_mlir_values(self.offs, values[idx:idx + offs_len])
idx += offs_len
# Scheduling state
new_num_persistent_clusters = new_from_mlir_values(
self.num_persistent_clusters, [values[idx]]
)
idx += 1
new_current_work_linear_idx = new_from_mlir_values(
self._current_work_linear_idx, [values[idx]]
)
idx += 1
# cta_id_in_cluster (3 values for Coord)
new_cta_id_in_cluster = new_from_mlir_values(
self.cta_id_in_cluster, values[idx:idx + 3]
)
idx += 3
# Expert tracking state
new_current_expert_idx = new_from_mlir_values(
self.current_expert_idx, [values[idx]]
)
idx += 1
new_expert_tile_start = new_from_mlir_values(
self.expert_tile_start, [values[idx]]
)
idx += 1
new_expert_tile_end = new_from_mlir_values(
self.expert_tile_end, [values[idx]]
)
idx += 1
return MoEStaticPersistentTileScheduler(
params=new_params,
offs=new_offs,
num_persistent_clusters=new_num_persistent_clusters,
current_work_linear_idx=new_current_work_linear_idx,
cta_id_in_cluster=new_cta_id_in_cluster,
current_expert_idx=new_current_expert_idx,
expert_tile_start=new_expert_tile_start,
expert_tile_end=new_expert_tile_end,
)
# =========================================================================
# Factory method
# =========================================================================
@staticmethod
@dsl_user_op
def create(
params: MoEStaticSchedulerParams,
offs: cute.Tensor,
block_idx: Tuple[Integer, Integer, Integer],
grid_dim: Tuple[Integer, Integer, Integer],
*,
loc=None,
ip=None,
) -> "MoEStaticPersistentTileScheduler":
"""
Create a MoE persistent tile scheduler.
:param params: Scheduler parameters (from host)
:param offs: Cumsum tensor of token counts per expert, shape (experts,)
:param block_idx: CUDA block index
:param grid_dim: CUDA grid dimensions
"""
num_persistent_clusters = cute.size(grid_dim, loc=loc, ip=ip) // cute.size(
params.cluster_shape_mn, loc=loc, ip=ip
)
bidx, bidy, bidz = block_idx
current_work_linear_idx = Int32(bidz)
cta_id_in_cluster = (
Int32(bidx % params.cluster_shape_mn[0]),
Int32(bidy % params.cluster_shape_mn[1]),
Int32(0),
)
# Initialize expert tracking to "before expert 0"
# The first call to _get_work_tile_for_linear_idx will advance to the correct expert
current_expert_idx = Int32(0)
expert_tile_start = Int32(0)
expert_tile_end = Int32(0) # Will be computed on first access
return MoEStaticPersistentTileScheduler(
params=params,
offs=offs,
num_persistent_clusters=num_persistent_clusters,
current_work_linear_idx=current_work_linear_idx,
cta_id_in_cluster=cta_id_in_cluster,
current_expert_idx=current_expert_idx,
expert_tile_start=expert_tile_start,
expert_tile_end=expert_tile_end,
)
# =========================================================================
# Tile iteration methods
# =========================================================================
@dsl_user_op
@cute.jit
def initial_work_tile_info(self, *, loc=None, ip=None) -> MoEWorkTileInfo:
"""Get the initial work tile info."""
return self._get_work_tile_for_linear_idx(
self._current_work_linear_idx, loc=loc, ip=ip
)
@dsl_user_op
@cute.jit
def advance_to_next_work(self, *, loc=None, ip=None) -> MoEWorkTileInfo:
"""Advance to the next work tile and return its info."""
self._current_work_linear_idx += self.num_persistent_clusters
return self._get_work_tile_for_linear_idx(
self._current_work_linear_idx, loc=loc, ip=ip
)
@dsl_user_op
@cute.jit
def _get_work_tile_for_linear_idx(
self,
cluster_linear_idx: Int32,
*,
loc=None,
ip=None
) -> MoEWorkTileInfo:
"""
Convert a linear cluster index to MoEWorkTileInfo.
Uses cached expert tracking state for O(1) fast path when staying
within the same expert. Advances expert state when needed.
Returns an invalid tile (expert_idx = -1) if cluster_linear_idx is out of range.
"""
# Ensure expert tracking is initialized and up-to-date
self._advance_expert_to_contain(cluster_linear_idx, loc=loc, ip=ip)
# Check if valid (still within expert range after advancing)
is_valid = self.current_expert_idx < self.expert_cnt
work_tile_info = MoEWorkTileInfo(
expert_idx=Int32(-1),
tile_m_idx=Int32(0),
tile_n_idx=Int32(0),
k_tile_cnt=Int32(0),
)
if is_valid:
# Compute local cluster tile indices within current expert
local_idx = cluster_linear_idx - self.expert_tile_start
cluster_tile_m_idx, cluster_tile_n_idx = self._decompose_local_idx(
local_idx, self.current_expert_idx, loc=loc, ip=ip
)
# Convert cluster tile indices to CTA tile indices
# cta_tile_idx = cluster_tile_idx * cluster_shape + cta_id_in_cluster
cta_tile_m_idx = (
cluster_tile_m_idx * self.cluster_shape_mn[0]
+ self.cta_id_in_cluster[0] # type: ignore[index]
)
cta_tile_n_idx = (
cluster_tile_n_idx * self.cluster_shape_mn[1]
+ self.cta_id_in_cluster[1] # type: ignore[index]
)
# Compute k_tile_cnt
k_tile_cnt = self._compute_k_tile_cnt(self.current_expert_idx, loc=loc, ip=ip)
work_tile_info = MoEWorkTileInfo(
expert_idx=self.current_expert_idx,
tile_m_idx=cta_tile_m_idx,
tile_n_idx=cta_tile_n_idx,
k_tile_cnt=k_tile_cnt,
)
return work_tile_info
@dsl_user_op
@cute.jit
def _advance_expert_to_contain(
self,
cluster_linear_idx: Int32,
*,
loc=None,
ip=None,
) -> None:
"""
Advance expert tracking state until current expert contains cluster_linear_idx,
or we run out of experts.
Fast path: If already in correct expert, no work needed.
"""
# Initialize expert_tile_end if this is the first call (expert_tile_end == 0)
if self.expert_tile_end == Int32(0):
tiles_for_expert_0 = self._compute_tiles_for_expert(Int32(0), loc=loc, ip=ip)
self.expert_tile_end = tiles_for_expert_0
# Advance until cluster_linear_idx < expert_tile_end or no more experts
while cluster_linear_idx >= self.expert_tile_end and self.current_expert_idx < self.expert_cnt:
self.current_expert_idx = self.current_expert_idx + 1
self.expert_tile_start = self.expert_tile_end
if self.current_expert_idx < self.expert_cnt:
tiles_for_expert = self._compute_tiles_for_expert(
self.current_expert_idx, loc=loc, ip=ip
)
self.expert_tile_end = self.expert_tile_end + tiles_for_expert
@dsl_user_op
@cute.jit
def _compute_tiles_for_expert(
self,
expert_idx: Int32,
*,
loc=None,
ip=None,
) -> Int32:
"""Compute total cluster tiles for a given expert."""
if const_expr(self.scenario == "2Dx2D"):
# Fixed M=hidden, N=intermediate
cluster_tile_m_cnt = (self.hidden + self.cluster_tile_m - 1) // self.cluster_tile_m
cluster_tile_n_cnt = (self.intermediate + self.cluster_tile_n - 1) // self.cluster_tile_n
return cluster_tile_m_cnt * cluster_tile_n_cnt
else: # 2Dx3D
# Variable M (tokens), fixed N
tokens_i = self.offs[expert_idx]
if expert_idx > 0:
tokens_i = tokens_i - self.offs[expert_idx - 1] # type: ignore[operator]
cluster_tile_m_cnt = (
tokens_i + self.cluster_tile_m - 1 # type: ignore[operator]
) // self.cluster_tile_m
cluster_tile_n_cnt = (
self.intermediate + self.cluster_tile_n - 1
) // self.cluster_tile_n
return cluster_tile_m_cnt * cluster_tile_n_cnt
@dsl_user_op
@cute.jit
def _decompose_local_idx(
self,
local_idx: Int32,
expert_idx: Int32,
*,
loc=None,
ip=None,
) -> Tuple[Int32, Int32]:
"""
Decompose local cluster tile index within expert to (cluster_tile_m_idx, cluster_tile_n_idx).
Uses "short side first" strategy: the shorter dimension changes faster.
This maximizes overlap between adjacent clusters for better L2 cache utilization.
For example, if m_cnt=2, n_cnt=8:
- N is longer, so M changes faster: local_idx = n_idx * m_cnt + m_idx
- Linearization order: (0,0), (1,0), (0,1), (1,1), (0,2), (1,2), ...
"""
# Get tile counts for M and N
cluster_tile_m_cnt, cluster_tile_n_cnt = self._get_cluster_tile_counts(
expert_idx, loc=loc, ip=ip
)
cluster_tile_m_idx = -1
cluster_tile_n_idx = -1
# Short side first: shorter dimension changes faster
# If m_cnt <= n_cnt: m is shorter, m changes faster
# local_idx = n_idx * m_cnt + m_idx
# If n_cnt < m_cnt: n is shorter, n changes faster
# local_idx = m_idx * n_cnt + n_idx
if cluster_tile_m_cnt <= cluster_tile_n_cnt:
# M is shorter or equal, M changes faster
cluster_tile_m_idx = local_idx % cluster_tile_m_cnt
cluster_tile_n_idx = local_idx // cluster_tile_m_cnt
else:
# N is shorter, N changes faster
cluster_tile_n_idx = local_idx % cluster_tile_n_cnt
cluster_tile_m_idx = local_idx // cluster_tile_n_cnt
return (cluster_tile_m_idx, cluster_tile_n_idx)
@dsl_user_op
@cute.jit
def _get_cluster_tile_counts(
self,
expert_idx: Int32,
*,
loc=None,
ip=None,
) -> Tuple[Int32, Int32]:
"""Get (cluster_tile_m_cnt, cluster_tile_n_cnt) for a given expert."""
if const_expr(self.scenario == "2Dx2D"):
# Fixed M=hidden, N=intermediate
cluster_tile_m_cnt = (self.hidden + self.cluster_tile_m - 1) // self.cluster_tile_m
cluster_tile_n_cnt = (self.intermediate + self.cluster_tile_n - 1) // self.cluster_tile_n
else: # 2Dx3D
# Variable M (tokens), fixed N
tokens_i = self.offs[expert_idx]
if expert_idx > 0:
tokens_i = tokens_i - self.offs[expert_idx - 1] # type: ignore[operator]
cluster_tile_m_cnt = (
tokens_i + self.cluster_tile_m - 1 # type: ignore[operator]
) // self.cluster_tile_m
cluster_tile_n_cnt = (
self.intermediate + self.cluster_tile_n - 1
) // self.cluster_tile_n
return (cluster_tile_m_cnt, cluster_tile_n_cnt)
@dsl_user_op
@cute.jit
def _compute_k_tile_cnt(
self,
expert_idx: Int32,
*,
loc=None,
ip=None,
) -> Int32:
"""
Compute the number of K tiles for this expert.
2Dx3D: K = hidden (fixed) -> k_tile_cnt = ceil(hidden / cta_tile_k)
2Dx2D: K = tokens_i (variable) -> k_tile_cnt = ceil(tokens_i / cta_tile_k)
"""
if const_expr(self.scenario == "2Dx3D"):
# K is hidden (fixed)
return (self.hidden + self.cta_tile_k - 1) // self.cta_tile_k
else: # 2Dx2D
# K is tokens_i (variable per expert)
tokens_i = self.offs[expert_idx]
if expert_idx > cutlass.Int32(0):
tokens_i = tokens_i - self.offs[expert_idx - 1] # type: ignore[operator]
return (tokens_i + self.cta_tile_k - 1) // self.cta_tile_k # type: ignore[return-value, operator]

View File

@@ -0,0 +1,443 @@
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
MoE Scheduler Extension.
Bridges the MoE tile scheduler (MoEStaticPersistentTileScheduler) with tensor-level
domain conversion and TMA descriptor selection. This is the "glue" layer between:
- Scheduler: produces MoEWorkTileInfo (expert_idx, tile_m, tile_n, k_tile_cnt)
- OnlineTensormapDescCreator: builds/retrieves TMA descriptors from workspace
- Kernel: orchestrates everything
Different kernel types (grouped_mm, scaled_grouped_mm, etc.) provide their own
MoESchedExtension subclass with kernel-specific domain conversion logic.
Key design principles:
- Unified interface: get_gmem_tensor() for all tensor types
- Free implementation: no role-based templates, each subclass writes its own logic
- Composable utilities: compute_expert_token_range, rewrite_tensor_shape, etc.
are available as tools but not mandatory
Architecture:
Scheduler ──(produces)──> MoEWorkTileInfo
expert_idx, tile_m, tile_n, k_cnt
v
Extension ──(uses)──> OnlineTensormapDescCreator
│ │
│ get_gmem_tensor() │ get_desc_ptr()
│ prefetch_for_expert() │ construct_and_write()
│ │
└── internal calls ───────┘
Kernel (caller): the only place that knows all three exist
"""
from abc import ABC, abstractmethod
from typing import Literal, Tuple, Union
import cutlass
import cutlass.cute as cute
from cutlass.cute.typing import Pointer
from cutlass.cutlass_dsl import Int32
from dataclasses import dataclass
from cutlass.utils.blockscaled_layout import tile_atom_to_shape_SF
from blackwell.kernel.moe.moe_utils import (
OnlineTensormapDescCreator,
tensormap_ptr_for_copy,
compute_expert_token_range,
rewrite_tensor_shape,
prefetch_tma_descriptor,
)
from blackwell.kernel.moe.moe_persistent_scheduler import MoEWorkTileInfo
@dataclass(frozen=True)
class MoESchedExtension(ABC):
"""
Abstract base class for MoE scheduler extensions.
Bridges MoEWorkTileInfo with tensor-level domain conversion and TMA
descriptor selection. Each kernel type (grouped_mm, scaled_grouped_mm, etc.)
provides its own subclass with kernel-specific logic.
The extension:
- Holds a reference to an OnlineTensormapDescCreator for expert-wise desc retrieval
- Implements get_gmem_tensor() to convert MoE-view tensors to per-expert tensors
- Implements prefetch_for_expert() to prefetch expert-wise TMA descriptors
Subclasses are free to add any additional attributes in __init__ (scenario,
codegen configs, etc.) and implement get_gmem_tensor with arbitrary logic
per tensor_name. No role-based templates or rigid patterns are imposed.
Usage in kernel (caller):
ext = ConcreteSchedExtension(tensormap_ctor, scenario=...)
while work_tile_info.is_valid_tile:
real_a, desc_a = ext.get_gmem_tensor("a", tma_tensor_a, offs, work_tile_info)
real_b, desc_b = ext.get_gmem_tensor("b", tma_tensor_b, offs, work_tile_info)
# Use real_a, desc_a in cute.copy ...
"""
def __init__(self, tensormap_ctor: OnlineTensormapDescCreator):
super().__init__()
self.tensormap_ctor = tensormap_ctor
@abstractmethod
def get_gmem_tensor(
self,
tensor_name: str,
gmem_tensor_in_moe_view: cute.Tensor,
offs: Union[cute.Tensor, Tuple[cute.Tensor, cute.Tensor]],
work_tile_info: MoEWorkTileInfo,
) -> Tuple[cute.Tensor, "Pointer | None"]:
"""
Convert an MoE-view tensor to the real per-expert tensor for the
current work tile, and return the appropriate TMA descriptor pointer.
The MoE-view tensor uses "fake" GEMM domain dimensions that span all
experts (e.g., fake_m = tokens_sum). This method slices/offsets it
to the current expert's actual region.
:param tensor_name: Identifies which tensor (e.g., "a", "b", "c", "sfa")
:param gmem_tensor_in_moe_view: Tensor in fake GEMM MNKL domain
:param offs: Either a single cumsum tensor (experts,), or a tuple of
(offs_token, offs_padded) where offs_padded provides
padded offsets for scale-factor domain conversion.
:param work_tile_info: Current work tile from the scheduler
:return: (real_tensor, tma_desc_ptr_or_none)
- real_tensor: domain-offset and shape-rewritten tensor for this expert
- tma_desc_ptr: expert-wise desc ptr (already converted for cute.copy),
or None if the caller should use the global TMA descriptor
"""
...
@abstractmethod
def prefetch_for_expert(self, expert_idx: Int32) -> None:
"""
Prefetch expert-wise TMA descriptors for the given expert.
Called when the scheduler advances to a new expert, allowing the TMA
descriptor cache to be warmed up before the descriptors are needed.
:param expert_idx: Index of the expert whose descriptors to prefetch
"""
...
# =============================================================================
# Grouped MM Extension
# =============================================================================
class GroupedMmSchedExtension(MoESchedExtension):
"""
MoE scheduler extension for grouped_mm: handles tensors a, b, c.
Domain conversion logic per scenario:
2Dx3D:
A: (fake_m, k, 1) -> offset fake_m by token_offset, global desc
B: (n, k, fake_l) -> offset fake_l by expert_idx, global desc
C: (fake_m, n, 1) -> rewrite shape only, expert-wise desc
2Dx2D:
A: (m, fake_k, 1) -> rewrite shape only, expert-wise desc
B: (n, fake_k, 1) -> rewrite shape only, expert-wise desc
C: (m, n, fake_l) -> offset fake_l by expert_idx, global desc
"""
def __init__(
self,
scenario: Literal["2Dx3D", "2Dx2D"],
tensormap_ctor: OnlineTensormapDescCreator,
):
super().__init__(tensormap_ctor)
self.scenario = scenario
@cute.jit
def get_gmem_tensor(
self,
tensor_name: str,
gmem_tensor_in_moe_view: cute.Tensor,
offs: cute.Tensor,
work_tile_info: MoEWorkTileInfo,
):
expert_idx = work_tile_info.expert_idx
token_offset, tokens_i = compute_expert_token_range(offs, expert_idx)
shape = gmem_tensor_in_moe_view.shape
c1 = cutlass.Int32(1)
if cutlass.const_expr(self.scenario == "2Dx3D"):
if cutlass.const_expr(tensor_name == "a"):
# A: (fake_m, k, 1) -> offset fake_m, global desc
real = cute.domain_offset((token_offset, 0, 0), gmem_tensor_in_moe_view)
real = rewrite_tensor_shape(real, (tokens_i, shape[1], c1)) # type: ignore[index]
return (real, None)
elif cutlass.const_expr(tensor_name == "b"):
# B: (n, k, fake_l) -> offset fake_l, global desc
real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view)
real = rewrite_tensor_shape(real, (shape[0], shape[1], c1)) # type: ignore[index]
return (real, None)
elif cutlass.const_expr(tensor_name == "c"):
# C: (fake_m, n, 1) -> expert-wise desc, no offset
real = rewrite_tensor_shape(
gmem_tensor_in_moe_view,
(tokens_i, shape[1], c1), # type: ignore[index]
)
desc = tensormap_ptr_for_copy(
self.tensormap_ctor.get_desc_ptr("c", expert_idx)
)
return (real, desc)
elif cutlass.const_expr(self.scenario == "2Dx2D"):
if cutlass.const_expr(tensor_name == "a"):
# A: (m, fake_k, 1) -> expert-wise desc, no offset
real = rewrite_tensor_shape(
gmem_tensor_in_moe_view,
(shape[0], tokens_i, c1), # type: ignore[index]
)
desc = tensormap_ptr_for_copy(
self.tensormap_ctor.get_desc_ptr("a", expert_idx)
)
return (real, desc)
elif cutlass.const_expr(tensor_name == "b"):
# B: (n, fake_k, 1) -> expert-wise desc, no offset
real = rewrite_tensor_shape(
gmem_tensor_in_moe_view,
(shape[0], tokens_i, c1), # type: ignore[index]
)
desc = tensormap_ptr_for_copy(
self.tensormap_ctor.get_desc_ptr("b", expert_idx)
)
return (real, desc)
elif cutlass.const_expr(tensor_name == "c"):
# C: (m, n, fake_l) -> offset fake_l, global desc
real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view)
real = rewrite_tensor_shape(real, (shape[0], shape[1], c1)) # type: ignore[index]
return (real, None)
raise ValueError("Invalid scenario or GEMM tensor name.")
@cute.jit
def prefetch_for_expert(self, expert_idx: Int32) -> None:
if cutlass.const_expr(self.scenario == "2Dx3D"):
prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("c", expert_idx))
elif cutlass.const_expr(self.scenario == "2Dx2D"):
prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("a", expert_idx))
prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("b", expert_idx))
else:
raise ValueError("Invalid scenario.")
# =============================================================================
# Scaled Grouped MM Extension (block-scaled MoE)
# =============================================================================
class ScaledGroupedMmSchedExtension(MoESchedExtension):
"""
MoE scheduler extension for scaled_grouped_mm: handles a, b, c, sfa, sfb.
Extends GroupedMmSchedExtension with scale-factor tensor support.
SFA/SFB are passed as flat GEMM-domain tensors and atom-tiled per expert
via tile_atom_to_shape_SF.
The offs parameter is always a tuple (offs_token, offs_padded):
- offs_token: cumsum offsets in data (activation) domain
- offs_padded: cumsum offsets in scale-factor domain (padded to atom granularity)
sf_vec_size is obtained from self.tensormap_ctor.sf_vec_size.
Domain conversion logic per scenario:
2Dx3D:
A: (fake_m, k, 1) -> offset fake_m by token_offset, global desc
B: (n, k, fake_l) -> offset fake_l by expert_idx, global desc
C: (fake_m, n, 1) -> rewrite shape, expert-wise desc
SFA: (fake_m_pad, k_pad, 1) -> offset fake_m_pad by padded_offset,
atom-tile, global desc
SFB: (n_pad, k_pad, fake_l) -> offset fake_l by expert_idx,
atom-tile, global desc
2Dx2D:
A: (m, fake_k, 1) -> rewrite shape, expert-wise desc
B: (n, fake_k, 1) -> rewrite shape, expert-wise desc
C: (m, n, fake_l) -> offset fake_l by expert_idx, global desc
SFA: (m_pad, fake_k_pad, 1) -> offset fake_k_pad by padded_offset,
atom-tile, expert-wise desc
SFB: (n_pad, fake_k_pad, 1) -> offset fake_k_pad by padded_offset,
atom-tile, expert-wise desc
"""
def __init__(
self,
scenario: Literal["2Dx3D", "2Dx2D"],
tensormap_ctor: OnlineTensormapDescCreator,
):
super().__init__(tensormap_ctor)
self.scenario = scenario
@cute.jit
def get_gmem_tensor(
self,
tensor_name: str,
gmem_tensor_in_moe_view: cute.Tensor,
offs: Tuple[cute.Tensor, cute.Tensor],
work_tile_info: MoEWorkTileInfo,
):
# Unpack the offs tuple
offs_token, offs_padded = offs
expert_idx = work_tile_info.expert_idx
token_offset, tokens_i = compute_expert_token_range(offs_token, expert_idx)
padded_offset, padded_size_i = compute_expert_token_range(
offs_padded, expert_idx
)
shape = gmem_tensor_in_moe_view.shape
stride = gmem_tensor_in_moe_view.stride
c1 = cutlass.Int32(1)
sf_vec_size = self.tensormap_ctor.sf_vec_size
if cutlass.const_expr(self.scenario == "2Dx3D"):
if cutlass.const_expr(tensor_name == "a"):
# A: (fake_m, k, 1) -> offset fake_m, global desc
real = cute.domain_offset((token_offset, 0, 0), gmem_tensor_in_moe_view)
real = rewrite_tensor_shape(real, (tokens_i, shape[1], c1)) # type: ignore[index]
return (real, None)
elif cutlass.const_expr(tensor_name == "b"):
# B: (n, k, fake_l) -> offset fake_l, global desc
real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view)
real = rewrite_tensor_shape(real, (shape[0], shape[1], c1)) # type: ignore[index]
return (real, None)
elif cutlass.const_expr(tensor_name == "c"):
# C: (fake_m, n, 1) -> expert-wise desc
real = rewrite_tensor_shape(
gmem_tensor_in_moe_view,
(tokens_i, shape[1], c1), # type: ignore[index]
)
desc = tensormap_ptr_for_copy(
self.tensormap_ctor.get_desc_ptr("c", expert_idx)
)
return (real, desc)
elif cutlass.const_expr(tensor_name == "sfa"):
# SFA: (fake_m_pad, k_pad, 1) -> offset fake_m_pad, atom-tile, global desc
real = cute.domain_offset(
(padded_offset, 0, 0), gmem_tensor_in_moe_view
)
per_expert_shape = (padded_size_i, shape[1], c1) # type: ignore[index]
sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size)
real = cute.make_tensor(
real.iterator, cute.make_layout(sf_layout.shape, stride=stride)
)
return (real, None)
elif cutlass.const_expr(tensor_name == "sfb"):
# SFB: (n_pad, k_pad, fake_l) -> offset fake_l, atom-tile, global desc
real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view)
per_expert_shape = (shape[0], shape[1], c1) # type: ignore[index]
sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size)
real = cute.make_tensor(
real.iterator, cute.make_layout(sf_layout.shape, stride=stride)
)
return (real, None)
elif cutlass.const_expr(self.scenario == "2Dx2D"):
if cutlass.const_expr(tensor_name == "a"):
# A: (m, fake_k, 1) -> expert-wise desc
real = rewrite_tensor_shape(
gmem_tensor_in_moe_view,
(shape[0], tokens_i, c1), # type: ignore[index]
)
desc = tensormap_ptr_for_copy(
self.tensormap_ctor.get_desc_ptr("a", expert_idx)
)
return (real, desc)
elif cutlass.const_expr(tensor_name == "b"):
# B: (n, fake_k, 1) -> expert-wise desc
real = rewrite_tensor_shape(
gmem_tensor_in_moe_view,
(shape[0], tokens_i, c1), # type: ignore[index]
)
desc = tensormap_ptr_for_copy(
self.tensormap_ctor.get_desc_ptr("b", expert_idx)
)
return (real, desc)
elif cutlass.const_expr(tensor_name == "c"):
# C: (m, n, fake_l) -> offset fake_l, global desc
real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view)
real = rewrite_tensor_shape(real, (shape[0], shape[1], c1)) # type: ignore[index]
return (real, None)
elif cutlass.const_expr(tensor_name == "sfa"):
# SFA: (m_pad, fake_k_pad, 1) -> offset fake_k_pad, atom-tile, expert-wise desc
per_expert_shape = (shape[0], padded_size_i, c1) # type: ignore[index]
sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size)
real = rewrite_tensor_shape(gmem_tensor_in_moe_view, sf_layout.shape)
desc = tensormap_ptr_for_copy(
self.tensormap_ctor.get_desc_ptr("sfa", expert_idx)
)
return (real, desc)
elif cutlass.const_expr(tensor_name == "sfb"):
# SFB: (n_pad, fake_k_pad, 1) -> offset fake_k_pad, atom-tile, expert-wise desc
per_expert_shape = (shape[0], padded_size_i, c1) # type: ignore[index]
sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size)
real = rewrite_tensor_shape(gmem_tensor_in_moe_view, sf_layout.shape)
desc = tensormap_ptr_for_copy(
self.tensormap_ctor.get_desc_ptr("sfb", expert_idx)
)
return (real, desc)
raise ValueError("Invalid scenario or tensor name.")
@cute.jit
def prefetch_for_expert(self, expert_idx: Int32) -> None:
if cutlass.const_expr(self.scenario == "2Dx3D"):
prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("c", expert_idx))
elif cutlass.const_expr(self.scenario == "2Dx2D"):
prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("a", expert_idx))
prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("b", expert_idx))
prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("sfa", expert_idx))
prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("sfb", expert_idx))
else:
raise ValueError("Invalid scenario.")

910
reference/moe_moe_utils.py Normal file
View File

@@ -0,0 +1,910 @@
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
Online TMA Descriptor Construction Utilities.
Provides utilities for dynamically creating TMA descriptors at kernel runtime
based on runtime-provided information (problem sizes, pointers, etc.).
Key components:
- OnlineTensormapDescCreator: Simplified ABC for TMA descriptor builders (2 abstract methods)
- TensormapWorkspace: Helper for linear workspace layout of TMA descriptors
- MoEGroupedGemmTensormapConstructor: TMA descriptor constructor for MoE Grouped GEMM
- GeneralGroupedGemmTensormapConstructor: TMA descriptor constructor for general Grouped GEMM
- Pointer utility functions (ptr_offset_bytes, gmem_ptr_to_generic, etc.)
- tensormap_ptr_for_copy: Convert raw desc ptr to cute.copy-compatible type
- compute_expert_token_range: Compute per-expert token offset and count from offs
- rewrite_tensor_shape: Debug-friendly tensor shape rewrite utility
"""
from abc import ABC, abstractmethod
from typing import Optional, Literal, Tuple, Union
import cutlass
import cutlass.cute as cute
from cutlass.cute.typing import AddressSpace, Pointer
from cutlass.cute.nvgpu import cpasync
from cutlass.cutlass_dsl import dsl_user_op, Int32
from cutlass._mlir import ir
from cutlass._mlir.dialects import llvm
from cutlass._mlir.dialects import cute as _cute_ir
from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir
from dataclasses import dataclass
from cutlass.utils.blockscaled_layout import tile_atom_to_shape_SF
TensormapDescBytes = 128
# =============================================================================
# Pointer Utilities
# =============================================================================
@dsl_user_op
@cute.jit
def spin_wait(
ptr: Pointer, condition, fail_sleep_cycles: int = 100, *, loc=None, ip=None
) -> None:
"""
Generic spin-wait.
Example usage:
# Wait until counter >= total_blocks
spin_wait(counter_ptr, lambda x: x >= total_blocks, fail_sleep_cycles=100)
# Wait until flag == 1
spin_wait(flag_ptr, lambda x: x == 1)
"""
current = cute.arch.load(ptr, ptr.dtype, cop="cg", loc=loc, ip=ip)
while not condition(current):
# Load with L1 cache bypass (ld.global.cg)
if cutlass.const_expr(fail_sleep_cycles > 0):
cute.arch.nanosleep(sleep_time=fail_sleep_cycles, loc=loc, ip=ip)
current = cute.arch.load(ptr, ptr.dtype, cop="cg", loc=loc, ip=ip)
@dsl_user_op
def gmem_ptr_to_generic(
gmem_ptr: Pointer,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> Pointer:
if gmem_ptr.memspace != AddressSpace.gmem:
raise ValueError(
f"gmem_ptr_to_generic requires pointer in gmem address space, "
f"got {gmem_ptr.memspace}"
)
# Get LLVM pointer and cast to generic address space
llvm_ptr = gmem_ptr.to_llvm_ptr(loc=loc, ip=ip)
generic_llvm_ptr = llvm.addrspacecast(
llvm.PointerType.get(AddressSpace.generic), llvm_ptr, loc=loc, ip=ip
)
# Create a new cute.Pointer with generic address space, preserving alignment
return cute.make_ptr(
gmem_ptr.dtype,
generic_llvm_ptr,
AddressSpace.generic,
assumed_align=gmem_ptr.alignment,
loc=loc,
ip=ip,
)
@dsl_user_op
def generic_ptr_to_gmem(
generic_ptr: Pointer,
*,
loc: Optional[ir.Location] = None,
ip: Optional[ir.InsertionPoint] = None,
) -> Pointer:
if generic_ptr.memspace != AddressSpace.generic:
raise ValueError(
f"generic_ptr_to_gmem requires pointer in generic address space, "
f"got {generic_ptr.memspace}"
)
# Get LLVM pointer and cast to gmem address space
llvm_ptr = generic_ptr.to_llvm_ptr(loc=loc, ip=ip)
gmem_llvm_ptr = llvm.addrspacecast(
llvm.PointerType.get(AddressSpace.gmem), llvm_ptr, loc=loc, ip=ip
)
# Create a new cute.Pointer with gmem address space, preserving alignment
return cute.make_ptr(
generic_ptr.dtype,
gmem_llvm_ptr,
AddressSpace.gmem,
assumed_align=generic_ptr.alignment,
loc=loc,
ip=ip,
)
@dsl_user_op
def prefetch_tma_descriptor(tma_desc_ptr: Pointer, *, loc=None, ip=None) -> None:
"""
Prefetch a TMA descriptor from global memory.
This function prefetches the TMA descriptor pointed to by tma_desc_ptr
into the TMA descriptor cache. The pointer must be in generic or global
address space. If a gmem pointer is passed, it will be automatically
converted to generic address space.
:param tma_desc_ptr: Pointer to the TMA descriptor in global or generic memory
:type tma_desc_ptr: Pointer
:raises ValueError: If pointer is not in generic or global address space
"""
if tma_desc_ptr.memspace not in (AddressSpace.gmem, AddressSpace.generic):
raise ValueError(
f"prefetch_tma_descriptor requires pointer in gmem or generic address space, "
f"got {tma_desc_ptr.memspace}"
)
# Convert gmem pointer to generic if needed
if tma_desc_ptr.memspace == AddressSpace.gmem:
tma_desc_ptr = gmem_ptr_to_generic(tma_desc_ptr, loc=loc, ip=ip)
# Convert cute.Pointer to LLVM pointer for prefetch
llvm_ptr = tma_desc_ptr.to_llvm_ptr(loc=loc, ip=ip)
from cutlass.cute.arch.nvvm_wrappers import prefetch as nvvm_prefetch
nvvm_prefetch(llvm_ptr, tensormap=True, loc=loc, ip=ip)
def ptr_offset_bytes(ptr: Pointer, byte_offset: int) -> Pointer:
"""Offset a pointer by a given number of bytes."""
element_offset = byte_offset * 8 // ptr.dtype.width
return ptr + element_offset
@dsl_user_op
def tensormap_ptr_for_copy(raw_ptr: Pointer, *, loc=None, ip=None) -> Pointer:
"""
Convert a raw TMA descriptor gmem pointer to the type expected by cute.copy.
cute.copy requires the tma_desc_ptr to be in generic address space and
recast to TmaDescriptorTiledType. This utility performs both conversions.
:param raw_ptr: Raw pointer to TMA descriptor in gmem
:type raw_ptr: Pointer
:return: Pointer compatible with cute.copy's tma_desc_ptr parameter
:rtype: Pointer
"""
generic_ptr = gmem_ptr_to_generic(raw_ptr, loc=loc, ip=ip)
tma_desc_ptr_ty = _cute_ir.PtrType.get(
_cute_nvgpu_ir.TmaDescriptorTiledType.get(),
generic_ptr.memspace,
generic_ptr.alignment,
)
return _cute_ir.recast_iter(tma_desc_ptr_ty, generic_ptr.value)
# =============================================================================
# MoE Utilities
# =============================================================================
@dsl_user_op
@cute.jit
def compute_expert_token_range(
offs: cute.Tensor,
expert_idx: Int32,
*,
loc=None,
ip=None,
) -> Tuple[Int32, Int32]:
"""
Compute token offset and count for a given expert from the cumsum offs tensor.
:param offs: Cumulative sum tensor of token counts per expert, shape (experts,)
:param expert_idx: Index of the expert
:return: (token_offset, tokens_i) where token_offset is the start position
and tokens_i is the number of tokens for this expert
"""
token_offset = Int32(0)
if expert_idx > Int32(0):
token_offset = offs[expert_idx - 1] # type: ignore[assignment]
tokens_i = offs[expert_idx] - token_offset
return token_offset, tokens_i
@dsl_user_op
def rewrite_tensor_shape(
tensor: cute.Tensor,
new_shape: Tuple,
*,
loc=None,
ip=None,
) -> cute.Tensor:
"""
Rewrite tensor shape while keeping the same stride and iterator.
This is primarily for debug friendliness - shows the actual expert's shape
instead of the fake global shape. No runtime overhead as it becomes
dead code in non-debug builds.
:param tensor: Source tensor whose stride and iterator to preserve
:param new_shape: New shape to apply
:return: New tensor with the given shape but original stride and iterator
"""
new_layout = cute.make_layout(new_shape, stride=tensor.stride, loc=loc, ip=ip)
return cute.make_tensor(tensor.iterator, new_layout, loc=loc, ip=ip)
# =============================================================================
# TMA Descriptor Workspace Helper
# =============================================================================
class TensormapWorkspace:
"""
Helper for linear workspace layout of TMA descriptors.
Manages address calculation for a workspace buffer containing TMA descriptors
organized as: for each executor (e.g., expert or group), a fixed set of
named descriptor slots.
Layout: [slot_0_exec_0, slot_1_exec_0, ..., slot_0_exec_1, slot_1_exec_1, ...]
Example:
# 2Dx3D MoE: only C is expert-wise
workspace = TensormapWorkspace(workspace_ptr, ["c"])
# 2Dx2D MoE: A and B are expert-wise
workspace = TensormapWorkspace(workspace_ptr, ["a", "b"])
# General grouped GEMM: all three tensors
workspace = TensormapWorkspace(workspace_ptr, ["a", "b", "c"])
"""
def __init__(self, workspace_ptr: Pointer, slot_names: list):
"""
:param workspace_ptr: Pointer to the beginning of the workspace buffer
:param slot_names: Ordered list of tensor names, defining the slot layout
per executor. e.g., ["a", "b", "c"]
"""
self.workspace_ptr = workspace_ptr
self._name_to_slot = {name: i for i, name in enumerate(slot_names)}
self._slots_per_executor = len(slot_names)
@cute.jit
def get_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
"""
Get the workspace pointer for a specific TMA descriptor.
:param tensor_name: Name of the tensor (must be one of the slot_names)
:param executor_idx: Index of the executor (e.g., group_idx or expert_idx)
:return: Aligned pointer to the TMA descriptor in workspace
"""
if cutlass.const_expr(tensor_name not in self._name_to_slot):
raise ValueError(
f"Invalid tensor_name '{tensor_name}', "
f"expected one of {list(self._name_to_slot.keys())}"
)
slot = self._name_to_slot[tensor_name]
byte_offset = (
executor_idx * self._slots_per_executor + slot
) * TensormapDescBytes
return ptr_offset_bytes(self.workspace_ptr, byte_offset).align(
TensormapDescBytes
)
@staticmethod
def size_bytes(num_slots: int, num_executors: int) -> int:
"""
Calculate workspace size in bytes.
:param num_slots: Number of descriptor slots per executor
:param num_executors: Total number of executors (e.g., expert_cnt or group_cnt)
:return: Total workspace size in bytes
"""
return num_slots * num_executors * TensormapDescBytes
# =============================================================================
# Online TMA Descriptor Creator (Abstract Base Class)
# =============================================================================
@dataclass(frozen=True)
class OnlineTensormapDescCreator(ABC):
"""
Abstract base class for building TMA descriptors online (at kernel runtime).
Subclasses store all needed parameters (both codegen-time configs and runtime
values) as explicit instance attributes in __init__. No dict-based APIs.
Subclasses must implement exactly 2 abstract methods:
- construct_and_write: Build TMA descriptor(s) for one executor and write to workspace
- get_desc_ptr: Return raw gmem pointer to a specific descriptor in workspace
To convert the raw pointer for use with cute.copy, callers should use the
standalone tensormap_ptr_for_copy() utility.
"""
@abstractmethod
def construct_and_write(self, executor_idx: Int32, dependency=None) -> None:
"""
Build TMA descriptor(s) for one executor and write to workspace.
:param executor_idx: Index of the executor (e.g., group_idx or expert_idx).
Semantics may vary by subclass when ``dependency`` is provided.
:param dependency: Optional pipeline consumer for inter-warp-group
synchronization. When provided, the subclass decides when to wait
(via ``dependency.wait_and_advance()``) and release. The subclass
also decides how to interpret ``executor_idx`` in this mode.
"""
...
@abstractmethod
def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
"""
Get the raw gmem pointer to a specific TMA descriptor in workspace.
:param tensor_name: Name identifying which tensor's descriptor
:param executor_idx: Index of the executor (e.g., group_idx or expert_idx)
:return: Raw pointer (gmem) to the TMA descriptor
"""
...
# =============================================================================
# MoE Grouped GEMM Tensormap Constructor
# =============================================================================
class MoEGroupedGemmTensormapConstructor(OnlineTensormapDescCreator):
"""
Tensormap descriptor constructor for MoE Grouped GEMM (expert-wise descriptors only).
Non-expert-wise descriptors are passed directly at kernel launch.
This class only handles:
- 2Dx3D: C descriptors (expert-wise, to avoid write conflicts)
- 2Dx2D: A and B descriptors (expert-wise, tokens is reduction axis)
All parameters are stored as explicit instance attributes (no dicts).
Workspace layout:
- 2Dx3D: [C_0, C_1, ..., C_{n-1}]
- 2Dx2D: [A_0, A_1, ..., A_{n-1}, B_0, B_1, ..., B_{n-1}]
"""
def __init__(
self,
scenario: Literal["2Dx3D", "2Dx2D"],
# Codegen-time configs
a_dtype,
b_dtype,
c_dtype,
a_smem_layout,
b_smem_layout,
epi_smem_layout,
a_tma_op,
b_tma_op,
c_tma_op,
tiled_mma,
mma_tiler,
cluster_layout_vmnk_shape,
epi_tile,
# Runtime params
a_tensor: cute.Tensor, # fake GEMM domain A
b_tensor: cute.Tensor, # fake GEMM domain B
c_tensor: cute.Tensor, # fake GEMM domain C
offs: cute.Tensor, # (experts,) cumsum
workspace_ptr: Pointer,
) -> None:
super().__init__()
self.scenario = scenario
# Codegen-time configs
self.a_dtype = a_dtype
self.b_dtype = b_dtype
self.c_dtype = c_dtype
self.a_smem_layout = a_smem_layout
self.b_smem_layout = b_smem_layout
self.epi_smem_layout = epi_smem_layout
self.a_tma_op = a_tma_op
self.b_tma_op = b_tma_op
self.c_tma_op = c_tma_op
self.tiled_mma = tiled_mma
self.mma_tiler = mma_tiler
self.cluster_layout_vmnk_shape = cluster_layout_vmnk_shape
self.epi_tile = epi_tile
# Runtime params
self.a_tensor = a_tensor
self.b_tensor = b_tensor
self.c_tensor = c_tensor
self.offs = offs
# Workspace with scenario-specific slot layout
if scenario == "2Dx3D":
self.workspace = TensormapWorkspace(workspace_ptr, ["c"])
else:
self.workspace = TensormapWorkspace(workspace_ptr, ["a", "b"])
@staticmethod
def get_workspace_size(scenario: Literal["2Dx3D", "2Dx2D"], expert_cnt: int) -> int:
"""Calculate workspace size in bytes for tensormap descriptors."""
if scenario == "2Dx3D":
return TensormapWorkspace.size_bytes(1, expert_cnt) # only C
else:
return TensormapWorkspace.size_bytes(2, expert_cnt) # A and B
@cute.jit
def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
return self.workspace.get_ptr(tensor_name, executor_idx)
@cute.jit
def construct_and_write(self, executor_idx: Int32, dependency=None) -> None:
"""
Create expert-wise tensormap descriptors for the given expert.
- 2Dx3D: Creates C descriptor for this expert
- 2Dx2D: Creates A and B descriptors for this expert
"""
if cutlass.const_expr(self.scenario == "2Dx3D"):
self._construct_c_desc_2dx3d(executor_idx)
else: # 2Dx2D
self._construct_ab_descs_2dx2d(executor_idx)
@cute.jit
def _construct_c_desc_2dx3d(self, expert_idx: Int32) -> None:
"""
2Dx3D: Create expert-wise C descriptor.
C tensor: (fake_m, n, 1) = (tokens_sum, intermediate, 1)
Slice fake_m -> (tokens_i, intermediate, 1) per expert.
"""
token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)
c_ptr = self.c_tensor.iterator
c_stride = self.c_tensor.stride
intermediate = self.c_tensor.shape[1] # type: ignore[index]
c1 = cutlass.Int32(1)
c0 = cutlass.Int32(0)
c_ptr_i = c_ptr + token_offset * c_stride[0] # type: ignore[index]
c_layout_i = cute.make_layout(
(tokens_i, intermediate, c1),
stride=(c_stride[0], c_stride[1], c0), # type: ignore[index]
)
c_tensor_i = cute.make_tensor(c_ptr_i, c_layout_i)
tma_atom_c, _ = cpasync.make_tiled_tma_atom(
self.c_tma_op,
c_tensor_i,
self.epi_smem_layout,
self.epi_tile,
)
cpasync.copy_tensormap(tma_atom_c, self.get_desc_ptr("c", expert_idx))
@cute.jit
def _construct_ab_descs_2dx2d(self, expert_idx: Int32) -> None:
"""
2Dx2D: Create expert-wise A and B descriptors.
A: (m, fake_k, 1) -> slice to (m, tokens_i, 1)
B: (n, fake_k, 1) -> slice to (n, tokens_i, 1)
"""
token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)
c1 = cutlass.Int32(1)
c0 = cutlass.Int32(0)
# A tensor: (m, fake_k, 1) -> (m, tokens_i, 1)
a_ptr = self.a_tensor.iterator
a_stride = self.a_tensor.stride
a_m = self.a_tensor.shape[0] # type: ignore[index]
a_ptr_i = a_ptr + token_offset * a_stride[1] # type: ignore[index]
a_layout_i = cute.make_layout(
(a_m, tokens_i, c1),
stride=(a_stride[0], a_stride[1], c0), # type: ignore[index]
)
a_tensor_i = cute.make_tensor(a_ptr_i, a_layout_i)
tma_atom_a, _ = cute.nvgpu.make_tiled_tma_atom_A(
self.a_tma_op,
a_tensor_i,
self.a_smem_layout,
self.mma_tiler,
self.tiled_mma,
self.cluster_layout_vmnk_shape,
)
cpasync.copy_tensormap(tma_atom_a, self.get_desc_ptr("a", expert_idx))
# B tensor: (n, fake_k, 1) -> (n, tokens_i, 1)
b_ptr = self.b_tensor.iterator
b_stride = self.b_tensor.stride
b_n = self.b_tensor.shape[0] # type: ignore[index]
b_ptr_i = b_ptr + token_offset * b_stride[1] # type: ignore[index]
b_layout_i = cute.make_layout(
(b_n, tokens_i, c1),
stride=(b_stride[0], b_stride[1], c0), # type: ignore[index]
)
b_tensor_i = cute.make_tensor(b_ptr_i, b_layout_i)
tma_atom_b, _ = cute.nvgpu.make_tiled_tma_atom_B(
self.b_tma_op,
b_tensor_i,
self.b_smem_layout,
self.mma_tiler,
self.tiled_mma,
self.cluster_layout_vmnk_shape,
)
cpasync.copy_tensormap(tma_atom_b, self.get_desc_ptr("b", expert_idx))
# =============================================================================
# MoE Scaled Grouped GEMM Tensormap Constructor
# =============================================================================
class MoEScaledGroupedGemmTensormapConstructor(OnlineTensormapDescCreator):
"""
Tensormap descriptor constructor for MoE Scaled Grouped GEMM (block-scaled).
.. py:attribute:: ChunkSize
:value: 128
Number of experts processed per chunk in the desc_init_kernel.
Must match the warp-group width (4 warps × 32 threads).
Extends MoEGroupedGemmTensormapConstructor with SFA/SFB descriptor support.
Expert-wise descriptors only — non-expert-wise descriptors are passed
directly at kernel launch.
Workspace layout:
- 2Dx3D: [C_0, C_1, ..., C_{n-1}] (1 slot per expert)
- 2Dx2D: [A_0, B_0, SFA_0, SFB_0, A_1, B_1, SFA_1, SFB_1, ...] (4 slots per expert)
:param scenario: "2Dx3D" or "2Dx2D"
:param sf_vec_size: Scale factor vector size (32 for MXFP8/MXFP4, 16 for NVFP4)
:param a_dtype: Data type for tensor A
:param b_dtype: Data type for tensor B
:param c_dtype: Data type for tensor C
:param sf_dtype: Data type for scale factors (SFA/SFB)
:param a_smem_layout: SMEM layout for A TMA
:param b_smem_layout: SMEM layout for B TMA
:param epi_smem_layout: SMEM layout for epilogue (C) TMA
:param sfa_smem_layout: SMEM layout for SFA TMA
:param sfb_smem_layout: SMEM layout for SFB TMA
:param a_tma_op: TMA operation for A
:param b_tma_op: TMA operation for B
:param c_tma_op: TMA operation for C (S2G store or reduce)
:param sfa_tma_op: TMA operation for SFA
:param sfb_tma_op: TMA operation for SFB
:param tiled_mma: TiledMma for A/B/SFA/C TMA atom construction
:param tiled_mma_sfb: TiledMma for SFB (separate due to 2CTA replication)
:param mma_tiler: MMA tiler shape (M, N, K)
:param mma_tiler_sfb: MMA tiler shape for SFB
:param cluster_layout_vmnk_shape: Cluster layout shape for A/B/SFA multicast
:param cluster_layout_sfb_vmnk_shape: Cluster layout shape for SFB multicast
:param epi_tile: Epilogue tile shape
:param a_tensor: Fake GEMM domain A tensor
:param b_tensor: Fake GEMM domain B tensor
:param c_tensor: Fake GEMM domain C tensor
:param sfa_tensor: Fake GEMM domain SFA tensor (atom-tiled layout)
:param sfb_tensor: Fake GEMM domain SFB tensor (atom-tiled layout)
:param offs: (experts,) cumsum offsets in data domain
:param offs_padded: (experts,) cumsum offsets in padded scale domain
:param workspace_ptr: Pointer to workspace for TMA descriptors
:param expert_cnt: Total number of experts
"""
ChunkSize = 128
def __init__(
self,
scenario: Literal["2Dx3D", "2Dx2D"],
sf_vec_size: int,
# Codegen-time configs: dtypes
a_dtype,
b_dtype,
c_dtype,
sf_dtype,
# Codegen-time configs: SMEM layouts
a_smem_layout,
b_smem_layout,
epi_smem_layout,
sfa_smem_layout,
sfb_smem_layout,
# Codegen-time configs: TMA ops
a_tma_op,
b_tma_op,
c_tma_op,
sfa_tma_op,
sfb_tma_op,
# Codegen-time configs: MMA / cluster / tile
tiled_mma,
tiled_mma_sfb,
mma_tiler,
mma_tiler_sfb,
cluster_layout_vmnk_shape,
cluster_layout_sfb_vmnk_shape,
epi_tile,
# Runtime params
a_tensor: cute.Tensor,
b_tensor: cute.Tensor,
c_tensor: cute.Tensor,
sfa_tensor: cute.Tensor,
sfb_tensor: cute.Tensor,
offs: cute.Tensor,
offs_padded: cute.Tensor,
workspace_ptr: Pointer,
expert_cnt: Optional[Union[Int32, int]] = None,
) -> None:
super().__init__()
self.scenario = scenario
self.sf_vec_size = sf_vec_size
# Dtypes
self.a_dtype = a_dtype
self.b_dtype = b_dtype
self.c_dtype = c_dtype
self.sf_dtype = sf_dtype
# SMEM layouts
self.a_smem_layout = a_smem_layout
self.b_smem_layout = b_smem_layout
self.epi_smem_layout = epi_smem_layout
self.sfa_smem_layout = sfa_smem_layout
self.sfb_smem_layout = sfb_smem_layout
# TMA ops
self.a_tma_op = a_tma_op
self.b_tma_op = b_tma_op
self.c_tma_op = c_tma_op
self.sfa_tma_op = sfa_tma_op
self.sfb_tma_op = sfb_tma_op
# MMA / cluster / tile
self.tiled_mma = tiled_mma
self.tiled_mma_sfb = tiled_mma_sfb
self.mma_tiler = mma_tiler
self.mma_tiler_sfb = mma_tiler_sfb
self.cluster_layout_vmnk_shape = cluster_layout_vmnk_shape
self.cluster_layout_sfb_vmnk_shape = cluster_layout_sfb_vmnk_shape
self.epi_tile = epi_tile
# Runtime params
self.a_tensor = a_tensor
self.b_tensor = b_tensor
self.c_tensor = c_tensor
self.sfa_tensor = sfa_tensor
self.sfb_tensor = sfb_tensor
self.offs = offs
self.offs_padded = offs_padded
self.expert_cnt = expert_cnt
# Workspace with scenario-specific slot layout
if scenario == "2Dx3D":
self.workspace = TensormapWorkspace(workspace_ptr, ["c"])
else:
self.workspace = TensormapWorkspace(workspace_ptr, ["a", "b", "sfa", "sfb"])
@staticmethod
def get_workspace_size(scenario: Literal["2Dx3D", "2Dx2D"], expert_cnt: int) -> int:
"""Calculate workspace size in bytes for tensormap descriptors."""
if scenario == "2Dx3D":
return TensormapWorkspace.size_bytes(1, expert_cnt) # C only
else:
return TensormapWorkspace.size_bytes(4, expert_cnt) # A, B, SFA, SFB
@cute.jit
def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
return self.workspace.get_ptr(tensor_name, executor_idx)
@cute.jit
def construct_and_write(self, lane_in_group: Int32, dependency=None) -> None:
"""
Create expert-wise tensormap descriptors for all experts.
``lane_in_group`` is the thread's position within its warp group
(0..ChunkSize-1). The method loops internally over all experts in
chunks of ``ChunkSize``, with two-phase pipeline synchronization
per chunk.
Per-chunk execution:
1. Phase 1: Build descriptors that do NOT depend on ``offs_padded``
(A/B for 2Dx2D, C for 2Dx3D). Overlaps with Group A's prefix sum.
2. Barrier: ``consumer.wait_and_advance()`` — all threads participate.
3. Phase 2: Build descriptors that depend on ``offs_padded``
(SFA/SFB for 2Dx2D). Reads padded offsets from SMEM buffer.
4. Release: ``handle.release()`` — all threads participate.
:param lane_in_group: Thread's position within the warp group (0..127).
:param dependency: ``(PipelineConsumer, smem_offs_padded)`` — the
consumer for mbarrier sync, and the SMEM tensor of shape
``(ChunkSize + 1,)`` with layout ``[carry, offs_padded[0..127]]``.
"""
consumer, smem_offs_padded = dependency
assert self.expert_cnt is not None
num_chunks = (self.expert_cnt + self.ChunkSize - 1) // self.ChunkSize
chunk_idx = cutlass.Int32(0)
while chunk_idx < num_chunks:
expert_idx = chunk_idx * self.ChunkSize + lane_in_group
in_bounds = expert_idx < self.expert_cnt
# Phase 1: non-dependent descriptors
if in_bounds:
if cutlass.const_expr(self.scenario == "2Dx2D"):
self._construct_ab_descs_2dx2d(expert_idx)
else:
self._construct_c_desc_2dx3d(expert_idx)
# All threads participate in barrier (fixed arrive count)
handle = consumer.wait_and_advance()
# Phase 2: dependent descriptors (read padded offsets from SMEM)
if in_bounds:
if cutlass.const_expr(self.scenario == "2Dx2D"):
# smem_offs_padded layout: [carry, chunk[0], ..., chunk[127]]
# padded_offset = smem[lane] (prev expert's cumulative)
# padded_end = smem[lane + 1] (this expert's cumulative)
padded_offset = smem_offs_padded[lane_in_group]
padded_size_i = smem_offs_padded[lane_in_group + 1] - padded_offset
self._construct_sf_descs_2dx2d_direct(
expert_idx, padded_offset, padded_size_i
)
# All threads release (fixed arrive count)
handle.release()
chunk_idx += 1
# -----------------------------------------------------------------
# 2Dx3D: C descriptor (same as MoEGroupedGemmTensormapConstructor)
# -----------------------------------------------------------------
@cute.jit
def _construct_c_desc_2dx3d(self, expert_idx: Int32) -> None:
"""
2Dx3D: Create expert-wise C descriptor.
C: (fake_m, n, 1) -> slice to (tokens_i, n, 1) per expert.
"""
token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)
c1 = cutlass.Int32(1)
c_i = cute.domain_offset((token_offset, 0, 0), self.c_tensor)
c_i = rewrite_tensor_shape(c_i, (tokens_i, self.c_tensor.shape[1], c1)) # type: ignore[index]
tma_atom_c, _ = cpasync.make_tiled_tma_atom(
self.c_tma_op,
c_i,
self.epi_smem_layout,
self.epi_tile,
)
cpasync.copy_tensormap(tma_atom_c, self.get_desc_ptr("c", expert_idx))
# -----------------------------------------------------------------
# 2Dx2D: A, B descriptors (same as MoEGroupedGemmTensormapConstructor)
# -----------------------------------------------------------------
@cute.jit
def _construct_ab_descs_2dx2d(self, expert_idx: Int32) -> None:
"""
2Dx2D: Create expert-wise A and B descriptors.
A: (m, fake_k, 1) -> slice to (m, tokens_i, 1)
B: (n, fake_k, 1) -> slice to (n, tokens_i, 1)
"""
token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)
c1 = cutlass.Int32(1)
# A: (m, fake_k, 1) -> domain_offset + rewrite shape
a_i = cute.domain_offset((0, token_offset, 0), self.a_tensor)
a_i = rewrite_tensor_shape(a_i, (self.a_tensor.shape[0], tokens_i, c1)) # type: ignore[index]
tma_atom_a, _ = cute.nvgpu.make_tiled_tma_atom_A(
self.a_tma_op,
a_i,
self.a_smem_layout,
self.mma_tiler,
self.tiled_mma,
self.cluster_layout_vmnk_shape,
)
cpasync.copy_tensormap(tma_atom_a, self.get_desc_ptr("a", expert_idx))
# B: (n, fake_k, 1) -> domain_offset + rewrite shape
b_i = cute.domain_offset((0, token_offset, 0), self.b_tensor)
b_i = rewrite_tensor_shape(b_i, (self.b_tensor.shape[0], tokens_i, c1)) # type: ignore[index]
tma_atom_b, _ = cute.nvgpu.make_tiled_tma_atom_B(
self.b_tma_op,
b_i,
self.b_smem_layout,
self.mma_tiler,
self.tiled_mma,
self.cluster_layout_vmnk_shape,
)
cpasync.copy_tensormap(tma_atom_b, self.get_desc_ptr("b", expert_idx))
# -----------------------------------------------------------------
# 2Dx2D: SFA, SFB descriptors (new for block-scaled)
# -----------------------------------------------------------------
@cute.jit
def _construct_sf_descs_2dx2d_direct(
self,
expert_idx: Int32,
padded_offset: Int32,
padded_size_i: Int32,
) -> None:
"""
2Dx2D: Create expert-wise SFA and SFB descriptors with pre-computed
padded offset and size.
This variant allows the caller to supply padded offsets from SMEM
(in desc_init_kernel) instead of reading from ``self.offs_padded`` in GMEM.
"""
c1 = cutlass.Int32(1)
a_chunks_to_move = (
padded_offset
// self.sf_vec_size
* cute.size(self.sfa_tensor, mode=[0])
// 128
)
a_elems_to_move = (
cute.size(self.sfa_tensor, mode=[0]) * padded_offset // self.sf_vec_size
)
b_chunks_to_move = (
padded_offset
// self.sf_vec_size
* cute.size(self.sfb_tensor, mode=[0])
// 128
)
b_elems_to_move = (
cute.size(self.sfb_tensor, mode=[0]) * padded_offset // self.sf_vec_size
)
per_expert_sfa_shape = (self.sfa_tensor.shape[0], padded_size_i, c1) # type: ignore[index]
sfa_layout_i = tile_atom_to_shape_SF(per_expert_sfa_shape, self.sf_vec_size)
sfa_i = cute.make_tensor(
self.sfa_tensor.iterator + a_elems_to_move, sfa_layout_i
)
tma_atom_sfa, _ = cute.nvgpu.make_tiled_tma_atom_A(
self.sfa_tma_op,
sfa_i,
self.sfa_smem_layout,
self.mma_tiler,
self.tiled_mma,
self.cluster_layout_vmnk_shape,
internal_type=cutlass.Uint64,
)
cpasync.copy_tensormap(tma_atom_sfa, self.get_desc_ptr("sfa", expert_idx))
per_expert_sfb_shape = (self.sfb_tensor.shape[0], padded_size_i, c1) # type: ignore[index]
sfb_layout_i = tile_atom_to_shape_SF(per_expert_sfb_shape, self.sf_vec_size)
sfb_i = cute.make_tensor(
self.sfb_tensor.iterator + b_elems_to_move, sfb_layout_i
)
tma_atom_sfb, _ = cute.nvgpu.make_tiled_tma_atom_B(
self.sfb_tma_op,
sfb_i,
self.sfb_smem_layout,
self.mma_tiler_sfb,
self.tiled_mma_sfb,
self.cluster_layout_sfb_vmnk_shape,
internal_type=cutlass.Uint64,
)
cpasync.copy_tensormap(tma_atom_sfb, self.get_desc_ptr("sfb", expert_idx))

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff