docs: add CuTeDSL rewrite plan + reference files
The C++ CUTLASS kernel is fundamentally broken (cosine 0.05 with real data). Switching to NVIDIA's CuTeDSL approach based on their official MoE scaled grouped GEMM example.

Reference files copied:
- moe_torch_scaled_grouped_mm.py (3900 lines, our new kernel)
- moe_utils.py, moe_persistent_scheduler.py, moe_sched_extension.py
- grouped_blockscaled_gemm.py, dense_blockscaled_gemm_persistent.py
- blockscaled_layout.py
93
REWRITE_PLAN.md
Normal file
@@ -0,0 +1,93 @@
# NVFP4 MegaMoE Kernel Rewrite: CuTeDSL

## Decision

The C++ CUTLASS kernel is fundamentally broken. The GEMM output reaches a
cosine similarity of only 0.05 on real data, even though the SF remap check
reports 0 errors and the uniform-input tests pass. The root cause is likely in
how CUTLASS's C++ API handles FP4 packing/tiling internally, which is
something we can't easily debug or fix.

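The cosine figure quoted above is assumed here to be the cosine similarity between the flattened kernel output and a trusted reference output. A minimal pure-Python sketch of that metric follows; the function name `cosine_similarity` and the sample vectors are illustrative only and are not taken from `layertest.py`.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two flattened outputs (1.0 = same direction)."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # convention chosen for this sketch, not a standard
    return dot / (norm_a * norm_b)


reference = [1.0, 2.0, 3.0, 4.0]
scaled = [2.0, 4.0, 6.0, 8.0]        # same direction, different magnitude
orthogonal = [2.0, -1.0, 4.0, -3.0]  # dot product with reference is 0

print(cosine_similarity(reference, scaled))
print(cosine_similarity(reference, orthogonal))
```

A value near 0, like the 0.05 reported here, means the output is close to uncorrelated with the reference. One plausible reason the uniform tests still pass is that a layout or permutation bug cannot change a tensor whose elements are all equal.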
We're replacing it with NVIDIA's CuTeDSL approach (Python-based CUTLASS
kernels compiled via MLIR → PTX). This is what the NVIDIA CUTLASS team
recommends for Blackwell, and they have a working reference:

`cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/moe/torch_scaled_grouped_mm.py`

This is a 3900-line MoE scaled grouped GEMM that supports NVFP4 out of the box.

## Reference Files (copied to reference/)

- `reference/moe_torch_scaled_grouped_mm.py`: the kernel (3900 lines)
- `reference/moe_moe_utils.py`: tensormap constructor
- `reference/moe_moe_persistent_scheduler.py`: MoE tile scheduler
- `reference/moe_moe_sched_extension.py`: scheduler extension for scaled GEMM
- `reference/grouped_blockscaled_gemm.py`: simpler grouped blockscaled reference
- `reference/dense_blockscaled_gemm_persistent.py`: dense (non-grouped) reference
- `reference/blockscaled_layout.py`: SF layout utilities

## Architecture

### The Reference Kernel: ScaledGroupedGemmKernel

7-warp persistent kernel:
- Warps 0-3: Epilogue (TMEM → RMEM → SMEM → GMEM)
- Warp 4: MMA (tcgen05.mma.block_scale with SFA/SFB in TMEM)
- Warp 5: TMA load (A, B, SFA, SFB from GMEM → SMEM)
- Warp 6: Scheduler (MoE persistent tile scheduler)

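The warp specialization listed above can be pictured as a lookup from thread index to role. The sketch below is plain Python for illustration, not CuTeDSL; the names `WARP_ROLES` and `warp_role` are hypothetical, and the 224-thread CTA size is simply 7 warps times 32 threads, assuming the kernel launches exactly the warps listed.

```python
WARP_SIZE = 32  # threads per warp on NVIDIA GPUs

# (first_warp, last_warp, role), copied from the warp list in this plan
WARP_ROLES = [
    (0, 3, "epilogue"),   # TMEM -> RMEM -> SMEM -> GMEM
    (4, 4, "mma"),        # issues the block-scaled MMA
    (5, 5, "tma_load"),   # GMEM -> SMEM loads of A, B, SFA, SFB
    (6, 6, "scheduler"),  # MoE persistent tile scheduler
]

NUM_WARPS = 7
THREADS_PER_CTA = NUM_WARPS * WARP_SIZE


def warp_role(thread_idx: int) -> str:
    """Return the role of the warp that owns `thread_idx`."""
    warp_idx = thread_idx // WARP_SIZE
    for first, last, role in WARP_ROLES:
        if first <= warp_idx <= last:
            return role
    raise ValueError(f"thread {thread_idx} is outside the {NUM_WARPS}-warp CTA")


print(THREADS_PER_CTA, warp_role(0), warp_role(128), warp_role(223))
```

Splitting roles across warps is what lets loads, MMA, and the epilogue of different tiles overlap in a persistent kernel.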
Supports:
- NVFP4 (Float4E2M1FN data, Float8E4M3FN scale factors, sf_vec_size=16)
- 2Dx3D scenario (our case): A(tokens, K) x B(experts, K, N) → C(tokens, N)
- PyTorch interface via torch.nn.functional.scaled_grouped_mm

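To make the 2Dx3D shape contract concrete, here is a small pure-Python reference of an unscaled grouped GEMM. It assumes tokens are already sorted by expert and that group boundaries are given as cumulative end offsets; both the function name `grouped_gemm_2d3d` and the `group_ends` argument are hypothetical, and scale factors are left out entirely.

```python
from typing import List

Matrix = List[List[float]]


def grouped_gemm_2d3d(a: Matrix, b: List[Matrix], group_ends: List[int]) -> Matrix:
    """C[t, :] = A[t, :] @ B[e] for every token row t that belongs to expert e.

    a:          (tokens, K), rows sorted by expert
    b:          (experts, K, N)
    group_ends: cumulative end row of each expert's token range
    """
    n = len(b[0][0])
    c: Matrix = []
    start = 0
    for expert, end in enumerate(group_ends):
        for t in range(start, end):
            row = [
                sum(a[t][k] * b[expert][k][j] for k in range(len(a[t])))
                for j in range(n)
            ]
            c.append(row)
        start = end
    return c


a = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]          # 3 tokens, K = 2
b = [
    [[1.0, 2.0], [3.0, 4.0]],                     # expert 0, (K, N) = (2, 2)
    [[10.0, 20.0], [30.0, 40.0]],                 # expert 1
]
c = grouped_gemm_2d3d(a, b, group_ends=[2, 3])    # tokens 0-1 -> expert 0, token 2 -> expert 1
print(c)
```

The M extent of each expert's GEMM is dynamic (it depends on routing), while K and N are fixed by the weights, which is why the kernel needs a dedicated MoE tile scheduler.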
### Our Adaptation

We need to:
1. Use ScaledGroupedGemmKernel directly (or with minimal adaptation)
2. Update our Python pipeline (nvfp4_mega_moe.py) to call the CuTeDSL kernel
3. Handle the MoE-specific steps outside the GEMM (routing, gate/up fusion, SiLU, scatter)

The CuTeDSL kernel handles A/B/SFA/SFB tiling and MMA internally.
We NO LONGER need:
- Our custom SF remap kernel (CuTeDSL handles SF layouts natively)
- Our C++ CUTLASS extension (cutlass_nvfp4_gemm.cu)
- transform_nvfp4_weights_for_mega_moe (CuTeDSL uses the natural layout)

We STILL need:
- stage_activation (BF16 → FP4 quantization)
- The MoE routing logic (top-k selection, slot mapping)
- Gate/up fusion with up_correction
- SiLU activation
- Scatter with routing weights

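As a rough picture of what an FP4 quantization step such as stage_activation has to produce, the sketch below quantizes one block of 16 values to E2M1 magnitudes with a single per-block scale. It is an illustration under stated assumptions, not the project's implementation: the scale stays a Python float instead of being rounded to Float8E4M3FN, codes are kept as floats instead of packed 4-bit values, round-to-nearest is assumed, and any further scaling the real pipeline applies is omitted. The names `quantize_block`, `quantize`, and `dequantize` are hypothetical.

```python
from typing import List, Tuple

# Magnitudes representable by a 4-bit E2M1 value (sign handled separately)
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
SF_VEC_SIZE = 16  # elements sharing one scale factor


def quantize_block(block: List[float]) -> Tuple[float, List[float]]:
    """Return (scale, codes) so that scale * codes[i] approximates block[i]."""
    amax = max(abs(x) for x in block)
    # Map the largest magnitude in the block onto the largest E2M1 value (6.0)
    scale = amax / E2M1_MAGNITUDES[-1] if amax > 0.0 else 1.0
    codes = []
    for x in block:
        mag = min(E2M1_MAGNITUDES, key=lambda m: abs(abs(x) / scale - m))
        codes.append(-mag if x < 0 else mag)
    return scale, codes


def quantize(values: List[float]) -> List[Tuple[float, List[float]]]:
    """Quantize a row block by block; the length must be a multiple of 16."""
    if len(values) % SF_VEC_SIZE != 0:
        raise ValueError("length must be a multiple of SF_VEC_SIZE")
    return [
        quantize_block(values[i : i + SF_VEC_SIZE])
        for i in range(0, len(values), SF_VEC_SIZE)
    ]


def dequantize(blocks: List[Tuple[float, List[float]]]) -> List[float]:
    return [scale * code for scale, codes in blocks for code in codes]


values = [float(i) for i in range(16)]
blocks = quantize(values)
restored = dequantize(blocks)
print(blocks[0][0], restored)
```

Because the largest gap between adjacent E2M1 magnitudes is 2 (from 4 to 6), the per-element error in this scheme is bounded by one scale unit.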
## Plan

### Phase 1: Get the reference kernel running standalone
- Set up CuTeDSL build environment on B200
- Run the reference example with NVFP4 config
- Verify it produces correct output

### Phase 2: Integrate into our pipeline
- Create a Python module that wraps ScaledGroupedGemmKernel
- Replace cutlass_nvfp4_gemm calls with CuTeDSL kernel calls
- Handle weight format (the CuTeDSL kernel expects raw FP4 + SF, no transpose)

### Phase 3: Full MoE pipeline
- Wire up stage_activation → L1 GEMM → SiLU → stage_activation → L2 GEMM → scatter
- Test with layertest.py (should get a cosine similarity of ~0.995)

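The Phase 3 data flow can be sketched end to end in plain Python with the quantization steps left out. Everything here is an assumption made for illustration: the per-token loop stands in for the grouped GEMM, the L1 output is assumed to be laid out as `[gate | up]`, the activation is assumed to be `SiLU(gate) * up`, routing weights are renormalized over the selected experts, and up_correction is omitted. The names `moe_forward`, `matvec`, and `top_k` are hypothetical.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(x: Vector, w: Matrix) -> Vector:
    """x of shape (K,) times w of shape (K, N) gives (N,)."""
    return [sum(x[k] * w[k][j] for k in range(len(x))) for j in range(len(w[0]))]


def silu(v: float) -> float:
    return v / (1.0 + math.exp(-v))


def top_k(scores: Vector, k: int) -> List[int]:
    """Indices of the k highest-scoring experts, best first."""
    return sorted(range(len(scores)), key=lambda e: scores[e], reverse=True)[:k]


def moe_forward(
    tokens: Matrix, router_scores: Matrix, w1: List[Matrix], w2: List[Matrix], k: int
) -> Matrix:
    hidden = len(tokens[0])
    out = [[0.0] * hidden for _ in tokens]
    for t, x in enumerate(tokens):
        experts = top_k(router_scores[t], k)
        total = sum(router_scores[t][e] for e in experts)
        for e in experts:
            # stage_activation (BF16 -> FP4) would run here; omitted in this sketch
            gate_up = matvec(x, w1[e])                 # L1 GEMM: hidden -> 2 * inter
            inter = len(gate_up) // 2
            act = [silu(gate_up[i]) * gate_up[inter + i] for i in range(inter)]
            # second stage_activation omitted
            y = matvec(act, w2[e])                     # L2 GEMM: inter -> hidden
            weight = router_scores[t][e] / total       # assumed renormalization
            for j in range(hidden):
                out[t][j] += weight * y[j]             # scatter with routing weight
    return out


tokens = [[1.0, 2.0]]
router_scores = [[0.1, 0.9]]
w1 = [[[0.0, 0.0], [0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]]]  # (experts, hidden, 2 * inter)
w2 = [[[0.0, 0.0]], [[1.0, -1.0]]]                         # (experts, inter, hidden)
print(moe_forward(tokens, router_scores, w1, w2, k=1))
```

A float reference like this is also a convenient source of expected outputs when checking the quantized pipeline's cosine similarity.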
### Phase 4: vLLM integration
- Update the vLLM patch to use the CuTeDSL kernel
- Remove the C++ extension build from the Dockerfile
- Test full inference

## Key Differences from Current Approach

| Current (C++ CUTLASS) | New (CuTeDSL) |
|----------------------|---------------|
| C++ .cu file | Python .py file |
| Manual SF remap kernel | CuTe handles SF natively |
| Manual weight transpose | CuTe handles layout |
| pip install C++ extension | cute.compile() JIT |
| Broken FP4 handling | Reference-verified |
| No Blackwell pipeline | Full TMA/MMA/Epilogue overlap |
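The scale-factor layout that the table says CuTe handles natively is built from the K-major atom in `reference/blockscaled_layout.py`, with shape `((32, 4), (sf_vec_size, 4))` and stride `((16, 4), (0, 1))`. The sketch below evaluates that atom by hand in plain Python, assuming CuTe's usual convention that the leftmost sub-mode varies fastest; `sf_atom_offset` is a hypothetical helper and is not part of the reference files.

```python
def sf_atom_offset(m: int, k: int, sf_vec_size: int = 16) -> int:
    """Offset of the scale factor covering element (m, k) inside one K-major atom.

    Atom shape:  ((32, 4), (sf_vec_size, 4))
    Atom stride: ((16, 4), (0, 1))
    """
    if not (0 <= m < 128 and 0 <= k < 4 * sf_vec_size):
        raise ValueError("coordinate outside the 128 x (4 * sf_vec_size) atom")
    m0, m1 = m % 32, m // 32   # split the row index over the (32, 4) sub-modes
    k1 = k // sf_vec_size      # k % sf_vec_size has stride 0: one SF per vector
    return m0 * 16 + m1 * 4 + k1


# Every (row, K-vector) pair of the atom maps to a distinct slot in [0, 512)
offsets = {sf_atom_offset(m, kv * 16) for m in range(128) for kv in range(4)}
print(len(offsets), min(offsets), max(offsets))
```

The stride-0 mode is why all 16 elements of a vector share one scale factor, and the 16/4/1 strides are the interleaving that the old custom SF remap kernel had to reproduce by hand.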
657
reference/blockscaled_layout.py
Normal file
@@ -0,0 +1,657 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# Use of this software is governed by the terms and conditions of the
# NVIDIA End User License Agreement (EULA), available at:
# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html
#
# Any use, reproduction, disclosure, or distribution of this software
# and related documentation outside the scope permitted by the EULA
# is strictly prohibited.

from dataclasses import dataclass, field
from typing import Optional

from cutlass.cutlass_dsl import dsl_user_op

import cutlass.cute as cute
from cutlass.cute.nvgpu import OperandMajorMode

from cutlass._mlir import ir
import cutlass._mlir.dialects.cute as _cute_ir
import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir


@dataclass(frozen=True)
class BlockScaledBasicChunk:
    """
    The basic scale factor atom layout decided by tcgen05 BlockScaled MMA Ops.

    This class represents the fixed layout pattern for scale factors used in
    tcgen05 BlockScaled MMA Ops. The layout is determined by the
    instruction specification and cannot be modified.
    See `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x>`.
    """

    sf_vec_size: int
    major_mode: OperandMajorMode = OperandMajorMode.K
    _layout: cute.Layout = field(init=False, repr=False)

    def __post_init__(self) -> None:
        if self.major_mode == OperandMajorMode.K:
            # K-major layout: (AtomMN, AtomK)
            atom_shape = ((32, 4), (self.sf_vec_size, 4))
            atom_stride = ((16, 4), (0, 1))
        else:
            # MN-major layout: (AtomK, AtomMN)
            atom_shape = ((self.sf_vec_size, 4), (32, 4))
            atom_stride = ((0, 1), (16, 4))

        object.__setattr__(
            self, "_layout", cute.make_layout(atom_shape, stride=atom_stride)
        )

    @property
    def layout(self) -> cute.Layout:
        """
        Get the layout for this block scaled chunk.

        :return: The layout representing the scale factor atom
        :rtype: cute.Layout
        """
        return self._layout


@dsl_user_op
def tile_atom_to_shape_SF(
    Shape: cute.Shape,
    sf_vec_size: int,
    *,
    loc: Optional[ir.Location] = None,
    ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
    """
    A helper function to get dynamic SFA/SFB layout by filling dynamic A/B shape to the scale factor atom layout.

    :param Shape: The shape of the A/B tensor
    :param sf_vec_size: Scale factor vector size

    :return: The layout of the SFA/SFB tensor
    :rtype: cute.Layout
    """
    # ((Atom_MN, Rest_MN),(Atom_K, Rest_K),RestL)
    sf_layout = cute.tile_to_shape(
        BlockScaledBasicChunk(sf_vec_size).layout, Shape, (2, 1, 3)
    )
    return sf_layout


@dsl_user_op
def make_smem_layout_sf(
    tile_shape: cute.Tile,
    sf_vec_size: int,
    num_stages: int,
    *,
    loc: Optional[ir.Location] = None,
    ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
    """
    A helper function to get the staged smem layout for SFA/SFB by tiling the scale factor atom layout to the tile shape and appending a stage mode.

    :param tile_shape: The shape of the scale factor tile
    :param sf_vec_size: Scale factor vector size
    :param num_stages: Number of stages

    :return: The staged smem layout of the SFA/SFB tensor
    :rtype: cute.Layout
    """

    smem_layout = cute.tile_to_shape(
        BlockScaledBasicChunk(sf_vec_size).layout,
        tile_shape,  # type: ignore[arg-type]
        (2, 1),
        loc=loc,
        ip=ip,
    )
    smem_layout_staged = cute.append(
        smem_layout,
        cute.make_layout(
            num_stages,
            stride=cute.cosize(cute.filter_zeros(smem_layout)),
            loc=loc,
            ip=ip,
        ),
        loc=loc,
        ip=ip,
    )
    return smem_layout_staged


@dsl_user_op
def make_smem_layout_sfa(
    tiled_mma: cute.TiledMma,
    mma_tiler_mnk: cute.Tile,
    sf_vec_size: int,
    num_stages: int,
    *,
    loc: Optional[ir.Location] = None,
    ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
    """
    Make smem layout for SFA based on:

    1. BlockScaledBasicChunk
    2. MMA tiler shape
    3. Scale factor vector size
    4. Number of stages

    :param tiled_mma: The tiled MMA
    :type tiled_mma: cute.TiledMma
    :param mma_tiler_mnk: The mma tiler shape
    :type mma_tiler_mnk: cute.Tile
    :param sf_vec_size: The scale factor vector size
    :type sf_vec_size: int
    :param num_stages: The number of stages
    :type num_stages: int

    :return: Smem layout for SFA
    :rtype: cute.Layout
    """
    # (CTA_Tile_Shape_M, MMA_Tile_Shape_K)
    sfa_tile_shape = (
        mma_tiler_mnk[0] // cute.size(tiled_mma.thr_id.shape),  # type: ignore[index]
        mma_tiler_mnk[2],  # type: ignore[index]
    )

    # ((Atom_M, Rest_M),(Atom_K, Rest_K))
    smem_layout = cute.tile_to_shape(
        BlockScaledBasicChunk(sf_vec_size).layout,
        sfa_tile_shape,  # type: ignore[arg-type]
        (2, 1),
    )

    # Number of MMA instructions along M and K within the MMA tile
    mma_tile_inst_m = mma_tiler_mnk[0] // cute.size(tiled_mma.shape_mnk, mode=[0])  # type: ignore[index]
    mma_tile_inst_k = mma_tiler_mnk[2] // cute.size(tiled_mma.shape_mnk, mode=[2])  # type: ignore[index]

    # (CTA_Tile_Shape_M, MMA_Inst_Shape_K)
    sfa_tile_shape = cute.shape_div(sfa_tile_shape, (mma_tile_inst_m, mma_tile_inst_k))
    # ((Atom_Inst_M, Atom_Inst_K), MMA_M, MMA_K)
    smem_layout = cute.tiled_divide(smem_layout, sfa_tile_shape)

    atom_m = 128
    tiler_inst = ((atom_m, sf_vec_size),)
    # (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K)
    smem_layout = cute.logical_divide(smem_layout, tiler_inst)

    # (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K, STAGE)
    sfa_smem_layout_staged = cute.append(
        smem_layout,
        cute.make_layout(
            num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout))
        ),
    )

    return sfa_smem_layout_staged


@dsl_user_op
def make_smem_layout_sfb(
    tiled_mma: cute.TiledMma,
    mma_tiler_mnk: cute.Tile,
    sf_vec_size: int,
    num_stages: int,
    *,
    loc: Optional[ir.Location] = None,
    ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
    """
    Make smem layout for SFB based on:

    1. BlockScaledBasicChunk
    2. MMA tiler shape
    3. Scale factor vector size
    4. Number of stages

    :param tiled_mma: The tiled MMA
    :type tiled_mma: cute.TiledMma
    :param mma_tiler_mnk: The mma tiler shape
    :type mma_tiler_mnk: cute.Tile
    :param sf_vec_size: The scale factor vector size
    :type sf_vec_size: int
    :param num_stages: The number of stages
    :type num_stages: int

    :return: Smem layout for SFB
    :rtype: cute.Layout
    """
    # (Round_Up(CTA_Tile_Shape_N, 128), MMA_Tile_Shape_K)
    sfb_tile_shape = (
        cute.round_up(mma_tiler_mnk[1], 128),  # type: ignore[index, arg-type]
        mma_tiler_mnk[2],  # type: ignore[index]
    )

    # ((Atom_N, Rest_N),(Atom_K, Rest_K))
    smem_layout = cute.tile_to_shape(
        BlockScaledBasicChunk(sf_vec_size).layout,
        sfb_tile_shape,  # type: ignore[arg-type]
        (2, 1),
    )

    # Number of MMA instructions along N and K within the MMA tile
    mma_tile_inst_n = mma_tiler_mnk[1] // cute.size(tiled_mma.shape_mnk, mode=[1])  # type: ignore[index]
    mma_tile_inst_k = mma_tiler_mnk[2] // cute.size(tiled_mma.shape_mnk, mode=[2])  # type: ignore[index]

    # (CTA_Tile_Shape_N, MMA_Inst_Shape_K)
    sfb_tile_shape = cute.shape_div(sfb_tile_shape, (mma_tile_inst_n, mma_tile_inst_k))
    # ((Atom_Inst_N, Atom_Inst_K), MMA_N, MMA_K)
    smem_layout = cute.tiled_divide(smem_layout, sfb_tile_shape)

    atom_n = 128
    tiler_inst = ((atom_n, sf_vec_size),)
    # (((Atom_Inst_N, Rest_N),(Atom_Inst_K, Rest_K)), MMA_N, MMA_K)
    smem_layout = cute.logical_divide(smem_layout, tiler_inst)

    # (((Atom_Inst_N, Rest_N),(Atom_Inst_K, Rest_K)), MMA_N, MMA_K, STAGE)
    sfb_smem_layout_staged = cute.append(
        smem_layout,
        cute.make_layout(
            num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout))
        ),
    )

    return sfb_smem_layout_staged


@dsl_user_op
def sm120_make_smem_layout_sfa(
    tiled_mma: cute.TiledMma,
    tile_shape_mnk: cute.Tile,
    sf_vec_size: int,
    num_stages: int,
    *,
    loc: Optional[ir.Location] = None,
    ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
    """
    Make smem layout for SFA based on:
    1. BlockScaledBasicChunk
    2. MMA tiler shape
    3. Scale factor vector size
    4. Number of stages

    :param tiled_mma: The tiled MMA
    :type tiled_mma: cute.TiledMma
    :param tile_shape_mnk: The tile shape
    :type tile_shape_mnk: cute.Tile
    :param sf_vec_size: The scale factor vector size
    :type sf_vec_size: int
    :param num_stages: The number of stages
    :type num_stages: int

    :return: Smem layout for SFA
    :rtype: cute.Layout
    """

    assert sf_vec_size == 16 or sf_vec_size == 32, "sf_vec_size must be 16 or 32"

    blk_mn = 128
    blk_sf = 4
    blk_elems = blk_mn * blk_sf
    mma_nsf = tiled_mma.shape_mnk[2] // sf_vec_size

    mn_basic_block_shape = (32, 4)
    mn_basic_block_stride = (16, 4)
    k_basic_block_shape = (sf_vec_size, mma_nsf)
    k_basic_block_stride = (0, 1)

    assert tile_shape_mnk[0] % blk_mn == 0, (  # type: ignore[index, operator]
        "tile_shape_mnk[0] must be divisible by blk_mn"
    )

    sSFA_shapeM = (mn_basic_block_shape, tile_shape_mnk[0] // blk_mn)  # type: ignore[index, operator]
    sSF_strideM = (mn_basic_block_stride, blk_elems)

    assert tile_shape_mnk[2] % (blk_sf * mma_nsf) == 0, (  # type: ignore[index]
        "tile_shape_mnk[2] must be divisible by blk_sf * mma_nsf"
    )

    sSFA_shapeK = (
        k_basic_block_shape,
        blk_sf // mma_nsf,
        tile_shape_mnk[2] // sf_vec_size // blk_sf,  # type: ignore[index, operator]
    )
    sSF_strideK = (
        k_basic_block_stride,
        mma_nsf,
        tile_shape_mnk[0] // blk_mn * blk_elems,  # type: ignore[index, operator]
    )

    sSFA_shape = (sSFA_shapeM, sSFA_shapeK)
    sSFA_stride = (sSF_strideM, sSF_strideK)

    smem_layout = cute.make_layout(sSFA_shape, stride=sSFA_stride)

    # (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K, STAGE)
    sfa_smem_layout_staged = cute.append(
        smem_layout,
        cute.make_layout(
            num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout))
        ),
    )

    return sfa_smem_layout_staged


@dsl_user_op
def sm120_make_smem_layout_sfb(
    tiled_mma: cute.TiledMma,
    tile_shape_mnk: cute.Tile,
    sf_vec_size: int,
    num_stages: int,
    *,
    loc: Optional[ir.Location] = None,
    ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
    """
    Make smem layout for SFB based on:
    1. BlockScaledBasicChunk
    2. MMA tiler shape
    3. Scale factor vector size
    4. Number of stages

    :param tiled_mma: The tiled MMA
    :type tiled_mma: cute.TiledMma
    :param tile_shape_mnk: The tile shape
    :type tile_shape_mnk: cute.Tile
    :param sf_vec_size: The scale factor vector size
    :type sf_vec_size: int
    :param num_stages: The number of stages
    :type num_stages: int

    :return: Smem layout for SFB
    :rtype: cute.Layout
    """

    # A single indivisible block will hold 4 scale factors of 128 rows/columns (A/B matrix).
    # 4 is chosen to make consecutive 32bits of data to have scale factors for only a single row(col).
    blk_mn = 128
    blk_sf = 4
    blk_elems = blk_mn * blk_sf

    assert sf_vec_size == 16 or sf_vec_size == 32, "sf_vec_size must be 16 or 32"

    assert tile_shape_mnk[1] % blk_mn == 0, (  # type: ignore[index, operator]
        "tile_shape_mnk[1] must be divisible by blk_mn"
    )

    assert tile_shape_mnk[2] % sf_vec_size == 0, (  # type: ignore[index, operator]
        "tile_shape_mnk[2] must be divisible by sf_vec_size"
    )

    mma_nsf = tiled_mma.shape_mnk[2] // sf_vec_size

    mn_basic_block_shape = (32, 4)
    mn_basic_block_stride = (16, 4)
    k_basic_block_shape = (sf_vec_size, mma_nsf)
    k_basic_block_stride = (0, 1)

    assert tile_shape_mnk[1] % blk_mn == 0, (  # type: ignore[index, operator]
        "tile_shape_mnk[1] must be divisible by blk_mn"
    )

    sSFA_shapeN = (mn_basic_block_shape, tile_shape_mnk[1] // blk_mn)  # type: ignore[index, operator]
    sSF_strideN = (mn_basic_block_stride, blk_elems)

    assert tile_shape_mnk[2] % (blk_sf * mma_nsf) == 0, (  # type: ignore[index]
        "tile_shape_mnk[2] must be divisible by blk_sf * mma_nsf"
    )

    sSFA_shapeK = (
        k_basic_block_shape,
        blk_sf // mma_nsf,
        tile_shape_mnk[2] // sf_vec_size // blk_sf,  # type: ignore[index, operator]
    )
    sSF_strideK = (
        k_basic_block_stride,
        mma_nsf,
        tile_shape_mnk[1] // blk_mn * blk_elems,  # type: ignore[index, operator]
    )

    sSFA_shape = (sSFA_shapeN, sSFA_shapeK)
    sSFA_stride = (sSF_strideN, sSF_strideK)

    smem_layout = cute.make_layout(sSFA_shape, stride=sSFA_stride)

    # (((Atom_Inst_N, Rest_N),(Atom_Inst_K, Rest_K)), MMA_N, MMA_K, STAGE)
    sfb_smem_layout_staged = cute.append(
        smem_layout,
        cute.make_layout(
            num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout))
        ),
    )

    return sfb_smem_layout_staged


@dsl_user_op
def make_tmem_layout_sfa(
    tiled_mma: cute.TiledMma,
    mma_tiler_mnk: cute.Tile,
    sf_vec_size: int,
    smem_layout: cute.Layout,
    *,
    loc: Optional[ir.Location] = None,
    ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
    """Make tmem layout for SFA based on:

    1. SFA smem layout per stage
    2. Cta tile shape m
    3. tiled MMA atom thr size
    4. Scale factor vector size

    :param tiled_mma: The tiled MMA
    :type tiled_mma: cute.TiledMma
    :param mma_tiler_mnk: The mma tiler shape
    :type mma_tiler_mnk: cute.Tile
    :param sf_vec_size: The scale factor vector size
    :type sf_vec_size: int
    :param smem_layout: The smem layout of SFA per stage
    :type smem_layout: cute.Layout

    :return: TMEM layout for SFA
    :rtype: cute.Layout
    """
    atom_thr_size = cute.size(tiled_mma.thr_id.shape, loc=loc, ip=ip)
    cta_tile_shape_m = mma_tiler_mnk[0] // atom_thr_size  # type: ignore[index]

    sfa_layout_ty = _cute_nvgpu_ir.make_tmem_layout_sfa(
        smem_layout, cta_tile_shape_m, atom_thr_size, sf_vec_size
    )
    return _cute_ir.static(sfa_layout_ty, loc=loc, ip=ip)


@dsl_user_op
def make_tmem_layout_sfb(
    tiled_mma: cute.TiledMma,
    mma_tiler_mnk: cute.Tile,
    sf_vec_size: int,
    smem_layout: cute.Layout,
    *,
    loc: Optional[ir.Location] = None,
    ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
    """Make tmem layout for SFB based on:

    1. SFB smem layout per stage
    2. Cta tile shape m
    3. tiled MMA atom thr size
    4. Scale factor vector size

    :param tiled_mma: The tiled MMA
    :type tiled_mma: cute.TiledMma
    :param mma_tiler_mnk: The mma tiler shape
    :type mma_tiler_mnk: cute.Tile
    :param sf_vec_size: The scale factor vector size
    :type sf_vec_size: int
    :param smem_layout: The smem layout of SFB per stage
    :type smem_layout: cute.Layout

    :return: TMEM layout for SFB
    :rtype: cute.Layout
    """
    atom_thr_size = cute.size(tiled_mma.thr_id.shape, loc=loc, ip=ip)
    cta_tile_shape_m = mma_tiler_mnk[0] // atom_thr_size  # type: ignore[index]

    sfb_layout_ty = _cute_nvgpu_ir.make_tmem_layout_sfb(
        smem_layout, cta_tile_shape_m, atom_thr_size, sf_vec_size
    )
    return _cute_ir.static(sfb_layout_ty, loc=loc, ip=ip)


@dataclass(frozen=True)
class Sm103BlockScaledBasicChunk:
    """
    Basic scale-factor atom layout decided by tcgen05 BlockScaled MMA Ops on SM103.

    Represents the fixed layout pattern for scale factors used by tcgen05
    BlockScaled MMA Ops on SM103. The layout is determined by the instruction
    specification and is not configurable.
    """

    sf_vec_size: int
    major_mode: OperandMajorMode = OperandMajorMode.K
    _layout: cute.Layout = field(init=False, repr=False)

    def __post_init__(self) -> None:
        atom_shape: cute.Shape
        atom_stride: cute.Stride
        if self.major_mode == OperandMajorMode.K:
            atom_shape = ((8, 4, 4), (self.sf_vec_size, 4))
            atom_stride = ((16, 128, 4), (0, 1))
        else:
            atom_shape = ((self.sf_vec_size, 4), (8, 4, 4))
            atom_stride = ((0, 1), (16, 128, 4))

        object.__setattr__(
            self, "_layout", cute.make_layout(shape=atom_shape, stride=atom_stride)
        )

    @property
    def layout(self) -> cute.Layout:
        return self._layout


@dsl_user_op
def sm103_make_smem_layout_sfa(
    tiled_mma: cute.TiledMma,
    mma_tiler: cute.Tile,
    sf_vec_size: int,
    num_stages: int,
    *,
    loc: Optional[ir.Location] = None,
    ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
    """
    Make SMEM layout for SFA based on:
    1) Sm103BlockScaledBasicChunk, 2) MMA tiler, 3) sf_vec_size, 4) stages.

    :param tiled_mma: The tiled MMA
    :type tiled_mma: cute.TiledMma
    :param mma_tiler: The mma tiler shape
    :type mma_tiler: cute.Tile
    :param sf_vec_size: The scale factor vector size
    :type sf_vec_size: int
    :param num_stages: The number of stages
    :type num_stages: int

    :return: Smem layout for SFA
    :rtype: cute.Layout
    """
    mma_shape_mk = tiled_mma.partition_shape_A((mma_tiler[0], mma_tiler[2]))  # type: ignore[index]
    sf_atom = Sm103BlockScaledBasicChunk(sf_vec_size, tiled_mma.op.a_major_mode).layout  # type: ignore[attr-defined]
    k_divisor = 4 if sf_vec_size == 16 else 2
    mma_sfa_tiler = (
        mma_shape_mk[0][0] * mma_shape_mk[1],
        mma_shape_mk[0][1] * mma_shape_mk[2] // k_divisor,
    )
    sfa_smem_atom_layout = cute.tiled_product(
        sf_atom,
        cute.make_layout(
            cute.shape_div(mma_sfa_tiler, cute.product_each(sf_atom.shape))
        ),
    )
    sfa_smem_layout_staged = cute.make_layout(
        shape=cute.append(sfa_smem_atom_layout.shape, num_stages),
        stride=cute.append(
            sfa_smem_atom_layout.stride,
            cute.size(cute.filter_zeros(sfa_smem_atom_layout)),
        ),
    )
    return sfa_smem_layout_staged


@dsl_user_op
def sm103_make_smem_layout_sfb(
    tiled_mma: cute.TiledMma,
    mma_tiler: cute.Tile,
    sf_vec_size: int,
    num_stages: int,
    *,
    loc: Optional[ir.Location] = None,
    ip: Optional[ir.InsertionPoint] = None,
) -> cute.Layout:
    """
    Make SMEM layout for SFB based on the basic chunk, MMA tiler, sf_vec_size, stages.

    :param tiled_mma: The tiled MMA
    :type tiled_mma: cute.TiledMma
    :param mma_tiler: The mma tiler shape
    :type mma_tiler: cute.Tile
    :param sf_vec_size: The scale factor vector size
    :type sf_vec_size: int
    :param num_stages: The number of stages
    :type num_stages: int

    :return: Smem layout for SFB
    :rtype: cute.Layout
    """
    sf_atom = Sm103BlockScaledBasicChunk(sf_vec_size, tiled_mma.op.b_major_mode).layout  # type: ignore[attr-defined]
    k_divisor = 4 if sf_vec_size == 16 else 2
    mma_sfb_tiler = (mma_tiler[1], mma_tiler[2] // k_divisor)  # type: ignore[index, operator]
    if mma_sfb_tiler[0] == 128:
        sfb_smem_atom_layout = cute.tiled_product(
            sf_atom,
            cute.make_layout(
                cute.shape_div(mma_sfb_tiler, cute.product_each(sf_atom.shape))
            ),
        )
    else:
        sf_k_major_atom256 = cute.make_layout(
            shape=(
                (32, 4, 2),
                (sf_vec_size, 4),
            ),
            stride=(
                (16, 4, mma_sfb_tiler[1] // sf_vec_size // 4 * 512),
                (0, 1),
            ),
        )
        sfb_smem_atom_layout = cute.tiled_product(
            sf_k_major_atom256,
            cute.make_layout(
                cute.shape_div(
                    mma_sfb_tiler, cute.product_each(sf_k_major_atom256.shape)
                )
            ),
        )

    sfb_smem_layout_staged = cute.make_layout(
        shape=cute.append(sfb_smem_atom_layout.shape, num_stages),
        stride=cute.append(
            sfb_smem_atom_layout.stride,
            cute.size(cute.filter_zeros(sfb_smem_atom_layout)),
        ),
    )
    return sfb_smem_layout_staged
3152
reference/dense_blockscaled_gemm_persistent.py
Normal file
File diff suppressed because it is too large
3278
reference/grouped_blockscaled_gemm.py
Normal file
File diff suppressed because it is too large
695
reference/moe_moe_persistent_scheduler.py
Normal file
@@ -0,0 +1,695 @@
|
||||
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
# Redistribution and use in source and binary forms, with or without
|
||||
# modification, are permitted provided that the following conditions are met:
|
||||
|
||||
# 1. Redistributions of source code must retain the above copyright notice, this
|
||||
# list of conditions and the following disclaimer.
|
||||
|
||||
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
# this list of conditions and the following disclaimer in the documentation
|
||||
# and/or other materials provided with the distribution.
|
||||
|
||||
# 3. Neither the name of the copyright holder nor the names of its
|
||||
# contributors may be used to endorse or promote products derived from
|
||||
# this software without specific prior written permission.
|
||||
|
||||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
"""
|
||||
MoE Persistent Tile Scheduler
|
||||
|
||||
A specialized tile scheduler for MoE (Mixture of Experts) grouped GEMM operations.
|
||||
This scheduler handles tile iteration across all experts, producing MoEWorkTileInfo
|
||||
(expert_idx, tile_m_idx, tile_n_idx, k_tile_cnt) for each tile.
|
||||
|
||||
Scenarios:
|
||||
- 2Dx3D (Forward): A(tokens_sum, hidden) x B(experts, intermediate, hidden) -> C(tokens_sum, intermediate)
|
||||
- 2Dx2D (Backward): A(intermediate, tokens_sum) x B(hidden, tokens_sum) -> C(experts, intermediate, hidden)
|
||||
|
||||
Key design principle:
|
||||
- Scheduler is ONLY responsible for tile iteration (tensor-agnostic, TMA-agnostic)
|
||||
- Domain conversion (fake tensor -> real expert tensor) is handled by MoESchedExtension
|
||||
- TMA descriptor management is handled by OnlineTensormapDescCreator
|
||||
- The kernel orchestrates all three components
|
||||
"""
|
||||
|
||||
from typing import List, Tuple, Literal
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cutlass_dsl import (
|
||||
Boolean,
|
||||
Int32,
|
||||
Integer,
|
||||
extract_mlir_values,
|
||||
new_from_mlir_values,
|
||||
const_expr,
|
||||
dsl_user_op,
|
||||
)
|
||||
from cutlass._mlir import ir
|
||||
|
||||
# =============================================================================
# Work Tile Info
# =============================================================================

class MoEWorkTileInfo:
    """
    Work tile information for MoE scheduler.

    Contains CTA-level tile information for executor warps:
    - expert_idx: Which expert (-1 means invalid/done)
    - tile_m_idx: CTA tile index along GEMM M dimension
    - tile_n_idx: CTA tile index along GEMM N dimension
    - k_tile_cnt: Number of CTA tiles along K dimension

    Note: These are CTA-level indices, not cluster-level.
          tile_l_idx is always 0 for MoE, executor can hardcode it.

    For 2Dx3D (Forward):
        M = tokens_i (dynamic), N = intermediate (fixed), K = hidden (fixed)

    For 2Dx2D (Backward):
        M = intermediate (fixed), N = hidden (fixed), K = tokens_i (dynamic)
    """

    def __init__(
        self,
        expert_idx: Int32,  # -1 means invalid tile
        tile_m_idx: Int32,
        tile_n_idx: Int32,
        k_tile_cnt: Int32,
    ):
        self.expert_idx = expert_idx
        self.tile_m_idx = tile_m_idx
        self.tile_n_idx = tile_n_idx
        self.k_tile_cnt = k_tile_cnt

    @property
    def is_valid_tile(self) -> Boolean:
        """Check if this is a valid work tile (expert_idx >= 0)."""
        return self.expert_idx >= Int32(0)

    def __extract_mlir_values__(self) -> List[ir.Value]:
        values = extract_mlir_values(self.expert_idx)
        values.extend(extract_mlir_values(self.tile_m_idx))
        values.extend(extract_mlir_values(self.tile_n_idx))
        values.extend(extract_mlir_values(self.k_tile_cnt))
        return values

    def __new_from_mlir_values__(self, values: List[ir.Value]) -> "MoEWorkTileInfo":
        assert len(values) == 4
        return MoEWorkTileInfo(
            expert_idx=new_from_mlir_values(self.expert_idx, [values[0]]),
            tile_m_idx=new_from_mlir_values(self.tile_m_idx, [values[1]]),
            tile_n_idx=new_from_mlir_values(self.tile_n_idx, [values[2]]),
            k_tile_cnt=new_from_mlir_values(self.k_tile_cnt, [values[3]]),
        )

    def to_rmem_tensor(self):
        """Pack work tile info fields into an rmem tensor of shape (4,) for vectorized smem copy."""
        rmem = cute.make_rmem_tensor((4,), Int32)
        rmem[0] = self.expert_idx
        rmem[1] = self.tile_m_idx
        rmem[2] = self.tile_n_idx
        rmem[3] = self.k_tile_cnt
        return rmem

    @staticmethod
    def from_rmem_tensor(rmem) -> "MoEWorkTileInfo":
        """Unpack work tile info from an rmem tensor of shape (4,)."""
        return MoEWorkTileInfo(
            expert_idx=rmem[0],  # type: ignore[arg-type]
            tile_m_idx=rmem[1],  # type: ignore[arg-type]
            tile_n_idx=rmem[2],  # type: ignore[arg-type]
            k_tile_cnt=rmem[3],  # type: ignore[arg-type]
        )


# =============================================================================
# Scheduler Parameters
# =============================================================================

class MoEStaticSchedulerParams:
    """
    Parameters for MoE tile scheduler.

    Uses unified semantics for both scenarios:
    - expert_shape: (expert_cnt, intermediate, hidden)

    For 2Dx3D: GEMM is (M=tokens_i, N=intermediate, K=hidden) per expert
    For 2Dx2D: GEMM is (M=hidden, N=intermediate, K=tokens_i) per expert

    Tile hierarchy:
    - cta_tile_shape_mnk: Single CTA tile shape (tile_m, tile_n, tile_k)
    - cluster_shape_mn: CTAs per cluster (cluster_m, cluster_n)
    - cluster_tile_shape_mn: Cluster tile shape = cta_tile_shape * cluster_shape

    This class is used both on host (for grid shape calculation) and on device
    (stored in scheduler). Codegen-time constants (scenario, cta_tile_shape_mnk,
    cluster_shape_mn) are NOT serialized to MLIR values.
    """

    def __init__(
        self,
        scenario: Literal["2Dx3D", "2Dx2D"],
        expert_shape: Tuple[int | Int32, int | Int32, int | Int32],  # (expert_cnt, intermediate, hidden)
        cta_tile_shape_mnk: Tuple[int, int, int],  # (tile_m, tile_n, tile_k)
        cluster_shape_mn: Tuple[int, int],  # (cluster_m, cluster_n)
    ):
        self.scenario = scenario
        e, i, h = expert_shape
        self.expert_cnt = e if isinstance(e, Int32) else Int32(e)
        self.intermediate = i if isinstance(i, Int32) else Int32(i)
        self.hidden = h if isinstance(h, Int32) else Int32(h)
        self.cta_tile_shape_mnk = cta_tile_shape_mnk
        self.cluster_shape_mn = cluster_shape_mn

    @property
    def cluster_tile_m(self) -> int:
        """Cluster tile size along M = cta_tile_m * cluster_m."""
        return self.cta_tile_shape_mnk[0] * self.cluster_shape_mn[0]

    @property
    def cluster_tile_n(self) -> int:
        """Cluster tile size along N = cta_tile_n * cluster_n."""
        return self.cta_tile_shape_mnk[1] * self.cluster_shape_mn[1]

    @property
    def cta_tile_k(self) -> int:
        """CTA tile size along K (same as cluster since cluster_k = 1)."""
        return self.cta_tile_shape_mnk[2]

    def __extract_mlir_values__(self) -> List[ir.Value]:
        """Only serialize runtime values, not codegen-time constants."""
        values = []
        values.extend(extract_mlir_values(self.expert_cnt))
        values.extend(extract_mlir_values(self.intermediate))
        values.extend(extract_mlir_values(self.hidden))
        return values

    def __new_from_mlir_values__(self, values: List[ir.Value]) -> "MoEStaticSchedulerParams":
        assert len(values) == 3
        return MoEStaticSchedulerParams(
            scenario=self.scenario,
            expert_shape=(
                new_from_mlir_values(self.expert_cnt, [values[0]]),
                new_from_mlir_values(self.intermediate, [values[1]]),
                new_from_mlir_values(self.hidden, [values[2]]),
            ),
            cta_tile_shape_mnk=self.cta_tile_shape_mnk,
            cluster_shape_mn=self.cluster_shape_mn,
        )

    @staticmethod
    def get_grid_shape(
        params: "MoEStaticSchedulerParams",
        max_active_clusters: int,
    ) -> Tuple[int, int, int]:
        """
        Compute grid shape for kernel launch.

        Since host doesn't know token distribution across experts,
        we launch max_active_clusters and let device-side scheduler
        determine which tiles are valid.
        """
        return (
            params.cluster_shape_mn[0],
            params.cluster_shape_mn[1],
            max_active_clusters,
        )


# =============================================================================
# Scheduler (Device-side)
# =============================================================================

class MoEStaticPersistentTileScheduler:
    """
    Persistent tile scheduler specialized for MoE grouped GEMM.

    This scheduler is ONLY responsible for tile iteration. It does NOT know
    about tensor types, TMA descriptors, or domain conversion. Those concerns
    are handled by MoESchedExtension and OnlineTensormapDescCreator respectively.

    Architecture:
    - Scheduler warp: Holds scheduler instance, iterates tiles, broadcasts work_tile_info
    - Executor warps: Read work_tile_info from smem, use MoESchedExtension for
      domain conversion and TMA desc selection

    The scheduler handles:
    - 2Dx3D: Dynamic M per expert (from offs), fixed N (intermediate) and K (hidden)
    - 2Dx2D: Fixed M (intermediate) and N (hidden), dynamic K per expert (reduction axis)

    Usage (Scheduler warp):
        scheduler = MoEStaticPersistentTileScheduler.create(params, offs, block_idx, grid_dim)
        work_tile_info = scheduler.initial_work_tile_info()
        # Broadcast work_tile_info to smem...

        while work_tile_info.is_valid_tile:
            # ... do work ...
            work_tile_info = scheduler.advance_to_next_work()
            # Broadcast work_tile_info to smem...

    Usage (Executor warps - via MoESchedExtension):
        # Read work_tile_info from smem...
        real_a, desc_a = ext.get_gmem_tensor("a", tma_tensor_a, offs, work_tile_info)
        real_b, desc_b = ext.get_gmem_tensor("b", tma_tensor_b, offs, work_tile_info)
        real_c, desc_c = ext.get_gmem_tensor("c", tma_tensor_c, offs, work_tile_info)
    """

    def __init__(
        self,
        # Params (contains scenario, expert_cnt, intermediate, hidden, tile/cluster shapes)
        params: MoEStaticSchedulerParams,
        # Runtime tensor for scheduling
        offs: cute.Tensor,  # (experts,) cumsum of token counts
        # Scheduling state
        num_persistent_clusters: Int32,
        current_work_linear_idx: Int32,
        cta_id_in_cluster: cute.Coord,
        # Expert tracking state (for O(1) advance within same expert)
        current_expert_idx: Int32,
        expert_tile_start: Int32,  # cumsum of tiles before current expert
        expert_tile_end: Int32,  # cumsum of tiles including current expert
    ):
        self.params = params
        self.offs = offs
        self.num_persistent_clusters = num_persistent_clusters
        self._current_work_linear_idx = current_work_linear_idx
        self.cta_id_in_cluster = cta_id_in_cluster
        # Expert tracking
        self.current_expert_idx = current_expert_idx
        self.expert_tile_start = expert_tile_start
        self.expert_tile_end = expert_tile_end

    # =========================================================================
    # Convenience accessors for params
    # =========================================================================

    @property
    def scenario(self) -> Literal["2Dx3D", "2Dx2D"]:
        return self.params.scenario

    @property
    def expert_cnt(self) -> Int32:
        return self.params.expert_cnt

    @property
    def intermediate(self) -> Int32:
        return self.params.intermediate

    @property
    def hidden(self) -> Int32:
        return self.params.hidden

    @property
    def cta_tile_shape_mnk(self) -> Tuple[int, int, int]:
        return self.params.cta_tile_shape_mnk

    @property
    def cluster_shape_mn(self) -> Tuple[int, int]:
        return self.params.cluster_shape_mn

    @property
    def cluster_tile_m(self) -> int:
        return self.params.cluster_tile_m

    @property
    def cluster_tile_n(self) -> int:
        return self.params.cluster_tile_n

    @property
    def cta_tile_k(self) -> int:
        return self.params.cta_tile_k

    # =========================================================================
    # MLIR value serialization (for SSA value passing in device code)
    # =========================================================================

    def __extract_mlir_values__(self) -> List[ir.Value]:
        values = []
        # Params (only runtime values are extracted)
        values.extend(extract_mlir_values(self.params))
        # Runtime tensor for scheduling
        values.extend(extract_mlir_values(self.offs))
        # Scheduling state
        values.extend(extract_mlir_values(self.num_persistent_clusters))
        values.extend(extract_mlir_values(self._current_work_linear_idx))
        values.extend(extract_mlir_values(self.cta_id_in_cluster))
        # Expert tracking state
        values.extend(extract_mlir_values(self.current_expert_idx))
        values.extend(extract_mlir_values(self.expert_tile_start))
        values.extend(extract_mlir_values(self.expert_tile_end))
        return values

    def __new_from_mlir_values__(
        self, values: List[ir.Value]
    ) -> "MoEStaticPersistentTileScheduler":
        idx = 0

        # Params (3 values: expert_cnt, intermediate, hidden)
        new_params = new_from_mlir_values(self.params, values[idx:idx + 3])
        idx += 3

        # Runtime tensor for scheduling (variable size)
        offs_len = len(extract_mlir_values(self.offs))
        new_offs = new_from_mlir_values(self.offs, values[idx:idx + offs_len])
        idx += offs_len

        # Scheduling state
        new_num_persistent_clusters = new_from_mlir_values(
            self.num_persistent_clusters, [values[idx]]
        )
        idx += 1
        new_current_work_linear_idx = new_from_mlir_values(
            self._current_work_linear_idx, [values[idx]]
        )
        idx += 1

        # cta_id_in_cluster (3 values for Coord)
        new_cta_id_in_cluster = new_from_mlir_values(
            self.cta_id_in_cluster, values[idx:idx + 3]
        )
        idx += 3

        # Expert tracking state
        new_current_expert_idx = new_from_mlir_values(
            self.current_expert_idx, [values[idx]]
        )
        idx += 1
        new_expert_tile_start = new_from_mlir_values(
            self.expert_tile_start, [values[idx]]
        )
        idx += 1
        new_expert_tile_end = new_from_mlir_values(
            self.expert_tile_end, [values[idx]]
        )
        idx += 1

        return MoEStaticPersistentTileScheduler(
            params=new_params,
            offs=new_offs,
            num_persistent_clusters=new_num_persistent_clusters,
            current_work_linear_idx=new_current_work_linear_idx,
            cta_id_in_cluster=new_cta_id_in_cluster,
            current_expert_idx=new_current_expert_idx,
            expert_tile_start=new_expert_tile_start,
            expert_tile_end=new_expert_tile_end,
        )

    # =========================================================================
    # Factory method
    # =========================================================================

    @staticmethod
    @dsl_user_op
    def create(
        params: MoEStaticSchedulerParams,
        offs: cute.Tensor,
        block_idx: Tuple[Integer, Integer, Integer],
        grid_dim: Tuple[Integer, Integer, Integer],
        *,
        loc=None,
        ip=None,
    ) -> "MoEStaticPersistentTileScheduler":
        """
        Create a MoE persistent tile scheduler.

        :param params: Scheduler parameters (from host)
        :param offs: Cumsum tensor of token counts per expert, shape (experts,)
        :param block_idx: CUDA block index
        :param grid_dim: CUDA grid dimensions
        """
        num_persistent_clusters = cute.size(grid_dim, loc=loc, ip=ip) // cute.size(
            params.cluster_shape_mn, loc=loc, ip=ip
        )

        bidx, bidy, bidz = block_idx
        current_work_linear_idx = Int32(bidz)

        cta_id_in_cluster = (
            Int32(bidx % params.cluster_shape_mn[0]),
            Int32(bidy % params.cluster_shape_mn[1]),
            Int32(0),
        )

        # Start expert tracking at expert 0 with an empty tile range (start == end == 0).
        # The first call to _get_work_tile_for_linear_idx computes expert_tile_end and
        # advances to the expert that contains this cluster's first tile.
        current_expert_idx = Int32(0)
        expert_tile_start = Int32(0)
        expert_tile_end = Int32(0)  # Will be computed on first access

        return MoEStaticPersistentTileScheduler(
            params=params,
            offs=offs,
            num_persistent_clusters=num_persistent_clusters,
            current_work_linear_idx=current_work_linear_idx,
            cta_id_in_cluster=cta_id_in_cluster,
            current_expert_idx=current_expert_idx,
            expert_tile_start=expert_tile_start,
            expert_tile_end=expert_tile_end,
        )

    # =========================================================================
    # Tile iteration methods
    # =========================================================================

    @dsl_user_op
    @cute.jit
    def initial_work_tile_info(self, *, loc=None, ip=None) -> MoEWorkTileInfo:
        """Get the initial work tile info."""
        return self._get_work_tile_for_linear_idx(
            self._current_work_linear_idx, loc=loc, ip=ip
        )

    @dsl_user_op
    @cute.jit
    def advance_to_next_work(self, *, loc=None, ip=None) -> MoEWorkTileInfo:
        """Advance to the next work tile and return its info."""
        self._current_work_linear_idx += self.num_persistent_clusters
        return self._get_work_tile_for_linear_idx(
            self._current_work_linear_idx, loc=loc, ip=ip
        )

    @dsl_user_op
    @cute.jit
    def _get_work_tile_for_linear_idx(
        self,
        cluster_linear_idx: Int32,
        *,
        loc=None,
        ip=None
    ) -> MoEWorkTileInfo:
        """
        Convert a linear cluster index to MoEWorkTileInfo.

        Uses cached expert tracking state for O(1) fast path when staying
        within the same expert. Advances expert state when needed.

        Returns an invalid tile (expert_idx = -1) if cluster_linear_idx is out of range.
        """
        # Ensure expert tracking is initialized and up-to-date
        self._advance_expert_to_contain(cluster_linear_idx, loc=loc, ip=ip)

        # Check if valid (still within expert range after advancing)
        is_valid = self.current_expert_idx < self.expert_cnt

        work_tile_info = MoEWorkTileInfo(
            expert_idx=Int32(-1),
            tile_m_idx=Int32(0),
            tile_n_idx=Int32(0),
            k_tile_cnt=Int32(0),
        )

        if is_valid:
            # Compute local cluster tile indices within current expert
            local_idx = cluster_linear_idx - self.expert_tile_start
            cluster_tile_m_idx, cluster_tile_n_idx = self._decompose_local_idx(
                local_idx, self.current_expert_idx, loc=loc, ip=ip
            )

            # Convert cluster tile indices to CTA tile indices
            # cta_tile_idx = cluster_tile_idx * cluster_shape + cta_id_in_cluster
            cta_tile_m_idx = (
                cluster_tile_m_idx * self.cluster_shape_mn[0]
                + self.cta_id_in_cluster[0]  # type: ignore[index]
            )
            cta_tile_n_idx = (
                cluster_tile_n_idx * self.cluster_shape_mn[1]
                + self.cta_id_in_cluster[1]  # type: ignore[index]
            )
            # Compute k_tile_cnt
            k_tile_cnt = self._compute_k_tile_cnt(self.current_expert_idx, loc=loc, ip=ip)

            work_tile_info = MoEWorkTileInfo(
                expert_idx=self.current_expert_idx,
                tile_m_idx=cta_tile_m_idx,
                tile_n_idx=cta_tile_n_idx,
                k_tile_cnt=k_tile_cnt,
            )
        return work_tile_info

    @dsl_user_op
    @cute.jit
    def _advance_expert_to_contain(
        self,
        cluster_linear_idx: Int32,
        *,
        loc=None,
        ip=None,
    ) -> None:
        """
        Advance expert tracking state until current expert contains cluster_linear_idx,
        or we run out of experts.

        Fast path: If already in correct expert, no work needed.
        """
        # Initialize expert_tile_end if this is the first call (expert_tile_end == 0)
        if self.expert_tile_end == Int32(0):
            tiles_for_expert_0 = self._compute_tiles_for_expert(Int32(0), loc=loc, ip=ip)
            self.expert_tile_end = tiles_for_expert_0

        # Advance until cluster_linear_idx < expert_tile_end or no more experts
        while cluster_linear_idx >= self.expert_tile_end and self.current_expert_idx < self.expert_cnt:
            self.current_expert_idx = self.current_expert_idx + 1
            self.expert_tile_start = self.expert_tile_end

            if self.current_expert_idx < self.expert_cnt:
                tiles_for_expert = self._compute_tiles_for_expert(
                    self.current_expert_idx, loc=loc, ip=ip
                )
                self.expert_tile_end = self.expert_tile_end + tiles_for_expert

    @dsl_user_op
    @cute.jit
    def _compute_tiles_for_expert(
        self,
        expert_idx: Int32,
        *,
        loc=None,
        ip=None,
    ) -> Int32:
        """Compute total cluster tiles for a given expert."""
        if const_expr(self.scenario == "2Dx2D"):
            # Fixed M=hidden, N=intermediate
            cluster_tile_m_cnt = (self.hidden + self.cluster_tile_m - 1) // self.cluster_tile_m
            cluster_tile_n_cnt = (self.intermediate + self.cluster_tile_n - 1) // self.cluster_tile_n
            return cluster_tile_m_cnt * cluster_tile_n_cnt
        else:  # 2Dx3D
            # Variable M (tokens), fixed N
            tokens_i = self.offs[expert_idx]
            if expert_idx > 0:
                tokens_i = tokens_i - self.offs[expert_idx - 1]  # type: ignore[operator]
            cluster_tile_m_cnt = (
                tokens_i + self.cluster_tile_m - 1  # type: ignore[operator]
            ) // self.cluster_tile_m
            cluster_tile_n_cnt = (
                self.intermediate + self.cluster_tile_n - 1
            ) // self.cluster_tile_n
            return cluster_tile_m_cnt * cluster_tile_n_cnt

    @dsl_user_op
    @cute.jit
    def _decompose_local_idx(
        self,
        local_idx: Int32,
        expert_idx: Int32,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32]:
        """
        Decompose local cluster tile index within expert to (cluster_tile_m_idx, cluster_tile_n_idx).

        Uses "short side first" strategy: the shorter dimension changes faster.
        This maximizes overlap between adjacent clusters for better L2 cache utilization.

        For example, if m_cnt=2, n_cnt=8:
        - N is longer, so M changes faster: local_idx = n_idx * m_cnt + m_idx
        - Linearization order: (0,0), (1,0), (0,1), (1,1), (0,2), (1,2), ...
        """
        # Get tile counts for M and N
        cluster_tile_m_cnt, cluster_tile_n_cnt = self._get_cluster_tile_counts(
            expert_idx, loc=loc, ip=ip
        )
        cluster_tile_m_idx = -1
        cluster_tile_n_idx = -1

        # Short side first: shorter dimension changes faster
        # If m_cnt <= n_cnt: m is shorter, m changes faster
        #   local_idx = n_idx * m_cnt + m_idx
        # If n_cnt < m_cnt: n is shorter, n changes faster
        #   local_idx = m_idx * n_cnt + n_idx
        if cluster_tile_m_cnt <= cluster_tile_n_cnt:
            # M is shorter or equal, M changes faster
            cluster_tile_m_idx = local_idx % cluster_tile_m_cnt
            cluster_tile_n_idx = local_idx // cluster_tile_m_cnt
        else:
            # N is shorter, N changes faster
            cluster_tile_n_idx = local_idx % cluster_tile_n_cnt
            cluster_tile_m_idx = local_idx // cluster_tile_n_cnt

        return (cluster_tile_m_idx, cluster_tile_n_idx)

    @dsl_user_op
    @cute.jit
    def _get_cluster_tile_counts(
        self,
        expert_idx: Int32,
        *,
        loc=None,
        ip=None,
    ) -> Tuple[Int32, Int32]:
        """Get (cluster_tile_m_cnt, cluster_tile_n_cnt) for a given expert."""
        if const_expr(self.scenario == "2Dx2D"):
            # Fixed M=hidden, N=intermediate
            cluster_tile_m_cnt = (self.hidden + self.cluster_tile_m - 1) // self.cluster_tile_m
            cluster_tile_n_cnt = (self.intermediate + self.cluster_tile_n - 1) // self.cluster_tile_n
        else:  # 2Dx3D
            # Variable M (tokens), fixed N
            tokens_i = self.offs[expert_idx]
            if expert_idx > 0:
                tokens_i = tokens_i - self.offs[expert_idx - 1]  # type: ignore[operator]
            cluster_tile_m_cnt = (
                tokens_i + self.cluster_tile_m - 1  # type: ignore[operator]
            ) // self.cluster_tile_m
            cluster_tile_n_cnt = (
                self.intermediate + self.cluster_tile_n - 1
            ) // self.cluster_tile_n
        return (cluster_tile_m_cnt, cluster_tile_n_cnt)

    @dsl_user_op
    @cute.jit
    def _compute_k_tile_cnt(
        self,
        expert_idx: Int32,
        *,
        loc=None,
        ip=None,
    ) -> Int32:
        """
        Compute the number of K tiles for this expert.

        2Dx3D: K = hidden (fixed) -> k_tile_cnt = ceil(hidden / cta_tile_k)
        2Dx2D: K = tokens_i (variable) -> k_tile_cnt = ceil(tokens_i / cta_tile_k)
        """
        if const_expr(self.scenario == "2Dx3D"):
            # K is hidden (fixed)
            return (self.hidden + self.cta_tile_k - 1) // self.cta_tile_k
        else:  # 2Dx2D
            # K is tokens_i (variable per expert)
            tokens_i = self.offs[expert_idx]
            if expert_idx > cutlass.Int32(0):
                tokens_i = tokens_i - self.offs[expert_idx - 1]  # type: ignore[operator]
            return (tokens_i + self.cta_tile_k - 1) // self.cta_tile_k  # type: ignore[return-value, operator]
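
The tile iteration in the scheduler above can be checked on the host with a plain-Python model. The following is a minimal sketch for the 2Dx3D case only, not the CuTeDSL code itself: `iterate_tiles`, `decompose`, and `ceil_div` are hypothetical helper names, `offs` is an ordinary list holding the cumulative token counts, and the expert range is computed eagerly rather than lazily on first access.

```python
from typing import Iterator, List, Tuple


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def decompose(local_idx: int, m_cnt: int, n_cnt: int) -> Tuple[int, int]:
    """Short side first: the dimension with fewer tiles varies fastest."""
    if m_cnt <= n_cnt:
        return local_idx % m_cnt, local_idx // m_cnt
    return local_idx // n_cnt, local_idx % n_cnt


def iterate_tiles(
    offs: List[int],
    intermediate: int,
    cluster_tile_m: int,
    cluster_tile_n: int,
    cluster_id: int,
    num_clusters: int,
) -> Iterator[Tuple[int, int, int]]:
    """Yield (expert_idx, cluster_tile_m_idx, cluster_tile_n_idx) for one cluster."""

    def counts(expert: int) -> Tuple[int, int]:
        # tokens_i is the difference of adjacent cumsum entries.
        tokens_i = offs[expert] - (offs[expert - 1] if expert > 0 else 0)
        return ceil_div(tokens_i, cluster_tile_m), ceil_div(intermediate, cluster_tile_n)

    if not offs:
        return
    expert, start = 0, 0
    m_cnt, n_cnt = counts(0)
    end = m_cnt * n_cnt
    linear_idx = cluster_id  # each cluster starts at its own index
    while True:
        # Skip forward until the current expert's tile range contains linear_idx.
        while linear_idx >= end and expert < len(offs):
            expert += 1
            start = end
            if expert < len(offs):
                m_cnt, n_cnt = counts(expert)
                end += m_cnt * n_cnt
        if expert >= len(offs):
            return  # the device code reports this as expert_idx == -1
        m_idx, n_idx = decompose(linear_idx - start, m_cnt, n_cnt)
        yield expert, m_idx, n_idx
        linear_idx += num_clusters  # persistent stride over the linear tile space


if __name__ == "__main__":
    # Three experts with 256, 0, and 384 tokens; four persistent clusters.
    for cid in range(4):
        print(cid, list(iterate_tiles([256, 256, 640], 512, 128, 256, cid, 4)))
```

Experts with zero tokens contribute no tiles, so the inner loop steps over them without yielding. Because every cluster advances by the same stride, the clusters together visit each tile exactly once.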
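
Two smaller pieces of arithmetic from the scheduler are also easy to verify in isolation: the cluster-to-CTA tile index conversion (`cta_tile_idx = cluster_tile_idx * cluster_shape + cta_id_in_cluster`) and the per-scenario K tile count. The sketch below uses plain integers; `cta_tiles_of_cluster` and `k_tile_cnt` are hypothetical helpers written for illustration, not functions from the reference files.

```python
from typing import List, Tuple


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def cta_tiles_of_cluster(
    cluster_tile_mn: Tuple[int, int],
    cluster_shape_mn: Tuple[int, int],
) -> List[Tuple[int, int]]:
    """List the (tile_m_idx, tile_n_idx) handled by each CTA of one cluster tile."""
    cluster_m_idx, cluster_n_idx = cluster_tile_mn
    shape_m, shape_n = cluster_shape_mn
    return [
        (cluster_m_idx * shape_m + cta_m, cluster_n_idx * shape_n + cta_n)
        for cta_m in range(shape_m)
        for cta_n in range(shape_n)
    ]


def k_tile_cnt(scenario: str, hidden: int, tokens_i: int, cta_tile_k: int) -> int:
    """K is hidden for 2Dx3D and the per-expert token count for 2Dx2D."""
    if scenario == "2Dx3D":
        return ceil_div(hidden, cta_tile_k)
    if scenario == "2Dx2D":
        return ceil_div(tokens_i, cta_tile_k)
    raise ValueError(f"unknown scenario: {scenario}")
```

With a cluster shape of (2, 1), for example, cluster tile (1, 3) maps to the two CTA tiles with M indices 2 and 3 and N index 3.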
443
reference/moe_moe_sched_extension.py
Normal file
@@ -0,0 +1,443 @@
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.

# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.

# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""
MoE Scheduler Extension.

Bridges the MoE tile scheduler (MoEStaticPersistentTileScheduler) with tensor-level
domain conversion and TMA descriptor selection. This is the "glue" layer between:

- Scheduler: produces MoEWorkTileInfo (expert_idx, tile_m, tile_n, k_tile_cnt)
- OnlineTensormapDescCreator: builds/retrieves TMA descriptors from workspace
- Kernel: orchestrates everything

Different kernel types (grouped_mm, scaled_grouped_mm, etc.) provide their own
MoESchedExtension subclass with kernel-specific domain conversion logic.

Key design principles:
- Unified interface: get_gmem_tensor() for all tensor types
- Free implementation: no role-based templates, each subclass writes its own logic
- Composable utilities: compute_expert_token_range, rewrite_tensor_shape, etc.
  are available as tools but not mandatory

Architecture:

    Scheduler ──(produces)──> MoEWorkTileInfo
                                    │
                    expert_idx, tile_m, tile_n, k_cnt
                                    │
                                    v
    Extension ──(uses)──> OnlineTensormapDescCreator
        │                         │
        │  get_gmem_tensor()      │  get_desc_ptr()
        │  prefetch_for_expert()  │  construct_and_write()
        │                         │
        └── internal calls ───────┘

    Kernel (caller): the only place that knows all three exist
"""

from abc import ABC, abstractmethod
from typing import Literal, Tuple, Union

import cutlass
import cutlass.cute as cute
from cutlass.cute.typing import Pointer
from cutlass.cutlass_dsl import Int32

from dataclasses import dataclass

from cutlass.utils.blockscaled_layout import tile_atom_to_shape_SF
from blackwell.kernel.moe.moe_utils import (
    OnlineTensormapDescCreator,
    tensormap_ptr_for_copy,
    compute_expert_token_range,
    rewrite_tensor_shape,
    prefetch_tma_descriptor,
)
from blackwell.kernel.moe.moe_persistent_scheduler import MoEWorkTileInfo


@dataclass(frozen=True)
class MoESchedExtension(ABC):
    """
    Abstract base class for MoE scheduler extensions.

    Bridges MoEWorkTileInfo with tensor-level domain conversion and TMA
    descriptor selection. Each kernel type (grouped_mm, scaled_grouped_mm, etc.)
    provides its own subclass with kernel-specific logic.

    The extension:
    - Holds a reference to an OnlineTensormapDescCreator for expert-wise desc retrieval
    - Implements get_gmem_tensor() to convert MoE-view tensors to per-expert tensors
    - Implements prefetch_for_expert() to prefetch expert-wise TMA descriptors

    Subclasses are free to add any additional attributes in __init__ (scenario,
    codegen configs, etc.) and implement get_gmem_tensor with arbitrary logic
    per tensor_name. No role-based templates or rigid patterns are imposed.

    Usage in kernel (caller):
        ext = ConcreteSchedExtension(tensormap_ctor, scenario=...)

        while work_tile_info.is_valid_tile:
            real_a, desc_a = ext.get_gmem_tensor("a", tma_tensor_a, offs, work_tile_info)
            real_b, desc_b = ext.get_gmem_tensor("b", tma_tensor_b, offs, work_tile_info)
            # Use real_a, desc_a in cute.copy ...
    """

    def __init__(self, tensormap_ctor: OnlineTensormapDescCreator):
        super().__init__()
        self.tensormap_ctor = tensormap_ctor

    @abstractmethod
    def get_gmem_tensor(
        self,
        tensor_name: str,
        gmem_tensor_in_moe_view: cute.Tensor,
        offs: Union[cute.Tensor, Tuple[cute.Tensor, cute.Tensor]],
        work_tile_info: MoEWorkTileInfo,
    ) -> Tuple[cute.Tensor, "Pointer | None"]:
        """
        Convert an MoE-view tensor to the real per-expert tensor for the
        current work tile, and return the appropriate TMA descriptor pointer.

        The MoE-view tensor uses "fake" GEMM domain dimensions that span all
        experts (e.g., fake_m = tokens_sum). This method slices/offsets it
        to the current expert's actual region.

        :param tensor_name: Identifies which tensor (e.g., "a", "b", "c", "sfa")
        :param gmem_tensor_in_moe_view: Tensor in fake GEMM MNKL domain
        :param offs: Either a single cumsum tensor (experts,), or a tuple of
                     (offs_token, offs_padded) where offs_padded provides
                     padded offsets for scale-factor domain conversion.
        :param work_tile_info: Current work tile from the scheduler
        :return: (real_tensor, tma_desc_ptr_or_none)
            - real_tensor: domain-offset and shape-rewritten tensor for this expert
            - tma_desc_ptr: expert-wise desc ptr (already converted for cute.copy),
              or None if the caller should use the global TMA descriptor
        """
        ...

    @abstractmethod
    def prefetch_for_expert(self, expert_idx: Int32) -> None:
        """
        Prefetch expert-wise TMA descriptors for the given expert.

        Called when the scheduler advances to a new expert, allowing the TMA
        descriptor cache to be warmed up before the descriptors are needed.

        :param expert_idx: Index of the expert whose descriptors to prefetch
        """
        ...


# =============================================================================
# Grouped MM Extension
# =============================================================================


class GroupedMmSchedExtension(MoESchedExtension):
    """
    MoE scheduler extension for grouped_mm: handles tensors a, b, c.

    Domain conversion logic per scenario:

    2Dx3D:
        A: (fake_m, k, 1) -> offset fake_m by token_offset, global desc
        B: (n, k, fake_l) -> offset fake_l by expert_idx, global desc
        C: (fake_m, n, 1) -> rewrite shape only, expert-wise desc

    2Dx2D:
        A: (m, fake_k, 1) -> rewrite shape only, expert-wise desc
        B: (n, fake_k, 1) -> rewrite shape only, expert-wise desc
        C: (m, n, fake_l) -> offset fake_l by expert_idx, global desc
    """

    def __init__(
        self,
        scenario: Literal["2Dx3D", "2Dx2D"],
        tensormap_ctor: OnlineTensormapDescCreator,
    ):
        super().__init__(tensormap_ctor)
        self.scenario = scenario

    @cute.jit
    def get_gmem_tensor(
        self,
        tensor_name: str,
        gmem_tensor_in_moe_view: cute.Tensor,
        offs: cute.Tensor,
        work_tile_info: MoEWorkTileInfo,
    ):
        expert_idx = work_tile_info.expert_idx
        token_offset, tokens_i = compute_expert_token_range(offs, expert_idx)

        shape = gmem_tensor_in_moe_view.shape
        c1 = cutlass.Int32(1)

        if cutlass.const_expr(self.scenario == "2Dx3D"):
            if cutlass.const_expr(tensor_name == "a"):
                # A: (fake_m, k, 1) -> offset fake_m, global desc
                real = cute.domain_offset((token_offset, 0, 0), gmem_tensor_in_moe_view)
                real = rewrite_tensor_shape(real, (tokens_i, shape[1], c1))  # type: ignore[index]
                return (real, None)
            elif cutlass.const_expr(tensor_name == "b"):
                # B: (n, k, fake_l) -> offset fake_l, global desc
                real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view)
                real = rewrite_tensor_shape(real, (shape[0], shape[1], c1))  # type: ignore[index]
                return (real, None)
            elif cutlass.const_expr(tensor_name == "c"):
                # C: (fake_m, n, 1) -> expert-wise desc, no offset
                real = rewrite_tensor_shape(
                    gmem_tensor_in_moe_view,
                    (tokens_i, shape[1], c1),  # type: ignore[index]
                )
                desc = tensormap_ptr_for_copy(
                    self.tensormap_ctor.get_desc_ptr("c", expert_idx)
                )
                return (real, desc)

        elif cutlass.const_expr(self.scenario == "2Dx2D"):
            if cutlass.const_expr(tensor_name == "a"):
                # A: (m, fake_k, 1) -> expert-wise desc, no offset
                real = rewrite_tensor_shape(
                    gmem_tensor_in_moe_view,
                    (shape[0], tokens_i, c1),  # type: ignore[index]
                )
                desc = tensormap_ptr_for_copy(
                    self.tensormap_ctor.get_desc_ptr("a", expert_idx)
                )
                return (real, desc)
            elif cutlass.const_expr(tensor_name == "b"):
                # B: (n, fake_k, 1) -> expert-wise desc, no offset
                real = rewrite_tensor_shape(
                    gmem_tensor_in_moe_view,
                    (shape[0], tokens_i, c1),  # type: ignore[index]
                )
                desc = tensormap_ptr_for_copy(
                    self.tensormap_ctor.get_desc_ptr("b", expert_idx)
                )
                return (real, desc)
            elif cutlass.const_expr(tensor_name == "c"):
                # C: (m, n, fake_l) -> offset fake_l, global desc
                real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view)
                real = rewrite_tensor_shape(real, (shape[0], shape[1], c1))  # type: ignore[index]
                return (real, None)

        raise ValueError("Invalid scenario or GEMM tensor name.")

    @cute.jit
    def prefetch_for_expert(self, expert_idx: Int32) -> None:
        if cutlass.const_expr(self.scenario == "2Dx3D"):
            prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("c", expert_idx))
        elif cutlass.const_expr(self.scenario == "2Dx2D"):
            prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("a", expert_idx))
            prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("b", expert_idx))
        else:
            raise ValueError("Invalid scenario.")


# =============================================================================
# Scaled Grouped MM Extension (block-scaled MoE)
# =============================================================================


class ScaledGroupedMmSchedExtension(MoESchedExtension):
    """
    MoE scheduler extension for scaled_grouped_mm: handles a, b, c, sfa, sfb.

    Mirrors GroupedMmSchedExtension (both derive from MoESchedExtension) and
    adds scale-factor tensor support.
    SFA/SFB are passed as flat GEMM-domain tensors and atom-tiled per expert
    via tile_atom_to_shape_SF.

    The offs parameter is always a tuple (offs_token, offs_padded):
    - offs_token: cumsum offsets in data (activation) domain
    - offs_padded: cumsum offsets in scale-factor domain (padded to atom granularity)

    sf_vec_size is obtained from self.tensormap_ctor.sf_vec_size.

    Domain conversion logic per scenario:

    2Dx3D:
        A: (fake_m, k, 1) -> offset fake_m by token_offset, global desc
        B: (n, k, fake_l) -> offset fake_l by expert_idx, global desc
        C: (fake_m, n, 1) -> rewrite shape, expert-wise desc
        SFA: (fake_m_pad, k_pad, 1) -> offset fake_m_pad by padded_offset,
                                       atom-tile, global desc
        SFB: (n_pad, k_pad, fake_l) -> offset fake_l by expert_idx,
                                       atom-tile, global desc

    2Dx2D:
        A: (m, fake_k, 1) -> rewrite shape, expert-wise desc
        B: (n, fake_k, 1) -> rewrite shape, expert-wise desc
        C: (m, n, fake_l) -> offset fake_l by expert_idx, global desc
        SFA: (m_pad, fake_k_pad, 1) -> atom-tile to the per-expert shape,
                                       expert-wise desc (no offset applied here)
        SFB: (n_pad, fake_k_pad, 1) -> atom-tile to the per-expert shape,
                                       expert-wise desc (no offset applied here)
    """

    def __init__(
        self,
        scenario: Literal["2Dx3D", "2Dx2D"],
        tensormap_ctor: OnlineTensormapDescCreator,
    ):
        super().__init__(tensormap_ctor)
        self.scenario = scenario

    @cute.jit
    def get_gmem_tensor(
        self,
        tensor_name: str,
        gmem_tensor_in_moe_view: cute.Tensor,
        offs: Tuple[cute.Tensor, cute.Tensor],
        work_tile_info: MoEWorkTileInfo,
    ):
        # Unpack the offs tuple
        offs_token, offs_padded = offs

        expert_idx = work_tile_info.expert_idx
        token_offset, tokens_i = compute_expert_token_range(offs_token, expert_idx)
        padded_offset, padded_size_i = compute_expert_token_range(
            offs_padded, expert_idx
        )

        shape = gmem_tensor_in_moe_view.shape
        stride = gmem_tensor_in_moe_view.stride
        c1 = cutlass.Int32(1)
        sf_vec_size = self.tensormap_ctor.sf_vec_size

        if cutlass.const_expr(self.scenario == "2Dx3D"):
            if cutlass.const_expr(tensor_name == "a"):
                # A: (fake_m, k, 1) -> offset fake_m, global desc
                real = cute.domain_offset((token_offset, 0, 0), gmem_tensor_in_moe_view)
                real = rewrite_tensor_shape(real, (tokens_i, shape[1], c1))  # type: ignore[index]
                return (real, None)

            elif cutlass.const_expr(tensor_name == "b"):
                # B: (n, k, fake_l) -> offset fake_l, global desc
                real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view)
                real = rewrite_tensor_shape(real, (shape[0], shape[1], c1))  # type: ignore[index]
                return (real, None)

            elif cutlass.const_expr(tensor_name == "c"):
                # C: (fake_m, n, 1) -> expert-wise desc
                real = rewrite_tensor_shape(
                    gmem_tensor_in_moe_view,
                    (tokens_i, shape[1], c1),  # type: ignore[index]
                )
                desc = tensormap_ptr_for_copy(
                    self.tensormap_ctor.get_desc_ptr("c", expert_idx)
                )
                return (real, desc)

            elif cutlass.const_expr(tensor_name == "sfa"):
                # SFA: (fake_m_pad, k_pad, 1) -> offset fake_m_pad, atom-tile, global desc
                real = cute.domain_offset(
                    (padded_offset, 0, 0), gmem_tensor_in_moe_view
                )
                per_expert_shape = (padded_size_i, shape[1], c1)  # type: ignore[index]
                sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size)
                real = cute.make_tensor(
                    real.iterator, cute.make_layout(sf_layout.shape, stride=stride)
                )
                return (real, None)

            elif cutlass.const_expr(tensor_name == "sfb"):
                # SFB: (n_pad, k_pad, fake_l) -> offset fake_l, atom-tile, global desc
                real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view)
                per_expert_shape = (shape[0], shape[1], c1)  # type: ignore[index]
                sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size)
                real = cute.make_tensor(
                    real.iterator, cute.make_layout(sf_layout.shape, stride=stride)
                )
                return (real, None)

        elif cutlass.const_expr(self.scenario == "2Dx2D"):
            if cutlass.const_expr(tensor_name == "a"):
                # A: (m, fake_k, 1) -> expert-wise desc
                real = rewrite_tensor_shape(
                    gmem_tensor_in_moe_view,
                    (shape[0], tokens_i, c1),  # type: ignore[index]
                )
                desc = tensormap_ptr_for_copy(
                    self.tensormap_ctor.get_desc_ptr("a", expert_idx)
                )
                return (real, desc)

            elif cutlass.const_expr(tensor_name == "b"):
                # B: (n, fake_k, 1) -> expert-wise desc
                real = rewrite_tensor_shape(
                    gmem_tensor_in_moe_view,
                    (shape[0], tokens_i, c1),  # type: ignore[index]
                )
                desc = tensormap_ptr_for_copy(
                    self.tensormap_ctor.get_desc_ptr("b", expert_idx)
                )
                return (real, desc)

            elif cutlass.const_expr(tensor_name == "c"):
                # C: (m, n, fake_l) -> offset fake_l, global desc
                real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view)
                real = rewrite_tensor_shape(real, (shape[0], shape[1], c1))  # type: ignore[index]
                return (real, None)

            elif cutlass.const_expr(tensor_name == "sfa"):
                # SFA: (m_pad, fake_k_pad, 1) -> atom-tile, expert-wise desc, no offset here
                per_expert_shape = (shape[0], padded_size_i, c1)  # type: ignore[index]
                sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size)
                real = rewrite_tensor_shape(gmem_tensor_in_moe_view, sf_layout.shape)
                desc = tensormap_ptr_for_copy(
                    self.tensormap_ctor.get_desc_ptr("sfa", expert_idx)
                )
                return (real, desc)

            elif cutlass.const_expr(tensor_name == "sfb"):
                # SFB: (n_pad, fake_k_pad, 1) -> atom-tile, expert-wise desc, no offset here
                per_expert_shape = (shape[0], padded_size_i, c1)  # type: ignore[index]
                sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size)
                real = rewrite_tensor_shape(gmem_tensor_in_moe_view, sf_layout.shape)
                desc = tensormap_ptr_for_copy(
                    self.tensormap_ctor.get_desc_ptr("sfb", expert_idx)
                )
                return (real, desc)

        raise ValueError("Invalid scenario or tensor name.")

    @cute.jit
    def prefetch_for_expert(self, expert_idx: Int32) -> None:
        if cutlass.const_expr(self.scenario == "2Dx3D"):
            prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("c", expert_idx))
        elif cutlass.const_expr(self.scenario == "2Dx2D"):
            prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("a", expert_idx))
            prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("b", expert_idx))
            prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("sfa", expert_idx))
            prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("sfb", expert_idx))
        else:
            raise ValueError("Invalid scenario.")
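# Illustrative sketch (not part of the reference file): one plausible way the
# (offs_token, offs_padded) pair consumed by ScaledGroupedMmSchedExtension could
# be built on the host for the 2Dx3D case. It assumes each expert's row count is
# rounded up to a fixed granularity before the cumulative sum. The names
# round_up and build_offsets and the value granularity=128 are hypothetical;
# the real padding rule lives outside this excerpt.

```python
from itertools import accumulate
from typing import List, Tuple


def round_up(value: int, multiple: int) -> int:
    """Round value up to the next multiple of `multiple`."""
    return (value + multiple - 1) // multiple * multiple


def build_offsets(
    tokens_per_expert: List[int], granularity: int
) -> Tuple[List[int], List[int]]:
    """Return (offs_token, offs_padded) as inclusive cumulative sums."""
    offs_token = list(accumulate(tokens_per_expert))
    offs_padded = list(
        accumulate(round_up(t, granularity) for t in tokens_per_expert)
    )
    return offs_token, offs_padded


# granularity=128 is an assumed example value, not taken from the source.
offs_token, offs_padded = build_offsets([3, 130, 0, 128], granularity=128)
print(offs_token)   # [3, 133, 133, 261]
print(offs_padded)  # [128, 384, 384, 512]
```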
910
reference/moe_moe_utils.py
Normal file
@@ -0,0 +1,910 @@
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.

# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.

# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""
Online TMA Descriptor Construction Utilities.

Provides utilities for dynamically creating TMA descriptors at kernel runtime
based on runtime-provided information (problem sizes, pointers, etc.).

Key components:
- OnlineTensormapDescCreator: Simplified ABC for TMA descriptor builders (2 abstract methods)
- TensormapWorkspace: Helper for linear workspace layout of TMA descriptors
- MoEGroupedGemmTensormapConstructor: TMA descriptor constructor for MoE Grouped GEMM
- MoEScaledGroupedGemmTensormapConstructor: TMA descriptor constructor for block-scaled MoE Grouped GEMM
- GeneralGroupedGemmTensormapConstructor: TMA descriptor constructor for general Grouped GEMM
- Pointer utility functions (ptr_offset_bytes, gmem_ptr_to_generic, etc.)
- tensormap_ptr_for_copy: Convert raw desc ptr to cute.copy-compatible type
- compute_expert_token_range: Compute per-expert token offset and count from offs
- rewrite_tensor_shape: Debug-friendly tensor shape rewrite utility
"""

from abc import ABC, abstractmethod
from typing import Optional, Literal, Tuple, Union

import cutlass
import cutlass.cute as cute
from cutlass.cute.typing import AddressSpace, Pointer
from cutlass.cute.nvgpu import cpasync
from cutlass.cutlass_dsl import dsl_user_op, Int32
from cutlass._mlir import ir
from cutlass._mlir.dialects import llvm
from cutlass._mlir.dialects import cute as _cute_ir
from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir
from dataclasses import dataclass

from cutlass.utils.blockscaled_layout import tile_atom_to_shape_SF

TensormapDescBytes = 128

# =============================================================================
# Pointer Utilities
# =============================================================================


@dsl_user_op
@cute.jit
def spin_wait(
    ptr: Pointer, condition, fail_sleep_cycles: int = 100, *, loc=None, ip=None
) -> None:
    """
    Generic spin-wait.

    Repeatedly loads the value at ``ptr`` and returns once ``condition(value)``
    holds. When ``fail_sleep_cycles > 0``, nanosleep is called between failed
    polls.

    Example usage:
        # Wait until counter >= total_blocks
        spin_wait(counter_ptr, lambda x: x >= total_blocks, fail_sleep_cycles=100)

        # Wait until flag == 1
        spin_wait(flag_ptr, lambda x: x == 1)
    """
    # Load with L1 cache bypass (ld.global.cg)
    current = cute.arch.load(ptr, ptr.dtype, cop="cg", loc=loc, ip=ip)
    while not condition(current):
        if cutlass.const_expr(fail_sleep_cycles > 0):
            cute.arch.nanosleep(sleep_time=fail_sleep_cycles, loc=loc, ip=ip)
        # Reload, again bypassing L1
        current = cute.arch.load(ptr, ptr.dtype, cop="cg", loc=loc, ip=ip)


@dsl_user_op
def gmem_ptr_to_generic(
    gmem_ptr: Pointer,
    *,
    loc: Optional[ir.Location] = None,
    ip: Optional[ir.InsertionPoint] = None,
) -> Pointer:
    if gmem_ptr.memspace != AddressSpace.gmem:
        raise ValueError(
            "gmem_ptr_to_generic requires pointer in gmem address space, "
            f"got {gmem_ptr.memspace}"
        )
    # Get LLVM pointer and cast to generic address space
    llvm_ptr = gmem_ptr.to_llvm_ptr(loc=loc, ip=ip)
    generic_llvm_ptr = llvm.addrspacecast(
        llvm.PointerType.get(AddressSpace.generic), llvm_ptr, loc=loc, ip=ip
    )
    # Create a new cute.Pointer with generic address space, preserving alignment
    return cute.make_ptr(
        gmem_ptr.dtype,
        generic_llvm_ptr,
        AddressSpace.generic,
        assumed_align=gmem_ptr.alignment,
        loc=loc,
        ip=ip,
    )


@dsl_user_op
def generic_ptr_to_gmem(
    generic_ptr: Pointer,
    *,
    loc: Optional[ir.Location] = None,
    ip: Optional[ir.InsertionPoint] = None,
) -> Pointer:
    if generic_ptr.memspace != AddressSpace.generic:
        raise ValueError(
            "generic_ptr_to_gmem requires pointer in generic address space, "
            f"got {generic_ptr.memspace}"
        )
    # Get LLVM pointer and cast to gmem address space
    llvm_ptr = generic_ptr.to_llvm_ptr(loc=loc, ip=ip)
    gmem_llvm_ptr = llvm.addrspacecast(
        llvm.PointerType.get(AddressSpace.gmem), llvm_ptr, loc=loc, ip=ip
    )
    # Create a new cute.Pointer with gmem address space, preserving alignment
    return cute.make_ptr(
        generic_ptr.dtype,
        gmem_llvm_ptr,
        AddressSpace.gmem,
        assumed_align=generic_ptr.alignment,
        loc=loc,
        ip=ip,
    )


@dsl_user_op
def prefetch_tma_descriptor(tma_desc_ptr: Pointer, *, loc=None, ip=None) -> None:
    """
    Prefetch a TMA descriptor from global memory.

    This function prefetches the TMA descriptor pointed to by tma_desc_ptr
    into the TMA descriptor cache. The pointer must be in generic or global
    address space. If a gmem pointer is passed, it will be automatically
    converted to generic address space.

    :param tma_desc_ptr: Pointer to the TMA descriptor in global or generic memory
    :type tma_desc_ptr: Pointer
    :raises ValueError: If pointer is not in generic or global address space
    """
    if tma_desc_ptr.memspace not in (AddressSpace.gmem, AddressSpace.generic):
        raise ValueError(
            "prefetch_tma_descriptor requires pointer in gmem or generic address space, "
            f"got {tma_desc_ptr.memspace}"
        )
    # Convert gmem pointer to generic if needed
    if tma_desc_ptr.memspace == AddressSpace.gmem:
        tma_desc_ptr = gmem_ptr_to_generic(tma_desc_ptr, loc=loc, ip=ip)
    # Convert cute.Pointer to LLVM pointer for prefetch
    llvm_ptr = tma_desc_ptr.to_llvm_ptr(loc=loc, ip=ip)
    from cutlass.cute.arch.nvvm_wrappers import prefetch as nvvm_prefetch

    nvvm_prefetch(llvm_ptr, tensormap=True, loc=loc, ip=ip)


def ptr_offset_bytes(ptr: Pointer, byte_offset: int) -> Pointer:
    """Offset a pointer by a given number of bytes."""
    element_offset = byte_offset * 8 // ptr.dtype.width
    return ptr + element_offset


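# Illustrative sketch (not part of the reference file): the byte-to-element
# conversion used by ptr_offset_bytes, written as plain Python so it can be
# checked without the DSL. The helper name byte_to_element_offset is
# hypothetical. Pointer arithmetic advances in elements, so the byte offset is
# first converted to bits and then divided by the element width; this also
# covers sub-byte element types such as 4-bit values.

```python
def byte_to_element_offset(byte_offset: int, dtype_width_bits: int) -> int:
    """Convert a byte offset into a count of elements of the given bit width."""
    return byte_offset * 8 // dtype_width_bits


# One 128-byte TMA descriptor slot, expressed in elements of various widths.
assert byte_to_element_offset(128, 8) == 128   # 8-bit elements
assert byte_to_element_offset(128, 64) == 16   # 64-bit elements
assert byte_to_element_offset(128, 4) == 256   # 4-bit (sub-byte) elements
```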
@dsl_user_op
def tensormap_ptr_for_copy(raw_ptr: Pointer, *, loc=None, ip=None) -> Pointer:
    """
    Convert a raw TMA descriptor gmem pointer to the type expected by cute.copy.

    cute.copy requires the tma_desc_ptr to be in generic address space and
    recast to TmaDescriptorTiledType. This utility performs both conversions.

    :param raw_ptr: Raw pointer to TMA descriptor in gmem
    :type raw_ptr: Pointer
    :return: Pointer compatible with cute.copy's tma_desc_ptr parameter
    :rtype: Pointer
    """
    generic_ptr = gmem_ptr_to_generic(raw_ptr, loc=loc, ip=ip)
    tma_desc_ptr_ty = _cute_ir.PtrType.get(
        _cute_nvgpu_ir.TmaDescriptorTiledType.get(),
        generic_ptr.memspace,
        generic_ptr.alignment,
    )
    return _cute_ir.recast_iter(tma_desc_ptr_ty, generic_ptr.value)


# =============================================================================
# MoE Utilities
# =============================================================================


@dsl_user_op
@cute.jit
def compute_expert_token_range(
    offs: cute.Tensor,
    expert_idx: Int32,
    *,
    loc=None,
    ip=None,
) -> Tuple[Int32, Int32]:
    """
    Compute token offset and count for a given expert from the cumsum offs tensor.

    :param offs: Inclusive cumulative sum of per-expert token counts, shape
        (experts,); offs[i] is the exclusive end position of expert i's tokens
    :param expert_idx: Index of the expert
    :return: (token_offset, tokens_i) where token_offset is the start position
        (0 for expert 0, otherwise offs[expert_idx - 1]) and tokens_i is the
        number of tokens for this expert
    """
    token_offset = Int32(0)
    if expert_idx > Int32(0):
        token_offset = offs[expert_idx - 1]  # type: ignore[assignment]
    tokens_i = offs[expert_idx] - token_offset
    return token_offset, tokens_i


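# Illustrative sketch (not part of the reference file): the same offset/count
# arithmetic as compute_expert_token_range, on plain Python lists. The helper
# name expert_token_range and the routing counts are made up for the example.
# It shows that an expert with zero tokens yields a zero-length range.

```python
from itertools import accumulate
from typing import List, Tuple


def expert_token_range(offs: List[int], expert_idx: int) -> Tuple[int, int]:
    """Return (token_offset, tokens_i) from an inclusive cumsum of token counts."""
    token_offset = offs[expert_idx - 1] if expert_idx > 0 else 0
    tokens_i = offs[expert_idx] - token_offset
    return token_offset, tokens_i


tokens_per_expert = [3, 0, 5, 2]             # made-up routing result
offs = list(accumulate(tokens_per_expert))   # [3, 3, 8, 10]
ranges = [expert_token_range(offs, e) for e in range(len(offs))]
print(ranges)  # [(0, 3), (3, 0), (3, 5), (8, 2)]
```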
@dsl_user_op
def rewrite_tensor_shape(
    tensor: cute.Tensor,
    new_shape: Tuple,
    *,
    loc=None,
    ip=None,
) -> cute.Tensor:
    """
    Rewrite tensor shape while keeping the same stride and iterator.

    This is primarily for debug friendliness - shows the actual expert's shape
    instead of the fake global shape. No runtime overhead as it becomes
    dead code in non-debug builds.

    :param tensor: Source tensor whose stride and iterator to preserve
    :param new_shape: New shape to apply
    :return: New tensor with the given shape but original stride and iterator
    """
    new_layout = cute.make_layout(new_shape, stride=tensor.stride, loc=loc, ip=ip)
    return cute.make_tensor(tensor.iterator, new_layout, loc=loc, ip=ip)


# =============================================================================
# TMA Descriptor Workspace Helper
# =============================================================================


class TensormapWorkspace:
    """
    Helper for linear workspace layout of TMA descriptors.

    Manages address calculation for a workspace buffer containing TMA descriptors
    organized as: for each executor (e.g., expert or group), a fixed set of
    named descriptor slots.

    Layout: [slot_0_exec_0, slot_1_exec_0, ..., slot_0_exec_1, slot_1_exec_1, ...]

    Example:
        # 2Dx3D MoE: only C is expert-wise
        workspace = TensormapWorkspace(workspace_ptr, ["c"])

        # 2Dx2D MoE: A and B are expert-wise
        workspace = TensormapWorkspace(workspace_ptr, ["a", "b"])

        # General grouped GEMM: all three tensors
        workspace = TensormapWorkspace(workspace_ptr, ["a", "b", "c"])
    """

    def __init__(self, workspace_ptr: Pointer, slot_names: list):
        """
        :param workspace_ptr: Pointer to the beginning of the workspace buffer
        :param slot_names: Ordered list of tensor names, defining the slot layout
            per executor. e.g., ["a", "b", "c"]
        """
        self.workspace_ptr = workspace_ptr
        self._name_to_slot = {name: i for i, name in enumerate(slot_names)}
        self._slots_per_executor = len(slot_names)

    @cute.jit
    def get_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
        """
        Get the workspace pointer for a specific TMA descriptor.

        :param tensor_name: Name of the tensor (must be one of the slot_names)
        :param executor_idx: Index of the executor (e.g., group_idx or expert_idx)
        :return: Aligned pointer to the TMA descriptor in workspace
        """
        if cutlass.const_expr(tensor_name not in self._name_to_slot):
            raise ValueError(
                f"Invalid tensor_name '{tensor_name}', "
                f"expected one of {list(self._name_to_slot.keys())}"
            )
        slot = self._name_to_slot[tensor_name]
        byte_offset = (
            executor_idx * self._slots_per_executor + slot
        ) * TensormapDescBytes
        return ptr_offset_bytes(self.workspace_ptr, byte_offset).align(
            TensormapDescBytes
        )

    @staticmethod
    def size_bytes(num_slots: int, num_executors: int) -> int:
        """
        Calculate workspace size in bytes.

        :param num_slots: Number of descriptor slots per executor
        :param num_executors: Total number of executors (e.g., expert_cnt or group_cnt)
        :return: Total workspace size in bytes
        """
        return num_slots * num_executors * TensormapDescBytes


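# Illustrative sketch (not part of the reference file): the executor-major
# address arithmetic of TensormapWorkspace, on plain integers. The names
# desc_byte_offset, workspace_size_bytes and TENSORMAP_DESC_BYTES are
# hypothetical stand-ins for get_ptr, size_bytes and TensormapDescBytes. It
# shows that with slots ["a", "b"] the descriptors interleave as
# A_0, B_0, A_1, B_1, ... rather than all A followed by all B.

```python
from typing import List

TENSORMAP_DESC_BYTES = 128  # mirrors TensormapDescBytes in the file above


def desc_byte_offset(slot_names: List[str], tensor_name: str, executor_idx: int) -> int:
    """Byte offset of one descriptor in an executor-major workspace."""
    slot = slot_names.index(tensor_name)  # raises ValueError for unknown names
    return (executor_idx * len(slot_names) + slot) * TENSORMAP_DESC_BYTES


def workspace_size_bytes(num_slots: int, num_executors: int) -> int:
    """Total bytes needed for num_slots descriptors per executor."""
    return num_slots * num_executors * TENSORMAP_DESC_BYTES


slots = ["a", "b"]  # 2Dx2D: A and B are expert-wise
layout = [
    (e, name, desc_byte_offset(slots, name, e)) for e in range(2) for name in slots
]
print(layout)  # [(0, 'a', 0), (0, 'b', 128), (1, 'a', 256), (1, 'b', 384)]
```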
# =============================================================================
# Online TMA Descriptor Creator (Abstract Base Class)
# =============================================================================


@dataclass(frozen=True)
class OnlineTensormapDescCreator(ABC):
    """
    Abstract base class for building TMA descriptors online (at kernel runtime).

    Subclasses store all needed parameters (both codegen-time configs and runtime
    values) as explicit instance attributes in __init__. No dict-based APIs.

    Subclasses must implement exactly 2 abstract methods:
    - construct_and_write: Build TMA descriptor(s) for one executor and write to workspace
    - get_desc_ptr: Return raw gmem pointer to a specific descriptor in workspace

    To convert the raw pointer for use with cute.copy, callers should use the
    standalone tensormap_ptr_for_copy() utility.
    """

    @abstractmethod
    def construct_and_write(self, executor_idx: Int32, dependency=None) -> None:
        """
        Build TMA descriptor(s) for one executor and write to workspace.

        :param executor_idx: Index of the executor (e.g., group_idx or expert_idx).
            Semantics may vary by subclass when ``dependency`` is provided.
        :param dependency: Optional pipeline consumer for inter-warp-group
            synchronization. When provided, the subclass decides when to wait
            (via ``dependency.wait_and_advance()``) and release. The subclass
            also decides how to interpret ``executor_idx`` in this mode.
        """
        ...

    @abstractmethod
    def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
        """
        Get the raw gmem pointer to a specific TMA descriptor in workspace.

        :param tensor_name: Name identifying which tensor's descriptor
        :param executor_idx: Index of the executor (e.g., group_idx or expert_idx)
        :return: Raw pointer (gmem) to the TMA descriptor
        """
        ...


# =============================================================================
|
||||
# MoE Grouped GEMM Tensormap Constructor
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class MoEGroupedGemmTensormapConstructor(OnlineTensormapDescCreator):
|
||||
"""
|
||||
Tensormap descriptor constructor for MoE Grouped GEMM (expert-wise descriptors only).
|
||||
|
||||
Non-expert-wise descriptors are passed directly at kernel launch.
|
||||
This class only handles:
|
||||
- 2Dx3D: C descriptors (expert-wise, to avoid write conflicts)
|
||||
- 2Dx2D: A and B descriptors (expert-wise, tokens is reduction axis)
|
||||
|
||||
All parameters are stored as explicit instance attributes (no dicts).
|
||||
|
||||
Workspace layout:
|
||||
- 2Dx3D: [C_0, C_1, ..., C_{n-1}]
|
||||
- 2Dx2D: [A_0, A_1, ..., A_{n-1}, B_0, B_1, ..., B_{n-1}]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scenario: Literal["2Dx3D", "2Dx2D"],
|
||||
# Codegen-time configs
|
||||
a_dtype,
|
||||
b_dtype,
|
||||
c_dtype,
|
||||
a_smem_layout,
|
||||
b_smem_layout,
|
||||
epi_smem_layout,
|
||||
a_tma_op,
|
||||
b_tma_op,
|
||||
c_tma_op,
|
||||
tiled_mma,
|
||||
mma_tiler,
|
||||
cluster_layout_vmnk_shape,
|
||||
epi_tile,
|
||||
# Runtime params
|
||||
a_tensor: cute.Tensor, # fake GEMM domain A
|
||||
b_tensor: cute.Tensor, # fake GEMM domain B
|
||||
c_tensor: cute.Tensor, # fake GEMM domain C
|
||||
offs: cute.Tensor, # (experts,) cumsum
|
||||
workspace_ptr: Pointer,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.scenario = scenario
|
||||
# Codegen-time configs
|
||||
self.a_dtype = a_dtype
|
||||
self.b_dtype = b_dtype
|
||||
self.c_dtype = c_dtype
|
||||
self.a_smem_layout = a_smem_layout
|
||||
self.b_smem_layout = b_smem_layout
|
||||
self.epi_smem_layout = epi_smem_layout
|
||||
self.a_tma_op = a_tma_op
|
||||
self.b_tma_op = b_tma_op
|
||||
self.c_tma_op = c_tma_op
|
||||
self.tiled_mma = tiled_mma
|
||||
self.mma_tiler = mma_tiler
|
||||
self.cluster_layout_vmnk_shape = cluster_layout_vmnk_shape
|
||||
self.epi_tile = epi_tile
|
||||
# Runtime params
|
||||
self.a_tensor = a_tensor
|
||||
self.b_tensor = b_tensor
|
||||
self.c_tensor = c_tensor
|
||||
self.offs = offs
|
||||
# Workspace with scenario-specific slot layout
|
||||
if scenario == "2Dx3D":
|
||||
self.workspace = TensormapWorkspace(workspace_ptr, ["c"])
|
||||
else:
|
||||
self.workspace = TensormapWorkspace(workspace_ptr, ["a", "b"])
|
||||
|
||||
@staticmethod
|
||||
def get_workspace_size(scenario: Literal["2Dx3D", "2Dx2D"], expert_cnt: int) -> int:
|
||||
"""Calculate workspace size in bytes for tensormap descriptors."""
|
||||
if scenario == "2Dx3D":
|
||||
return TensormapWorkspace.size_bytes(1, expert_cnt) # only C
|
||||
else:
|
||||
return TensormapWorkspace.size_bytes(2, expert_cnt) # A and B
|
||||
|
||||
@cute.jit
|
||||
def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
|
||||
return self.workspace.get_ptr(tensor_name, executor_idx)
|
||||
|
||||
@cute.jit
|
||||
def construct_and_write(self, executor_idx: Int32, dependency=None) -> None:
|
||||
"""
|
||||
Create expert-wise tensormap descriptors for the given expert.
|
||||
|
||||
- 2Dx3D: Creates C descriptor for this expert
|
||||
- 2Dx2D: Creates A and B descriptors for this expert
|
||||
"""
|
||||
if cutlass.const_expr(self.scenario == "2Dx3D"):
|
||||
self._construct_c_desc_2dx3d(executor_idx)
|
||||
else: # 2Dx2D
|
||||
self._construct_ab_descs_2dx2d(executor_idx)
|
||||
|
||||
@cute.jit
|
||||
def _construct_c_desc_2dx3d(self, expert_idx: Int32) -> None:
|
||||
"""
|
||||
2Dx3D: Create expert-wise C descriptor.
|
||||
C tensor: (fake_m, n, 1) = (tokens_sum, intermediate, 1)
|
||||
Slice fake_m -> (tokens_i, intermediate, 1) per expert.
|
||||
"""
|
||||
token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)
|
||||
|
||||
c_ptr = self.c_tensor.iterator
|
||||
c_stride = self.c_tensor.stride
|
||||
intermediate = self.c_tensor.shape[1] # type: ignore[index]
|
||||
|
||||
c1 = cutlass.Int32(1)
|
||||
c0 = cutlass.Int32(0)
|
||||
|
||||
c_ptr_i = c_ptr + token_offset * c_stride[0] # type: ignore[index]
|
||||
c_layout_i = cute.make_layout(
|
||||
(tokens_i, intermediate, c1),
|
||||
stride=(c_stride[0], c_stride[1], c0), # type: ignore[index]
|
||||
)
|
||||
c_tensor_i = cute.make_tensor(c_ptr_i, c_layout_i)
|
||||
|
||||
tma_atom_c, _ = cpasync.make_tiled_tma_atom(
|
||||
self.c_tma_op,
|
||||
c_tensor_i,
|
||||
self.epi_smem_layout,
|
||||
self.epi_tile,
|
||||
)
|
||||
cpasync.copy_tensormap(tma_atom_c, self.get_desc_ptr("c", expert_idx))
|
||||
|
||||
    @cute.jit
    def _construct_ab_descs_2dx2d(self, expert_idx: Int32) -> None:
        """
        2Dx2D: Create expert-wise A and B descriptors.
        A: (m, fake_k, 1) -> slice to (m, tokens_i, 1)
        B: (n, fake_k, 1) -> slice to (n, tokens_i, 1)
        """
        token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)

        c1 = cutlass.Int32(1)
        c0 = cutlass.Int32(0)

        # A tensor: (m, fake_k, 1) -> (m, tokens_i, 1)
        a_ptr = self.a_tensor.iterator
        a_stride = self.a_tensor.stride
        a_m = self.a_tensor.shape[0]  # type: ignore[index]

        a_ptr_i = a_ptr + token_offset * a_stride[1]  # type: ignore[index]
        a_layout_i = cute.make_layout(
            (a_m, tokens_i, c1),
            stride=(a_stride[0], a_stride[1], c0),  # type: ignore[index]
        )
        a_tensor_i = cute.make_tensor(a_ptr_i, a_layout_i)

        tma_atom_a, _ = cute.nvgpu.make_tiled_tma_atom_A(
            self.a_tma_op,
            a_tensor_i,
            self.a_smem_layout,
            self.mma_tiler,
            self.tiled_mma,
            self.cluster_layout_vmnk_shape,
        )
        cpasync.copy_tensormap(tma_atom_a, self.get_desc_ptr("a", expert_idx))

        # B tensor: (n, fake_k, 1) -> (n, tokens_i, 1)
        b_ptr = self.b_tensor.iterator
        b_stride = self.b_tensor.stride
        b_n = self.b_tensor.shape[0]  # type: ignore[index]

        b_ptr_i = b_ptr + token_offset * b_stride[1]  # type: ignore[index]
        b_layout_i = cute.make_layout(
            (b_n, tokens_i, c1),
            stride=(b_stride[0], b_stride[1], c0),  # type: ignore[index]
        )
        b_tensor_i = cute.make_tensor(b_ptr_i, b_layout_i)

        tma_atom_b, _ = cute.nvgpu.make_tiled_tma_atom_B(
            self.b_tma_op,
            b_tensor_i,
            self.b_smem_layout,
            self.mma_tiler,
            self.tiled_mma,
            self.cluster_layout_vmnk_shape,
        )
        cpasync.copy_tensormap(tma_atom_b, self.get_desc_ptr("b", expert_idx))

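The per-expert A/B slicing in `_construct_ab_descs_2dx2d` is plain pointer arithmetic: advance the base pointer by `token_offset * stride[1]` and reuse the original strides with a narrower K extent. A minimal pure-Python sketch of that idea follows, using a flat list in place of GMEM. The helper names `expert_token_range` and `slice_along_k` are hypothetical stand-ins. In particular, `expert_token_range` assumes that `offs` is an inclusive cumulative sum of per-expert token counts; the real `compute_expert_token_range` is not shown in this diff and may differ.

```python
def expert_token_range(offs, expert_idx):
    """Assumed semantics: offs[i] is the inclusive cumsum of tokens up to expert i."""
    start = offs[expert_idx - 1] if expert_idx > 0 else 0
    return start, offs[expert_idx] - start


def slice_along_k(flat, m, stride, token_offset, tokens_i):
    """View an (m, fake_k) buffer as (m, tokens_i), starting at column token_offset."""
    base = token_offset * stride[1]  # same offset formula as a_ptr_i / b_ptr_i
    return [
        [flat[base + i * stride[0] + j * stride[1]] for j in range(tokens_i)]
        for i in range(m)
    ]


# 2 x 5 row-major buffer with strides (5, 1); three experts own 1, 3 and 1 tokens.
flat = list(range(10))
offs = [1, 4, 5]
token_offset, tokens_i = expert_token_range(offs, 1)
expert_1_view = slice_along_k(flat, 2, (5, 1), token_offset, tokens_i)
```

Because only the base offset and the K extent change, the strides passed to `cute.make_layout` can stay those of the full tensor, which is why the per-expert layouts above reuse `a_stride` and `b_stride` unchanged.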
# =============================================================================
# MoE Scaled Grouped GEMM Tensormap Constructor
# =============================================================================


class MoEScaledGroupedGemmTensormapConstructor(OnlineTensormapDescCreator):
    """
    Tensormap descriptor constructor for MoE Scaled Grouped GEMM (block-scaled).

    .. py:attribute:: ChunkSize
       :value: 128

       Number of experts processed per chunk in the desc_init_kernel.
       Must match the warp-group width (4 warps × 32 threads).

    Extends MoEGroupedGemmTensormapConstructor with SFA/SFB descriptor support.

    Expert-wise descriptors only — non-expert-wise descriptors are passed
    directly at kernel launch.

    Workspace layout:
    - 2Dx3D: [C_0, C_1, ..., C_{n-1}] (1 slot per expert)
    - 2Dx2D: [A_0, B_0, SFA_0, SFB_0, A_1, B_1, SFA_1, SFB_1, ...] (4 slots per expert)

    :param scenario: "2Dx3D" or "2Dx2D"
    :param sf_vec_size: Scale factor vector size (32 for MXFP8/MXFP4, 16 for NVFP4)
    :param a_dtype: Data type for tensor A
    :param b_dtype: Data type for tensor B
    :param c_dtype: Data type for tensor C
    :param sf_dtype: Data type for scale factors (SFA/SFB)
    :param a_smem_layout: SMEM layout for A TMA
    :param b_smem_layout: SMEM layout for B TMA
    :param epi_smem_layout: SMEM layout for epilogue (C) TMA
    :param sfa_smem_layout: SMEM layout for SFA TMA
    :param sfb_smem_layout: SMEM layout for SFB TMA
    :param a_tma_op: TMA operation for A
    :param b_tma_op: TMA operation for B
    :param c_tma_op: TMA operation for C (S2G store or reduce)
    :param sfa_tma_op: TMA operation for SFA
    :param sfb_tma_op: TMA operation for SFB
    :param tiled_mma: TiledMma for A/B/SFA/C TMA atom construction
    :param tiled_mma_sfb: TiledMma for SFB (separate due to 2CTA replication)
    :param mma_tiler: MMA tiler shape (M, N, K)
    :param mma_tiler_sfb: MMA tiler shape for SFB
    :param cluster_layout_vmnk_shape: Cluster layout shape for A/B/SFA multicast
    :param cluster_layout_sfb_vmnk_shape: Cluster layout shape for SFB multicast
    :param epi_tile: Epilogue tile shape
    :param a_tensor: Fake GEMM domain A tensor
    :param b_tensor: Fake GEMM domain B tensor
    :param c_tensor: Fake GEMM domain C tensor
    :param sfa_tensor: Fake GEMM domain SFA tensor (atom-tiled layout)
    :param sfb_tensor: Fake GEMM domain SFB tensor (atom-tiled layout)
    :param offs: (experts,) cumsum offsets in data domain
    :param offs_padded: (experts,) cumsum offsets in padded scale domain
    :param workspace_ptr: Pointer to workspace for TMA descriptors
    :param expert_cnt: Total number of experts
    """

    ChunkSize = 128

    def __init__(
        self,
        scenario: Literal["2Dx3D", "2Dx2D"],
        sf_vec_size: int,
        # Codegen-time configs: dtypes
        a_dtype,
        b_dtype,
        c_dtype,
        sf_dtype,
        # Codegen-time configs: SMEM layouts
        a_smem_layout,
        b_smem_layout,
        epi_smem_layout,
        sfa_smem_layout,
        sfb_smem_layout,
        # Codegen-time configs: TMA ops
        a_tma_op,
        b_tma_op,
        c_tma_op,
        sfa_tma_op,
        sfb_tma_op,
        # Codegen-time configs: MMA / cluster / tile
        tiled_mma,
        tiled_mma_sfb,
        mma_tiler,
        mma_tiler_sfb,
        cluster_layout_vmnk_shape,
        cluster_layout_sfb_vmnk_shape,
        epi_tile,
        # Runtime params
        a_tensor: cute.Tensor,
        b_tensor: cute.Tensor,
        c_tensor: cute.Tensor,
        sfa_tensor: cute.Tensor,
        sfb_tensor: cute.Tensor,
        offs: cute.Tensor,
        offs_padded: cute.Tensor,
        workspace_ptr: Pointer,
        expert_cnt: Optional[Union[Int32, int]] = None,
    ) -> None:
        super().__init__()
        self.scenario = scenario
        self.sf_vec_size = sf_vec_size
        # Dtypes
        self.a_dtype = a_dtype
        self.b_dtype = b_dtype
        self.c_dtype = c_dtype
        self.sf_dtype = sf_dtype
        # SMEM layouts
        self.a_smem_layout = a_smem_layout
        self.b_smem_layout = b_smem_layout
        self.epi_smem_layout = epi_smem_layout
        self.sfa_smem_layout = sfa_smem_layout
        self.sfb_smem_layout = sfb_smem_layout
        # TMA ops
        self.a_tma_op = a_tma_op
        self.b_tma_op = b_tma_op
        self.c_tma_op = c_tma_op
        self.sfa_tma_op = sfa_tma_op
        self.sfb_tma_op = sfb_tma_op
        # MMA / cluster / tile
        self.tiled_mma = tiled_mma
        self.tiled_mma_sfb = tiled_mma_sfb
        self.mma_tiler = mma_tiler
        self.mma_tiler_sfb = mma_tiler_sfb
        self.cluster_layout_vmnk_shape = cluster_layout_vmnk_shape
        self.cluster_layout_sfb_vmnk_shape = cluster_layout_sfb_vmnk_shape
        self.epi_tile = epi_tile
        # Runtime params
        self.a_tensor = a_tensor
        self.b_tensor = b_tensor
        self.c_tensor = c_tensor
        self.sfa_tensor = sfa_tensor
        self.sfb_tensor = sfb_tensor
        self.offs = offs
        self.offs_padded = offs_padded
        self.expert_cnt = expert_cnt
        # Workspace with scenario-specific slot layout
        if scenario == "2Dx3D":
            self.workspace = TensormapWorkspace(workspace_ptr, ["c"])
        else:
            self.workspace = TensormapWorkspace(workspace_ptr, ["a", "b", "sfa", "sfb"])

    @staticmethod
    def get_workspace_size(scenario: Literal["2Dx3D", "2Dx2D"], expert_cnt: int) -> int:
        """Calculate workspace size in bytes for tensormap descriptors."""
        if scenario == "2Dx3D":
            return TensormapWorkspace.size_bytes(1, expert_cnt)  # C only
        else:
            return TensormapWorkspace.size_bytes(4, expert_cnt)  # A, B, SFA, SFB

    @cute.jit
    def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer:
        return self.workspace.get_ptr(tensor_name, executor_idx)

    @cute.jit
    def construct_and_write(self, lane_in_group: Int32, dependency=None) -> None:
        """
        Create expert-wise tensormap descriptors for all experts.

        ``lane_in_group`` is the thread's position within its warp group
        (0..ChunkSize-1). The method loops internally over all experts in
        chunks of ``ChunkSize``, with two-phase pipeline synchronization
        per chunk.

        Per-chunk execution:

        1. Phase 1: Build descriptors that do NOT depend on ``offs_padded``
           (A/B for 2Dx2D, C for 2Dx3D). Overlaps with Group A's prefix sum.
        2. Barrier: ``consumer.wait_and_advance()`` — all threads participate.
        3. Phase 2: Build descriptors that depend on ``offs_padded``
           (SFA/SFB for 2Dx2D). Reads padded offsets from SMEM buffer.
        4. Release: ``handle.release()`` — all threads participate.

        :param lane_in_group: Thread's position within the warp group (0..127).
        :param dependency: ``(PipelineConsumer, smem_offs_padded)`` — the
            consumer for mbarrier sync, and the SMEM tensor of shape
            ``(ChunkSize + 1,)`` with layout ``[carry, offs_padded[0..127]]``.
        """
        consumer, smem_offs_padded = dependency
        assert self.expert_cnt is not None
        num_chunks = (self.expert_cnt + self.ChunkSize - 1) // self.ChunkSize

        chunk_idx = cutlass.Int32(0)
        while chunk_idx < num_chunks:
            expert_idx = chunk_idx * self.ChunkSize + lane_in_group
            in_bounds = expert_idx < self.expert_cnt

            # Phase 1: non-dependent descriptors
            if in_bounds:
                if cutlass.const_expr(self.scenario == "2Dx2D"):
                    self._construct_ab_descs_2dx2d(expert_idx)
                else:
                    self._construct_c_desc_2dx3d(expert_idx)

            # All threads participate in barrier (fixed arrive count)
            handle = consumer.wait_and_advance()

            # Phase 2: dependent descriptors (read padded offsets from SMEM)
            if in_bounds:
                if cutlass.const_expr(self.scenario == "2Dx2D"):
                    # smem_offs_padded layout: [carry, chunk[0], ..., chunk[127]]
                    # padded_offset = smem[lane] (prev expert's cumulative)
                    # padded_end = smem[lane + 1] (this expert's cumulative)
                    padded_offset = smem_offs_padded[lane_in_group]
                    padded_size_i = smem_offs_padded[lane_in_group + 1] - padded_offset
                    self._construct_sf_descs_2dx2d_direct(
                        expert_idx, padded_offset, padded_size_i
                    )

            # All threads release (fixed arrive count)
            handle.release()

            chunk_idx += 1

    # -----------------------------------------------------------------
    # 2Dx3D: C descriptor (same as MoEGroupedGemmTensormapConstructor)
    # -----------------------------------------------------------------

    @cute.jit
    def _construct_c_desc_2dx3d(self, expert_idx: Int32) -> None:
        """
        2Dx3D: Create expert-wise C descriptor.
        C: (fake_m, n, 1) -> slice to (tokens_i, n, 1) per expert.
        """
        token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)
        c1 = cutlass.Int32(1)

        c_i = cute.domain_offset((token_offset, 0, 0), self.c_tensor)
        c_i = rewrite_tensor_shape(c_i, (tokens_i, self.c_tensor.shape[1], c1))  # type: ignore[index]

        tma_atom_c, _ = cpasync.make_tiled_tma_atom(
            self.c_tma_op,
            c_i,
            self.epi_smem_layout,
            self.epi_tile,
        )
        cpasync.copy_tensormap(tma_atom_c, self.get_desc_ptr("c", expert_idx))

    # -----------------------------------------------------------------
    # 2Dx2D: A, B descriptors (same as MoEGroupedGemmTensormapConstructor)
    # -----------------------------------------------------------------

    @cute.jit
    def _construct_ab_descs_2dx2d(self, expert_idx: Int32) -> None:
        """
        2Dx2D: Create expert-wise A and B descriptors.
        A: (m, fake_k, 1) -> slice to (m, tokens_i, 1)
        B: (n, fake_k, 1) -> slice to (n, tokens_i, 1)
        """
        token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx)
        c1 = cutlass.Int32(1)

        # A: (m, fake_k, 1) -> domain_offset + rewrite shape
        a_i = cute.domain_offset((0, token_offset, 0), self.a_tensor)
        a_i = rewrite_tensor_shape(a_i, (self.a_tensor.shape[0], tokens_i, c1))  # type: ignore[index]

        tma_atom_a, _ = cute.nvgpu.make_tiled_tma_atom_A(
            self.a_tma_op,
            a_i,
            self.a_smem_layout,
            self.mma_tiler,
            self.tiled_mma,
            self.cluster_layout_vmnk_shape,
        )
        cpasync.copy_tensormap(tma_atom_a, self.get_desc_ptr("a", expert_idx))

        # B: (n, fake_k, 1) -> domain_offset + rewrite shape
        b_i = cute.domain_offset((0, token_offset, 0), self.b_tensor)
        b_i = rewrite_tensor_shape(b_i, (self.b_tensor.shape[0], tokens_i, c1))  # type: ignore[index]

        tma_atom_b, _ = cute.nvgpu.make_tiled_tma_atom_B(
            self.b_tma_op,
            b_i,
            self.b_smem_layout,
            self.mma_tiler,
            self.tiled_mma,
            self.cluster_layout_vmnk_shape,
        )
        cpasync.copy_tensormap(tma_atom_b, self.get_desc_ptr("b", expert_idx))

    # -----------------------------------------------------------------
    # 2Dx2D: SFA, SFB descriptors (new for block-scaled)
    # -----------------------------------------------------------------

    @cute.jit
    def _construct_sf_descs_2dx2d_direct(
        self,
        expert_idx: Int32,
        padded_offset: Int32,
        padded_size_i: Int32,
    ) -> None:
        """
        2Dx2D: Create expert-wise SFA and SFB descriptors with pre-computed
        padded offset and size.

        This variant allows the caller to supply padded offsets from SMEM
        (in desc_init_kernel) instead of reading from ``self.offs_padded`` in GMEM.
        """
        c1 = cutlass.Int32(1)

        a_chunks_to_move = (
            padded_offset
            // self.sf_vec_size
            * cute.size(self.sfa_tensor, mode=[0])
            // 128
        )
        a_elems_to_move = (
            cute.size(self.sfa_tensor, mode=[0]) * padded_offset // self.sf_vec_size
        )
        b_chunks_to_move = (
            padded_offset
            // self.sf_vec_size
            * cute.size(self.sfb_tensor, mode=[0])
            // 128
        )
        b_elems_to_move = (
            cute.size(self.sfb_tensor, mode=[0]) * padded_offset // self.sf_vec_size
        )

        per_expert_sfa_shape = (self.sfa_tensor.shape[0], padded_size_i, c1)  # type: ignore[index]
        sfa_layout_i = tile_atom_to_shape_SF(per_expert_sfa_shape, self.sf_vec_size)
        sfa_i = cute.make_tensor(
            self.sfa_tensor.iterator + a_elems_to_move, sfa_layout_i
        )

        tma_atom_sfa, _ = cute.nvgpu.make_tiled_tma_atom_A(
            self.sfa_tma_op,
            sfa_i,
            self.sfa_smem_layout,
            self.mma_tiler,
            self.tiled_mma,
            self.cluster_layout_vmnk_shape,
            internal_type=cutlass.Uint64,
        )
        cpasync.copy_tensormap(tma_atom_sfa, self.get_desc_ptr("sfa", expert_idx))

        per_expert_sfb_shape = (self.sfb_tensor.shape[0], padded_size_i, c1)  # type: ignore[index]
        sfb_layout_i = tile_atom_to_shape_SF(per_expert_sfb_shape, self.sf_vec_size)
        sfb_i = cute.make_tensor(
            self.sfb_tensor.iterator + b_elems_to_move, sfb_layout_i
        )

        tma_atom_sfb, _ = cute.nvgpu.make_tiled_tma_atom_B(
            self.sfb_tma_op,
            sfb_i,
            self.sfb_smem_layout,
            self.mma_tiler_sfb,
            self.tiled_mma_sfb,
            self.cluster_layout_sfb_vmnk_shape,
            internal_type=cutlass.Uint64,
        )
        cpasync.copy_tensormap(tma_atom_sfb, self.get_desc_ptr("sfb", expert_idx))
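To make the chunked loop and the `[carry, chunk[0], ..., chunk[127]]` SMEM layout used by `construct_and_write` concrete, here is a host-side pure-Python sketch. It is an illustration under stated assumptions, not the kernel's implementation: `chunk_count` and `padded_ranges_for_chunk` are hypothetical names, and no barriers, warp groups, or TMA calls are modelled. The sketch only reproduces the index arithmetic: ceiling division for the chunk count, and `padded_offset = smem[lane]`, `padded_size_i = smem[lane + 1] - smem[lane]` for each lane.

```python
CHUNK_SIZE = 128  # mirrors ChunkSize in the class above


def chunk_count(expert_cnt, chunk_size=CHUNK_SIZE):
    """Ceiling division, as in construct_and_write's num_chunks."""
    return (expert_cnt + chunk_size - 1) // chunk_size


def padded_ranges_for_chunk(carry, chunk_cumsum):
    """Return (padded_offset, padded_size) per lane for one chunk.

    carry        -- cumulative padded offset at the end of the previous chunk
    chunk_cumsum -- inclusive cumulative padded offsets for this chunk's experts
    """
    smem = [carry] + list(chunk_cumsum)  # layout: [carry, chunk[0], chunk[1], ...]
    return [
        (smem[lane], smem[lane + 1] - smem[lane])
        for lane in range(len(chunk_cumsum))
    ]


# Three experts in a chunk that starts at padded offset 256;
# the middle expert received no tokens, so its padded size is 0.
ranges = padded_ranges_for_chunk(256, [384, 384, 640])
```

Storing the carry in slot 0 lets every lane read its start offset from `smem[lane]` without special-casing the first expert of a chunk.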
2019
reference/moe_torch_grouped_mm.py
Normal file
File diff suppressed because it is too large
3901
reference/moe_torch_scaled_grouped_mm.py
Normal file
File diff suppressed because it is too large