Files
nvfp4-megamoe-kernel/reference/moe_torch_grouped_mm.py
biondizzle a2ea836c74 docs: add CuTeDSL rewrite plan + reference files
The C++ CUTLASS kernel is fundamentally broken (cosine 0.05 with real
data). Switching to NVIDIA's CuTeDSL approach based on their official
MoE scaled grouped GEMM example.

Reference files copied:
- moe_torch_scaled_grouped_mm.py (3900 lines — our new kernel)
- moe_utils.py, moe_persistent_scheduler.py, moe_sched_extension.py
- grouped_blockscaled_gemm.py, dense_blockscaled_gemm_persistent.py
- blockscaled_layout.py
2026-05-16 02:41:51 +00:00

2020 lines
79 KiB
Python

# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import sys
from typing import Optional, Tuple, Literal, Type, Union
import cuda.bindings.driver as cuda
import cutlass
import cutlass.cute as cute
from cutlass.cute.typing import Pointer
from cutlass.cute.nvgpu import cpasync, tcgen05
import cutlass.utils as utils
import cutlass.pipeline as pipeline
from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait
if __name__ == "__main__":
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(current_dir, "../../.."))
from blackwell.kernel.moe.moe_utils import (
MoEGroupedGemmTensormapConstructor,
)
from blackwell.kernel.moe.moe_persistent_scheduler import (
MoEStaticSchedulerParams,
MoEStaticPersistentTileScheduler,
MoEWorkTileInfo,
)
from blackwell.kernel.moe.moe_sched_extension import GroupedMmSchedExtension
from cutlass.utils.gemm.sm100 import (
transform_partitioned_tensor_layout,
epilogue_tmem_copy_and_partition,
epilogue_smem_copy_and_partition,
)
class GroupedGemmKernel:
"""
Grouped GEMM kernel for MoE operations.
PyTorch interface (from torch.nn.functional.grouped_mm):
- 2Dx3D (Forward): mat_a(tokens_sum, K) x mat_b(experts, K, N) -> out(tokens_sum, N)
- 2Dx2D (Weight grad): mat_a(hidden, tokens_sum) x mat_b(tokens_sum, intermediate) -> out(experts, hidden, intermediate)
Kernel interface uses "fake" GEMM MNKL domain:
2Dx3D:
A_cute: (gemm_fake_m, gemm_k, 1) # fake_m = tokens_sum, scheduler will offset
B_cute: (gemm_n, gemm_k, gemm_fake_l) # fake_l = expert_idx, scheduler will select
C_cute: (gemm_fake_m, gemm_n, 1) # fake_m = tokens_sum, scheduler will offset
2Dx2D:
A_cute: (gemm_m, gemm_fake_k, 1) # fake_k = tokens_sum, scheduler will offset
B_cute: (gemm_n, gemm_fake_k, 1) # fake_k = tokens_sum, scheduler will offset
C_cute: (gemm_m, gemm_n, gemm_fake_l) # fake_l = expert_idx, scheduler will select
The scheduler handles the fake dimensions by:
- For fake_m/fake_k: Computing token_offset from offs and adjusting tensor coord
- For fake_l: Selecting expert slice via L coordinate
"""
def __init__(
    self,
    scenario: Literal["2Dx3D", "2Dx2D"],
    out_dtype: Type[cutlass.Numeric],
    accumulate_on_output: bool,
    separate_tensormap_init: bool = True,
    fixed_expert_cnt: Optional[int] = None,
    acc_dtype: Type[cutlass.Numeric] = cutlass.Float32,
    mma_tiler_mnk: Tuple[int, int, int] = (128, 128, 64),
    cluster_shape_mnk: Tuple[int, int, int] = (1, 1, 1),
    use_2cta_instrs: bool = False,
):
    """
    Record the codegen-time configuration of the grouped GEMM kernel.

    Parameters:
    - scenario: "2Dx3D" (forward) or "2Dx2D" (weight grad). Any other value
      is rejected here; without this check it would silently take the
      2Dx2D code path in __call__.
    - out_dtype: stored on the instance as-is.
    - accumulate_on_output: when True, C is written with a TMA reduce op
      instead of a plain TMA store. Only allowed for "2Dx2D".
    - separate_tensormap_init: when True, __call__ launches
      desc_init_kernel before the main kernel.
    - fixed_expert_cnt: stored, but not used yet.
    - acc_dtype: accumulator dtype handed to the tiled MMA.
    - mma_tiler_mnk: MMA tile shape. M and N are validated here; K is
      applied later in _setup_attributes.
    - cluster_shape_mnk: only the M and N entries are kept.
    - use_2cta_instrs: selects tcgen05.CtaGroup.TWO instead of ONE.

    Raises:
    - ValueError: unknown scenario, accumulate_on_output combined with
      "2Dx3D", or an invalid MMA tiler / cluster shape.
    """
    if scenario not in ("2Dx3D", "2Dx2D"):
        raise ValueError(
            f"Unsupported scenario {scenario!r}: expected '2Dx3D' or '2Dx2D'."
        )

    # User-provided configs
    self.scenario = scenario
    self.out_dtype = out_dtype
    self.accumulate_on_output = accumulate_on_output
    self.separate_tensormap_init = separate_tensormap_init
    self.fixed_expert_cnt = fixed_expert_cnt  # Not used yet...
    self.acc_dtype = acc_dtype
    self.mma_tiler_mnk = mma_tiler_mnk
    # Only the M and N cluster extents are used; the third entry is dropped.
    self.cluster_shape_mn = (cluster_shape_mnk[0], cluster_shape_mnk[1])
    self.use_2cta_instrs = use_2cta_instrs
    self.arch = "sm_100"

    if accumulate_on_output and scenario == "2Dx3D":
        raise ValueError(
            "Non-sense config: grad accumulate should only happens in 2Dx2D."
        )
    self._validate_mma_tiler_and_cluster_shape()

    # K dimension is deferred in _setup_attributes
    self.mma_tiler = (mma_tiler_mnk[0], mma_tiler_mnk[1], 1)

    # CTA group for tcgen05 MMA
    self.cta_group = (
        tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
    )

    # Occupancy and warp specialization:
    # warps 0-3 epilogue, warp 4 MMA, warp 5 TMA load, warp 6 scheduler.
    self.occupancy = 1
    self.epilogue_warp_id = (0, 1, 2, 3)
    self.mma_warp_id = 4
    self.tma_warp_id = 5
    self.sched_warp_id = 6
    # 32 threads per warp times the number of specialized warps listed above.
    self.threads_per_cta = 32 * len(
        (
            self.mma_warp_id,
            self.tma_warp_id,
            self.sched_warp_id,
            *self.epilogue_warp_id,
        )
    )

    # Barrier IDs for synchronization
    self.epilog_sync_bar_id = 1
    self.tmem_alloc_sync_bar_id = 2
    self.tmem_dealloc_sync_bar_id = 3
def _validate_mma_tiler_and_cluster_shape(self):
    """
    Validate codegen-time MMA tiler and cluster shape constraints.

    Reads self.mma_tiler_mnk (M and N only), self.cluster_shape_mn and
    self.use_2cta_instrs. The K entry of the tiler is not checked here.

    Raises:
    - ValueError: if tiler M is not in the allowed set for the chosen CTA
      mode, tiler N is not a multiple of 32 in [32, 256], cluster M is odd
      while use_2cta_instrs is set, a cluster dim is not a power of two, or
      the cluster has more than 16 CTAs in the M x N plane.
    """

    def _is_power_of_two(value):
        # A positive integer is a power of two iff exactly one bit is set.
        return value > 0 and (value & (value - 1)) == 0

    tile_m, tile_n = self.mma_tiler_mnk[0], self.mma_tiler_mnk[1]
    cluster_m, cluster_n = self.cluster_shape_mn

    # Allowed tiler M values depend on whether 2-CTA instructions are used.
    if self.use_2cta_instrs:
        valid_m = [128, 256]
    else:
        valid_m = [64, 128]
    if tile_m not in valid_m:
        raise ValueError(
            f"mma_tiler M ({tile_m}) must be one of {valid_m} "
            f"(use_2cta_instrs={self.use_2cta_instrs})"
        )

    if tile_n not in range(32, 257, 32):
        raise ValueError(
            f"mma_tiler N ({tile_n}) must be a multiple of 32 in [32, 256]"
        )

    # With 2-CTA instructions the cluster M extent must be even;
    # otherwise the divisor is 1 and this check never fires.
    required_m_multiple = 2 if self.use_2cta_instrs else 1
    if cluster_m % required_m_multiple != 0:
        raise ValueError(
            f"cluster_shape M ({cluster_m}) must be even when use_2cta_instrs=True"
        )

    if (
        cluster_m * cluster_n > 16
        or not _is_power_of_two(cluster_m)
        or not _is_power_of_two(cluster_n)
    ):
        raise ValueError(
            f"Invalid cluster_shape ({cluster_m}, {cluster_n}): each dim must be "
            f"a power of 2, and product must be <= 16"
        )
def _create_tiled_mma(self) -> cute.TiledMma:
    """
    Build the tiled MMA for the currently configured operands.

    Uses the A/B dtypes and major modes stored on the instance, plus the
    accumulator dtype, the CTA group and the (M, N) part of the MMA tiler.
    """
    operand_dtypes = (self.a_dtype, self.b_dtype)
    operand_major_modes = (self.a_major_mode, self.b_major_mode)
    # Only the M and N extents of the tiler are passed on.
    tiler_mn = self.mma_tiler[:2]
    return utils.sm100.make_trivial_tiled_mma(
        *operand_dtypes,
        *operand_major_modes,
        self.acc_dtype,
        self.cta_group,
        tiler_mn,
    )
def _setup_attributes(self) -> None:
    """
    Set up configurations that depend on GEMM inputs.

    Requires a_dtype, b_dtype, c_dtype, a_major_mode, b_major_mode and
    c_layout to be set on the instance already (__call__ does this).

    This method configures:
    - tiled_mma with correct dtypes and major modes
    - MMA/cluster/tile shapes
    - Cluster layout
    - Multicast CTA counts
    - Epilogue tile shape
    - Stage counts (ACC, A/B, C, scheduler)
    - SMEM layouts for A/B/C
    - Tensor memory allocation columns
    - TMA load bytes

    Raises:
    - ValueError: if the user-specified mma_tiler K is not a multiple of
      the MMA instruction's K size.
    """
    tiled_mma = self._create_tiled_mma()

    # Use the user-specified K dimension directly from mma_tiler_mnk, after
    # verifying it is a multiple of the MMA instruction's native K size.
    # An explicit raise is used (rather than assert) so the check is not
    # stripped under `python -O` and matches the other validation errors.
    mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
    if cutlass.const_expr(self.mma_tiler_mnk[2] % mma_inst_shape_k != 0):
        raise ValueError(
            f"mma_tiler K ({self.mma_tiler_mnk[2]}) must be a multiple of "
            f"MMA instruction K ({mma_inst_shape_k})"
        )
    self.mma_tiler = (
        self.mma_tiler[0],
        self.mma_tiler[1],
        self.mma_tiler_mnk[2],
    )

    # CTA tile shape: the MMA tile M is split across the CTAs of one MMA atom.
    self.cta_tile_shape_mnk = (
        self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
        self.mma_tiler[1],
        self.mma_tiler[2],
    )

    # Cluster layout
    self.cluster_layout_vmnk = cute.tiled_divide(
        cute.make_layout((*self.cluster_shape_mn, 1)),
        (tiled_mma.thr_id.shape,),
    )

    # Multicast CTA counts: A uses mode 2 of the VMNK cluster layout,
    # B uses mode 1.
    self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
    self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
    self.is_a_mcast = self.num_mcast_ctas_a > 1
    self.is_b_mcast = self.num_mcast_ctas_b > 1

    # Epilogue tile shape (always use TMA store for MoE)
    self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
        self.cta_tile_shape_mnk,
        self.use_2cta_instrs,
        self.c_layout,
        self.c_dtype,
    )

    # Single-stage C SMEM layout, used only to size one epilogue stage.
    c_smem_layout = utils.sm100.make_smem_layout_epi(
        self.c_dtype, self.c_layout, self.epi_tile, 1
    )
    self.smem_capacity = utils.get_smem_capacity_in_bytes()

    # Compute stage counts
    self.num_acc_stage = 2
    self.num_c_stage = 2  # Always use TMA store for MoE
    a_smem_layout_stage_one = utils.sm100.make_smem_layout_a(
        tiled_mma, self.mma_tiler, self.a_dtype, 1
    )
    b_smem_layout_stage_one = utils.sm100.make_smem_layout_b(
        tiled_mma, self.mma_tiler, self.b_dtype, 1
    )
    ab_bytes_per_stage = cute.size_in_bytes(
        self.a_dtype, a_smem_layout_stage_one
    ) + cute.size_in_bytes(self.b_dtype, b_smem_layout_stage_one)
    mbar_helpers_bytes = 1024
    c_bytes_per_stage = cute.size_in_bytes(self.c_dtype, c_smem_layout)
    c_bytes = c_bytes_per_stage * self.num_c_stage
    self.num_sched_stages = 2
    sched_work_tile_bytes_per_stage = 16  # 4 fields * sizeof(Int32)
    sched_bytes = sched_work_tile_bytes_per_stage * self.num_sched_stages
    fixed_overhead = mbar_helpers_bytes + c_bytes + sched_bytes

    # A/B stages take whatever the per-CTA SMEM budget leaves after the
    # fixed overhead.
    self.num_ab_stage = (
        self.smem_capacity // self.occupancy - fixed_overhead
    ) // ab_bytes_per_stage

    # Refine epilogue stages with remaining SMEM. fixed_overhead already
    # counts the two initial C stages, so this only adds the extra C stages
    # that fit in the leftover.
    self.num_c_stage += (
        self.smem_capacity
        - self.occupancy * ab_bytes_per_stage * self.num_ab_stage
        - self.occupancy * fixed_overhead
    ) // (self.occupancy * c_bytes_per_stage)

    # SMEM layouts with the final stage counts
    self.a_smem_layout_staged = utils.sm100.make_smem_layout_a(
        tiled_mma, self.mma_tiler, self.a_dtype, self.num_ab_stage
    )
    self.b_smem_layout_staged = utils.sm100.make_smem_layout_b(
        tiled_mma, self.mma_tiler, self.b_dtype, self.num_ab_stage
    )
    self.c_smem_layout_staged = utils.sm100.make_smem_layout_epi(
        self.c_dtype, self.c_layout, self.epi_tile, self.num_c_stage
    )

    # Tensor memory allocation columns, sized from an accumulator fragment
    # that carries all accumulator stages.
    acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
    tCtAcc_fake = tiled_mma.make_fragment_C(
        cute.append(acc_shape, self.num_acc_stage)
    )
    self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(
        tCtAcc_fake, arch=self.arch
    )

    # TMA load bytes: one A stage plus one B stage, scaled by the number of
    # CTAs in the MMA atom.
    atom_thr_size = cute.size(tiled_mma.thr_id.shape)
    a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
    b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
    a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout)
    b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout)
    self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size
def get_workspace_size(self, expert_cnt: int) -> int:
    """
    Return the workspace size needed for the expert-wise TMA descriptors.

    The sizing is delegated to MoEGroupedGemmTensormapConstructor for this
    kernel's scenario:
    2Dx3D: one C descriptor per expert -> expert_cnt * TensormapDescBytes
    2Dx2D: A and B descriptors per expert -> 2 * expert_cnt * TensormapDescBytes
    """
    constructor_cls = MoEGroupedGemmTensormapConstructor
    workspace_size = constructor_cls.get_workspace_size(self.scenario, expert_cnt)
    return workspace_size
@cute.jit
def __call__(
    self,
    mat_a: cute.Tensor,  # PyTorch mat_a
    mat_b: cute.Tensor,  # PyTorch mat_b
    out: cute.Tensor,  # PyTorch output
    offs: cute.Tensor,  # (experts,) cumsum
    bias: Optional[cute.Tensor],
    workspace: cute.Tensor,
    max_active_clusters: cutlass.Constexpr,
    stream: cuda.CUstream,
) -> None:
    """
    Host-side entry point: build GEMM-domain views and launch the kernels.

    Stages, in order:
    1. Re-view mat_a / mat_b / out as "fake MNKL" GEMM tensors
    2. Read dtypes and major modes off those views
    3. Check the K tile width and derive the kernel attributes
    4. Build the TMA atoms for A, B and C
    5. Build the scheduler parameters and the launch grid
    6. Optionally launch desc_init_kernel, then launch the main kernel

    bias must be None; any other value raises NotImplementedError.
    """
    if cutlass.const_expr(bias is not None):
        raise NotImplementedError("bias is not supported yet (align with torch).")

    # -----------------------------------------------------------------
    # 1. PyTorch tensors -> GEMM domain (fake MNKL)
    # -----------------------------------------------------------------
    unit_extent = cutlass.Int32(1)
    zero_stride = cutlass.Int32(0)
    if cutlass.const_expr(self.scenario == "2Dx3D"):
        # A: mat_a (tokens_sum, hidden) viewed as (fake_m, k, 1)
        tokens_sum, hidden = mat_a.shape
        a_shape = (tokens_sum, hidden, unit_extent)
        a_stride = (mat_a.stride[0], mat_a.stride[1], zero_stride)
        a_gemm = cute.make_tensor(
            mat_a.iterator, cute.make_layout(a_shape, stride=a_stride)
        )

        # B: mat_b (experts, hidden, intermediate) viewed as (n, k, fake_l)
        experts, hidden_b, intermediate = mat_b.shape
        b_shape = (intermediate, hidden_b, experts)
        b_stride = (mat_b.stride[2], mat_b.stride[1], mat_b.stride[0])
        b_gemm = cute.make_tensor(
            mat_b.iterator, cute.make_layout(b_shape, stride=b_stride)
        )

        # C: out (tokens_sum, intermediate) viewed as (fake_m, n, 1)
        c_shape = (tokens_sum, intermediate, unit_extent)
        c_stride = (out.stride[0], out.stride[1], zero_stride)
        c_gemm = cute.make_tensor(
            out.iterator, cute.make_layout(c_shape, stride=c_stride)
        )

        expert_cnt = experts
        intermediate_dim = intermediate
        hidden_dim = hidden
    else:  # 2Dx2D
        # A: mat_a (hidden, tokens_sum) viewed as (m, fake_k, 1)
        hidden, tokens_sum = mat_a.shape
        a_shape = (hidden, tokens_sum, unit_extent)
        a_stride = (mat_a.stride[0], mat_a.stride[1], zero_stride)
        a_gemm = cute.make_tensor(
            mat_a.iterator, cute.make_layout(a_shape, stride=a_stride)
        )

        # B: mat_b (tokens_sum, intermediate) viewed as (n, fake_k, 1)
        tokens_sum_b, intermediate = mat_b.shape
        b_shape = (intermediate, tokens_sum_b, unit_extent)
        b_stride = (mat_b.stride[1], mat_b.stride[0], zero_stride)
        b_gemm = cute.make_tensor(
            mat_b.iterator, cute.make_layout(b_shape, stride=b_stride)
        )

        # C: out (experts, hidden, intermediate) viewed as (m, n, fake_l)
        experts, hidden_c, intermediate_c = out.shape
        c_shape = (hidden_c, intermediate_c, experts)
        c_stride = (out.stride[1], out.stride[2], out.stride[0])
        c_gemm = cute.make_tensor(
            out.iterator, cute.make_layout(c_shape, stride=c_stride)
        )

        expert_cnt = experts
        intermediate_dim = intermediate
        hidden_dim = hidden

    # -----------------------------------------------------------------
    # 2. Dtypes and major modes come from the GEMM-domain views
    # -----------------------------------------------------------------
    self.a_dtype: Type[cutlass.Numeric] = a_gemm.element_type
    self.b_dtype: Type[cutlass.Numeric] = b_gemm.element_type
    self.c_dtype: Type[cutlass.Numeric] = c_gemm.element_type
    self.a_major_mode = utils.LayoutEnum.from_tensor(a_gemm).mma_major_mode()
    self.b_major_mode = utils.LayoutEnum.from_tensor(b_gemm).mma_major_mode()
    self.c_layout = utils.LayoutEnum.from_tensor(c_gemm)

    # -----------------------------------------------------------------
    # 3. Check the K tile width in bits, then derive kernel attributes
    # -----------------------------------------------------------------
    tile_k = self.mma_tiler_mnk[2]
    a_tile_bits = self.a_dtype.width * tile_k
    b_tile_bits = self.b_dtype.width * tile_k
    if cutlass.const_expr(a_tile_bits % 256 != 0):
        raise ValueError(
            f"a_dtype ({self.a_dtype.width}b) * mma_tiler K ({tile_k}) = "
            f"{a_tile_bits}b, must be a multiple of 256b (MMA instruction K width)"
        )
    if cutlass.const_expr(b_tile_bits % 256 != 0):
        raise ValueError(
            f"b_dtype ({self.b_dtype.width}b) * mma_tiler K ({tile_k}) = "
            f"{b_tile_bits}b, must be a multiple of 256b (MMA instruction K width)"
        )
    self._setup_attributes()
    tiled_mma = self._create_tiled_mma()

    # -----------------------------------------------------------------
    # 4. TMA atoms for A, B and C
    # -----------------------------------------------------------------
    cluster_vmnk_shape = self.cluster_layout_vmnk.shape

    # A load
    load_op_a = utils.sm100.cluster_shape_to_tma_atom_A(
        self.cluster_shape_mn, tiled_mma.thr_id
    )
    a_stage_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
    tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
        load_op_a,
        a_gemm,
        a_stage_layout,
        self.mma_tiler,
        tiled_mma,
        cluster_vmnk_shape,
    )

    # B load
    load_op_b = utils.sm100.cluster_shape_to_tma_atom_B(
        self.cluster_shape_mn, tiled_mma.thr_id
    )
    b_stage_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
    tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
        load_op_b,
        b_gemm,
        b_stage_layout,
        self.mma_tiler,
        tiled_mma,
        cluster_vmnk_shape,
    )

    # C store, or C reduce when accumulating onto the existing output
    if cutlass.const_expr(self.accumulate_on_output):
        store_op_c = cpasync.CopyReduceBulkTensorTileS2GOp()
    else:
        store_op_c = cpasync.CopyBulkTensorTileS2GOp()
    c_stage_layout = cute.select(self.c_smem_layout_staged, mode=[0, 1])
    tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
        store_op_c, c_gemm, c_stage_layout, self.epi_tile
    )

    # -----------------------------------------------------------------
    # 5. Scheduler parameters and launch grid
    # -----------------------------------------------------------------
    expert_shape = (expert_cnt, intermediate_dim, hidden_dim)
    sched_params = MoEStaticSchedulerParams(
        scenario=self.scenario,
        expert_shape=expert_shape,
        cta_tile_shape_mnk=self.cta_tile_shape_mnk,
        cluster_shape_mn=self.cluster_shape_mn,
    )
    grid = MoEStaticSchedulerParams.get_grid_shape(
        sched_params, max_active_clusters
    )

    # -----------------------------------------------------------------
    # 5.5 Optional descriptor-init launch
    # -----------------------------------------------------------------
    # Expert-wise TMA descriptors are written into the workspace ahead of
    # the main kernel; both launches go to the same stream, so stream
    # ordering makes the descriptors ready before the main kernel starts.
    #   2Dx3D: C desc per expert (C has dynamic fake_m per expert)
    #   2Dx2D: A,B desc per expert (A,B have dynamic fake_k per expert)
    if cutlass.const_expr(self.separate_tensormap_init):
        desc_init_launcher = self.desc_init_kernel(
            tiled_mma,
            a_gemm,
            b_gemm,
            c_gemm,
            offs,
            expert_cnt,
            workspace.iterator,
            self.cluster_layout_vmnk,
            self.a_smem_layout_staged,
            self.b_smem_layout_staged,
            self.c_smem_layout_staged,
            self.epi_tile,
        )
        # One block of one warp per expert.
        desc_init_launcher.launch(
            grid=(expert_cnt, 1, 1),
            block=[32, 1, 1],
            stream=stream,
            min_blocks_per_mp=1,
        )

    # -----------------------------------------------------------------
    # 6. Main kernel launch
    # -----------------------------------------------------------------
    main_launcher = self.kernel(
        tiled_mma,
        tma_atom_a,
        tma_tensor_a,
        tma_atom_b,
        tma_tensor_b,
        tma_atom_c,
        tma_tensor_c,
        a_gemm,
        b_gemm,
        c_gemm,
        offs,
        sched_params,
        workspace.iterator,
        self.cluster_layout_vmnk,
        self.a_smem_layout_staged,
        self.b_smem_layout_staged,
        self.c_smem_layout_staged,
        self.epi_tile,
    )
    main_launcher.launch(
        grid=grid,
        block=[self.threads_per_cta, 1, 1],
        cluster=(*self.cluster_shape_mn, 1),
        stream=stream,
        min_blocks_per_mp=self.occupancy,
    )
# GPU device kernel: TMA descriptor initialization
@cute.kernel
def desc_init_kernel(
    self,
    tiled_mma: cute.TiledMma,
    a_gemm: cute.Tensor,  # GEMM domain A (fake MNKL)
    b_gemm: cute.Tensor,  # GEMM domain B (fake MNKL)
    c_gemm: cute.Tensor,  # GEMM domain C (fake MNKL)
    offs: cute.Tensor,  # (experts,) cumsum
    expert_cnt: Union[cutlass.Int32, int],
    workspace_ptr: Pointer,
    cluster_layout_vmnk: cute.Layout,
    a_smem_layout_staged: cute.ComposedLayout,
    b_smem_layout_staged: cute.ComposedLayout,
    c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout],
    epi_tile: cute.Tile,
):
    """
    Device kernel that pre-builds the expert-wise TMA descriptors.

    Launch shape used by __call__:
    Grid: (expert_cnt, 1, 1) - one block per expert
    Block: (32, 1, 1) - one warp per block

    The block whose x index is `e` asks the tensormap constructor to build
    and write the descriptors of expert `e` into the workspace buffer.
    2Dx3D: C descriptor per expert (C has dynamic fake_m per expert)
    2Dx2D: A and B descriptors per expert (A/B have dynamic fake_k per expert)
    """
    # Per-stage SMEM layouts: drop the stage mode of A/B, keep modes 0/1 of C.
    first_stage = (None, None, None, 0)
    a_stage_layout = cute.slice_(a_smem_layout_staged, first_stage)
    b_stage_layout = cute.slice_(b_smem_layout_staged, first_stage)
    c_stage_layout = cute.select(c_smem_layout_staged, mode=[0, 1])

    # TMA ops: loads for A/B, and a store (or reduce, when accumulating
    # onto the output) for C.
    load_op_a = utils.sm100.cluster_shape_to_tma_atom_A(
        self.cluster_shape_mn, tiled_mma.thr_id
    )
    load_op_b = utils.sm100.cluster_shape_to_tma_atom_B(
        self.cluster_shape_mn, tiled_mma.thr_id
    )
    if cutlass.const_expr(self.accumulate_on_output):
        store_op_c = cpasync.CopyReduceBulkTensorTileS2GOp()
    else:
        store_op_c = cpasync.CopyBulkTensorTileS2GOp()

    # The constructor is rebuilt here from individually passed pieces.
    descriptor_builder = MoEGroupedGemmTensormapConstructor(
        scenario=self.scenario,
        a_dtype=self.a_dtype,
        b_dtype=self.b_dtype,
        c_dtype=self.c_dtype,
        a_smem_layout=a_stage_layout,
        b_smem_layout=b_stage_layout,
        epi_smem_layout=c_stage_layout,
        a_tma_op=load_op_a,
        b_tma_op=load_op_b,
        c_tma_op=store_op_c,
        tiled_mma=tiled_mma,
        mma_tiler=self.mma_tiler,
        cluster_layout_vmnk_shape=cluster_layout_vmnk.shape,
        epi_tile=epi_tile,
        a_tensor=a_gemm,
        b_tensor=b_gemm,
        c_tensor=c_gemm,
        offs=offs,
        workspace_ptr=workspace_ptr,
    )

    # The block's x index selects the expert to build descriptors for.
    block_x, _, _ = cute.arch.block_idx()
    descriptor_builder.construct_and_write(block_x)
# GPU device kernel: main GEMM kernel
@cute.kernel
def kernel(
self,
tiled_mma: cute.TiledMma,
tma_atom_a: cute.CopyAtom,
tma_tensor_a: cute.Tensor,
tma_atom_b: cute.CopyAtom,
tma_tensor_b: cute.Tensor,
tma_atom_c: cute.CopyAtom,
tma_tensor_c: cute.Tensor,
a_gemm: cute.Tensor, # GEMM domain A (fake MNKL)
b_gemm: cute.Tensor, # GEMM domain B (fake MNKL)
c_gemm: cute.Tensor, # GEMM domain C (fake MNKL)
offs: cute.Tensor, # (experts,) cumsum
sched_params: MoEStaticSchedulerParams,
workspace_ptr: Pointer,
cluster_layout_vmnk: cute.Layout,
a_smem_layout_staged: cute.ComposedLayout,
b_smem_layout_staged: cute.ComposedLayout,
c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout],
epi_tile: cute.Tile,
):
"""
GPU device kernel for MoE Grouped GEMM.
Warp specialization:
- Warps 0-3: Epilogue warps (TMEM -> RMEM -> SMEM -> GMEM)
- Warp 4: MMA warp (tcgen05.mma)
- Warp 5: TMA load warp (also prefetches expert-wise TMA descriptors)
The kernel uses MoEStaticPersistentTileScheduler to iterate over tiles
across all experts. For each tile:
1. TMA load warp fetches A/B tiles using get_gmem_tensor
2. MMA warp performs matrix multiply-accumulate
3. Epilogue warps store results using TMA store/reduce
Note: Python objects holding MLIR values cannot be kernel params.
The following are constructed inside the kernel from individually-passed params:
- tensormap_ctor: MoEGroupedGemmTensormapConstructor (online tensormap builder)
- ext: GroupedMmSchedExtension (domain conversion + TMA desc selection)
"""
# =================================================================
# Reconstruct dicts that can't be passed as kernel params
# =================================================================
# Construct TMA descriptor creator and scheduler extension
a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0))
b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, None, 0))
epi_smem_layout = cute.select(c_smem_layout_staged, mode=[0, 1])
a_tma_op = utils.sm100.cluster_shape_to_tma_atom_A(
self.cluster_shape_mn, tiled_mma.thr_id
)
b_tma_op = utils.sm100.cluster_shape_to_tma_atom_B(
self.cluster_shape_mn, tiled_mma.thr_id
)
if cutlass.const_expr(self.accumulate_on_output):
c_tma_op = cpasync.CopyReduceBulkTensorTileS2GOp()
else:
c_tma_op = cpasync.CopyBulkTensorTileS2GOp()
tensormap_ctor = MoEGroupedGemmTensormapConstructor(
scenario=self.scenario,
a_dtype=self.a_dtype,
b_dtype=self.b_dtype,
c_dtype=self.c_dtype,
a_smem_layout=a_smem_layout,
b_smem_layout=b_smem_layout,
epi_smem_layout=epi_smem_layout,
a_tma_op=a_tma_op,
b_tma_op=b_tma_op,
c_tma_op=c_tma_op,
tiled_mma=tiled_mma,
mma_tiler=self.mma_tiler,
cluster_layout_vmnk_shape=cluster_layout_vmnk.shape,
epi_tile=epi_tile,
a_tensor=a_gemm,
b_tensor=b_gemm,
c_tensor=c_gemm,
offs=offs,
workspace_ptr=workspace_ptr,
)
ext = GroupedMmSchedExtension(
scenario=self.scenario, tensormap_ctor=tensormap_ctor
)
# =================================================================
# Kernel setup
# =================================================================
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2
# CTA/thread coordinates
bidx, bidy, bidz = cute.arch.block_idx()
mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
is_leader_cta = mma_tile_coord_v == 0
cta_rank_in_cluster = cute.arch.make_warp_uniform(
cute.arch.block_idx_in_cluster()
)
block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(
cta_rank_in_cluster
)
tidx, _, _ = cute.arch.thread_idx()
# =================================================================
# SharedStorage
# =================================================================
@cute.struct
class SharedStorage:
ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
acc_full_mbar_ptr: cute.struct.MemRange[
cutlass.Int64, self.num_acc_stage * 2
]
sched_buf: cute.struct.MemRange[cutlass.Int32, self.num_sched_stages * 4]
sched_mbar_ptr: cute.struct.MemRange[
cutlass.Int64, self.num_sched_stages * 2
]
tmem_dealloc_mbar_ptr: cutlass.Int64
tmem_holding_buf: cutlass.Int32
smem = utils.SmemAllocator()
storage = smem.allocate(SharedStorage)
# =================================================================
# Pipelines
# =================================================================
# AB pipeline (TMA load → MMA)
ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1
ab_pipeline_consumer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, num_tma_producer
)
ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create(
barrier_storage=storage.ab_full_mbar_ptr.data_ptr(),
num_stages=self.num_ab_stage,
producer_group=ab_pipeline_producer_group,
consumer_group=ab_pipeline_consumer_group,
tx_count=self.num_tma_load_bytes,
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
).make_participants()
# ACC pipeline (MMA → epilogue)
acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread)
num_acc_consumer_threads = (
len(self.epilogue_warp_id) * 32 * (2 if use_2cta_instrs else 1)
)
acc_pipeline_consumer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, num_acc_consumer_threads
)
acc_pipeline = pipeline.PipelineUmmaAsync.create(
barrier_storage=storage.acc_full_mbar_ptr.data_ptr(),
num_stages=self.num_acc_stage,
producer_group=acc_pipeline_producer_group,
consumer_group=acc_pipeline_consumer_group,
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
)
# Scheduler pipeline (sched warp → tma/mma/epi warps)
sched_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32)
num_sched_consumer_threads = 32 * len(
(self.tma_warp_id, self.mma_warp_id, *self.epilogue_warp_id)
)
sched_consumer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread, num_sched_consumer_threads
)
sched_pipeline = pipeline.PipelineAsync.create(
num_stages=self.num_sched_stages,
producer_group=sched_producer_group,
consumer_group=sched_consumer_group,
barrier_storage=storage.sched_mbar_ptr.data_ptr(),
defer_sync=True,
)
# TMEM allocator
tmem_alloc_barrier = pipeline.NamedBarrier(
barrier_id=self.tmem_alloc_sync_bar_id,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)),
)
tmem = utils.TmemAllocator(
storage.tmem_holding_buf.ptr,
barrier_for_retrieve=tmem_alloc_barrier,
allocator_warp_id=self.epilogue_warp_id[0],
is_two_cta=use_2cta_instrs,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr.ptr,
)
# Cluster barrier sync after init
pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True)
# =================================================================
# SMEM tensors A/B
# =================================================================
# (MMA, MMA_M, MMA_K, STAGE)
sA = smem.allocate_tensor(
element_type=self.a_dtype,
layout=a_smem_layout_staged.outer,
byte_alignment=128,
swizzle=a_smem_layout_staged.inner,
)
# (MMA, MMA_N, MMA_K, STAGE)
sB = smem.allocate_tensor(
element_type=self.b_dtype,
layout=b_smem_layout_staged.outer,
byte_alignment=128,
swizzle=b_smem_layout_staged.inner,
)
# Multicast masks
a_full_mcast_mask = None
b_full_mcast_mask = None
if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs):
a_full_mcast_mask = cpasync.create_tma_multicast_mask(
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2
)
b_full_mcast_mask = cpasync.create_tma_multicast_mask(
cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1
)
# MMA fragments (SMEM → TMEM partitions)
# (MMA, MMA_M, MMA_K, STAGE)
tCrA = tiled_mma.make_fragment_A(sA)
# (MMA, MMA_N, MMA_K, STAGE)
tCrB = tiled_mma.make_fragment_B(sB)
# (MMA, MMA_M, MMA_N)
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
# (MMA, MMA_M, MMA_N, STAGE)
tCtAcc_fake = tiled_mma.make_fragment_C(
cute.append(acc_shape, self.num_acc_stage)
)
# Cluster wait before TMEM alloc
pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn)
# =================================================================
# Scheduler warp (warp 6)
# =================================================================
sched_buf_ptr = storage.sched_buf.data_ptr()
sched_copy_atom = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(), cutlass.Int32, num_bits_per_copy=128
)
sched_buf_tensor = cute.make_tensor(
sched_buf_ptr, cute.make_layout((4, self.num_sched_stages), stride=(1, 4))
)
if warp_idx == self.sched_warp_id:
scheduler = MoEStaticPersistentTileScheduler.create(
sched_params, offs, cute.arch.block_idx(), cute.arch.grid_dim()
)
sched_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_sched_stages
)
# Always produce the initial work_tile_info first
work_tile_info = scheduler.initial_work_tile_info()
sched_pipeline.producer_acquire(sched_producer_state)
rmem = work_tile_info.to_rmem_tensor()
cute.copy(
sched_copy_atom,
rmem,
sched_buf_tensor[(None, sched_producer_state.index)],
)
cute.arch.fence_proxy("async.shared", space="cta")
sched_pipeline.producer_commit(sched_producer_state)
sched_producer_state.advance()
# Iterate remaining tiles starting from the first advance
work_tile_info = scheduler.advance_to_next_work()
while work_tile_info.is_valid_tile:
ext.prefetch_for_expert(work_tile_info.expert_idx)
sched_pipeline.producer_acquire(sched_producer_state)
rmem = work_tile_info.to_rmem_tensor()
cute.copy(
sched_copy_atom,
rmem,
sched_buf_tensor[(None, sched_producer_state.index)],
)
cute.arch.fence_proxy("async.shared", space="cta")
sched_pipeline.producer_commit(sched_producer_state)
sched_producer_state.advance()
work_tile_info = scheduler.advance_to_next_work()
# Write invalid sentinel (expert_idx = -1) so consumers exit
sched_pipeline.producer_acquire(sched_producer_state)
sentinel = MoEWorkTileInfo(
cutlass.Int32(-1), cutlass.Int32(0), cutlass.Int32(0), cutlass.Int32(0)
)
rmem = sentinel.to_rmem_tensor()
cute.copy(
sched_copy_atom,
rmem,
sched_buf_tensor[(None, sched_producer_state.index)],
)
cute.arch.fence_proxy("async.shared", space="cta")
sched_pipeline.producer_commit(sched_producer_state)
sched_pipeline.producer_tail(sched_producer_state)
# =================================================================
# TMA load warp (warp 5)
# =================================================================
if warp_idx == self.tma_warp_id:
sched_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_sched_stages
)
# Read initial work_tile_info
sched_pipeline.consumer_wait(sched_consumer_state)
rmem = cute.make_rmem_tensor((4,), cutlass.Int32)
cute.copy(
sched_copy_atom,
sched_buf_tensor[(None, sched_consumer_state.index)],
rmem,
)
work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem)
cute.arch.fence_acq_rel_cta()
sched_pipeline.consumer_release(sched_consumer_state)
sched_consumer_state.advance()
while work_tile_info.is_valid_tile:
k_tile_cnt = work_tile_info.k_tile_cnt
# Get real GEMM domain tensors + TMA desc ptrs via extension
real_a, desc_ptr_a = ext.get_gmem_tensor(
"a",
tma_tensor_a,
offs,
work_tile_info,
)
real_b, desc_ptr_b = ext.get_gmem_tensor(
"b",
tma_tensor_b,
offs,
work_tile_info,
)
# local_tile for this tile's A and B
gA_mkl = cute.local_tile(
real_a,
cute.slice_(self.mma_tiler, (None, 0, None)),
(None, None, None),
)
gB_nkl = cute.local_tile(
real_b,
cute.slice_(self.mma_tiler, (0, None, None)),
(None, None, None),
)
# MMA partition for TMA
thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
tCgA = thr_mma.partition_A(gA_mkl)
tCgB = thr_mma.partition_B(gB_nkl)
# TMA partition
a_cta_layout = cute.make_layout(
cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape
)
b_cta_layout = cute.make_layout(
cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape
)
tAsA, tAgA = cpasync.tma_partition(
tma_atom_a,
block_in_cluster_coord_vmnk[2],
a_cta_layout,
cute.group_modes(sA, 0, 3),
cute.group_modes(tCgA, 0, 3),
)
tBsB, tBgB = cpasync.tma_partition(
tma_atom_b,
block_in_cluster_coord_vmnk[1],
b_cta_layout,
cute.group_modes(sB, 0, 3),
cute.group_modes(tCgB, 0, 3),
)
# Slice to current tile coords (L=0 for MoE, expert already selected)
mma_tile_m = work_tile_info.tile_m_idx // cute.size(
tiled_mma.thr_id.shape
)
tAgA_slice = tAgA[(None, mma_tile_m, None, 0)]
tBgB_slice = tBgB[(None, work_tile_info.tile_n_idx, None, 0)]
# TMA load loop
ab_producer.reset()
peek_ab_empty_status = ab_producer.try_acquire()
for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
handle = ab_producer.acquire_and_advance(peek_ab_empty_status)
peek_ab_empty_status = cutlass.Boolean(1)
if handle.count + 1 < k_tile_cnt:
peek_ab_empty_status = ab_producer.try_acquire()
cute.copy(
tma_atom_a,
tAgA_slice[(None, handle.count)],
tAsA[(None, handle.index)],
tma_bar_ptr=handle.barrier,
tma_desc_ptr=desc_ptr_a,
mcast_mask=a_full_mcast_mask,
)
cute.copy(
tma_atom_b,
tBgB_slice[(None, handle.count)],
tBsB[(None, handle.index)],
tma_bar_ptr=handle.barrier,
tma_desc_ptr=desc_ptr_b,
mcast_mask=b_full_mcast_mask,
)
# Read next work_tile_info
sched_pipeline.consumer_wait(sched_consumer_state)
rmem = cute.make_rmem_tensor((4,), cutlass.Int32)
cute.copy(
sched_copy_atom,
sched_buf_tensor[(None, sched_consumer_state.index)],
rmem,
)
work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem)
cute.arch.fence_acq_rel_cta()
sched_pipeline.consumer_release(sched_consumer_state)
sched_consumer_state.advance()
ab_producer.tail()
# =================================================================
# MMA warp (warp 4)
# =================================================================
if warp_idx == self.mma_warp_id:
# Retrieve TMEM
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
acc_producer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_acc_stage
)
sched_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_sched_stages
)
# Read initial work_tile_info
sched_pipeline.consumer_wait(sched_consumer_state)
rmem = cute.make_rmem_tensor((4,), cutlass.Int32)
cute.copy(
sched_copy_atom,
sched_buf_tensor[(None, sched_consumer_state.index)],
rmem,
)
work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem)
cute.arch.fence_acq_rel_cta()
sched_pipeline.consumer_release(sched_consumer_state)
sched_consumer_state.advance()
while work_tile_info.is_valid_tile:
k_tile_cnt = work_tile_info.k_tile_cnt
if is_leader_cta:
tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)]
# AB consumer mainloop
ab_consumer.reset()
peek_ab_full_status = cutlass.Boolean(1)
if k_tile_cnt > 0:
peek_ab_full_status = ab_consumer.try_wait()
acc_pipeline.producer_acquire(acc_producer_state)
for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1):
handle = ab_consumer.wait_and_advance(peek_ab_full_status)
peek_ab_full_status = cutlass.Boolean(1)
if handle.count + 1 < k_tile_cnt:
peek_ab_full_status = ab_consumer.try_wait()
tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile != 0)
tile_crd = (None, None, None, handle.index)
cute.gemm(
tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc
)
handle.release()
if k_tile_cnt > 0:
acc_pipeline.producer_commit(acc_producer_state)
if k_tile_cnt > 0:
acc_producer_state.advance()
# Read next work_tile_info
sched_pipeline.consumer_wait(sched_consumer_state)
rmem = cute.make_rmem_tensor((4,), cutlass.Int32)
cute.copy(
sched_copy_atom,
sched_buf_tensor[(None, sched_consumer_state.index)],
rmem,
)
work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem)
cute.arch.fence_acq_rel_cta()
sched_pipeline.consumer_release(sched_consumer_state)
sched_consumer_state.advance()
acc_pipeline.producer_tail(acc_producer_state)
# =================================================================
# SMEM tensor C (allocated after MMA section, same as dense)
# =================================================================
sC = smem.allocate_tensor(
element_type=self.c_dtype,
layout=c_smem_layout_staged.outer,
byte_alignment=128,
swizzle=c_smem_layout_staged.inner,
)
# =================================================================
# Epilogue warps (warps 0-3)
# =================================================================
if warp_idx < self.mma_warp_id:
# Allocate TMEM
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
acc_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage
)
sched_consumer_state = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_sched_stages
)
c_producer_group = pipeline.CooperativeGroup(
pipeline.Agent.Thread,
32 * len(self.epilogue_warp_id),
)
c_pipeline = pipeline.PipelineTmaStore.create(
num_stages=self.num_c_stage, producer_group=c_producer_group
)
epilog_sync_barrier = pipeline.NamedBarrier(
barrier_id=self.epilog_sync_bar_id,
num_threads=32 * len(self.epilogue_warp_id),
)
# Epilogue copy setup (same for all tiles, depends only on shapes)
# Transform ACC layout: ((ATOM_M, ATOM_N), MMA_M, MMA_N, STAGE)
# -> ((ATOM_M, MMA_M), (ATOM_N, MMA_N), STAGE)
tCtAcc_transformed = transform_partitioned_tensor_layout(tCtAcc_base)
num_tiles_executed = cutlass.Int32(0)
# Read initial work_tile_info
sched_pipeline.consumer_wait(sched_consumer_state)
rmem = cute.make_rmem_tensor((4,), cutlass.Int32)
cute.copy(
sched_copy_atom,
sched_buf_tensor[(None, sched_consumer_state.index)],
rmem,
)
work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem)
cute.arch.fence_acq_rel_cta()
sched_pipeline.consumer_release(sched_consumer_state)
sched_consumer_state.advance()
while work_tile_info.is_valid_tile:
k_tile_cnt = work_tile_info.k_tile_cnt
# Get real C tensor + TMA desc ptr via extension
real_c, desc_ptr_c = ext.get_gmem_tensor(
"c",
tma_tensor_c,
offs,
work_tile_info,
)
# local_tile + partition for C
gC_mnl = cute.local_tile(
real_c,
cute.slice_(self.mma_tiler, (None, None, 0)),
(None, None, None),
)
thr_mma = tiled_mma.get_slice(mma_tile_coord_v)
tCgC = thr_mma.partition_C(gC_mnl)
tCgC_transformed = transform_partitioned_tensor_layout(tCgC)
mma_tile_coord_mnl = (
work_tile_info.tile_m_idx // cute.size(tiled_mma.thr_id.shape),
work_tile_info.tile_n_idx,
cutlass.Int32(0),
)
# Partition for TMEM → RMEM copy
tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = (
epilogue_tmem_copy_and_partition(
self,
tidx,
tCtAcc_transformed,
tCgC_transformed,
epi_tile,
use_2cta_instrs,
)
)
tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype)
tiled_copy_r2s, tRS_rC, tRS_sC = epilogue_smem_copy_and_partition(
self, tiled_copy_t2r, tTR_rC, tidx, sC
)
# TMA partition for C store (with expert-wise desc_ptr)
tCgC_epi = cute.flat_divide(tCgC_transformed, epi_tile)
bSG_sC, bSG_gC_partitioned = cpasync.tma_partition(
tma_atom_c,
0,
cute.make_layout(1),
cute.group_modes(sC, 0, 2),
cute.group_modes(tCgC_epi, 0, 2),
)
bSG_gC = bSG_gC_partitioned[(None, None, None, *mma_tile_coord_mnl)]
# Set TMEM buffer for current tile
tTR_tAcc = tTR_tAcc_base[
(None, None, None, None, None, acc_consumer_state.index)
]
# Wait for accumulator buffer full
if k_tile_cnt > 0:
acc_pipeline.consumer_wait(acc_consumer_state)
tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC))
# Store accumulator to global memory in subtiles
subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
num_prev_subtiles = num_tiles_executed * subtile_cnt
for subtile_idx in range(subtile_cnt):
# TMEM → RMEM
tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)]
if cutlass.const_expr(self.scenario == "2Dx2D"):
if k_tile_cnt > 0:
cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
else:
cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
# Convert to output dtype
acc_vec = cute.zeros_like(tiled_copy_r2s.retile(tTR_rAcc))
if cutlass.const_expr(self.scenario == "2Dx2D"):
if k_tile_cnt > 0:
acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load()
else:
acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load()
acc_vec = acc_vec.to(self.c_dtype)
tRS_rC.store(acc_vec)
# RMEM → SMEM
c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage
cute.copy(
tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)]
)
cute.arch.fence_proxy("async.shared", space="cta")
epilog_sync_barrier.arrive_and_wait()
# SMEM → GMEM (TMA store or TMA reduce)
if warp_idx == self.epilogue_warp_id[0]:
cute.copy(
tma_atom_c,
bSG_sC[(None, c_buffer)],
bSG_gC[(None, subtile_idx)],
tma_desc_ptr=desc_ptr_c,
)
c_pipeline.producer_commit()
c_pipeline.producer_acquire()
epilog_sync_barrier.arrive_and_wait()
# Release accumulator buffer
if k_tile_cnt > 0:
acc_pipeline.consumer_release(acc_consumer_state)
acc_consumer_state.advance()
num_tiles_executed += cutlass.Int32(1)
# Read next work_tile_info
sched_pipeline.consumer_wait(sched_consumer_state)
rmem = cute.make_rmem_tensor((4,), cutlass.Int32)
cute.copy(
sched_copy_atom,
sched_buf_tensor[(None, sched_consumer_state.index)],
rmem,
)
work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem)
cute.arch.fence_acq_rel_cta()
sched_pipeline.consumer_release(sched_consumer_state)
sched_consumer_state.advance()
# Wait for C store complete
c_pipeline.producer_tail()
# Free TMEM
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
# =============================================================================
# Host Validation
# =============================================================================
from dataclasses import dataclass, field
import re
import numpy as np
import torch
import cutlass.torch as cutlass_torch
def torch_version_lt(major: int, minor: int) -> bool:
    """Return True when the installed torch is older than ``major.minor``.

    Only the leading ``<major>.<minor>`` digits of ``torch.__version__`` are
    inspected, so local build suffixes (anything after the minor number) are
    ignored. If the version string cannot be parsed at all, a warning is
    printed and True is returned, i.e. the torch build is treated as "old".
    """
    parsed = re.match(r"^\s*(\d+)\.(\d+)", torch.__version__)
    if parsed is not None:
        installed = tuple(int(component) for component in parsed.groups())
        return installed < (major, minor)

    # Unparseable version string: take the conservative "older" answer.
    print(
        "WARNING: failed to parse torch.__version__, "
        "falling back to torch._grouped_mm host reference."
    )
    return True
@dataclass
class ProblemDesc:
    """Description of one grouped-GEMM problem instance (sizes, dtypes, layouts)."""

    tokens: int
    experts: int
    top_k_select: int
    # True -> near-equal group sizes; False -> randomly sized groups.
    balance_route: bool
    hidden: int
    intermediate: int
    scenario: Literal["2Dx3D", "2Dx2D"]
    ab_dtype: torch.dtype
    out_dtype: torch.dtype
    acc_dtype: torch.dtype
    grad_accumulate: bool = False
    # GEMM-domain layout control (which axis is stride-1)
    #   A (M, K): "k_major" (default) or "m_major"
    #   B (N, K): "n_major" (default) or "k_major"
    #   C (M, N): "n_major" (default) or "m_major"
    a_layout: Literal["k_major", "m_major"] = "k_major"
    b_layout: Literal["k_major", "n_major"] = "n_major"
    c_layout: Literal["m_major", "n_major"] = "n_major"

    def __str__(self) -> str:
        """Render a single-line, human-readable summary of the problem."""
        # "torch.bfloat16" -> "bfloat16"
        ab_name, out_name, acc_name = (
            str(dtype).split(".")[-1]
            for dtype in (self.ab_dtype, self.out_dtype, self.acc_dtype)
        )
        if self.balance_route:
            route = "balanced"
        else:
            route = "random"
        sections = [
            f"ProblemDesc: {self.scenario}",
            f"tokens={self.tokens} experts={self.experts} "
            f"top_k={self.top_k_select} route={route}",
            f"hidden={self.hidden} intermediate={self.intermediate}",
            f"{ab_name}->{out_name}(acc={acc_name}) "
            f"grad_acc={self.grad_accumulate}",
            f"layout: A={self.a_layout} B={self.b_layout} C={self.c_layout}",
        ]
        return " | ".join(sections)
@dataclass
class ImplDesc:
    """Kernel implementation knobs: tile shape, cluster shape and build options."""

    mma_tiler_mnk: Tuple[int, int, int]
    cluster_shape_mnk: Tuple[int, int, int]
    use_2cta_instrs: bool
    # None means the expert count is not fixed; shown as "dynamic" in __str__.
    static_expert_cnt: Optional[int] = None
    separate_tensormap_init: bool = True

    def __str__(self) -> str:
        """Render a single-line, human-readable summary of the implementation."""
        tile = ",".join(str(extent) for extent in self.mma_tiler_mnk)
        cluster = ",".join(str(extent) for extent in self.cluster_shape_mnk)
        if self.static_expert_cnt is None:
            static_e = "dynamic"
        else:
            static_e = self.static_expert_cnt
        head = f"ImplDesc: tile={tile} cluster={cluster} 2cta={self.use_2cta_instrs}"
        tail = f"static_E={static_e} sep_tmap={self.separate_tensormap_init}"
        return f"{head} | {tail}"
@dataclass
class MiscDesc:
    """Run-mode switches for the tester (perf vs. correctness, reference choice)."""

    perf_run: bool = False
    perf_e2e: bool = False
    compare_with_bmm: bool = False
    compare_with_sol: bool = False
    # Not a constructor argument; computed in __post_init__ from the
    # installed torch version.
    no_torch_210: bool = field(init=False)

    def __post_init__(self):
        """Derive ``no_torch_210`` and reject incompatible flag combinations.

        Raises:
            ValueError: if ``perf_e2e`` is set without ``perf_run``, or
                together with ``compare_with_sol``.
        """
        self.no_torch_210 = torch_version_lt(2, 10)
        if self.perf_e2e:
            if not self.perf_run:
                raise ValueError("--perf_e2e requires --perf_run to be enabled.")
            if self.compare_with_sol:
                raise ValueError(
                    "--perf_e2e and --compare_with_sol are mutually exclusive."
                )

    def __str__(self) -> str:
        """Render a single-line, human-readable summary of the run mode."""
        reference = "grouped_mm"
        if self.compare_with_bmm:
            reference = "bmm"
        fields = (
            "MiscDesc:",
            f"perf={self.perf_run}",
            f"perf_e2e={self.perf_e2e}",
            f"ref={reference}",
            f"sol={self.compare_with_sol}",
            f"no_torch_210={self.no_torch_210}",
        )
        return " ".join(fields)
def l2_flush(size_mb: int = 400) -> None:
    """Best-effort L2 flush: fill a throwaway ``size_mb`` MiB buffer on the GPU.

    Args:
        size_mb: Size of the temporary uint8 CUDA tensor, in MiB.
    """
    bytes_per_mib = 1024 * 1024
    # The result is intentionally not bound to a name: only the write traffic
    # matters, and the temporary is released as soon as the call returns.
    torch.randint(
        0,
        256,
        (size_mb * bytes_per_mib,),
        dtype=torch.uint8,
        device="cuda",
    )
class GroupedGemmTester:
    """Host-side driver that builds inputs, runs ``GroupedGemmKernel`` and checks it.

    Typical use is ``GroupedGemmTester(problem, impl, misc).run()``, which
    generates inputs, computes a PyTorch reference (skipped for perf runs),
    launches the kernel and finally validates the output bit-exactly.
    """

    def __init__(self, problem: ProblemDesc, impl: ImplDesc, misc: MiscDesc):
        """Store the three descriptors and derive the sizes used by every stage."""
        self.problem = problem
        self.impl = impl
        self.misc = misc

        # Total length of the grouped dimension that _generate_offs() splits
        # across experts.
        self.tokens_after_repeat = problem.tokens * problem.top_k_select
        self.expert_cnt = problem.experts
        self.hidden = problem.hidden
        self.intermediate = problem.intermediate

        # Filled in later by generate_inputs() / compute_reference() /
        # run_kernel(); None until then.
        self.A_tensor: Optional[torch.Tensor] = None
        self.B_tensor: Optional[torch.Tensor] = None
        self.C_tensor: Optional[torch.Tensor] = None
        self.C_ref_tensor: Optional[torch.Tensor] = None
        self.offs_tensor: Optional[torch.Tensor] = None
        self.workspace_tensor: Optional[torch.Tensor] = None

        # This should be a common func
        # torch dtype -> cutlass numeric type; only these three are supported,
        # any other dtype raises KeyError at lookup time.
        self.temp_type_mapping = {
            torch.float32: cutlass.Float32,
            torch.bfloat16: cutlass.BFloat16,
            torch.float16: cutlass.Float16,
        }

    def _generate_offs(self) -> torch.Tensor:
        """Generate group-end offsets.

        Some experts may receive 0 tokens (valid in real MoE routing).

        Returns:
            int32 CUDA tensor with one entry per expert holding the inclusive
            running sum of the group sizes; the last entry equals
            ``tokens_after_repeat``.
        """
        total = self.tokens_after_repeat
        expert_cnt = self.expert_cnt

        if self.problem.balance_route:
            # Spread the remainder over the first `remainder` experts.
            base, remainder = divmod(total, expert_cnt)
            sizes = [base + (1 if i < remainder else 0) for i in range(expert_cnt)]
        else:
            proportions = np.random.dirichlet([0.5] * expert_cnt)
            raw = np.floor(proportions * total).astype(int)
            deficit = total - raw.sum()
            # Flooring loses rows: hand them back one at a time to the expert
            # that is currently furthest below its target share.
            while deficit > 0:
                idx = int(np.argmin(raw / (proportions * total + 1e-12)))
                raw[idx] += 1
                deficit -= 1
            # Defensive counterpart: take rows from the most over-served
            # non-empty expert if the rounded sizes overshoot.
            while deficit < 0:
                ratios = np.where(
                    raw > 0,
                    raw / (proportions * total + 1e-12),
                    -np.inf,
                )
                idx = int(np.argmax(ratios))
                raw[idx] -= 1
                deficit += 1
            sizes = raw.tolist()

        # Internal invariant of the two branches above, not input validation.
        assert sum(sizes) == total

        offsets = np.cumsum(sizes).tolist()
        return torch.tensor(offsets, dtype=torch.int32, device="cuda")

    def _generate_tensor(self, shape: Tuple) -> torch.Tensor:
        """Create a CUDA input tensor of ``ab_dtype`` with the given shape.

        Perf runs use normally distributed values. Correctness runs draw
        integers from {-1, 0, 1} (presumably so the bit-exact comparison in
        ``validate`` is meaningful).
        """
        if self.misc.perf_run:
            return torch.randn(shape, dtype=self.problem.ab_dtype, device="cuda")
        else:
            return torch.randint(-1, 2, shape, device="cuda", dtype=torch.int8).to(
                self.problem.ab_dtype
            )

    def _allocate_output(self, shape: Tuple) -> torch.Tensor:
        """Create a CUDA output tensor of ``out_dtype`` pre-filled with -1."""
        return torch.full(shape, -1, dtype=self.problem.out_dtype, device="cuda")

    def _get_stream(self) -> cuda.CUstream:
        """Wrap torch's current CUDA stream handle as a ``cuda.CUstream``."""
        return cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    def generate_inputs(self) -> None:
        """Populate ``offs_tensor``, ``A_tensor``, ``B_tensor`` and ``C_tensor``.

        The requested stride-1 axis is realised by allocating the tensor with
        that axis last and, where needed, returning a transposed view, so the
        logical PyTorch shape is the same for every layout choice.

        Raises:
            ValueError: if ``problem.scenario`` is not "2Dx3D" or "2Dx2D".
        """
        self.offs_tensor = self._generate_offs()
        tokens = self.tokens_after_repeat
        hidden = self.hidden
        intermediate = self.intermediate
        expert_cnt = self.expert_cnt
        problem = self.problem

        if problem.scenario == "2Dx3D":
            # PyTorch shape: A (tokens, hidden), B (expert_cnt, hidden, intermediate), C (tokens, intermediate)
            # GEMM domain:   A (M=tokens, K=hidden), B (N=intermediate, K=hidden), C (M=tokens, N=intermediate)

            # GEMM A: k_major → K(hidden) stride-1; m_major → M(tokens) stride-1
            if problem.a_layout == "k_major":
                self.A_tensor = self._generate_tensor((tokens, hidden))
            else:
                self.A_tensor = self._generate_tensor((hidden, tokens)).T

            # GEMM B: n_major → N(intermediate) stride-1; k_major → K(hidden) stride-1
            if problem.b_layout == "n_major":
                self.B_tensor = self._generate_tensor(
                    (expert_cnt, hidden, intermediate)
                )
            else:
                self.B_tensor = self._generate_tensor(
                    (expert_cnt, intermediate, hidden)
                ).transpose(1, 2)

            # GEMM C: n_major → N(intermediate) stride-1; m_major → M(tokens) stride-1
            if problem.c_layout == "n_major":
                self.C_tensor = self._allocate_output((tokens, intermediate))
            else:
                self.C_tensor = self._allocate_output((intermediate, tokens)).T
        elif problem.scenario == "2Dx2D":
            # PyTorch shape: mat_a (hidden, tokens), mat_b (tokens, intermediate), out (expert_cnt, hidden, intermediate)
            #   out matches weight shape (expert_cnt, hidden, intermediate) for weight gradient
            # GEMM domain:   A (M=hidden, K=tokens), B (N=intermediate, K=tokens), C (M=hidden, N=intermediate)

            # GEMM A: k_major → K(tokens) stride-1; m_major → M(hidden) stride-1
            if problem.a_layout == "k_major":
                self.A_tensor = self._generate_tensor((hidden, tokens))
            else:
                self.A_tensor = self._generate_tensor((tokens, hidden)).T

            # GEMM B: n_major → N(intermediate) stride-1; k_major → K(tokens) stride-1
            if problem.b_layout == "n_major":
                self.B_tensor = self._generate_tensor((tokens, intermediate))
            else:
                self.B_tensor = self._generate_tensor((intermediate, tokens)).T

            # GEMM C: n_major → N(intermediate) stride-1; m_major → M(hidden) stride-1
            if problem.c_layout == "n_major":
                self.C_tensor = self._allocate_output(
                    (expert_cnt, hidden, intermediate)
                )
            else:
                self.C_tensor = self._allocate_output(
                    (expert_cnt, intermediate, hidden)
                ).transpose(1, 2)

            # With accumulation enabled the output starts from zero instead of
            # the -1 fill, in place so the chosen strides are kept.
            if problem.grad_accumulate:
                self.C_tensor *= 0
        else:
            raise ValueError(f"Unknown scenario: {problem.scenario}")

    def compute_reference(self) -> None:
        """Compute ``C_ref_tensor`` with PyTorch; a no-op for perf runs."""
        if self.misc.perf_run:
            return
        if self.misc.compare_with_bmm:
            self._compute_reference_bmm()
        else:
            self._compute_reference_grouped_mm()

    def _compute_reference_grouped_mm(self) -> None:
        """Reference via torch's grouped matmul.

        Uses the private ``torch._grouped_mm`` when ``misc.no_torch_210`` is
        set, otherwise ``torch.nn.functional.grouped_mm``.
        """
        grouped_mm_op = (
            torch._grouped_mm
            if self.misc.no_torch_210
            else torch.nn.functional.grouped_mm
        )
        self.C_ref_tensor = grouped_mm_op(
            self.A_tensor,
            self.B_tensor,
            offs=self.offs_tensor,
            out_dtype=self.problem.out_dtype,
        )

    def _compute_reference_bmm(self) -> None:
        """Manual per-expert torch.mm loop as reference (avoids grouped_mm bugs on small cases)."""
        # Preallocate the full reference output to avoid keeping both the per-expert
        # results list and the final cat/stack result alive at the same time.
        self.C_ref_tensor = torch.empty_like(self.C_tensor)

        # Copy all group boundaries to the host in one transfer instead of
        # issuing one blocking .item() call per expert.
        group_ends = self.offs_tensor.tolist()

        prev = 0
        for i in range(self.expert_cnt):
            cur = group_ends[i]
            if self.problem.scenario == "2Dx3D":
                # A (tokens, hidden), B (E, hidden, intermediate) → C_i (tokens_i, intermediate)
                a_slice = self.A_tensor[prev:cur, :]
                b_slice = self.B_tensor[i]
                self.C_ref_tensor[prev:cur, :].copy_(torch.mm(a_slice, b_slice))
            else:  # 2Dx2D
                # A (hidden, tokens), B (tokens, intermediate) → C_i (hidden, intermediate)
                a_slice = self.A_tensor[:, prev:cur]
                b_slice = self.B_tensor[prev:cur, :]
                self.C_ref_tensor[i, :, :].copy_(torch.mm(a_slice, b_slice))
            prev = cur

    def create_kernel(self) -> GroupedGemmKernel:
        """Build a ``GroupedGemmKernel`` configured from ``problem`` and ``impl``.

        Output accumulation is only requested for the "2Dx2D" scenario, even
        if ``problem.grad_accumulate`` is set.
        """
        return GroupedGemmKernel(
            scenario=self.problem.scenario,
            out_dtype=self.temp_type_mapping[self.problem.out_dtype],
            accumulate_on_output=self.problem.grad_accumulate
            and self.problem.scenario == "2Dx2D",
            separate_tensormap_init=self.impl.separate_tensormap_init,
            fixed_expert_cnt=self.impl.static_expert_cnt,
            acc_dtype=self.temp_type_mapping[self.problem.acc_dtype],
            mma_tiler_mnk=self.impl.mma_tiler_mnk,
            cluster_shape_mnk=self.impl.cluster_shape_mnk,
            use_2cta_instrs=self.impl.use_2cta_instrs,
        )

    def run_kernel(self, kernel: GroupedGemmKernel) -> Optional[float]:
        """Run our CuTe kernel.

        Args:
            kernel: Kernel object from ``create_kernel``.

        Returns:
            Average kernel time in ms when perf_e2e is enabled, None otherwise.
        """
        workspace_size = kernel.get_workspace_size(self.expert_cnt)
        # Workspace starts as all 0xFF bytes.
        self.workspace_tensor = torch.full(
            (workspace_size,), 255, dtype=torch.uint8, device="cuda"
        )
        torch.cuda.synchronize()

        ab_cutlass_dtype = self.temp_type_mapping[self.problem.ab_dtype]
        out_cutlass_dtype = self.temp_type_mapping[self.problem.out_dtype]

        # cute_tensor_like returns (cute tensor, torch tensor); the torch
        # tensors replace the attributes so later steps (validate, SOL) use
        # them. NOTE(review): presumably each pair shares storage — confirm.
        a_cute, self.A_tensor = cutlass_torch.cute_tensor_like(
            self.A_tensor, ab_cutlass_dtype, is_dynamic_layout=True, assumed_align=16
        )
        b_cute, self.B_tensor = cutlass_torch.cute_tensor_like(
            self.B_tensor, ab_cutlass_dtype, is_dynamic_layout=True, assumed_align=16
        )
        c_cute, self.C_tensor = cutlass_torch.cute_tensor_like(
            self.C_tensor, out_cutlass_dtype, is_dynamic_layout=True, assumed_align=16
        )

        # Expert-indexed buffers only get a dynamic layout when the expert
        # count is not fixed through impl.static_expert_cnt.
        is_dynamic_expert_cnt = self.impl.static_expert_cnt is None
        offs_cute, self.offs_tensor = cutlass_torch.cute_tensor_like(
            self.offs_tensor,
            cutlass.Int32,
            is_dynamic_layout=is_dynamic_expert_cnt,
            assumed_align=16,
        )
        workspace_cute, self.workspace_tensor = cutlass_torch.cute_tensor_like(
            self.workspace_tensor,
            cutlass.Uint8,
            is_dynamic_layout=is_dynamic_expert_cnt,
            assumed_align=128,
        )

        # Query max active clusters from hardware
        cluster_size = self.impl.cluster_shape_mnk[0] * self.impl.cluster_shape_mnk[1]
        max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size)

        print(f"A_tensor: {tuple(self.A_tensor.shape)}:{self.A_tensor.stride()}")
        print(f"B_tensor: {tuple(self.B_tensor.shape)}:{self.B_tensor.stride()}")
        print(
            f"offset_tensor: {tuple(self.offs_tensor.shape)}:{self.offs_tensor.stride()}"
        )
        print(f"C_tensor: {tuple(self.C_tensor.shape)}:{self.C_tensor.stride()}")

        stream = self._get_stream()

        # Positional arguments shared by every launch below; the None slot is
        # the (unused) bias.
        tensor_args = (
            a_cute,
            b_cute,
            c_cute,
            offs_cute,
            None,  # bias
            workspace_cute,
        )

        if self.misc.perf_e2e:
            # Compile once up front so the timed launches exclude compilation.
            # The compiled callable takes no max_active_clusters argument.
            compiled = cute.compile(kernel, *tensor_args, max_active_clusters, stream)

            warmup_iters = 4
            timed_iters = 4

            for _ in range(warmup_iters):
                l2_flush()
                compiled(*tensor_args, stream)
            torch.cuda.synchronize()

            times = []
            for _ in range(timed_iters):
                l2_flush()
                # Drain the flush before the start event is recorded so it is
                # not part of the measured interval.
                torch.cuda.synchronize()
                start_evt = torch.cuda.Event(enable_timing=True)
                end_evt = torch.cuda.Event(enable_timing=True)
                start_evt.record()
                compiled(*tensor_args, stream)
                end_evt.record()
                torch.cuda.synchronize()
                times.append(start_evt.elapsed_time(end_evt))

            avg_ms = sum(times) / len(times)
            print(f"[perf_e2e] Individual times (ms): {[f'{t:.4f}' for t in times]}")
            print(f"[perf_e2e] Average kernel time: {avg_ms:.4f} ms")
            return avg_ms
        else:
            l2_flush()
            kernel(*tensor_args, max_active_clusters, stream)
            torch.cuda.synchronize()
            return None

    def validate(self) -> None:
        """Compare the kernel output with the reference bit-exactly.

        Does nothing for perf runs, where no reference is computed.

        Raises:
            AssertionError: if ``C_tensor`` and ``C_ref_tensor`` differ.
        """
        if self.misc.perf_run:
            return
        # Raised explicitly rather than via `assert` so the check is still
        # performed when Python runs with -O (which strips assert statements).
        if not torch.equal(self.C_tensor, self.C_ref_tensor):
            raise AssertionError("Validation failed: C_tensor != C_ref_tensor")

    def run_sol_comparison(self) -> None:
        """Run a dense batched GEMM as Speed-of-Light reference.

        Reinterprets the tensors from the grouped run as equally sized
        per-expert batches via contiguous/view/permute. Note that
        ``.contiguous()`` only avoids a copy when the tensor is already
        contiguous; transposed layouts are copied first.

        Raises:
            AssertionError: if ``tokens_after_repeat`` is not divisible by the
                expert count (groups would not be equally sized).
        """
        import os
        import sys

        # Make the examples root importable so the dense kernel can be found.
        _examples_root = os.path.abspath(
            os.path.join(os.path.dirname(__file__), "..", "..", "..")
        )
        if _examples_root not in sys.path:
            sys.path.insert(0, _examples_root)
        from blackwell.kernel.dense_gemm.dense_gemm_persistent import (
            PersistentDenseGemmKernel,
        )

        tokens = self.tokens_after_repeat
        experts = self.expert_cnt
        assert tokens % experts == 0, (
            f"compare_with_sol requires tokens*top_k ({tokens}) "
            f"evenly divisible by experts ({experts}) so every group "
            f"has exactly the same size"
        )
        tpe = tokens // experts  # tokens per expert

        if self.problem.scenario == "2Dx3D":
            M, N, K, L = tpe, self.intermediate, self.hidden, experts
        else:  # 2Dx2D
            M, N, K, L = self.hidden, self.intermediate, tpe, experts

        # Reshape into GEMM-domain batch-last: A(M,K,L), B(N,K,L), C(M,N,L).
        # Data values are irrelevant (perf only) — just need correct shape
        # and stride pattern so the dense kernel sees the right major mode.
        if self.problem.a_layout == "k_major":
            a_sol = self.A_tensor.contiguous().view(L, M, K).permute(1, 2, 0)
            leading_dim_a = 1
        else:
            a_sol = self.A_tensor.contiguous().view(L, K, M).permute(2, 1, 0)
            leading_dim_a = 0

        if self.problem.b_layout == "n_major":
            b_sol = self.B_tensor.contiguous().view(L, K, N).permute(2, 1, 0)
            leading_dim_b = 0
        else:
            b_sol = self.B_tensor.contiguous().view(L, N, K).permute(1, 2, 0)
            leading_dim_b = 1

        if self.problem.c_layout == "n_major":
            c_sol = self.C_tensor.contiguous().view(L, M, N).permute(1, 2, 0)
            leading_dim_c = 1
        else:
            c_sol = self.C_tensor.contiguous().view(L, N, M).permute(2, 1, 0)
            leading_dim_c = 0

        from cutlass.cute.runtime import from_dlpack

        a_cute_sol = from_dlpack(a_sol, assumed_align=16).mark_layout_dynamic(
            leading_dim=leading_dim_a
        )
        b_cute_sol = from_dlpack(b_sol, assumed_align=16).mark_layout_dynamic(
            leading_dim=leading_dim_b
        )
        c_cute_sol = from_dlpack(c_sol, assumed_align=16).mark_layout_dynamic(
            leading_dim=leading_dim_c
        )

        # The dense kernel takes only the (M, N) parts of the tile and
        # cluster shapes.
        mma_tiler_mn = self.impl.mma_tiler_mnk[:2]
        cluster_shape_mn = self.impl.cluster_shape_mnk[:2]
        cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1]
        max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size)

        sol_kernel = PersistentDenseGemmKernel(
            acc_dtype=self.temp_type_mapping[self.problem.acc_dtype],
            use_2cta_instrs=self.impl.use_2cta_instrs,
            mma_tiler_mn=mma_tiler_mn,
            cluster_shape_mn=cluster_shape_mn,
            use_tma_store=True,
        )

        print(f"\n[SOL] Dense BMM: M={M} N={N} K={K} L={L}")
        print(f"[SOL] a_sol: {tuple(a_sol.shape)}:{a_sol.stride()}")
        print(f"[SOL] b_sol: {tuple(b_sol.shape)}:{b_sol.stride()}")
        print(f"[SOL] c_sol: {tuple(c_sol.shape)}:{c_sol.stride()}")

        l2_flush()
        sol_kernel(
            a_cute_sol,
            b_cute_sol,
            c_cute_sol,
            max_active_clusters,
            self._get_stream(),
        )
        torch.cuda.synchronize()

    def run(self) -> None:
        """Execute the full flow: inputs, kernel, optional profiling, validation.

        With ``misc.perf_e2e`` only the event-timed kernel run happens.
        Otherwise the reference, the kernel and (when applicable) the SOL
        dense GEMM run under the torch profiler and its summary is printed.
        """
        from torch.profiler import profile, ProfilerActivity

        print(self.problem)
        print(self.impl)
        print(self.misc)

        self.generate_inputs()
        kernel = self.create_kernel()

        if self.misc.perf_e2e:
            self.run_kernel(kernel)
        else:
            with profile(
                activities=[ProfilerActivity.CUDA], record_shapes=True
            ) as prof:
                self.compute_reference()
                self.run_kernel(kernel)
                if self.misc.compare_with_sol:
                    if self.misc.perf_run and self.problem.balance_route:
                        self.run_sol_comparison()
                    else:
                        # Tell the user why the requested comparison is
                        # missing instead of dropping it silently.
                        print(
                            "[SOL] Skipped: --compare_with_sol needs "
                            "--perf_run and --balance_route."
                        )

            print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))

        self.validate()
if __name__ == "__main__":
    # Command-line entry point: parse flags, build the three descriptors and
    # run one test/benchmark; prints "PASS" if nothing raised.
    import argparse

    def parse_dtype(s: str) -> torch.dtype:
        """Resolve a dtype name such as ``"bfloat16"`` to its ``torch.dtype``.

        Raises:
            ValueError: if ``torch`` has no dtype with that name (this also
                rejects names of non-dtype attributes such as functions).
        """
        dtype = getattr(torch, s, None)
        if not isinstance(dtype, torch.dtype):
            raise ValueError(f"Unknown torch dtype: {s!r}")
        return dtype

    def parse_tuple(s: str) -> Tuple[int, ...]:
        """Parse a comma-separated list of integers, e.g. ``"128,128,64"``."""
        return tuple(int(x) for x in s.split(","))

    parser = argparse.ArgumentParser()

    # Problem description
    parser.add_argument("--tokens", type=int, default=128)
    parser.add_argument("--experts", type=int, default=128)
    parser.add_argument("--top_k_select", type=int, default=8)
    parser.add_argument("--balance_route", action="store_true", default=False)
    parser.add_argument("--hidden", type=int, default=2048)
    parser.add_argument("--intermediate", type=int, default=7168)
    parser.add_argument(
        "--scenario", type=str, default="2Dx3D", choices=["2Dx3D", "2Dx2D"]
    )
    parser.add_argument("--ab_dtype", type=str, default="bfloat16")
    parser.add_argument("--out_dtype", type=str, default="bfloat16")
    parser.add_argument("--acc_dtype", type=str, default="float32")
    parser.add_argument("--grad_accumulate", action="store_true", default=False)
    parser.add_argument(
        "--a_layout", type=str, default="k_major", choices=["k_major", "m_major"]
    )
    parser.add_argument(
        "--b_layout", type=str, default="n_major", choices=["k_major", "n_major"]
    )
    parser.add_argument(
        "--c_layout", type=str, default="n_major", choices=["m_major", "n_major"]
    )

    # Implementation knobs
    parser.add_argument("--mma_tiler_mnk", type=str, default="128,128,64")
    parser.add_argument("--cluster_shape_mnk", type=str, default="1,1,1")
    parser.add_argument("--use_2cta_instrs", action="store_true", default=False)
    parser.add_argument("--static_expert_cnt", type=int, default=None)
    parser.add_argument("--separate_tensormap_init", action="store_true", default=False)

    # Run mode
    parser.add_argument("--perf_run", action="store_true", default=False)
    parser.add_argument("--perf_e2e", action="store_true", default=False)
    parser.add_argument("--compare_with_bmm", action="store_true", default=False)
    parser.add_argument("--compare_with_sol", action="store_true", default=False)

    args = parser.parse_args()

    problem = ProblemDesc(
        tokens=args.tokens,
        experts=args.experts,
        top_k_select=args.top_k_select,
        balance_route=args.balance_route,
        hidden=args.hidden,
        intermediate=args.intermediate,
        scenario=args.scenario,
        ab_dtype=parse_dtype(args.ab_dtype),
        out_dtype=parse_dtype(args.out_dtype),
        acc_dtype=parse_dtype(args.acc_dtype),
        grad_accumulate=args.grad_accumulate,
        a_layout=args.a_layout,
        b_layout=args.b_layout,
        c_layout=args.c_layout,
    )

    # The flag is forced on regardless of the command line; the message below
    # says the fused variant is not available yet.
    if not args.separate_tensormap_init:
        print(
            "Change separate_tensormap_init to True as current the fused version not implmented yet."
        )
        args.separate_tensormap_init = True

    # Both shapes are (M, N, K) triples; reject anything else here rather
    # than letting a malformed tuple reach the kernel.
    mma_tiler_mnk = parse_tuple(args.mma_tiler_mnk)
    cluster_shape_mnk = parse_tuple(args.cluster_shape_mnk)
    for flag, shape in (
        ("--mma_tiler_mnk", mma_tiler_mnk),
        ("--cluster_shape_mnk", cluster_shape_mnk),
    ):
        if len(shape) != 3:
            parser.error(
                f"{flag} expects three comma-separated integers, got {len(shape)}"
            )

    impl = ImplDesc(
        mma_tiler_mnk=mma_tiler_mnk,
        cluster_shape_mnk=cluster_shape_mnk,
        use_2cta_instrs=args.use_2cta_instrs,
        static_expert_cnt=args.static_expert_cnt,
        separate_tensormap_init=args.separate_tensormap_init,
    )
    misc = MiscDesc(
        perf_run=args.perf_run,
        perf_e2e=args.perf_e2e,
        compare_with_bmm=args.compare_with_bmm,
        compare_with_sol=args.compare_with_sol,
    )
    # On older torch builds always use the per-expert torch.mm reference.
    if misc.no_torch_210:
        misc.compare_with_bmm = True
        print("Override to set --compare_with_bmm to avoid possible torch crash.")

    tester = GroupedGemmTester(problem, impl, misc)
    tester.run()
    print("PASS")