From a2ea836c74f1d1e4070f7d1092ebf7cb5c360686 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 16 May 2026 02:41:51 +0000 Subject: [PATCH] docs: add CuTeDSL rewrite plan + reference files MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The C++ CUTLASS kernel is fundamentally broken (cosine 0.05 with real data). Switching to NVIDIA's CuTeDSL approach based on their official MoE scaled grouped GEMM example. Reference files copied: - moe_torch_scaled_grouped_mm.py (3900 lines — our new kernel) - moe_utils.py, moe_persistent_scheduler.py, moe_sched_extension.py - grouped_blockscaled_gemm.py, dense_blockscaled_gemm_persistent.py - blockscaled_layout.py --- REWRITE_PLAN.md | 93 + reference/blockscaled_layout.py | 657 +++ .../dense_blockscaled_gemm_persistent.py | 3152 +++++++++++++ reference/grouped_blockscaled_gemm.py | 3278 ++++++++++++++ reference/moe_moe_persistent_scheduler.py | 695 +++ reference/moe_moe_sched_extension.py | 443 ++ reference/moe_moe_utils.py | 910 ++++ reference/moe_torch_grouped_mm.py | 2019 +++++++++ reference/moe_torch_scaled_grouped_mm.py | 3901 +++++++++++++++++ 9 files changed, 15148 insertions(+) create mode 100644 REWRITE_PLAN.md create mode 100644 reference/blockscaled_layout.py create mode 100644 reference/dense_blockscaled_gemm_persistent.py create mode 100644 reference/grouped_blockscaled_gemm.py create mode 100644 reference/moe_moe_persistent_scheduler.py create mode 100644 reference/moe_moe_sched_extension.py create mode 100644 reference/moe_moe_utils.py create mode 100644 reference/moe_torch_grouped_mm.py create mode 100644 reference/moe_torch_scaled_grouped_mm.py diff --git a/REWRITE_PLAN.md b/REWRITE_PLAN.md new file mode 100644 index 00000000..dc8ab005 --- /dev/null +++ b/REWRITE_PLAN.md @@ -0,0 +1,93 @@ +# NVFP4 MegaMoE Kernel Rewrite — CuTeDSL + +## Decision + +The C++ CUTLASS kernel is fundamentally broken. The GEMM produces cosine 0.05 +with real data (despite SF remap = 0 errors and uniform tests passing). The +root cause is likely in how CUTLASS's C++ API handles FP4 packing/tiling +internally — something we can't easily debug or fix. + +We're replacing it with NVIDIA's CuTeDSL approach (Python-based CUTLASS +kernels compiled via MLIR → PTX). This is what the NVIDIA CUTLASS team +recommends for Blackwell, and they have a working reference: + +`cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/moe/torch_scaled_grouped_mm.py` + +This is a 3900-line MoE scaled grouped GEMM that supports NVFP4 out of the box. + +## Reference Files (copied to reference/) + +- `reference/moe_torch_scaled_grouped_mm.py` — the kernel (3900 lines) +- `reference/moe_moe_utils.py` — tensormap constructor +- `reference/moe_moe_persistent_scheduler.py` — MoE tile scheduler +- `reference/moe_moe_sched_extension.py` — scheduler extension for scaled GEMM +- `reference/grouped_blockscaled_gemm.py` — simpler grouped blockscaled reference +- `reference/dense_blockscaled_gemm_persistent.py` — dense (non-grouped) reference +- `reference/blockscaled_layout.py` — SF layout utilities + +## Architecture + +### The Reference Kernel: ScaledGroupedGemmKernel + +7-warp persistent kernel: +- Warps 0-3: Epilogue (TMEM → RMEM → SMEM → GMEM) +- Warp 4: MMA (tcgen05.mma.block_scale with SFA/SFB in TMEM) +- Warp 5: TMA load (A, B, SFA, SFB from GMEM → SMEM) +- Warp 6: Scheduler (MoE persistent tile scheduler) + +Supports: +- NVFP4 (Float4E2M1FN + Float8E4M3FN + sf_vec_size=16) +- 2Dx3D scenario (our case): A(tokens, K) x B(experts, K, N) → C(tokens, N) +- PyTorch interface via torch.nn.functional.scaled_grouped_mm + +### Our Adaptation + +We need to: +1. Use ScaledGroupedGemmKernel directly (or with minimal adaptation) +2. Replace our Python pipeline (nvfp4_mega_moe.py) to call the CuTeDSL kernel +3. Handle the MoE-specific stuff (routing, gate/up fusion, SiLU, scatter) + +The CuTeDSL kernel handles A/B/SFA/SFB tiling and MMA internally. +We NO LONGER need: +- Our custom SF remap kernel (CuTeDSL handles SF layouts natively) +- Our C++ CUTLASS extension (cutlass_nvfp4_gemm.cu) +- transform_nvfp4_weights_for_mega_moe (CuTeDSL uses the natural layout) + +We STILL need: +- stage_activation (BF16 → FP4 quantization) +- The MoE routing logic (top-k selection, slot mapping) +- Gate/up fusion with up_correction +- SiLU activation +- Scatter with routing weights + +## Plan + +### Phase 1: Get the reference kernel running standalone +- Set up CuTeDSL build environment on B200 +- Run the reference example with NVFP4 config +- Verify it produces correct output + +### Phase 2: Integrate into our pipeline +- Create a Python module that wraps ScaledGroupedGemmKernel +- Replace cutlass_nvfp4_gemm calls with CuTeDSL kernel calls +- Handle weight format (the CuTeDSL kernel expects raw FP4 + SF, no transpose) + +### Phase 3: Full MoE pipeline +- Wire up stage_activation → L1 GEMM → SiLU → stage_activation → L2 GEMM → scatter +- Test with layertest.py (should get cosine ~0.995) + +### Phase 4: vLLM integration +- Update the vLLM patch to use the CuTeDSL kernel +- Remove the C++ extension build from the Dockerfile +- Test full inference + +## Key Differences from Current Approach + +| Current (C++ CUTLASS) | New (CuTeDSL) | +|----------------------|---------------| +| C++ .cu file | Python .py file | +| Manual SF remap kernel | CuTe handles SF natively | +| Manual weight transpose | CuTe handles layout | +| pip install C++ extension | cute.compile() JIT | +| Broken FP4 handling | Reference-verified | +| No Blackwell pipeline | Full TMA/MMA/Epilogue overlap | diff --git a/reference/blockscaled_layout.py b/reference/blockscaled_layout.py new file mode 100644 index 00000000..7bc4ee6f --- /dev/null +++ b/reference/blockscaled_layout.py @@ -0,0 +1,657 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +from dataclasses import dataclass, field +from typing import Optional + +from cutlass.cutlass_dsl import dsl_user_op + +import cutlass.cute as cute +from cutlass.cute.nvgpu import OperandMajorMode + +from cutlass._mlir import ir +import cutlass._mlir.dialects.cute as _cute_ir +import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir + + +@dataclass(frozen=True) +class BlockScaledBasicChunk: + """ + The basic scale factor atom layout decided by tcgen05 BlockScaled MMA Ops. + + This class represents the fixed layout pattern for scale factors used in + tcgen05 BlockScaled MMA Ops. The layout is determined by the + instruction specification and cannot be modified. + See `PTX documentation `. + """ + + sf_vec_size: int + major_mode: OperandMajorMode = OperandMajorMode.K + _layout: cute.Layout = field(init=False, repr=False) + + def __post_init__(self) -> None: + if self.major_mode == OperandMajorMode.K: + # K-major layout: (AtomMN, AtomK) + atom_shape = ((32, 4), (self.sf_vec_size, 4)) + atom_stride = ((16, 4), (0, 1)) + else: + # MN-major layout: (AtomK, AtomMN) + atom_shape = ((self.sf_vec_size, 4), (32, 4)) + atom_stride = ((0, 1), (16, 4)) + + object.__setattr__( + self, "_layout", cute.make_layout(atom_shape, stride=atom_stride) + ) + + @property + def layout(self) -> cute.Layout: + """ + Get the layout for this block scaled chunk. + + :return: The layout representing the scale factor atom + :rtype: cute.Layout + """ + return self._layout + + +@dsl_user_op +def tile_atom_to_shape_SF( + Shape: cute.Shape, + sf_vec_size: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """ + A helper function to get dynamic SFA/SFB layout by filling dynamic A/B shape to the scale factor atom layout. + + :param Shape: The shape of the A/B tensor + :param sf_vec_size: Scale factor vector size + + :return: The layout of the SFA/SFB tensor + :rtype: cute.Layout + """ + # ((Atom_MN, Rest_MN),(Atom_K, Rest_K),RestL) + sf_layout = cute.tile_to_shape( + BlockScaledBasicChunk(sf_vec_size).layout, Shape, (2, 1, 3) + ) + return sf_layout + + +@dsl_user_op +def make_smem_layout_sf( + tile_shape: cute.Tile, + sf_vec_size: int, + num_stages: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """ + A helper function to get dynamic SFA/SFB layout by filling dynamic A/B shape to the scale factor atom layout. + + :param Shape: The shape of the A/B tensor + :param sf_vec_size: Scale factor vector size + :param num_stages: Number of stages + + :return: The layout of the SFA/SFB tensor + :rtype: cute.Layout + """ + + smem_layout = cute.tile_to_shape( + BlockScaledBasicChunk(sf_vec_size).layout, + tile_shape, # type: ignore[arg-type] + (2, 1), + loc=loc, + ip=ip, + ) + smem_layout_staged = cute.append( + smem_layout, + cute.make_layout( + num_stages, + stride=cute.cosize(cute.filter_zeros(smem_layout)), + loc=loc, + ip=ip, + ), + loc=loc, + ip=ip, + ) + return smem_layout_staged + + +@dsl_user_op +def make_smem_layout_sfa( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: cute.Tile, + sf_vec_size: int, + num_stages: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """ + Make smem layout for SFA based on: + + 1. BlockScaledBasicChunk + 2. MMA tiler shape + 3. Scale factor vector size + 4. Number of stages + + :param tiled_mma: The tiled MMA + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The mma tiler shape + :type mma_tiler_mnk: cute.Tile + :param sf_vec_size: The scale factor vector size + :type sf_vec_size: int + :param num_stages: The number of stages + :type num_stages: int + + :return: Smem layout for SFA + :rtype: cute.Layout + """ + # (CTA_Tile_Shape_M, MMA_Tile_Shape_K) + sfa_tile_shape = ( + mma_tiler_mnk[0] // cute.size(tiled_mma.thr_id.shape), # type: ignore[index] + mma_tiler_mnk[2], # type: ignore[index] + ) + + # ((Atom_M, Rest_M),(Atom_K, Rest_K)) + smem_layout = cute.tile_to_shape( + BlockScaledBasicChunk(sf_vec_size).layout, + sfa_tile_shape, # type: ignore[arg-type] + (2, 1), + ) + + # Number of MMA instructions to cover all k-tiles + mma_tile_inst_m = mma_tiler_mnk[0] // cute.size(tiled_mma.shape_mnk, mode=[0]) # type: ignore[index] + mma_tile_inst_k = mma_tiler_mnk[2] // cute.size(tiled_mma.shape_mnk, mode=[2]) # type: ignore[index] + + # (CTA_Tile_Shape_M, MMA_Inst_Shape_K) + sfa_tile_shape = cute.shape_div(sfa_tile_shape, (mma_tile_inst_m, mma_tile_inst_k)) + # ((Atom_Inst_M, Atom_Inst_K), MMA_M, MMA_K)) + smem_layout = cute.tiled_divide(smem_layout, sfa_tile_shape) + + atom_m = 128 + tiler_inst = ((atom_m, sf_vec_size),) + # (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K) + smem_layout = cute.logical_divide(smem_layout, tiler_inst) + + # (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K, STAGE) + sfa_smem_layout_staged = cute.append( + smem_layout, + cute.make_layout( + num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout)) + ), + ) + + return sfa_smem_layout_staged + + +@dsl_user_op +def make_smem_layout_sfb( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: cute.Tile, + sf_vec_size: int, + num_stages: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """ + Make smem layout for SFB based on: + + 1. BlockScaledBasicChunk + 2. MMA tiler shape + 3. Scale factor vector size + 4. Number of stages + + :param tiled_mma: The tiled MMA + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The mma tiler shape + :type mma_tiler_mnk: cute.Tile + :param sf_vec_size: The scale factor vector size + :type sf_vec_size: int + :param num_stages: The number of stages + :type num_stages: int + + :return: Smem layout for SFA + :rtype: cute.Layout + """ + # (Round_Up(CTA_Tile_Shape_N, 128), MMA_Tile_Shape_K) + sfb_tile_shape = ( + cute.round_up(mma_tiler_mnk[1], 128), # type: ignore[index, arg-type] + mma_tiler_mnk[2], # type: ignore[index] + ) + + # ((Atom_N, Rest_N),(Atom_K, Rest_K)) + smem_layout = cute.tile_to_shape( + BlockScaledBasicChunk(sf_vec_size).layout, + sfb_tile_shape, # type: ignore[arg-type] + (2, 1), + ) + + # Number of MMA instructions to cover all k-tiles + mma_tile_inst_n = mma_tiler_mnk[1] // cute.size(tiled_mma.shape_mnk, mode=[1]) # type: ignore[index] + mma_tile_inst_k = mma_tiler_mnk[2] // cute.size(tiled_mma.shape_mnk, mode=[2]) # type: ignore[index] + + # (CTA_Tile_Shape_N, MMA_Inst_Shape_K) + sfb_tile_shape = cute.shape_div(sfb_tile_shape, (mma_tile_inst_n, mma_tile_inst_k)) + # ((Atom_Inst_N, Atom_Inst_K), MMA_N, MMA_K) + smem_layout = cute.tiled_divide(smem_layout, sfb_tile_shape) + + atom_n = 128 + tiler_inst = ((atom_n, sf_vec_size),) + # (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K) + smem_layout = cute.logical_divide(smem_layout, tiler_inst) + + # (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K, STAGE) + sfb_smem_layout_staged = cute.append( + smem_layout, + cute.make_layout( + num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout)) + ), + ) + + return sfb_smem_layout_staged + + +@dsl_user_op +def sm120_make_smem_layout_sfa( + tiled_mma: cute.TiledMma, + tile_shape_mnk: cute.Tile, + sf_vec_size: int, + num_stages: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """ + Make smem layout for SFA based on: + 1. BlockScaledBasicChunk + 2. MMA tiler shape + 3. Scale factor vector size + 4. Number of stages + + :param tiled_mma: The tiled MMA + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The mma tiler shape + :type mma_tiler_mnk: cute.Tile + :param sf_vec_size: The scale factor vector size + :type sf_vec_size: int + :param num_stages: The number of stages + :type num_stages: int + + :return: Smem layout for SFA + :rtype: cute.Layout + """ + + assert sf_vec_size == 16 or sf_vec_size == 32, "sf_vec_size must be 16 or 32" + + blk_mn = 128 + blk_sf = 4 + blk_elems = blk_mn * blk_sf + mma_nsf = tiled_mma.shape_mnk[2] // sf_vec_size + + mn_basic_block_shape = (32, 4) + mn_basic_block_stride = (16, 4) + k_basic_block_shape = (sf_vec_size, mma_nsf) + k_basic_block_stride = (0, 1) + + assert tile_shape_mnk[0] % blk_mn == 0, ( # type: ignore[index, operator] + "tile_shape_mnk[0] must be divisible by blk_mn" + ) + + sSFA_shapeM = (mn_basic_block_shape, tile_shape_mnk[0] // blk_mn) # type: ignore[index, operator] + sSF_strideM = (mn_basic_block_stride, blk_elems) + + assert tile_shape_mnk[2] % (blk_sf * mma_nsf) == 0, ( # type: ignore[index] + "tile_shape_mnk[2] must be divisible by blk_sf * mma_nsf" + ) + + sSFA_shapeK = ( + k_basic_block_shape, + blk_sf // mma_nsf, + tile_shape_mnk[2] // sf_vec_size // blk_sf, # type: ignore[index, operator] + ) + sSF_strideK = ( + k_basic_block_stride, + mma_nsf, + tile_shape_mnk[0] // blk_mn * blk_elems, # type: ignore[index, operator] + ) + + sSFA_shape = (sSFA_shapeM, sSFA_shapeK) + sSFA_stride = (sSF_strideM, sSF_strideK) + + smem_layout = cute.make_layout(sSFA_shape, stride=sSFA_stride) + + # (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K, STAGE) + sfa_smem_layout_staged = cute.append( + smem_layout, + cute.make_layout( + num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout)) + ), + ) + + return sfa_smem_layout_staged + + +@dsl_user_op +def sm120_make_smem_layout_sfb( + tiled_mma: cute.TiledMma, + tile_shape_mnk: cute.Tile, + sf_vec_size: int, + num_stages: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """ + Make smem layout for SFB based on: + 1. BlockScaledBasicChunk + 2. MMA tiler shape + 3. Scale factor vector size + 4. Number of stages + + :param tiled_mma: The tiled MMA + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The mma tiler shape + :type mma_tiler_mnk: cute.Tile + :param sf_vec_size: The scale factor vector size + :type sf_vec_size: int + :param num_stages: The number of stages + :type num_stages: int + + :return: Smem layout for SFA + :rtype: cute.Layout + """ + + # A single indivisible block will hold 4 scale factors of 128 rows/columns (A/B matrix). + # 4 is chosen to make consecutive 32bits of data to have scale factors for only a single row(col). + blk_mn = 128 + blk_sf = 4 + blk_elems = blk_mn * blk_sf + + assert sf_vec_size == 16 or sf_vec_size == 32, "sf_vec_size must be 16 or 32" + + assert tile_shape_mnk[1] % blk_mn == 0, ( # type: ignore[index, operator] + "tile_shape_mnk[1] must be divisible by blk_mn" + ) + + assert tile_shape_mnk[2] % sf_vec_size == 0, ( # type: ignore[index, operator] + "tile_shape_mnk[2] must be divisible by sf_vec_size" + ) + + mma_nsf = tiled_mma.shape_mnk[2] // sf_vec_size + + mn_basic_block_shape = (32, 4) + mn_basic_block_stride = (16, 4) + k_basic_block_shape = (sf_vec_size, mma_nsf) + k_basic_block_stride = (0, 1) + + assert tile_shape_mnk[1] % blk_mn == 0, ( # type: ignore[index, operator] + "tile_shape_mnk[1] must be divisible by blk_mn" + ) + + sSFA_shapeN = (mn_basic_block_shape, tile_shape_mnk[1] // blk_mn) # type: ignore[index, operator] + sSF_strideN = (mn_basic_block_stride, blk_elems) + + assert tile_shape_mnk[2] % (blk_sf * mma_nsf) == 0, ( # type: ignore[index] + "tile_shape_mnk[2] must be divisible by blk_sf * mma_nsf" + ) + + sSFA_shapeK = ( + k_basic_block_shape, + blk_sf // mma_nsf, + tile_shape_mnk[2] // sf_vec_size // blk_sf, # type: ignore[index, operator] + ) + sSF_strideK = ( + k_basic_block_stride, + mma_nsf, + tile_shape_mnk[1] // blk_mn * blk_elems, # type: ignore[index, operator] + ) + + sSFA_shape = (sSFA_shapeN, sSFA_shapeK) + sSFA_stride = (sSF_strideN, sSF_strideK) + + smem_layout = cute.make_layout(sSFA_shape, stride=sSFA_stride) + + # (((Atom_Inst_M, Rest_M),(Atom_Inst_K, Rest_K)), MMA_M, MMA_K, STAGE) + sfb_smem_layout_staged = cute.append( + smem_layout, + cute.make_layout( + num_stages, stride=cute.cosize(cute.filter_zeros(smem_layout)) + ), + ) + + return sfb_smem_layout_staged + + +@dsl_user_op +def make_tmem_layout_sfa( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: cute.Tile, + sf_vec_size: int, + smem_layout: cute.Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """Make tmem layout for SFA based on: + + 1. SFA smem layout per stage + 2. Cta tile shape m + 3. tiled MMA atom thr size + 4. Scale factor vector size + + :param tiled_mma: The tiled MMA + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The mma tiler shape + :type mma_tiler_mnk: cute.Tile + :param sf_vec_size: The scale factor vector size + :type sf_vec_size: int + :param smem_layout: The smem layout of SFA per stage + :type smem_layout: cute.Layout + + :return: TMEM layout for SFA + :rtype: cute.Layout + """ + atom_thr_size = cute.size(tiled_mma.thr_id.shape, loc=loc, ip=ip) + cta_tile_shape_m = mma_tiler_mnk[0] // atom_thr_size # type: ignore[index] + + sfa_layout_ty = _cute_nvgpu_ir.make_tmem_layout_sfa( + smem_layout, cta_tile_shape_m, atom_thr_size, sf_vec_size + ) + return _cute_ir.static(sfa_layout_ty, loc=loc, ip=ip) + + +@dsl_user_op +def make_tmem_layout_sfb( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: cute.Tile, + sf_vec_size: int, + smem_layout: cute.Layout, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """Make tmem layout for SFB based on: + + 1. SFB smem layout per stage + 2. Cta tile shape m + 3. tiled MMA atom thr size + 4. Scale factor vector size + + :param tiled_mma: The tiled MMA + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The mma tiler shape + :type mma_tiler_mnk: cute.Tile + :param sf_vec_size: The scale factor vector size + :type sf_vec_size: int + :param smem_layout: The smem layout of SFB per stage + :type smem_layout: cute.Layout + + :return: TMEM layout for SFB + :rtype: cute.Layout + """ + atom_thr_size = cute.size(tiled_mma.thr_id.shape, loc=loc, ip=ip) + cta_tile_shape_m = mma_tiler_mnk[0] // atom_thr_size # type: ignore[index] + + sfb_layout_ty = _cute_nvgpu_ir.make_tmem_layout_sfb( + smem_layout, cta_tile_shape_m, atom_thr_size, sf_vec_size + ) + return _cute_ir.static(sfb_layout_ty, loc=loc, ip=ip) + + +@dataclass(frozen=True) +class Sm103BlockScaledBasicChunk: + """ + Basic scale-factor atom layout decided by tcgen05 BlockScaled MMA Ops on SM103. + + Represents the fixed layout pattern for scale factors used by tcgen05 + BlockScaled MMA Ops on SM103. The layout is determined by the instruction + specification and is not configurable. + """ + + sf_vec_size: int + major_mode: OperandMajorMode = OperandMajorMode.K + _layout: cute.Layout = field(init=False, repr=False) + + def __post_init__(self) -> None: + atom_shape: cute.Shape + atom_stride: cute.Stride + if self.major_mode == OperandMajorMode.K: + atom_shape = ((8, 4, 4), (self.sf_vec_size, 4)) + atom_stride = ((16, 128, 4), (0, 1)) + else: + atom_shape = ((self.sf_vec_size, 4), (8, 4, 4)) + atom_stride = ((0, 1), (16, 128, 4)) + + object.__setattr__( + self, "_layout", cute.make_layout(shape=atom_shape, stride=atom_stride) + ) + + @property + def layout(self) -> cute.Layout: + return self._layout + + +@dsl_user_op +def sm103_make_smem_layout_sfa( + tiled_mma: cute.TiledMma, + mma_tiler: cute.Tile, + sf_vec_size: int, + num_stages: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """ + Make SMEM layout for SFA based on: + 1) Sm103BlockScaledBasicChunk, 2) MMA tiler, 3) sf_vec_size, 4) stages. + + :param tiled_mma: The tiled MMA + :type tiled_mma: cute.TiledMma + :param mma_tiler: The mma tiler shape + :type mma_tiler: cute.Tile + :param sf_vec_size: The scale factor vector size + :type sf_vec_size: int + :param num_stages: The number of stages + :type num_stages: int + + :return: Smem layout for SFA + :rtype: cute.Layout + """ + mma_shape_mk = tiled_mma.partition_shape_A((mma_tiler[0], mma_tiler[2])) # type: ignore[index] + sf_atom = Sm103BlockScaledBasicChunk(sf_vec_size, tiled_mma.op.a_major_mode).layout # type: ignore[attr-defined] + k_divisor = 4 if sf_vec_size == 16 else 2 + mma_sfa_tiler = ( + mma_shape_mk[0][0] * mma_shape_mk[1], + mma_shape_mk[0][1] * mma_shape_mk[2] // k_divisor, + ) + sfa_smem_atom_layout = cute.tiled_product( + sf_atom, + cute.make_layout( + cute.shape_div(mma_sfa_tiler, cute.product_each(sf_atom.shape)) + ), + ) + sfa_smem_layout_staged = cute.make_layout( + shape=cute.append(sfa_smem_atom_layout.shape, num_stages), + stride=cute.append( + sfa_smem_atom_layout.stride, + cute.size(cute.filter_zeros(sfa_smem_atom_layout)), + ), + ) + return sfa_smem_layout_staged + + +@dsl_user_op +def sm103_make_smem_layout_sfb( + tiled_mma: cute.TiledMma, + mma_tiler: cute.Tile, + sf_vec_size: int, + num_stages: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """ + Make SMEM layout for SFB based on the basic chunk, MMA tiler, sf_vec_size, stages. + + :param tiled_mma: The tiled MMA + :type tiled_mma: cute.TiledMma + :param mma_tiler: The mma tiler shape + :type mma_tiler: cute.Tile + :param sf_vec_size: The scale factor vector size + :type sf_vec_size: int + :param num_stages: The number of stages + :type num_stages: int + + :return: Smem layout for SFB + :rtype: cute.Layout + """ + sf_atom = Sm103BlockScaledBasicChunk(sf_vec_size, tiled_mma.op.b_major_mode).layout # type: ignore[attr-defined] + k_divisor = 4 if sf_vec_size == 16 else 2 + mma_sfb_tiler = (mma_tiler[1], mma_tiler[2] // k_divisor) # type: ignore[index, operator] + if mma_sfb_tiler[0] == 128: + sfb_smem_atom_layout = cute.tiled_product( + sf_atom, + cute.make_layout( + cute.shape_div(mma_sfb_tiler, cute.product_each(sf_atom.shape)) + ), + ) + else: + sf_k_major_atom256 = cute.make_layout( + shape=( + (32, 4, 2), + (sf_vec_size, 4), + ), + stride=( + (16, 4, mma_sfb_tiler[1] // sf_vec_size // 4 * 512), + (0, 1), + ), + ) + sfb_smem_atom_layout = cute.tiled_product( + sf_k_major_atom256, + cute.make_layout( + cute.shape_div( + mma_sfb_tiler, cute.product_each(sf_k_major_atom256.shape) + ) + ), + ) + + sfb_smem_layout_staged = cute.make_layout( + shape=cute.append(sfb_smem_atom_layout.shape, num_stages), + stride=cute.append( + sfb_smem_atom_layout.stride, + cute.size(cute.filter_zeros(sfb_smem_atom_layout)), + ), + ) + return sfb_smem_layout_staged diff --git a/reference/dense_blockscaled_gemm_persistent.py b/reference/dense_blockscaled_gemm_persistent.py new file mode 100644 index 00000000..3b9bbe48 --- /dev/null +++ b/reference/dense_blockscaled_gemm_persistent.py @@ -0,0 +1,3152 @@ +# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +from typing import Type, Tuple, Union, Literal + +import cuda.bindings.driver as cuda +import torch + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.torch as cutlass_torch +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils +from cutlass.cute.runtime import make_ptr + +""" +This example provides an experimental implementation of the SM100 batched dense blockscaled GEMM kernel, please note that the APIs and implementation details related to this kernel may change in future releases. + +A high-performance persistent batched dense blockscaled GEMM example for the NVIDIA Blackwell SM100 architecture +using CUTE DSL. +- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M") for MXF8 input type and can only be row-major("K") for MXF4/NVF4 input type +- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K") for MXF8 input type and can only be row-major("K") for MXF4/NVF4 input type +- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M") +- Matrix SFA layout is filled internally according to A shape and BlockScaledBasicChunk, which has M×ceil_div(K, sf_vec_size)×L elements respectively +- Matrix SFB layout is filled internally according to B shape and BlockScaledBasicChunk, which has N×ceil_div(K, sf_vec_size)×L elements respectively + +This GEMM kernel supports the following features: + - Utilizes Tensor Memory Access (TMA) for efficient memory operations + - Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations (including 2cta mma instructions) + - Implements TMA multicast with cluster to reduce L2 memory traffic + - Support persistent tile scheduling to better overlap memory load/store with mma between tiles + - Support warp specialization to avoid explicit pipelining between mainloop load and mma + +This GEMM works as follows: +1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations. +2. MMA warp: + - Load scale factor A/B from shared memory (SMEM) to tensor memory (TMEM) using tcgen05.cp instruction. + - Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction. +3. EPILOGUE warp: + - Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld. + - Type convert C matrix to output type. + - Optionally store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations, + or directly store C matrix from registers (RMEM) to global memory (GMEM) without TMA operations. + - Optionally accept an elementwise lambda function epilogue_op to apply to the output tensor: + e.g., relu can set epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0)) + +SM100 tcgen05.mma.kind.block_scale instructions operate as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Read scalefactor A from TMEM +- Read scalefactor B from TMEM +- Write accumulator to TMEM +The accumulator in TMEM must then be loaded to registers before writing back to GMEM. + +Input arguments to this example is shown below: + +.. code-block:: bash + + python examples/blackwell/dense_blockscaled_gemm_persistent.py \ + --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU --sf_vec_size 16 \ + --c_dtype Float16 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,1024,1 + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/dense_blockscaled_gemm_persistent.py \ + --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU --sf_vec_size 16 \ + --c_dtype Float16 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,1024,1 \ + --warmup_iterations 1 --iterations 10 --skip_ref_check + + +Constraints: +* Supported input data types: mxf8, mxf4, nvf4 + see detailed valid dtype combinations in below Sm100BlockScaledPersistentDenseGemmKernel class documentation +* A/B tensor must have the same data type, mixed data type is not supported (e.g., mxf8 x mxf4) +* Mma tiler M must be 128 or 256(use_2cta_instrs) +* Mma tiler N must be 64/128/192/256 +* Cluster shape M/N must be positive and power of 2, total cluster size <= 16 +* Cluster shape M must be multiple of 2 if Mma tiler M is 256(use_2cta_instrs) +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned, + i.e, number of elements is a multiple of 16 and 32 for Float8 and Float4, respectively. +""" + + +class Sm100BlockScaledPersistentDenseGemmKernel: + """This class implements batched matrix multiplication (C = A x SFA x B x SFB) with support for various data types + and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization. + + :param sf_vec_size: Scalefactor vector size. + :type sf_vec_size: int + :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N) + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + + :note: In current version, A and B tensor must have the same data type + - i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported + + :note: Supported combinations of A/B data types, SF data typs and SF vector size: + - MXF8: A/B: Float8E5M2/Float8E4M3FN + SF: Float8E8M0FNU + sf_vec_size: 32 + - MXF4: A/B: Float4E2M1FN + SF: Float8E8M0FNU + sf_vec_size: 32 + - NVF4: A/B: Float4E2M1FN + SF: Float8E8M0FNU/Float8E4M3FN + sf_vec_size: 16 + + :note: Supported accumulator data types: + - Float32 + + :note: Supported C data types: + - Float32 + - Float16/BFloat16 + - Float8E4M3FN/Float8E5M2 + :note: Constraints: + - MMA tiler M must be 128 or 256 (use_2cta_instrs) + - MMA tiler N must be 64/128/192/256 + - Cluster shape M must be multiple of 2 if Mma tiler M is 256 + - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 + - Also, Cluster shape M/N must be <= 4 for scale factor multicasts due to limited size of scale factors + + Example: + >>> gemm = Sm100BlockScaledPersistentDenseGemmKernel( + ... sf_vec_size=16, + ... mma_tiler_mn=(256, 128), + ... cluster_shape_mn=(2, 1) + ... ) + >>> gemm(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, max_active_clusters, stream) + """ + + def __init__( + self, + sf_vec_size: int, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ): + """Initializes the configuration for a Blackwell dense GEMM kernel. + + This configuration includes several key aspects: + + 1. MMA Instruction Settings (tcgen05): + - acc_dtype: Data types for MMA accumulator, always set to Float32 + - sf_vec_size: Scalefactor A/B vector size. + - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler. + + 2. Cluster Shape: + - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster. + + :param sf_vec_size: Scalefactor vector size. + :type sf_vec_size: int + :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster. + :type cluster_shape_mn: Tuple[int, int] + """ + + self.acc_dtype = cutlass.Float32 + self.sf_vec_size = sf_vec_size + self.use_2cta_instrs = mma_tiler_mn[0] == 256 + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + + self.cta_group = ( + tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + self.occupancy = 1 + # Set specialized warp ids + self.epilog_warp_id = ( + 0, + 1, + 2, + 3, + ) + self.mma_warp_id = 4 + self.tma_warp_id = 5 + self.threads_per_warp = 32 + self.threads_per_cta = self.threads_per_warp * len( + (self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id) + ) + # Set barrier id for epilogue sync and tmem ptr sync + self.epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=self.threads_per_warp * len(self.epilog_warp_id), + ) + self.tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=2, + num_threads=self.threads_per_warp * len((self.mma_warp_id, *self.epilog_warp_id)), + ) + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + self.num_tmem_alloc_cols = cute.arch.get_max_tmem_alloc_cols("sm_100") + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B/SFA/SFB + - Computing epilogue subtile + - Setting up A/B/SFA/SFB/C stage counts in shared memory + - Computing A/B/SFA/SFB/C shared memory layout + """ + # Compute mma instruction shapes + # (MMA_Tile_Shape_M, MMA_Tile_Shape_N, MMA_Inst_Shape_K) + self.mma_inst_shape_mn = ( + self.mma_tiler[0], + self.mma_tiler[1], + ) + # (CTA_Tile_Shape_M, Round_Up(MMA_Tile_Shape_N, 128), MMA_Inst_Shape_K) + self.mma_inst_shape_mn_sfb = ( + self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1), + cute.round_up(self.mma_inst_shape_mn[1], 128), + ) + + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mn, + ) + + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mn_sfb, + ) + + # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_inst_shape_mn[0], + self.mma_inst_shape_mn[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.mma_tiler_sfb = ( + self.mma_inst_shape_mn_sfb[0], + self.mma_inst_shape_mn_sfb[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + self.cta_tile_shape_mnk_sfb = ( + self.mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler_sfb[1], + self.mma_tiler_sfb[2], + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + self.cluster_layout_sfb_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma_sfb.thr_id.shape,), + ) + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.num_mcast_ctas_sfb = cute.size(self.cluster_layout_sfb_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + self.is_sfb_mcast = self.num_mcast_ctas_sfb > 1 + + # Compute epilogue subtile + self.epi_tile = sm100_utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + self.epi_tile_n = cute.size(self.epi_tile[1]) + + # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory + self.num_acc_stage, self.num_ab_stage, self.num_c_stage = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.sf_dtype, + self.sf_vec_size, + self.smem_capacity, + self.occupancy, + ) + + # Compute A/B/SFA/SFB/C shared memory layout + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_c_stage, + ) + + # Overlap and double buffer accumulator when num_acc_stage == 1 for cta_tile_n = 256 case + self.overlapping_accum = self.num_acc_stage == 1 + + # Compute number of TMEM columns for SFA/SFB/Accumulator + sf_atom_mn = 32 + self.num_sfa_tmem_cols = ( + self.cta_tile_shape_mnk[0] // sf_atom_mn + ) * mma_inst_tile_k + self.num_sfb_tmem_cols = ( + self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn + ) * mma_inst_tile_k + self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols + self.num_accumulator_tmem_cols = ( + self.cta_tile_shape_mnk[1] * self.num_acc_stage + if not self.overlapping_accum + else self.cta_tile_shape_mnk[1] * 2 - self.num_sf_tmem_cols + ) + + # Only when overlapping_accum is enabled, we need to release accumulator buffer early in epilogue + self.iter_acc_early_release_in_epilogue = ( + self.num_sf_tmem_cols // self.epi_tile_n + ) + + @cute.jit + def __call__( + self, + a_ptr: cute.Pointer, + b_ptr: cute.Pointer, + sfa_ptr: cute.Pointer, + sfb_ptr: cute.Pointer, + c_ptr: cute.Pointer, + layouts: cutlass.Constexpr[ + Tuple[tcgen05.OperandMajorMode, tcgen05.OperandMajorMode, utils.LayoutEnum] + ], + problem_mnkl: Tuple[int, int, int, int], + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the GEMM operation in steps: + - Setup static attributes before smem/grid/tma computation + - Setup TMA load/store atoms and tensors + - Compute grid size with regard to hardware constraints + - Define shared storage for kernel + - Launch the kernel synchronously + + :param a_tensor: Input tensor A + :type a_tensor: cute.Tensor + :param b_tensor: Input tensor B + :type b_tensor: cute.Tensor + :param sfa_tensor: Scale factor tensor A + :type sfa_tensor: cute.Tensor + :param sfb_tensor: Scale factor tensor B + :type sfb_tensor: cute.Tensor + :param c_tensor: Output tensor C + :type c_tensor: cute.Tensor + :param max_active_clusters: Maximum number of active clusters + :type max_active_clusters: cutlass.Constexpr + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + :param epilogue_op: Optional elementwise lambda function to apply to the output tensor + :type epilogue_op: cutlass.Constexpr + :raises TypeError: If input data types are incompatible with the MMA instruction. + """ + # Setup static attributes before smem/grid/tma computation + self.a_dtype: Type[cutlass.Numeric] = a_ptr.value_type + self.b_dtype: Type[cutlass.Numeric] = b_ptr.value_type + self.sf_dtype: Type[cutlass.Numeric] = sfa_ptr.value_type + self.c_dtype: Type[cutlass.Numeric] = c_ptr.value_type + + m, n, k, l = problem_mnkl + self.a_major_mode, self.b_major_mode, self.c_layout = layouts + + # Check if input data types are compatible with MMA instruction + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + a_layout = cute.make_ordered_layout((m, cute.assume(k, 32), l), order=(0, 1, 2)) + if cutlass.const_expr(self.a_major_mode == tcgen05.OperandMajorMode.K): + a_layout = cute.make_ordered_layout( + (cute.assume(m, 32), k, l), order=(1, 0, 2) + ) + b_layout = cute.make_ordered_layout((n, cute.assume(k, 32), l), order=(0, 1, 2)) + if cutlass.const_expr(self.b_major_mode == tcgen05.OperandMajorMode.K): + b_layout = cute.make_ordered_layout( + (cute.assume(n, 32), k, l), order=(1, 0, 2) + ) + c_layout = cute.make_ordered_layout((cute.assume(m, 32), n, l), order=(0, 1, 2)) + if cutlass.const_expr(self.c_layout == utils.LayoutEnum.ROW_MAJOR): + c_layout = cute.make_ordered_layout( + (m, cute.assume(n, 32), l), order=(1, 0, 2) + ) + a_tensor = cute.make_tensor(a_ptr, a_layout) + b_tensor = cute.make_tensor(b_ptr, b_layout) + c_tensor = cute.make_tensor(c_ptr, c_layout) + + # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout + # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL) + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( + a_tensor.shape, self.sf_vec_size + ) + sfa_tensor = cute.make_tensor(sfa_ptr, sfa_layout) + + # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( + b_tensor.shape, self.sf_vec_size + ) + sfb_tensor = cute.make_tensor(sfb_ptr, sfb_layout) + + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mn, + ) + + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mn_sfb, + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a_tensor, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # Setup TMA load for B + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b_tensor, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # Setup TMA load for SFA + sfa_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfa_smem_layout = cute.slice_( + self.sfa_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( + sfa_op, + sfa_tensor, + sfa_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + + # Setup TMA load for SFB + sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfb_smem_layout = cute.slice_( + self.sfb_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + sfb_op, + sfb_tensor, + sfb_smem_layout, + self.mma_tiler_sfb, + tiled_mma_sfb, + self.cluster_layout_sfb_vmnk.shape, + internal_type=cutlass.Int16, + ) + + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): + x = tma_tensor_sfb.stride[0][1] + y = cute.ceil_div(tma_tensor_sfb.shape[0][1], 4) + + new_shape = ( + (tma_tensor_sfb.shape[0][0], ((2, 2), y)), + tma_tensor_sfb.shape[1], + tma_tensor_sfb.shape[2], + ) + # Use right multiplication for ScaledBasis (3 * x instead of x * 3) + x_times_3 = 3 * x + new_stride = ( + (tma_tensor_sfb.stride[0][0], ((x, x), x_times_3)), + tma_tensor_sfb.stride[1], + tma_tensor_sfb.stride[2], + ) + tma_tensor_sfb_new_layout = cute.make_layout(new_shape, stride=new_stride) + tma_tensor_sfb = cute.make_tensor( + tma_tensor_sfb.iterator, tma_tensor_sfb_new_layout + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) + sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) + self.num_tma_load_bytes = ( + a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size + ) * atom_thr_size + + # Setup TMA store for C + epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + c_tensor, + epi_smem_layout, + self.epi_tile, + ) + + # Compute grid size + self.tile_sched_params, grid = self._compute_grid( + c_tensor, + self.cta_tile_shape_mnk, + self.cluster_shape_mn, + max_active_clusters, + ) + + self.buffer_align_bytes = 1024 + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + tmem_dealloc_mbar: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + cute.cosize(self.c_smem_layout_staged.outer), + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sSFA: cute.struct.Align[ + cute.struct.MemRange[ + self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sSFB: cute.struct.Align[ + cute.struct.MemRange[ + self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tiled_mma_sfb, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + tma_atom_c, + tma_tensor_c, + self.cluster_layout_vmnk, + self.cluster_layout_sfb_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + self.c_smem_layout_staged, + self.epi_tile, + self.tile_sched_params, + epilogue_op, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + stream=stream, + min_blocks_per_mp=1, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tiled_mma_sfb: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + mSFA_mkl: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + mSFB_nkl: cute.Tensor, + tma_atom_c: cute.CopyAtom, + mC_mnl: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + cluster_layout_sfb_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout], + epi_tile: cute.Tile, + tile_sched_params: utils.PersistentTileSchedulerParams, + epilogue_op: cutlass.Constexpr, + ): + """ + GPU device kernel performing the Persistent batched GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # + # Prefetch tma desc + # + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_sfa) + cpasync.prefetch_descriptor(tma_atom_sfb) + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + # Initialize mainloop ab_pipeline (barrier) and states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ) + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = self.threads_per_warp * len(self.epilog_warp_id) * ( + 2 if use_2cta_instrs else 1 + ) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_acc_consumer_threads + ) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ) + + # Tensor memory dealloc barrier init + tmem = utils.TmemAllocator( + storage.tmem_holding_buf.ptr, + barrier_for_retrieve=self.tmem_alloc_barrier, + allocator_warp_id=self.epilog_warp_id[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr, + ) + + # Cluster arrive after barrier init + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True) + + # + # Setup smem tensor A/B/SFA/SFB/C + # + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = storage.sC.get_tensor( + c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner + ) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + # (MMA, MMA_M, MMA_K, STAGE) + sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) + # (MMA, MMA_N, MMA_K, STAGE) + sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged) + + # + # Compute multicast mask for A/B/SFA/SFB buffer full + # + a_full_mcast_mask = None + b_full_mcast_mask = None + sfa_full_mcast_mask = None + sfb_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + sfa_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + sfb_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1 + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bK, RestM, RestK, RestL) + gSFA_mkl = cute.local_tile( + mSFA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gSFB_nkl = cute.local_tile( + mSFB_nkl, + cute.slice_(self.mma_tiler_sfb, (0, None, None)), + (None, None, None), + ) + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgSFA = thr_mma.partition_A(gSFA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCgC = thr_mma.partition_C(gC_mnl) + + # + # Partition global/shared tensor for TMA load A/B + # + # TMA load A partition_S/D + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # TMA load scaled factor A partition_S/D + sfa_cta_layout = a_cta_layout + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsSFA, tAgSFA = cute.nvgpu.cpasync.tma_partition( + tma_atom_sfa, + block_in_cluster_coord_vmnk[2], + sfa_cta_layout, + cute.group_modes(sSFA, 0, 3), + cute.group_modes(tCgSFA, 0, 3), + ) + tAsSFA = cute.filter_zeros(tAsSFA) + tAgSFA = cute.filter_zeros(tAgSFA) + + # TMA load scaled factor B partition_S/D + sfb_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition( + tma_atom_sfb, + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB, 0, 3), + ) + tBsSFB = cute.filter_zeros(tBsSFB) + tBgSFB = cute.filter_zeros(tBgSFB) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + if cutlass.const_expr(self.overlapping_accum): + num_acc_stage_overlapped = 2 + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, num_acc_stage_overlapped) + ) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = cute.make_tensor( + tCtAcc_fake.iterator, + cute.make_layout( + tCtAcc_fake.shape, + stride=( + tCtAcc_fake.stride[0], + tCtAcc_fake.stride[1], + tCtAcc_fake.stride[2], + (256 - self.num_sf_tmem_cols) * tCtAcc_fake.stride[0][1], + ), + ), + ) + else: + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, self.num_acc_stage) + ) + + # + # Cluster wait before tensor memory alloc + # + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn) + + # + # Specialized TMA load warp + # + if warp_idx == self.tma_warp_id: + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + ab_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_ab_stage + ) + + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), RestK) + tAgA_slice = tAgA[ + (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + ] + # ((atom_v, rest_v), RestK) + tBgB_slice = tBgB[ + (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + ] + + # ((atom_v, rest_v), RestK) + tAgSFA_slice = tAgSFA[ + (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + ] + + slice_n = mma_tile_coord_mnl[1] + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): + slice_n = mma_tile_coord_mnl[1] // 2 + # ((atom_v, rest_v), RestK) + tBgSFB_slice = tBgSFB[(None, slice_n, None, mma_tile_coord_mnl[2])] + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + ab_producer_state.reset_count() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + # + # Tma load loop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # Conditionally wait for AB buffer empty + ab_pipeline.producer_acquire( + ab_producer_state, peek_ab_empty_status + ) + + # TMA load A/B/SFA/SFB + cute.copy( + tma_atom_a, + tAgA_slice[(None, ab_producer_state.count)], + tAsA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_slice[(None, ab_producer_state.count)], + tBsB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atom_sfa, + tAgSFA_slice[(None, ab_producer_state.count)], + tAsSFA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=sfa_full_mcast_mask, + ) + cute.copy( + tma_atom_sfb, + tBgSFB_slice[(None, ab_producer_state.count)], + tBsSFB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=sfb_full_mcast_mask, + ) + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Wait A/B buffer empty + # + ab_pipeline.producer_tail(ab_producer_state) + + # + # Specialized MMA warp + # + if warp_idx == self.mma_warp_id: + # + # Bar sync for retrieve tensor memory ptr from shared mem + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator/SFA/SFB tensor + # + acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # Make accumulator tmem tensor + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # Make SFA tmem tensor + sfa_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + self.num_accumulator_tmem_cols, + dtype=self.sf_dtype, + ) + # (MMA, MMA_M, MMA_K) + tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) + + # Make SFB tmem tensor + sfb_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols, + dtype=self.sf_dtype, + ) + # (MMA, MMA_N, MMA_K) + tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) + # + # Partition for S2T copy of SFA/SFB + # + ( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t, + tCtSFA_compact_s2t, + ) = self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA) + ( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t, + tCtSFB_compact_s2t, + ) = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + ab_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_ab_stage + ) + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # Get accumulator stage index + if cutlass.const_expr(self.overlapping_accum): + acc_stage_index = acc_producer_state.phase ^ 1 + else: + acc_stage_index = acc_producer_state.index + + # Set tensor memory buffer for current tile + # (MMA, MMA_M, MMA_N) + tCtAcc = tCtAcc_base[(None, None, None, acc_stage_index)] + + # Peek (try_wait) AB buffer full for k_tile = 0 + ab_consumer_state.reset_count() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # + # Wait for accumulator buffer empty + # + if is_leader_cta: + acc_pipeline.producer_acquire(acc_producer_state) + + tCtSFB_mma = tCtSFB + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): + # If this is an ODD tile, shift the TMEM start address for cta_tile_shape_n=192 case by two words (ignores first 64 columns of SFB) + offset = ( + cutlass.Int32(2) + if mma_tile_coord_mnl[1] % 2 == 1 + else cutlass.Int32(0) + ) + shifted_ptr = cute.recast_ptr( + acc_tmem_ptr + + self.num_accumulator_tmem_cols + + self.num_sfa_tmem_cols + + offset, + dtype=self.sf_dtype, + ) + tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout) + elif cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): + # Move in increments of 64 columns of SFB + offset = cutlass.Int32((mma_tile_coord_mnl[1] % 2) * 2) + shifted_ptr = cute.recast_ptr( + acc_tmem_ptr + + self.num_accumulator_tmem_cols + + self.num_sfa_tmem_cols + + offset, + dtype=self.sf_dtype, + ) + tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout) + + # + # Reset the ACCUMULATE field for each tile + # + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + # + # Mma mainloop + # + for k_tile in range(k_tile_cnt): + if is_leader_cta: + # Conditionally wait for AB buffer full + ab_pipeline.consumer_wait( + ab_consumer_state, peek_ab_full_status + ) + + # Copy SFA/SFB from smem to tmem + s2t_stage_coord = ( + None, + None, + None, + None, + ab_consumer_state.index, + ) + tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord] + tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord] + cute.copy( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t_staged, + tCtSFA_compact_s2t, + ) + cute.copy( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t_staged, + tCtSFB_compact_s2t, + ) + + # tCtAcc += tCrA * tCrSFA * tCrB * tCrSFB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = ( + None, + None, + kblock_idx, + ab_consumer_state.index, + ) + + # Set SFA/SFB tensor to tiled_mma + sf_kblock_coord = (None, None, kblock_idx) + tiled_mma.set( + tcgen05.Field.SFA, + tCtSFA[sf_kblock_coord].iterator, + ) + tiled_mma.set( + tcgen05.Field.SFB, + tCtSFB_mma[sf_kblock_coord].iterator, + ) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_pipeline.consumer_release(ab_consumer_state) + + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + ab_consumer_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt: + if is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # + # Async arrive accumulator buffer full + # + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + acc_producer_state.advance() + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Wait for accumulator buffer empty + # + acc_pipeline.producer_tail(acc_producer_state) + # + # Specialized epilogue warps + # + if warp_idx < self.mma_warp_id: + # + # Alloc tensor memory buffer + # + tmem.allocate(self.num_tmem_alloc_cols) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # + # Partition for epilogue + # + epi_tidx = tidx + ( + tiled_copy_t2r, + tTR_tAcc_base, + tTR_rAcc, + ) = self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs + ) + + tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, epi_tidx, sC + ) + ( + tma_atom_c, + bSG_sC, + bSG_gC_partitioned, + ) = self.epilog_gmem_copy_and_partition( + epi_tidx, tma_atom_c, tCgC, epi_tile, sC + ) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + # Threads/warps participating in tma store pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len(self.epilog_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=c_producer_group, + ) + + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # + # Slice to per mma tile index + # + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC_partitioned[ + ( + None, + None, + None, + *mma_tile_coord_mnl, + ) + ] + + # Get accumulator stage index + if cutlass.const_expr(self.overlapping_accum): + acc_stage_index = acc_consumer_state.phase + reverse_subtile = True if acc_stage_index == 0 else False + else: + acc_stage_index = acc_consumer_state.index + + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[ + (None, None, None, None, None, acc_stage_index) + ] + + # + # Wait for accumulator buffer full + # + acc_pipeline.consumer_wait(acc_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt + for subtile_idx in cutlass.range(subtile_cnt): + real_subtile_idx = subtile_idx + if cutlass.const_expr(self.overlapping_accum): + if reverse_subtile: + real_subtile_idx = ( + self.cta_tile_shape_mnk[1] // self.epi_tile_n + - 1 + - subtile_idx + ) + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, real_subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Async arrive accumulator buffer empty ealier when overlapping_accum is enabled + # + if cutlass.const_expr(self.overlapping_accum): + if subtile_idx == self.iter_acc_early_release_in_epilogue: + # Fence for TMEM load + cute.arch.fence_view_async_tmem_load() + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + # + # Convert to C type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tRS_rC.store(acc_vec) + + # + # Store C to shared memory + # + c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage + cute.copy( + tiled_copy_r2s, + tRS_rC, + tRS_sC[(None, None, None, c_buffer)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + "async.shared", + space="cta", + ) + self.epilog_sync_barrier.arrive_and_wait() + + # + # TMA store C to global memory + # + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, real_subtile_idx)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + self.epilog_sync_barrier.arrive_and_wait() + + # + # Async arrive accumulator buffer empty + # + if cutlass.const_expr(not self.overlapping_accum): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Dealloc the tensor memory buffer + # + tmem.relinquish_alloc_permit() + self.epilog_sync_barrier.arrive_and_wait() + tmem.free(acc_tmem_ptr) + # + # Wait for C store complete + # + c_pipeline.producer_tail() + + def mainloop_s2t_copy_and_partition( + self, + sSF: cute.Tensor, + tSF: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for smem to tmem load for scale factor tensor, then use it to partition smem memory (source) and tensor memory (destination). + + :param sSF: The scale factor tensor in smem + :type sSF: cute.Tensor + :param tSF: The scale factor tensor in tmem + :type tSF: cute.Tensor + + :return: A tuple containing (tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t) where: + - tiled_copy_s2t: The tiled copy operation for smem to tmem load for scale factor tensor(s2t) + - tCsSF_compact_s2t: The partitioned scale factor tensor in smem + - tSF_compact_s2t: The partitioned scale factor tensor in tmem + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSF_compact = cute.filter_zeros(sSF) + # (MMA, MMA_MN, MMA_K) + tCtSF_compact = cute.filter_zeros(tSF) + + # Make S2T CopyAtom and tiledCopy + copy_atom_s2t = cute.make_copy_atom( + tcgen05.Cp4x32x128bOp(self.cta_group), + self.sf_dtype, + ) + tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact) + thr_copy_s2t = tiled_copy_s2t.get_slice(0) + + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t, tCsSF_compact_s2t_ + ) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) + tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact) + + return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0, None)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy( + copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_rmem_tensor( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). + + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + :type sepi: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tidx: cutlass.Int32, + atom: Union[cute.CopyAtom, cute.TiledCopy], + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to: + partition shared memory (source) and global memory (destination) for TMA store version. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version + :type atom: cute.CopyAtom or cute.TiledCopy + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing (tma_atom_c, bSG_sC, bSG_gC) where: + - tma_atom_c: The TMA copy atom + - bSG_sC: The partitioned shared memory tensor C + - bSG_gC: The partitioned global tensor C + :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + + tma_atom_c = atom + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + smem_capacity: int, + occupancy: int, + ) -> Tuple[int, int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout enum of operand C. + :type c_layout: utils.LayoutEnum + :param sf_dtype: Data type of Scale factor. + :type sf_dtype: type[cutlass.Numeric] + :param sf_vec_size: Scale factor vector size. + :type sf_vec_size: int + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + + :return: A tuple containing the computed number of stages for: + (ACC stages, A/B operand stages, C stages) + :rtype: tuple[int, int, int] + """ + # ACC stages + num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2 + + # Default C stages + num_c_stage = 2 + + # Calculate smem layout and size for one stage of A, B, SFA, SFB and C + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # a tmp 1 stage is provided + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # a tmp 1 stage is provided + ) + sfa_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + 1, # a tmp 1 stage is provided + ) + sfb_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + 1, # a tmp 1 stage is provided + ) + + c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, + ) + + ab_bytes_per_stage = ( + cute.size_in_bytes(a_dtype, a_smem_layout_stage_one) + + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfa_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one) + ) + mbar_helpers_bytes = 1024 + c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + c_bytes = c_bytes_per_stage * num_c_stage + + # Calculate A/B/SFA/SFB stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial C stages bytes + # Divide remaining by bytes needed per A/B/SFA/SFB stage + num_ab_stage = ( + smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes) + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B/SFA/SFB stages and reserved bytes + # Add remaining unused smem to epilogue + num_c_stage += ( + smem_capacity + - occupancy * ab_bytes_per_stage * num_ab_stage + - occupancy * (mbar_helpers_bytes + c_bytes) + ) // (occupancy * c_bytes_per_stage) + + return num_acc_stage, num_ab_stage, num_c_stage + + @staticmethod + def _compute_grid( + c: cute.Tensor, + cta_tile_shape_mnk: Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + max_active_clusters: cutlass.Constexpr, + ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]: + """Use persistent tile scheduler to compute the grid size for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr + + :return: A tuple containing: + - tile_sched_params: Parameters for the persistent tile scheduler. + - grid: Grid shape for kernel launch. + :rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]] + """ + c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0)) + gc = cute.zipped_divide(c, tiler=c_shape) + num_ctas_mnl = gc[(0, (None, None, None))].shape + cluster_shape_mnl = (*cluster_shape_mn, 1) + + tile_sched_params = utils.PersistentTileSchedulerParams( + num_ctas_mnl, cluster_shape_mnl + ) + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def is_valid_dtypes_and_scale_factor_vec_size( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + c_dtype: Type[cutlass.Numeric], + ) -> bool: + """ + Check if the dtypes and sf_vec_size are valid combinations + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param sf_dtype: The data type of the scale factor + :type sf_dtype: Type[cutlass.Numeric] + :param sf_vec_size: The vector size of the scale factor + :type sf_vec_size: int + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + + :return: True if the dtypes and sf_vec_size are valid, False otherwise + :rtype: bool + """ + is_valid = True + + # Check valid ab_dtype + if ab_dtype not in { + cutlass.Float4E2M1FN, + cutlass.Float8E5M2, + cutlass.Float8E4M3FN, + }: + is_valid = False + + # Check valid sf_vec_size + if sf_vec_size not in {16, 32}: + is_valid = False + + # Check valid sf_dtype + if sf_dtype not in {cutlass.Float8E8M0FNU, cutlass.Float8E4M3FN}: + is_valid = False + + # Check valid sf_dtype and sf_vec_size combinations + if sf_dtype == cutlass.Float8E4M3FN and sf_vec_size == 32: + is_valid = False + if ab_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} and sf_vec_size == 16: + is_valid = False + + # Check valid c_dtype + if c_dtype not in { + cutlass.Float32, + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float8E5M2, + cutlass.Float8E4M3FN, + }: + is_valid = False + + return is_valid + + @staticmethod + def is_valid_layouts( + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: Literal["m", "k"], + b_major: Literal["n", "k"], + c_major: Literal["m", "n"], + ) -> bool: + """ + Check if layouts and dtypes are valid combinations + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major dimension of the A tensor + :type a_major: Literal["m", "k"] + :param b_major: The major dimension of the B tensor + :type b_major: Literal["n", "k"] + :param c_major: The major dimension of the C tensor + :type c_major: Literal["m", "n"] + + :return: True if the layouts are valid, False otherwise + :rtype: bool + """ + is_valid = True + + if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"): + is_valid = False + return is_valid + + @staticmethod + def is_valid_mma_tiler_and_cluster_shape( + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ) -> bool: + """ + Check if the mma tiler and cluster shape are valid + + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + + :return: True if the mma tiler and cluster shape are valid, False otherwise + :rtype: bool + """ + is_valid = True + # Skip invalid mma tile shape + if mma_tiler_mn[0] not in [128, 256]: + is_valid = False + if mma_tiler_mn[1] not in [64, 128, 192, 256]: + is_valid = False + # Skip illegal cluster shape + if cluster_shape_mn[0] % (2 if mma_tiler_mn[0] == 256 else 1) != 0: + is_valid = False + # Skip invalid cluster shape + is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + # Special cluster shape check for scale factor multicasts. + # Due to limited size of scale factors, we can't multicast among more than 4 CTAs. + or cluster_shape_mn[0] > 4 + or cluster_shape_mn[1] > 4 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + m: int, + n: int, + k: int, + l: int, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: Literal["m", "k"], + b_major: Literal["n", "k"], + c_major: Literal["m", "n"], + ) -> bool: + """ + Check if the tensor alignment is valid + + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: Literal["m", "k"] + :param b_major: The major axis of the B tensor + :type b_major: Literal["n", "k"] + :param c_major: The major axis of the C tensor + :type c_major: Literal["m", "n"] + + :return: True if the problem shape is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + is_valid = False + return is_valid + + @staticmethod + def can_implement( + mnkl: Tuple[int, int, int, int], + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: Literal["m", "k"], + b_major: Literal["n", "k"], + c_major: Literal["m", "n"], + sf_vec_size: int, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ) -> bool: + """ + Check if the gemm can be implemented + + :param mnkl: The problem size as a tuple (M, N, K, L). + :type mnkl: Tuple[int, int, int, int] + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param sf_dtype: The data type of the scale factor tensor + :type sf_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: Literal["m", "k"] + :param b_major: The major axis of the B tensor + :type b_major: Literal["n", "k"] + :param c_major: The major axis of the C tensor + :type c_major: Literal["m", "n"] + :param sf_vec_size: The vector size + :type sf_vec_size: int + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + :return: True if the gemm can be implemented, False otherwise + :rtype: bool + """ + # Unpack parameters + m, n, k, l = mnkl + can_implement = True + # Skip unsupported types + if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_dtypes_and_scale_factor_vec_size( + ab_dtype, sf_dtype, sf_vec_size, c_dtype + ): + can_implement = False + # Skip unsupported layouts + if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_layouts( + ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + # Skip invalid mma tile shape and cluster shape + if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_mma_tiler_and_cluster_shape( + mma_tiler_mn, cluster_shape_mn + ): + can_implement = False + # Skip illegal problem shape for load/store alignment + if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_tensor_alignment( + m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + return can_implement + + +# Helper function to convert scale factor tensor from MKL layout to (32, 4, restM, 4, restK, l) format +@cute.jit +def cvt_sf_MKL_to_M32x4xrm_K4xrk_L( + sf_ref_ptr: cute.Pointer, + sf_mma_ptr: cute.Pointer, + mn: int, + sf_k: int, + l: int, + mma_shape: tuple, +): + mma_permute_order = (3, 4, 1, 5, 2, 0) + permuted_shape = tuple(mma_shape[i] for i in mma_permute_order) + cute_layout = cute.make_ordered_layout(permuted_shape, order=(2, 1, 4, 0, 3, 5)) + + sf_ref_tensor = cute.make_tensor( + sf_ref_ptr, cute.make_layout((mn, sf_k, l), stride=(sf_k, 1, mn * sf_k)) + ) + sf_mma_tensor = cute.make_tensor(sf_mma_ptr, cute_layout) + + sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3) + sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3) + for i in cutlass.range(cute.size(sf_ref_tensor)): + mkl_coord = sf_ref_tensor.layout.get_hier_coord(i) + sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord] + pass + + +# Helper function for ceil division +def ceil_div(a, b): + return (a + b - 1) // b + + +# Convert scale factor tensors from (m, k, l) to (32, 4, restM, 4, restK, l) format +def create_and_reorder_scale_factor_tensor( + l, mn, k, sf_vec_size, sf_dtype, torch_tensor +): + """ + Create the CUTE-format scale factor tensor on CUDA based on the reference tensor. + """ + sf_k = ceil_div(k, sf_vec_size) + atom_m = (32, 4) + atom_k = 4 + mma_shape = ( + l, # batch size + ceil_div(mn, atom_m[0] * atom_m[1]), + ceil_div(sf_k, atom_k), + atom_m[0], + atom_m[1], + atom_k, + ) + + # Generate a random int8 tensor, then convert to float8_e4m3fn + cute_tensor = torch.ones(mma_shape, dtype=cutlass_torch.dtype(sf_dtype)).permute( + 3, 4, 1, 5, 2, 0 + ) + + # Call the helper function to do layout conversion + cvt_sf_MKL_to_M32x4xrm_K4xrk_L( + make_ptr( + sf_dtype, + torch_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ), + make_ptr( + sf_dtype, + cute_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ), + mn, + sf_k, + l, + mma_shape, + ) + return cute_tensor.cuda() + + +# Compile the persistent dense blockscaled GEMM operation +def scaled_mm( + gemm_obj: Sm100BlockScaledPersistentDenseGemmKernel, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + a_major: Literal["m", "k"], + b_major: Literal["n", "k"], + c_major: Literal["m", "n"], + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + options: str = "", +): + # Construct CuTe Pointers + a_ptr = make_ptr(ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16) + b_ptr = make_ptr(ab_dtype, 0, cute.AddressSpace.gmem, assumed_align=16) + c_ptr = make_ptr(c_dtype, 0, cute.AddressSpace.gmem, assumed_align=16) + sfa_ptr = make_ptr(sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32) + sfb_ptr = make_ptr(sf_dtype, 0, cute.AddressSpace.gmem, assumed_align=32) + + a_major_mode = ( + tcgen05.OperandMajorMode.K if a_major == "k" else tcgen05.OperandMajorMode.MN + ) + b_major_mode = ( + tcgen05.OperandMajorMode.K if b_major == "k" else tcgen05.OperandMajorMode.MN + ) + c_layout = ( + utils.LayoutEnum.ROW_MAJOR if c_major == "n" else utils.LayoutEnum.COL_MAJOR + ) + return cute.compile( + gemm_obj, + a_ptr, + b_ptr, + sfa_ptr, + sfb_ptr, + c_ptr, + (a_major_mode, b_major_mode, c_layout), + (cutlass.Int32(0), cutlass.Int32(0), cutlass.Int32(0), cutlass.Int32(0)), + max_active_clusters, + stream, + epilogue_op, + options=options, + ) + + +def is_emulated_dtype( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], +) -> bool: + if c_dtype in { + cutlass.Float32, + cutlass.Float16, + cutlass.BFloat16, + }: + if ab_dtype == cutlass.Float4E2M1FN and sf_dtype == cutlass.Float8E4M3FN: + return False + if ab_dtype == cutlass.Float8E4M3FN and sf_dtype == cutlass.Float8E8M0FNU: + return False + + return True + + +# Convert scale factor tensor from MKL layout to blocked layout +def to_blocked(input_matrix): + rows, cols = input_matrix.shape + # Please ensure rows and cols are multiples of 128 and 4 respectively + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + + # Pad the input matrix if necessary + if padded_rows != rows or padded_cols != cols: + # For FP8 types, convert to float32 for padding, then convert back + original_dtype = input_matrix.dtype + input_float32 = input_matrix.to(torch.float32) + padded = torch.nn.functional.pad( + input_float32, + (0, padded_cols - cols, 0, padded_rows - rows), + mode="constant", + value=0, + ) + # Convert back to original dtype if needed + if original_dtype != input_float32.dtype: + padded = padded.to(original_dtype) + else: + padded = input_matrix + blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + return rearranged.flatten() + + +# Reference implementation of the persistent dense blockscaled GEMM operation (emulated version) +def reference_scaled_mm_emulated( + a: torch.Tensor, + b: torch.Tensor, + sfa: torch.Tensor, + sfb: torch.Tensor, + c: torch.Tensor, + mnkl: Tuple[int, int, int, int], + sf_vec_size: int, + c_dtype: Type[cutlass.Numeric], +): + m, n, k, l = mnkl + sfa_expanded = ( + torch.repeat_interleave(sfa, sf_vec_size, dim=1)[:, :k, :] + .to(dtype=torch.float32) + .cuda() + ) + sfb_expanded = ( + torch.repeat_interleave(sfb, sf_vec_size, dim=1)[:, :k, :] + .to(dtype=torch.float32) + .cuda() + ) + res_a = torch.einsum("mkl,mkl->mkl", a, sfa_expanded) + res_b = torch.einsum("nkl,nkl->nkl", b, sfb_expanded) + # Cast res_a and res_b to float32 for einsum to avoid NotImplementedError on 'Byte' + ref = torch.einsum("mkl,nkl->mnl", res_a, res_b) + c_ref = ref.to(dtype=cutlass_torch.dtype(c_dtype)) + return c_ref + + +# Reference implementation of the persistent dense blockscaled GEMM operation (non-emulated version) +def reference_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + sfa: torch.Tensor, + sfb: torch.Tensor, + c: torch.Tensor, + mnkl: Tuple[int, int, int, int], + c_dtype: Type[cutlass.Numeric], +): + m, n, k, l = mnkl + c_ref = torch.clone(c) + for l_idx in range(l): + # Convert the scale factor tensor to blocked format + scale_a = to_blocked(sfa[:, :, l_idx]) + scale_b = to_blocked(sfb[:, :, l_idx]) + # Ensure a_slice is row-major (M, K) with stride (K, 1) + a_slice = a[:, :, l_idx].contiguous() + # Ensure b_slice is row-major (N, K) so that transpose gives column-major (K, N) + b_slice = b[:, :, l_idx].contiguous() + # (m, k) @ (n, k).T -> (m, n) + res = torch._scaled_mm( + a_slice, + b_slice.transpose(0, 1), + scale_a.cuda(), + scale_b.cuda(), + bias=None, + out_dtype=c_ref.dtype, + ) + c_ref[:, :, l_idx] = res + return c_ref + + +# Construct CuTe Pointers for the persistent dense blockscaled GEMM operation (emulated version) +def construct_cute_pointers_emulated( + a: torch.Tensor, + b: torch.Tensor, + sfa: torch.Tensor, + sfb: torch.Tensor, + c: torch.Tensor, + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], +): + a_cute, _ = cutlass_torch.cute_tensor_like( + a.cpu(), + ab_dtype, + is_dynamic_layout=True, + assumed_align=16, + ) + a_cute = cutlass_torch.convert_cute_tensor( + a, + a_cute, + ab_dtype, + is_dynamic_layout=True, + ) + b_cute, _ = cutlass_torch.cute_tensor_like( + b.cpu(), + ab_dtype, + is_dynamic_layout=True, + assumed_align=16, + ) + b_cute = cutlass_torch.convert_cute_tensor( + b, + b_cute, + ab_dtype, + is_dynamic_layout=True, + ) + a_ptr = a_cute.iterator + b_ptr = b_cute.iterator + + sfa_ptr = make_ptr( + sf_dtype, sfa.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + sfb_ptr = make_ptr( + sf_dtype, sfb.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + c_ptr = make_ptr(c_dtype, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=16) + return a_ptr, b_ptr, c_ptr, sfa_ptr, sfb_ptr, a_cute, b_cute + + +# Construct CuTe Pointers for the persistent dense blockscaled GEMM operation (non-emulated version) +def construct_cute_pointers( + a: torch.Tensor, + b: torch.Tensor, + sfa: torch.Tensor, + sfb: torch.Tensor, + c: torch.Tensor, + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], +): + a_ptr = make_ptr(ab_dtype, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=16) + b_ptr = make_ptr(ab_dtype, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=16) + sfa_ptr = make_ptr( + sf_dtype, sfa.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + sfb_ptr = make_ptr( + sf_dtype, sfb.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + c_ptr = make_ptr(c_dtype, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=16) + return a_ptr, b_ptr, c_ptr, sfa_ptr, sfb_ptr + + +# Use uint8 and uint32 to emulate unsupported +# dtype in torch +def prepare_tensors_emulated( + mnkl: Tuple[int, int, int, int], + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + c_dtype: Type[cutlass.Numeric], + a_major: Literal["m", "k"], + b_major: Literal["n", "k"], + c_major: Literal["m", "n"], +): + m, n, k, l = mnkl + sf_k = ceil_div(k, sf_vec_size) + + # Create tensor SFA/SFB with values in [1, 3) + sfa = ( + torch.randint(0, 3, (l, m, sf_k), dtype=torch.uint8) + .permute(1, 2, 0) + .to(dtype=cutlass_torch.dtype(sf_dtype)) + ) + sfb = ( + torch.randint(0, 3, (l, n, sf_k), dtype=torch.uint8) + .permute(1, 2, 0) + .to(dtype=cutlass_torch.dtype(sf_dtype)) + ) + + # Create tensor A/B with values in [0, 2) + if a_major == "k": + a = torch.randint(-2, 2, (l, m, k), dtype=torch.float32, device="cuda").permute( + 1, 2, 0 + ) + else: + a = torch.randint(-2, 2, (l, k, m), dtype=torch.float32, device="cuda").permute( + 2, 1, 0 + ) + if b_major == "k": + b = torch.randint(-2, 2, (l, n, k), dtype=torch.float32, device="cuda").permute( + 1, 2, 0 + ) + else: + b = torch.randint(-2, 2, (l, k, n), dtype=torch.float32, device="cuda").permute( + 2, 1, 0 + ) + if c_major == "n": + c = torch.empty( + (l, m, n), dtype=cutlass_torch.dtype(c_dtype), device="cuda" + ).permute(1, 2, 0) + else: + c = torch.empty( + (l, n, m), dtype=cutlass_torch.dtype(c_dtype), device="cuda" + ).permute(2, 1, 0) + return a, b, c, sfa, sfb + + +def prepare_tensors( + mnkl: Tuple[int, int, int, int], + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + c_dtype: Type[cutlass.Numeric], + a_major: Literal["m", "k"], + b_major: Literal["n", "k"], + c_major: Literal["m", "n"], +): + m, n, k, l = mnkl + + if ab_dtype == cutlass.Float4E2M1FN: + # Using int8 for torch.float4_e2m1fn_x2 tensor allocation + # Thus the size of k needs to be halved in this case. + k_fct = 2 + else: + k_fct = 1 + + sf_k = ceil_div(k, sf_vec_size) + + # Create tensor SFA/SFB + sfa = ( + torch.randint(0, 3, (l, m, sf_k), dtype=torch.uint8) + .permute(1, 2, 0) + .to(dtype=cutlass_torch.dtype(sf_dtype)) + ) + sfb = ( + torch.randint(0, 3, (l, n, sf_k), dtype=torch.uint8) + .permute(1, 2, 0) + .to(dtype=cutlass_torch.dtype(sf_dtype)) + ) + + # Create tensor A/B/C + if a_major == "k": + a = torch.randint( + -2, 2, (l, m, k // k_fct), dtype=torch.int8, device="cuda" + ).permute(1, 2, 0) + else: + a = torch.randint(-2, 2, (l, k, m), dtype=torch.int8, device="cuda").permute( + 2, 1, 0 + ) + if b_major == "k": + b = torch.randint( + -2, 2, (l, n, k // k_fct), dtype=torch.int8, device="cuda" + ).permute(1, 2, 0) + else: + b = torch.randint(-2, 2, (l, k, n), dtype=torch.int8, device="cuda").permute( + 2, 1, 0 + ) + if c_major == "n": + c = torch.randint( + -2, 2, (l, m, n), dtype=cutlass_torch.dtype(c_dtype), device="cuda" + ).permute(1, 2, 0) + else: + c = torch.randint( + -2, 2, (l, n, m), dtype=cutlass_torch.dtype(c_dtype), device="cuda" + ).permute(2, 1, 0) + + if ab_dtype == cutlass.Float4E2M1FN: + a = a.view(dtype=torch.float4_e2m1fn_x2) + b = b.view(dtype=torch.float4_e2m1fn_x2) + else: + a = a.to(dtype=cutlass_torch.dtype(ab_dtype)) + b = b.to(dtype=cutlass_torch.dtype(ab_dtype)) + + c = c.to(dtype=cutlass_torch.dtype(c_dtype)) + return a, b, c, sfa, sfb + + +# This will show how to covert torch tensor +# and pass to CuTe kernel +def run_scaled_mm( + mnkl: Tuple[int, int, int, int], + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + c_dtype: Type[cutlass.Numeric], + a_major: Literal["m", "k"], + b_major: Literal["n", "k"], + c_major: Literal["m", "n"], + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + tolerance: float = 1e-01, + warmup_iterations: int = 0, + iterations: int = 1, + skip_ref_check: bool = False, + use_cold_l2: bool = False, + **kwargs, +): + """Execute a persistent batched dense blockscaled GEMM operation on Blackwell architecture with performance benchmarking (non-emulated dtypes). + + This function prepares input tensors, configures and launches the persistent GEMM kernel, + optionally performs reference validation, and benchmarks the execution performance. + + :param mnkl: Problem size (M, N, K, L) + :type mnkl: Tuple[int, int, int, int] + :param ab_dtype: Data type for input tensors A and B + :type ab_dtype: Type[cutlass.Numeric] + :param sf_dtype: Data type for scale factor tensor + :type sf_dtype: Type[cutlass.Numeric] + :param sf_vec_size: Vector size for scale factor tensor + :type sf_vec_size: int + :param c_dtype: Data type for output tensor C + :type c_dtype: Type[cutlass.Numeric] + :param a_major/b_major/c_major: Memory layout of tensor A/B/C + :type a_major/b_major/c_major: Literal["m", "k", "n"] + :param mma_tiler_mn: MMA tiling size. + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster shape. + :type cluster_shape_mn: Tuple[int, int] + :param tolerance: Tolerance value for reference validation comparison, defaults to 1e-01 + :type tolerance: float, optional + :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0 + :type warmup_iterations: int, optional + :param iterations: Number of benchmark iterations to run, defaults to 1 + :type iterations: int, optional + :param skip_ref_check: Whether to skip reference result validation, defaults to False + :type skip_ref_check: bool, optional + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False + :type use_cold_l2: bool, optional + :raises RuntimeError: If CUDA GPU is not available + :raises ValueError: If the configuration is invalid or unsupported by the kernel + :return: Execution time of the GEMM kernel + :rtype: float + """ + print("Running Sm100 Persistent Dense BlockScaled GEMM test with:") + print(f"mnkl: {mnkl}") + print(f"AB dtype: {ab_dtype}, SF dtype: {sf_dtype}, SF Vec size: {sf_vec_size}") + print(f"C dtype: {c_dtype}") + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}") + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}") + + # Unpack parameters + m, n, k, l = mnkl + + # Configure gemm kernel + gemm = Sm100BlockScaledPersistentDenseGemmKernel( + sf_vec_size, + mma_tiler_mn, + cluster_shape_mn, + ) + + # Skip unsupported testcase + if not gemm.can_implement( + mnkl, + ab_dtype, + sf_dtype, + c_dtype, + a_major, + b_major, + c_major, + sf_vec_size, + mma_tiler_mn, + cluster_shape_mn, + ): + raise TypeError( + f"Unsupported testcase {ab_dtype}, {sf_dtype}, {sf_vec_size}, {c_dtype}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}" + ) + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + torch.manual_seed(1111) + + # Get current CUDA stream from PyTorch + torch_stream = torch.cuda.current_stream() + # Get the raw stream pointer as a CUstream + current_stream = cuda.CUstream(torch_stream.cuda_stream) + + # Check if configuration can be implemented + max_active_clusters = utils.HardwareInfo().get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1] + ) + + # Compile gemm kernel with fake tensors + compiled_gemm = scaled_mm( + gemm, + ab_dtype, + c_dtype, + sf_dtype, + a_major, + b_major, + c_major, + max_active_clusters, + current_stream, + options=f"--opt-level 2", + ) + + # Create Torch Tensors for A, scale factor A, B, scale factor B, C + a, b, c, sfa, sfb = prepare_tensors( + mnkl, ab_dtype, sf_dtype, sf_vec_size, c_dtype, a_major, b_major, c_major + ) + # Reorder scale factor tensors to (32, 4, restM, 4, restK, l) format + sfa_reordered = create_and_reorder_scale_factor_tensor( + l, m, k, sf_vec_size, sf_dtype, sfa + ) + sfb_reordered = create_and_reorder_scale_factor_tensor( + l, n, k, sf_vec_size, sf_dtype, sfb + ) + # Construct CuTe Pointers + a_ptr, b_ptr, c_ptr, sfa_ptr, sfb_ptr = construct_cute_pointers( + a, + b, + sfa_reordered, + sfb_reordered, + c, + ab_dtype, + sf_dtype, + c_dtype, + ) + + # Compute reference result + if not skip_ref_check: + # Execute kernel once for reference checking + compiled_gemm( + a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, (m, n, k, l), current_stream + ) + c_ref = reference_scaled_mm(a, b, sfa, sfb, c, (m, n, k, l), c_dtype) + if c_dtype in (cutlass.Float8E5M2, cutlass.Float8E4M3FN): + # Rtol=0.001 and atol=0.1 are not supported for bitwise comparison of + # low dimensional floats. Please use rtol=0.0 and atol=0.0. + tolerance = 0.0 + torch.testing.assert_close(c, c_ref, atol=tolerance, rtol=tolerance) + + def generate_inputs(): + a, b, c, sfa, sfb = prepare_tensors( + mnkl, + ab_dtype, + sf_dtype, + sf_vec_size, + c_dtype, + a_major, + b_major, + c_major, + ) + # Reorder scale factor tensors to (32, 4, restM, 4, restK, l) format + sfa_reordered = create_and_reorder_scale_factor_tensor( + l, m, k, sf_vec_size, sf_dtype, sfa + ) + sfb_reordered = create_and_reorder_scale_factor_tensor( + l, n, k, sf_vec_size, sf_dtype, sfb + ) + # Construct CuTe Pointers + a_ptr, b_ptr, c_ptr, sfa_ptr, sfb_ptr = construct_cute_pointers( + a, + b, + sfa_reordered, + sfb_reordered, + c, + ab_dtype, + sf_dtype, + c_dtype, + ) + jit_args = cute.testing.JitArguments( + a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, (m, n, k, l), current_stream + ) + # Keep references to external variables (e.g., Torch tensors when taking a view) + jit_args.add_to_scope([a, b, sfa_reordered, sfb_reordered, c]) + return jit_args + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + a.numel() * a.element_size() + + b.numel() * b.element_size() + + sfa.numel() * sfa.element_size() + + sfb.numel() * sfb.element_size() + + c.numel() * c.element_size() + ) + workspace_count = cute.testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = cute.testing.benchmark( + compiled_gemm, + workspace_generator=generate_inputs, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + return exec_time # Return execution time in microseconds + + +# This is to compatible with the other narrow +# precision combinations are not supported in either +# torch or dlpack. For example, Float4E2M1FN with Float8E8M0FNU. +def run_scaled_mm_with_emulated_dtype( + mnkl: Tuple[int, int, int, int], + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + c_dtype: Type[cutlass.Numeric], + a_major: Literal["m", "k"], + b_major: Literal["n", "k"], + c_major: Literal["m", "n"], + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + tolerance: float = 1e-01, + warmup_iterations: int = 0, + iterations: int = 1, + skip_ref_check: bool = False, + use_cold_l2: bool = False, + **kwargs, +): + """Execute a persistent batched dense blockscaled GEMM operation on Blackwell architecture with performance benchmarking (emulated dtypes). + + This function prepares input tensors, configures and launches the persistent GEMM kernel, + optionally performs reference validation, and benchmarks the execution performance. + + :param mnkl: Problem size (M, N, K, L) + :type mnkl: Tuple[int, int, int, int] + :param ab_dtype: Data type for input tensors A and B + :type ab_dtype: Type[cutlass.Numeric] + :param sf_dtype: Data type for scale factor tensor + :type sf_dtype: Type[cutlass.Numeric] + :param sf_vec_size: Vector size for scale factor tensor + :type sf_vec_size: int + :param c_dtype: Data type for output tensor C + :type c_dtype: Type[cutlass.Numeric] + :param a_major/b_major/c_major: Memory layout of tensor A/B/C + :type a_major/b_major/c_major: Literal["m", "n","k"] + :param mma_tiler_mn: MMA tiling size. + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster shape. + :type cluster_shape_mn: Tuple[int, int] + :param tolerance: Tolerance value for reference validation comparison, defaults to 1e-01 + :type tolerance: float, optional + :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0 + :type warmup_iterations: int, optional + :param iterations: Number of benchmark iterations to run, defaults to 1 + :type iterations: int, optional + :param skip_ref_check: Whether to skip reference result validation, defaults to False + :type skip_ref_check: bool, optional + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False + :type use_cold_l2: bool, optional + :raises RuntimeError: If CUDA GPU is not available + :raises ValueError: If the configuration is invalid or unsupported by the kernel + :return: Execution time of the GEMM kernel + :rtype: float + """ + print("Running Sm100 Persistent Dense BlockScaled GEMM test (Emulated) with:") + print(f"mnkl: {mnkl}") + print(f"AB dtype: {ab_dtype}, SF dtype: {sf_dtype}, SF Vec size: {sf_vec_size}") + print(f"C dtype: {c_dtype}") + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}") + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}") + + # Unpack parameters + m, n, k, l = mnkl + + # Configure gemm kernel + gemm = Sm100BlockScaledPersistentDenseGemmKernel( + sf_vec_size, + mma_tiler_mn, + cluster_shape_mn, + ) + + # Skip unsupported testcase + if not gemm.can_implement( + mnkl, + ab_dtype, + sf_dtype, + c_dtype, + a_major, + b_major, + c_major, + sf_vec_size, + mma_tiler_mn, + cluster_shape_mn, + ): + raise TypeError( + f"Unsupported testcase {ab_dtype}, {sf_dtype}, {sf_vec_size}, {c_dtype}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}" + ) + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + torch.manual_seed(1111) + + # Get current CUDA stream from PyTorch + torch_stream = torch.cuda.current_stream() + # Get the raw stream pointer as a CUstream + current_stream = cuda.CUstream(torch_stream.cuda_stream) + + # Check if configuration can be implemented + max_active_clusters = utils.HardwareInfo().get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1] + ) + + # Compile gemm kernel with fake tensors + compiled_gemm = scaled_mm( + gemm, + ab_dtype, + c_dtype, + sf_dtype, + a_major, + b_major, + c_major, + max_active_clusters, + current_stream, + options=f"--opt-level 2", + ) + + # Create Torch Tensors for A, scale factor A, B, scale factor B, C + a, b, c, sfa, sfb = prepare_tensors_emulated( + mnkl, ab_dtype, sf_dtype, sf_vec_size, c_dtype, a_major, b_major, c_major + ) + # Reorder scale factor tensors to (32, 4, restM, 4, restK, l) format + sfa_reordered = create_and_reorder_scale_factor_tensor( + l, m, k, sf_vec_size, sf_dtype, sfa + ) + sfb_reordered = create_and_reorder_scale_factor_tensor( + l, n, k, sf_vec_size, sf_dtype, sfb + ) + # Construct CuTe Pointers + a_ptr, b_ptr, c_ptr, sfa_ptr, sfb_ptr, a_cute, b_cute = ( + construct_cute_pointers_emulated( + a, + b, + sfa_reordered, + sfb_reordered, + c, + ab_dtype, + sf_dtype, + c_dtype, + ) + ) + + # Compute reference result + if not skip_ref_check: + # Execute kernel once for reference checking + compiled_gemm( + a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, (m, n, k, l), current_stream + ) + c_ref = reference_scaled_mm_emulated( + a, b, sfa, sfb, c, (m, n, k, l), sf_vec_size, c_dtype + ) + if c_dtype in (cutlass.Float8E5M2, cutlass.Float8E4M3FN): + # Rtol=0.001 and atol=0.1 are not supported for bitwise comparison of + # low dimensional floats. Please use rtol=0.0 and atol=0.0. + tolerance = 0.0 + torch.testing.assert_close(c, c_ref, atol=tolerance, rtol=tolerance) + + def generate_inputs(): + a, b, c, sfa, sfb = prepare_tensors_emulated( + mnkl, + ab_dtype, + sf_dtype, + sf_vec_size, + c_dtype, + a_major, + b_major, + c_major, + ) + # Reorder scale factor tensors to (32, 4, restM, 4, restK, l) format + sfa_reordered = create_and_reorder_scale_factor_tensor( + l, m, k, sf_vec_size, sf_dtype, sfa + ) + sfb_reordered = create_and_reorder_scale_factor_tensor( + l, n, k, sf_vec_size, sf_dtype, sfb + ) + # Construct CuTe Pointers + a_ptr, b_ptr, c_ptr, sfa_ptr, sfb_ptr, a_cute, b_cute = ( + construct_cute_pointers_emulated( + a, + b, + sfa_reordered, + sfb_reordered, + c, + ab_dtype, + sf_dtype, + c_dtype, + ) + ) + jit_args = cute.testing.JitArguments( + a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr, (m, n, k, l), current_stream + ) + # Keep references to external variables (e.g., Torch tensors when taking a view) + jit_args.add_to_scope([a, b, sfa_reordered, sfb_reordered, c, a_cute, b_cute]) + return jit_args + + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + a.numel() * a.element_size() + + b.numel() * b.element_size() + + sfa.numel() * sfa.element_size() + + sfb.numel() * sfb.element_size() + + c.numel() * c.element_size() + ) + workspace_count = cute.testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = cute.testing.benchmark( + compiled_gemm, + workspace_generator=generate_inputs, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + return exec_time # Return execution time in microseconds + + +def run( + mnkl: Tuple[int, int, int, int], + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + c_dtype: Type[cutlass.Numeric], + a_major: Literal["m", "k"], + b_major: Literal["n", "k"], + c_major: Literal["m", "n"], + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + tolerance: float = 1e-01, + warmup_iterations: int = 0, + iterations: int = 1, + skip_ref_check: bool = False, + use_cold_l2: bool = False, + **kwargs, +): + """ + Execute the appropriate GEMM function based on dtype. + + Routes to either run_scaled_mm_with_emulated_dtype or run_scaled_mm + depending on whether the dtypes require emulation. + """ + if is_emulated_dtype(ab_dtype, sf_dtype, c_dtype): + exec_time = run_scaled_mm_with_emulated_dtype( + mnkl, + ab_dtype, + sf_dtype, + sf_vec_size, + c_dtype, + a_major, + b_major, + c_major, + mma_tiler_mn, + cluster_shape_mn, + tolerance, + warmup_iterations, + iterations, + skip_ref_check, + use_cold_l2, + ) + else: + exec_time = run_scaled_mm( + mnkl, + ab_dtype, + sf_dtype, + sf_vec_size, + c_dtype, + a_major, + b_major, + c_major, + mma_tiler_mn, + cluster_shape_mn, + tolerance, + warmup_iterations, + iterations, + skip_ref_check, + use_cold_l2, + ) + return exec_time + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + parser = argparse.ArgumentParser( + description="Example of Sm100 Dense Persistent BlockScaled GEMM." + ) + + parser.add_argument( + "--mnkl", + type=parse_comma_separated_ints, + default=(512, 256, 256, 1), + help="mnkl dimensions (comma-separated)", + ) + parser.add_argument( + "--mma_tiler_mn", + type=parse_comma_separated_ints, + default=(128, 128), + help="Mma tile shape (comma-separated)", + ) + parser.add_argument( + "--cluster_shape_mn", + type=parse_comma_separated_ints, + default=(1, 1), + help="Cluster shape (comma-separated)", + ) + parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.Float4E2M1FN) + parser.add_argument("--sf_dtype", type=cutlass.dtype, default=cutlass.Float8E4M3FN) + parser.add_argument("--sf_vec_size", type=int, default=16) + parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.Float16) + parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k") + parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations to run the kernel", + ) + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + + if len(args.mnkl) != 4: + parser.error("--mnkl must contain exactly 4 values") + + if len(args.mma_tiler_mn) != 2: + parser.error("--mma_tiler_mn must contain exactly 2 values") + + if len(args.cluster_shape_mn) != 2: + parser.error("--cluster_shape_mn must contain exactly 2 values") + + # Execute GEMM with appropriate function based on dtype + run( + args.mnkl, + args.ab_dtype, + args.sf_dtype, + args.sf_vec_size, + args.c_dtype, + args.a_major, + args.b_major, + args.c_major, + args.mma_tiler_mn, + args.cluster_shape_mn, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + print("PASS") diff --git a/reference/grouped_blockscaled_gemm.py b/reference/grouped_blockscaled_gemm.py new file mode 100644 index 00000000..1f6d6586 --- /dev/null +++ b/reference/grouped_blockscaled_gemm.py @@ -0,0 +1,3278 @@ +# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import functools +from typing import List, Type, Tuple, Union +from inspect import isclass + +import torch +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.torch as cutlass_torch +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils +from cutlass.cute.runtime import from_dlpack + +""" +This example provides an experimental implementation of the SM100 grouped blockscaled GEMM kernel, please note that the APIs and implementation details related to this kernel may change in future releases. + +A grouped blockscaled GEMM example for the NVIDIA Blackwell SM100 architecture using CUTE DSL + +This example demonstrates an implementation of grouped blockscaled GEMM using a TMA plus Blackwell SM100 TensorCore +warp-specialized persistent kernel. +The grouped GEMM workload computes a batch of GEMM operations with distinct problem sizes. Pointers to matrices +in global memory are passed to the kernel in an array (also held in global memory). Similarly, problem shapes and +strides are also stored in arrays in GMEM. + +This differs from "Batched Array" GEMM since the size of each GEMM problem in the grouped GEMM concept may be distinct. + +To run this example: + +.. code-block:: bash + + python examples/blackwell/grouped_blockscaled_gemm.py \ + --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU --sf_vec_size 16 \ + --c_dtype Float16 \ + --mma_tiler_mn 128,128 --cluster_shape_mn 1,1 \ + --problem_sizes_mnkl "(8192,1280,32,1),(32,384,1536,1),(640,1280,32,1),(640,160,32,1)" \ + --num_groups 4 + +The above example command makes 4 groups of different m, n, k sizes. The Blackwell tcgen05 MMA tile shape +is specified as (128, 64) and the cluster shape is (1,1). The input, mma accumulator and output data type +are set as fp16, fp32 and fp16, respectively. + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/grouped_blockscaled_gemm.py \ + --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU --sf_vec_size 16 \ + --c_dtype Float16 \ + --mma_tiler_mn 128,128 --cluster_shape_mn 1,1 \ + --problem_sizes_mnkl "(8192,1280,32,1),(32,384,1536,1),(640,1280,32,1),(640,160,32,1)" \ + --num_groups 4 + --warmup_iterations 1 --iterations 10 --skip_ref_check + +Constraints: +* Supported input data types: mxf8, mxf4, nvf4 + see detailed valid dtype combinations in below Sm100GroupedBlockScaledGemmKernel class documentation +* A/B tensors must have the same data type, mixed data type is not supported (e.g., mxf8 x mxf4) +* Mma tiler M must be 128 or 256(use_2cta_instrs) +* Mma tiler N must be 128 or 256 +* Cluster shape M/N must be positive and power of 2, total cluster size <= 16 +* Cluster shape M/N must be <= 4 for scale factor multicasts due to limited size of scale factors +* Cluster shape M must be multiple of 2 if Mma tiler M is 256(use_2cta_instrs) +* The l mode(aka, batch size) for each group must be 1. +* The majorness for A, B and C must be the same across all groups. +* The contiguous dimension of A/B/C tensors in each group must be at least 16 bytes aligned, + i.e, number of elements is a multiple of 16 and 32 for Float8 and Float4, respectively. +""" + + +class Sm100GroupedBlockScaledGemmKernel: + """This example demonstrates an implementation of grouped blockscaled GEMM using a TMA plus Blackwell SM100 TensorCore + warp-specialized persistent kernel. + + :param sf_vec_size: Scalefactor vector size. + :type sf_vec_size: int + :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N) + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + + :note: In current version, A and B tensors must have the same data type + - i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported + + :note: Supported combinations of A/B data types, SF data typs and SF vector size: + - MXF8: A/B: Float8E5M2/Float8E4M3FN + SF: Float8E8M0FNU + sf_vec_size: 32 + - MXF4: A/B: Float4E2M1FN + SF: Float8E8M0FNU + sf_vec_size: 32 + - NVF4: A/B: Float4E2M1FN + SF: Float8E8M0FNU/Float8E4M3FN + sf_vec_size: 16 + + :note: Supported accumulator data types: + - Float32 + + :note: Supported C data types: + - Float32 + - Float16/BFloat16 + - Float8E4M3FN/Float8E5M2 + :note: Constraints: + - MMA tiler M must be 128 or 256 (use_2cta_instrs) + - MMA tiler N must be 128/256 + - Cluster shape M must be multiple of 2 if Mma tiler M is 256 + - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 + - Cluster shape M/N must be <= 4 for scale factor multicasts due to limited size of scale factors + """ + + def __init__( + self, + sf_vec_size: int, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ): + """Initializes the configuration for a Blackwell grouped blockscaled GEMM kernel. + + Besides configurations for dense persistent blockscaled GEMM, there is an extra config specific to grouped blockscaled GEMM: + + :param sf_vec_size: Scalefactor vector size. + :type sf_vec_size: int + :param mma_tiler_mn: tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: tuple[int, int] + :param cluster_shape_mn: tuple (ClusterM, ClusterN) shape of the cluster. + :type cluster_shape_mn: tuple[int, int] + """ + self.acc_dtype = cutlass.Float32 + self.sf_vec_size = sf_vec_size + self.use_2cta_instrs = mma_tiler_mn[0] == 256 + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + + self.cta_group = ( + tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + self.tensormap_update_mode = utils.TensorMapUpdateMode.SMEM + + self.occupancy = 1 + # Set specialized warp ids + self.epilog_warp_id = ( + 0, + 1, + 2, + 3, + ) + self.mma_warp_id = 4 + self.tma_warp_id = 5 + self.threads_per_cta = 32 * len( + (self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id) + ) + # Set barrier for epilogue sync and tmem ptr sync + self.epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=32 * len(self.epilog_warp_id), + ) + self.tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=2, + num_threads=32 * len((self.mma_warp_id, *self.epilog_warp_id)), + ) + # Barrier used by MMA/TMA warps to signal A/B tensormap initialization completion + self.tensormap_ab_init_barrier = pipeline.NamedBarrier( + barrier_id=3, + num_threads=64, + ) + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + self.num_tmem_alloc_cols = cute.arch.get_max_tmem_alloc_cols("sm_100") + + # Set up configurations that dependent on gemm inputs. + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B/SFA/SFB + - Computing epilogue subtile + - Setting up A/B/SFA/SFB/C stage counts in shared memory + - Computing A/B/SFA/SFB/C shared memory layout + - Checking reserved smem bytes size capacity for mbar, tensor memory management and tensormap updates utilization + """ + # Compute mma instruction shapes + # (MMA_Tile_Shape_M, MMA_Tile_Shape_N, MMA_Inst_Shape_K) + self.mma_inst_shape_mn = ( + self.mma_tiler[0], + self.mma_tiler[1], + ) + # (CTA_Tile_Shape_M, Round_Up(MMA_Tile_Shape_N, 128), MMA_Inst_Shape_K) + self.mma_inst_shape_mn_sfb = ( + self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1), + cute.round_up(self.mma_inst_shape_mn[1], 128), + ) + + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mn, + ) + + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mn_sfb, + ) + + # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_inst_shape_mn[0], + self.mma_inst_shape_mn[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.mma_tiler_sfb = ( + self.mma_inst_shape_mn_sfb[0], + self.mma_inst_shape_mn_sfb[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + self.cluster_tile_shape_mnk = tuple( + x * y for x, y in zip(self.cta_tile_shape_mnk, (*self.cluster_shape_mn, 1)) + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + self.cluster_layout_sfb_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma_sfb.thr_id.shape,), + ) + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.num_mcast_ctas_sfb = cute.size(self.cluster_layout_sfb_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + self.is_sfb_mcast = self.num_mcast_ctas_sfb > 1 + + # Compute epilogue subtile + self.epi_tile = sm100_utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + + # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory + self.num_acc_stage, self.num_ab_stage, self.num_c_stage = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.sf_dtype, + self.sf_vec_size, + self.smem_capacity, + self.occupancy, + ) + + # Compute A/B/SFA/SFB/C shared memory layout + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_c_stage, + ) + + mbar_smem_bytes = self._get_mbar_smem_bytes( + num_acc_stage=self.num_acc_stage, + num_ab_stage=self.num_ab_stage, + num_c_stage=self.num_c_stage, + ) + + # Use utils.TensorMapUpdateMode.SMEM by default + tensormap_smem_bytes = ( + Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap + * Sm100GroupedBlockScaledGemmKernel.num_tensormaps + ) + if ( + mbar_smem_bytes + + tensormap_smem_bytes + + Sm100GroupedBlockScaledGemmKernel.tensor_memory_management_bytes + > self.reserved_smem_bytes + ): + raise ValueError( + f"smem consumption for mbar and tensormap {mbar_smem_bytes + tensormap_smem_bytes} exceeds the " + f"reserved smem bytes {self.reserved_smem_bytes}" + ) + + @cute.jit + def __call__( + self, + initial_a: cute.Tensor, + initial_b: cute.Tensor, + initial_c: cute.Tensor, + initial_sfa: cute.Tensor, + initial_sfb: cute.Tensor, + group_count: cutlass.Constexpr[int], + problem_shape_mnkl: cute.Tensor, + strides_abc: cute.Tensor, + tensor_address_abc: cute.Tensor, + tensor_address_sfasfb: cute.Tensor, + total_num_clusters: cutlass.Constexpr[int], + tensormap_cute_tensor: cute.Tensor, + max_active_clusters: cutlass.Constexpr[int], + stream: cuda.CUstream, + ): + """Execute the GEMM operation in steps: + - Setup static attributes before smem/grid/tma computation + - Setup TMA load/store atoms and tensors + - Compute grid size with regard to hardware constraints + - Define shared storage for kernel + - Launch the kernel synchronously + + For grouped GEMM, tensor shapes, tensor strides, and tensor address are all provided + by different tensors in global memory. The "initial" tensors only carry data type and + majorness information. + + :param initial_a: Initial tensor A, used for data type and majorness information. + :type initial_a: cute.Tensor + :param initial_b: Initial tensor B, used for data type and majorness information. + :type initial_b: cute.Tensor + :param initial_c: Initial tensor C, used for data type and majorness information. + :type initial_c: cute.Tensor + :param initial_sfa: Initial tensor SFA, used for data type and majorness information. + :type initial_sfa: cute.Tensor + :param initial_sfb: Initial tensor SFB, used for data type and majorness information. + :type initial_sfb: cute.Tensor + :param group_count: The number of GEMM groups. + :type group_count: cutlass.Constexpr[int] + :param problem_shape_mnkl: Tensor containing the (M, N, K, L) shape for each group. + :type problem_shape_mnkl: cute.Tensor + :param strides_abc: Tensor containing the strides for A, B, and C for each group. + :type strides_abc: cute.Tensor + :param tensor_address_abc: Tensor containing the base addresses for A, B, and C for each group. + :type tensor_address_abc: cute.Tensor + :param tensor_address_sfasfb: Tensor containing the base addresses for SFA and SFB for each group. + :type tensor_address_sfasfb: cute.Tensor + :param total_num_clusters: Total number of clusters needed for all groups. + :type total_num_clusters: cutlass.Constexpr[int] + :param tensormap_cute_tensor: Tensor for storing tensormaps. + :type tensormap_cute_tensor: cute.Tensor + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr[int] + :param stream: CUDA stream for asynchronous execution. + :type stream: cuda.CUstream + :raises TypeError: If A and B data types do not match. + """ + self.a_dtype = initial_a.element_type + self.b_dtype = initial_b.element_type + self.sf_dtype = initial_sfa.element_type + self.c_dtype = initial_c.element_type + self.is_nvfp4_output = self.c_dtype is cutlass.Float4E2M1FN + self.a_major_mode = utils.LayoutEnum.from_tensor(initial_a).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(initial_b).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(initial_c) + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout + # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL) + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( + initial_a.shape, self.sf_vec_size + ) + initial_sfa = cute.make_tensor(initial_sfa.iterator, sfa_layout) + + # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( + initial_b.shape, self.sf_vec_size + ) + initial_sfb = cute.make_tensor(initial_sfb.iterator, sfb_layout) + + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mn, + ) + + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mn_sfb, + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + initial_a, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # Setup TMA load for B + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + initial_b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # Setup TMA load for SFA + sfa_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfa_smem_layout = cute.slice_( + self.sfa_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( + sfa_op, + initial_sfa, + sfa_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + + # Setup TMA load for SFB + sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfb_smem_layout = cute.slice_( + self.sfb_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + sfb_op, + initial_sfb, + sfb_smem_layout, + self.mma_tiler_sfb, + tiled_mma_sfb, + self.cluster_layout_sfb_vmnk.shape, + internal_type=cutlass.Int16, + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) + sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) + self.num_tma_load_bytes = ( + a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size + ) * atom_thr_size + + # Setup TMA store for C + epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + initial_c, + epi_smem_layout, + self.epi_tile, + ) + + # Compute grid size + self.tile_sched_params, grid = self._compute_grid( + total_num_clusters, self.cluster_shape_mn, max_active_clusters + ) + + self.buffer_align_bytes = 1024 + self.size_tensormap_in_i64 = ( + Sm100GroupedBlockScaledGemmKernel.num_tensormaps + * Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap + // 8 + ) + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + tensormap_buffer: cute.struct.MemRange[ + cutlass.Int64, self.size_tensormap_in_i64 + ] + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + tmem_dealloc_mbar: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + cute.cosize(self.c_smem_layout_staged.outer), + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sSFA: cute.struct.Align[ + cute.struct.MemRange[ + self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sSFB: cute.struct.Align[ + cute.struct.MemRange[ + self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tiled_mma_sfb, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + tma_atom_c, + tma_tensor_c, + self.cluster_layout_vmnk, + self.cluster_layout_sfb_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + self.c_smem_layout_staged, + self.epi_tile, + self.tile_sched_params, + group_count, + problem_shape_mnkl, + strides_abc, + tensor_address_abc, + tensor_address_sfasfb, + tensormap_cute_tensor, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + stream=stream, + min_blocks_per_mp=1, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tiled_mma_sfb: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + mSFA_mkl: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + mSFB_nkl: cute.Tensor, + tma_atom_c: cute.CopyAtom, + mC_mnl: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + cluster_layout_sfb_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout], + epi_tile: cute.Tile, + tile_sched_params: utils.PersistentTileSchedulerParams, + group_count: cutlass.Constexpr, + problem_sizes_mnkl: cute.Tensor, + strides_abc: cute.Tensor, + ptrs_abc: cute.Tensor, + ptrs_sfasfb: cute.Tensor, + tensormaps: cute.Tensor, + ): + """ + GPU device kernel performing the grouped GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + if warp_idx == self.tma_warp_id: + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_a) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_b) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_sfa) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_sfb) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: tensormap buffer, a+b full/empty, accumulator full/empty, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tensormap_smem_ptr = storage.tensormap_buffer.data_ptr() + tensormap_a_smem_ptr = tensormap_smem_ptr + tensormap_b_smem_ptr = ( + tensormap_a_smem_ptr + + Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap // 8 + ) + tensormap_sfa_smem_ptr = ( + tensormap_b_smem_ptr + + Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap // 8 + ) + tensormap_sfb_smem_ptr = ( + tensormap_sfa_smem_ptr + + Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap // 8 + ) + tensormap_c_smem_ptr = ( + tensormap_sfb_smem_ptr + + Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap // 8 + ) + + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar.ptr + tmem_holding_buf_ptr = storage.tmem_holding_buf.ptr + + # Initialize mainloop ab_pipeline (barrier) and states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = len(self.epilog_warp_id) * ( + 2 if use_2cta_instrs else 1 + ) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_acc_consumer_threads + ) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Tensor memory dealloc barrier init + if use_2cta_instrs: + if warp_idx == self.tma_warp_id: + num_tmem_dealloc_threads = 32 + with cute.arch.elect_one(): + cute.arch.mbarrier_init( + tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads + ) + + # Cluster arrive after barrier init + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True) + + # + # Setup smem tensor A/B/SFA/SFB/C + # + sC = storage.sC.get_tensor( + c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner + ) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + # (MMA, MMA_M, MMA_K, STAGE) + sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) + # (MMA, MMA_N, MMA_K, STAGE) + sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged) + + # + # Compute multicast mask for A/B/SFA/SFB buffer full + # + a_full_mcast_mask = None + b_full_mcast_mask = None + sfa_full_mcast_mask = None + sfb_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + sfa_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + sfb_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1 + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bK, RestM, RestK, RestL) + gSFA_mkl = cute.local_tile( + mSFA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gSFB_nkl = cute.local_tile( + mSFB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgSFA = thr_mma.partition_A(gSFA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCgC = thr_mma.partition_C(gC_mnl) + + # + # Partition global/shared tensor for TMA load A/B + # + # TMA load A partition_S/D + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # TMA load scaled factor A partition_S/D + sfa_cta_layout = a_cta_layout + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsSFA, tAgSFA = cute.nvgpu.cpasync.tma_partition( + tma_atom_sfa, + block_in_cluster_coord_vmnk[2], + sfa_cta_layout, + cute.group_modes(sSFA, 0, 3), + cute.group_modes(tCgSFA, 0, 3), + ) + tAsSFA = cute.filter_zeros(tAsSFA) + tAgSFA = cute.filter_zeros(tAgSFA) + + # TMA load scaled factor B partition_S/D + sfb_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition( + tma_atom_sfb, + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB, 0, 3), + ) + tBsSFB = cute.filter_zeros(tBsSFB) + tBgSFB = cute.filter_zeros(tBgSFB) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, self.num_acc_stage) + ) + + # + # Cluster wait before tensor memory alloc + # + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn) + + # + # Get tensormap buffer address + # + grid_dim = cute.arch.grid_dim() + tensormap_workspace_idx = ( + bidz * grid_dim[1] * grid_dim[0] + bidy * grid_dim[0] + bidx + ) + + tensormap_manager = utils.TensorMapManager( + utils.TensorMapUpdateMode.SMEM, + Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap, + ) + tensormap_a_gmem_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(tensormap_workspace_idx, 0, None)].iterator + ) + tensormap_b_gmem_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(tensormap_workspace_idx, 1, None)].iterator + ) + tensormap_sfa_gmem_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(tensormap_workspace_idx, 2, None)].iterator + ) + tensormap_sfb_gmem_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(tensormap_workspace_idx, 3, None)].iterator + ) + tensormap_c_gmem_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(tensormap_workspace_idx, 4, None)].iterator + ) + + # + # Persistent tile scheduling loop + # + # When the problem shapes are on device, we launch one CTA per SM. + # The if condition later prevents the warps from extra CTAs from doing any work. + tile_sched = utils.StaticPersistentGroupTileScheduler.create( + tile_sched_params, + cute.arch.block_idx(), + grid_dim, + self.cluster_tile_shape_mnk, + utils.create_initial_search_state(), + group_count, + problem_sizes_mnkl, + ) + initial_work_tile_info = tile_sched.initial_work_tile_info() + + # + # Specialized TMA load warp + # + if warp_idx == self.tma_warp_id and initial_work_tile_info.is_valid_tile: + # + # Persistent tile scheduling loop + # + work_tile = initial_work_tile_info + + tensormap_init_done = cutlass.Boolean(False) + # group index of last tile + last_group_idx = cutlass.Int32(-1) + + ab_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_ab_stage + ) + + while work_tile.is_valid_tile: + grouped_gemm_cta_tile_info = work_tile.group_search_result + cur_k_tile_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k + cur_group_idx = grouped_gemm_cta_tile_info.group_idx + is_k_tile_cnt_zero = cur_k_tile_cnt == 0 + # Do not load any data if cur_k_tile_cnt is 0 + if not is_k_tile_cnt_zero: + is_group_changed = cur_group_idx != last_group_idx + # skip tensormap update if we're working on the same group + if is_group_changed: + real_tensor_a = self.make_tensor_abc_for_tensormap_update( + cur_group_idx, + self.a_dtype, + ( + grouped_gemm_cta_tile_info.problem_shape_m, + grouped_gemm_cta_tile_info.problem_shape_n, + grouped_gemm_cta_tile_info.problem_shape_k, + ), + strides_abc, + ptrs_abc, + 0, # 0 for tensor A + ) + real_tensor_b = self.make_tensor_abc_for_tensormap_update( + cur_group_idx, + self.b_dtype, + ( + grouped_gemm_cta_tile_info.problem_shape_m, + grouped_gemm_cta_tile_info.problem_shape_n, + grouped_gemm_cta_tile_info.problem_shape_k, + ), + strides_abc, + ptrs_abc, + 1, # 1 for tensor B + ) + real_tensor_sfa = self.make_tensor_sfasfb_for_tensormap_update( + cur_group_idx, + self.sf_dtype, + ( + grouped_gemm_cta_tile_info.problem_shape_m, + grouped_gemm_cta_tile_info.problem_shape_n, + grouped_gemm_cta_tile_info.problem_shape_k, + ), + ptrs_sfasfb, + 0, # 0 for tensor SFA + ) + real_tensor_sfb = self.make_tensor_sfasfb_for_tensormap_update( + cur_group_idx, + self.sf_dtype, + ( + grouped_gemm_cta_tile_info.problem_shape_m, + grouped_gemm_cta_tile_info.problem_shape_n, + grouped_gemm_cta_tile_info.problem_shape_k, + ), + ptrs_sfasfb, + 1, # 1 for tensor SFB + ) + if not tensormap_init_done: + # wait tensormap initialization complete + self.tensormap_ab_init_barrier.arrive_and_wait() + tensormap_init_done = True + + tensormap_manager.update_tensormap( + ( + real_tensor_a, + real_tensor_b, + real_tensor_sfa, + real_tensor_sfb, + ), + (tma_atom_a, tma_atom_b, tma_atom_sfa, tma_atom_sfb), + ( + tensormap_a_gmem_ptr, + tensormap_b_gmem_ptr, + tensormap_sfa_gmem_ptr, + tensormap_sfb_gmem_ptr, + ), + self.tma_warp_id, + ( + tensormap_a_smem_ptr, + tensormap_b_smem_ptr, + tensormap_sfa_smem_ptr, + tensormap_sfb_smem_ptr, + ), + ) + + mma_tile_coord_mnl = ( + grouped_gemm_cta_tile_info.cta_tile_idx_m + // cute.size(tiled_mma.thr_id.shape), + grouped_gemm_cta_tile_info.cta_tile_idx_n, + 0, + ) + + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), RestK) + tAgA_slice = tAgA[ + (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + ] + # ((atom_v, rest_v), RestK) + tBgB_slice = tBgB[ + (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + ] + + # ((atom_v, rest_v), RestK) + tAgSFA_slice = tAgSFA[ + (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + ] + # ((atom_v, rest_v), RestK) + tBgSFB_slice = tBgSFB[ + (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + ] + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + ab_producer_state.reset_count() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < cur_k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + + if is_group_changed: + tensormap_manager.fence_tensormap_update(tensormap_a_gmem_ptr) + tensormap_manager.fence_tensormap_update(tensormap_b_gmem_ptr) + tensormap_manager.fence_tensormap_update(tensormap_sfa_gmem_ptr) + tensormap_manager.fence_tensormap_update(tensormap_sfb_gmem_ptr) + # + # Tma load loop + # + for k_tile in cutlass.range(0, cur_k_tile_cnt, 1, unroll=1): + # Conditionally wait for AB buffer empty + ab_pipeline.producer_acquire( + ab_producer_state, peek_ab_empty_status + ) + + # TMA load A/B/SFA/SFB + cute.copy( + tma_atom_a, + tAgA_slice[(None, ab_producer_state.count)], + tAsA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier( + ab_producer_state + ), + mcast_mask=a_full_mcast_mask, + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_a_gmem_ptr, + cute.AddressSpace.generic, + ), + ) + cute.copy( + tma_atom_b, + tBgB_slice[(None, ab_producer_state.count)], + tBsB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier( + ab_producer_state + ), + mcast_mask=b_full_mcast_mask, + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_b_gmem_ptr, + cute.AddressSpace.generic, + ), + ) + cute.copy( + tma_atom_sfa, + tAgSFA_slice[(None, ab_producer_state.count)], + tAsSFA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier( + ab_producer_state + ), + mcast_mask=sfa_full_mcast_mask, + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_sfa_gmem_ptr, + cute.AddressSpace.generic, + ), + ) + cute.copy( + tma_atom_sfb, + tBgSFB_slice[(None, ab_producer_state.count)], + tBsSFB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier( + ab_producer_state + ), + mcast_mask=sfb_full_mcast_mask, + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_sfb_gmem_ptr, + cute.AddressSpace.generic, + ), + ) + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < cur_k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + else: + if not tensormap_init_done: + # wait tensormap initialization complete + self.tensormap_ab_init_barrier.arrive_and_wait() + tensormap_init_done = True + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + last_group_idx = cur_group_idx + + # + # Wait A/B buffer empty + # + ab_pipeline.producer_tail(ab_producer_state) + + # + # Specialized MMA warp + # + if warp_idx == self.mma_warp_id and initial_work_tile_info.is_valid_tile: + # + # Initialize tensormaps for A, B, SFA and SFB + # + tensormap_manager.init_tensormap_from_atom( + tma_atom_a, tensormap_a_smem_ptr, self.mma_warp_id + ) + tensormap_manager.init_tensormap_from_atom( + tma_atom_b, tensormap_b_smem_ptr, self.mma_warp_id + ) + tensormap_manager.init_tensormap_from_atom( + tma_atom_sfa, tensormap_sfa_smem_ptr, self.mma_warp_id + ) + tensormap_manager.init_tensormap_from_atom( + tma_atom_sfb, tensormap_sfb_smem_ptr, self.mma_warp_id + ) + # indicate tensormap initialization has finished + self.tensormap_ab_init_barrier.arrive_and_wait() + + # + # Bar sync for retrieve tensor memory ptr from shared mem + # + self.tmem_alloc_barrier.arrive_and_wait() + + # + # Retrieving tensor memory ptr and make accumulator/SFA/SFB tensor + # + # Make accumulator tmem tensor + acc_tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf_ptr, + ) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # Make SFA tmem tensor + sfa_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base), + dtype=self.sf_dtype, + ) + # (MMA, MMA_M, MMA_K) + tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) + + # Make SFB tmem tensor + sfb_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + + tcgen05.find_tmem_tensor_col_offset(tCtSFA), + dtype=self.sf_dtype, + ) + # (MMA, MMA_N, MMA_K) + tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) + # + # Partition for S2T copy of SFA/SFB + # + tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = ( + self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA) + ) + tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = ( + self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB) + ) + + # + # Persistent tile scheduling loop + # + work_tile = initial_work_tile_info + + ab_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_ab_stage + ) + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + while work_tile.is_valid_tile: + cur_group_idx = work_tile.group_search_result.group_idx + problem_shape_k = work_tile.group_search_result.problem_shape_k + + # MMA warp is only interested in number of tiles along K dimension + cur_k_tile_cnt = ( + problem_shape_k + self.cluster_tile_shape_mnk[2] - 1 + ) // self.cluster_tile_shape_mnk[2] + is_k_tile_cnt_zero = cur_k_tile_cnt == 0 + + # (MMA, MMA_M, MMA_N) + tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] + + # Peek (try_wait) AB buffer full for k_tile = 0 + ab_consumer_state.reset_count() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < cur_k_tile_cnt and is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # + # Wait for accumulator buffer empty + # + if is_leader_cta and not is_k_tile_cnt_zero: + acc_pipeline.producer_acquire(acc_producer_state) + + # + # Reset the ACCUMULATE field for each tile + # + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + # + # Mma mainloop + # + for k_tile in range(cur_k_tile_cnt): + if is_leader_cta: + # Conditionally wait for AB buffer full + ab_pipeline.consumer_wait( + ab_consumer_state, peek_ab_full_status + ) + + # Copy SFA/SFB from smem to tmem + s2t_stage_coord = ( + None, + None, + None, + None, + ab_consumer_state.index, + ) + tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord] + tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord] + cute.copy( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t_staged, + tCtSFA_compact_s2t, + ) + cute.copy( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t_staged, + tCtSFB_compact_s2t, + ) + + # tCtAcc += tCrA * tCrSFA * tCrB * tCrSFB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = ( + None, + None, + kblock_idx, + ab_consumer_state.index, + ) + + # Set SFA/SFB tensor to tiled_mma + sf_kblock_coord = (None, None, kblock_idx) + tiled_mma.set( + tcgen05.Field.SFA, + tCtSFA[sf_kblock_coord].iterator, + ) + tiled_mma.set( + tcgen05.Field.SFB, + tCtSFB[sf_kblock_coord].iterator, + ) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_pipeline.consumer_release(ab_consumer_state) + + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + ab_consumer_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < cur_k_tile_cnt: + if is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # + # Async arrive accumulator buffer full + # + if not is_k_tile_cnt_zero: + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + acc_producer_state.advance() + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Wait for accumulator buffer empty + # + acc_pipeline.producer_tail(acc_producer_state) + + # + # Specialized epilogue warps + # + if warp_idx < self.mma_warp_id and initial_work_tile_info.is_valid_tile: + # initialize tensorap for C + tensormap_manager.init_tensormap_from_atom( + tma_atom_c, + tensormap_c_smem_ptr, + self.epilog_warp_id[0], + ) + # + # Alloc tensor memory buffer + # + if warp_idx == self.epilog_warp_id[0]: + cute.arch.alloc_tmem( + self.num_tmem_alloc_cols, + tmem_holding_buf_ptr, + is_two_cta=use_2cta_instrs, + ) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + self.tmem_alloc_barrier.arrive_and_wait() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + acc_tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf_ptr, + ) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + ### Start from here + # + # Partition for epilogue + # + epi_tidx = tidx + tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = ( + self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs + ) + ) + + tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, epi_tidx, sC + ) + tma_atom_c, bSG_sC, bSG_gC_partitioned = ( + self.epilog_gmem_copy_and_partition( + epi_tidx, tma_atom_c, tCgC, epi_tile, sC + ) + ) + + # + # Persistent tile scheduling loop + # + work_tile = initial_work_tile_info + + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + # Threads/warps participating in tma store pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilog_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=c_producer_group, + ) + # group index to start searching + last_group_idx = cutlass.Int32(-1) + + while work_tile.is_valid_tile: + grouped_gemm_cta_tile_info = work_tile.group_search_result + cur_group_idx = grouped_gemm_cta_tile_info.group_idx + cur_k_tile_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k + is_k_tile_cnt_zero = cur_k_tile_cnt == 0 + is_group_changed = cur_group_idx != last_group_idx + + # We still need to store 0s when k_tile_cnt is 0 + if is_group_changed: + # construct tensor c based on real shape, stride information + real_tensor_c = self.make_tensor_abc_for_tensormap_update( + cur_group_idx, + self.c_dtype, + ( + grouped_gemm_cta_tile_info.problem_shape_m, + grouped_gemm_cta_tile_info.problem_shape_n, + grouped_gemm_cta_tile_info.problem_shape_k, + ), + strides_abc, + ptrs_abc, + 2, # 2 for tensor C + ) + tensormap_manager.update_tensormap( + ((real_tensor_c),), + ((tma_atom_c),), + ((tensormap_c_gmem_ptr),), + self.epilog_warp_id[0], + (tensormap_c_smem_ptr,), + ) + + mma_tile_coord_mnl = ( + grouped_gemm_cta_tile_info.cta_tile_idx_m + // cute.size(tiled_mma.thr_id.shape), + grouped_gemm_cta_tile_info.cta_tile_idx_n, + 0, + ) + + # + # Slice to per mma tile index + # + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC_partitioned[ + ( + None, + None, + None, + *mma_tile_coord_mnl, + ) + ] + + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[ + (None, None, None, None, None, acc_consumer_state.index) + ] + + # + # Wait for accumulator buffer full + # + if not is_k_tile_cnt_zero: + acc_pipeline.consumer_wait(acc_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + + if is_group_changed: + if warp_idx == self.epilog_warp_id[0]: + tensormap_manager.fence_tensormap_update(tensormap_c_gmem_ptr) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt + for subtile_idx in range(subtile_cnt): + if not is_k_tile_cnt_zero: + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Convert to C type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + tRS_rC.store(acc_vec.to(self.c_dtype)) + else: + if cutlass.const_expr(self.is_nvfp4_output): + zeros_i8 = cute.make_rmem_tensor( + cute.recast_layout( + cutlass.Int8.width, + self.c_dtype.width, + tRS_rC.layout, + ), + cutlass.Int8, + ) + zeros_i8.fill(0) + tRS_rC.store( + cute.recast_tensor(zeros_i8, self.c_dtype).load() + ) + else: + tRS_rC.fill(0) + + # + # Store C to shared memory + # + c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage + cute.copy( + tiled_copy_r2s, + tRS_rC, + tRS_sC[(None, None, None, c_buffer)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + "async.shared", + space="cta", + ) + self.epilog_sync_barrier.arrive_and_wait() + + # + # TMA store C to global memory + # + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, subtile_idx)], + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_c_gmem_ptr, + cute.AddressSpace.generic, + ), + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + self.epilog_sync_barrier.arrive_and_wait() + # + # Async arrive accumulator buffer empty + # + if not is_k_tile_cnt_zero: + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + last_group_idx = cur_group_idx + + # + # Dealloc the tensor memory buffer + # + if warp_idx == self.epilog_warp_id[0]: + cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs) + self.epilog_sync_barrier.arrive_and_wait() + if warp_idx == self.epilog_warp_id[0]: + if use_2cta_instrs: + cute.arch.mbarrier_arrive( + tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1 + ) + cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) + cute.arch.dealloc_tmem( + acc_tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs + ) + # + # Wait for C store complete + # + c_pipeline.producer_tail() + + @cute.jit + def make_tensor_abc_for_tensormap_update( + self, + group_idx: cutlass.Int32, + dtype: Type[cutlass.Numeric], + problem_shape_mnk: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32], + strides_abc: cute.Tensor, + tensor_address_abc: cute.Tensor, + tensor_index: int, + ): + """Extract stride and tensor address for a given group and construct a global tensor for A, B or C. + + This function is used within the kernel to dynamically create a CUTE tensor + representing A, B, or C for the current group being processed, using the + group-specific address, shape, and stride information. + + :param group_idx: The index of the current group within the grouped GEMM. + :type group_idx: cutlass.Int32 + :param dtype: The data type of the tensor elements (e.g., cutlass.Float16). + :type dtype: Type[cutlass.Numeric] + :param problem_shape_mnk: The (M, N, K) problem shape for the current group. + :type problem_shape_mnk: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32] + :param strides_abc: Tensor containing strides for A, B, C for all groups. Layout: (group_count, 3, 2). + :type strides_abc: cute.Tensor + :param tensor_address_abc: Tensor containing global memory addresses for A, B, C for all groups. Layout: (group_count, 3). + :type tensor_address_abc: cute.Tensor + :param tensor_index: Specifies which tensor to create: 0 for A, 1 for B, 2 for C. + :type tensor_index: int + :return: A CUTE tensor representing the requested global memory tensor (A, B, or C) for the specified group. + :rtype: cute.Tensor + :raises TypeError: If the provided dtype is not a subclass of cutlass.Numeric. + """ + ptr_i64 = tensor_address_abc[(group_idx, tensor_index)] + if cutlass.const_expr( + not isclass(dtype) or not issubclass(dtype, cutlass.Numeric) + ): + raise TypeError( + f"dtype must be a type of cutlass.Numeric, got {type(dtype)}" + ) + tensor_gmem_ptr = cute.make_ptr( + dtype, ptr_i64, cute.AddressSpace.gmem, assumed_align=16 + ) + + strides_tensor_gmem = strides_abc[(group_idx, tensor_index, None)] + strides_tensor_reg = cute.make_rmem_tensor( + cute.make_layout(2), + strides_abc.element_type, + ) + cute.autovec_copy(strides_tensor_gmem, strides_tensor_reg) + stride_mn = strides_tensor_reg[0] + stride_k = strides_tensor_reg[1] + c1 = cutlass.Int32(1) + c0 = cutlass.Int32(0) + + if cutlass.const_expr(tensor_index == 0): # tensor A + m = problem_shape_mnk[0] + k = problem_shape_mnk[2] + return cute.make_tensor( + tensor_gmem_ptr, + cute.make_layout((m, k, c1), stride=(stride_mn, stride_k, c0)), + ) + elif cutlass.const_expr(tensor_index == 1): # tensor B + n = problem_shape_mnk[1] + k = problem_shape_mnk[2] + return cute.make_tensor( + tensor_gmem_ptr, + cute.make_layout((n, k, c1), stride=(stride_mn, stride_k, c0)), + ) + else: # tensor C + m = problem_shape_mnk[0] + n = problem_shape_mnk[1] + return cute.make_tensor( + tensor_gmem_ptr, + cute.make_layout((m, n, c1), stride=(stride_mn, stride_k, c0)), + ) + + @cute.jit + def make_tensor_sfasfb_for_tensormap_update( + self, + group_idx: cutlass.Int32, + dtype: Type[cutlass.Numeric], + problem_shape_mnk: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32], + tensor_address_sfasfb: cute.Tensor, + tensor_index: int, + ): + """Extract tensor address for a given group and construct a global tensor for SFA or SFB. + + This function is used within the kernel to dynamically create a CUTE tensor + representing SFA or SFB for the current group being processed, using the + group-specific address, shape information. + + :param group_idx: The index of the current group within the grouped GEMM. + :type group_idx: cutlass.Int32 + :param dtype: The data type of the tensor elements (e.g., cutlass.Float16). + :type dtype: Type[cutlass.Numeric] + :param problem_shape_mnk: The (M, N, K) problem shape for the current group. + :type problem_shape_mnk: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32] + :param tensor_address_sfasfb: Tensor containing global memory addresses for SFA, SFB for all groups. Layout: (group_count, 2). + :type tensor_address_sfasfb: cute.Tensor + :param tensor_index: Specifies which tensor to create: 0 for SFA, 1 for SFB. + :type tensor_index: int + :return: A CUTE tensor representing the requested global memory tensor (SFA, SFB) for the specified group. + :rtype: cute.Tensor + :raises TypeError: If the provided dtype is not a subclass of cutlass.Numeric. + """ + ptr_i64 = tensor_address_sfasfb[(group_idx, tensor_index)] + if cutlass.const_expr( + not isclass(dtype) or not issubclass(dtype, cutlass.Numeric) + ): + raise TypeError( + f"dtype must be a type of cutlass.Numeric, got {type(dtype)}" + ) + tensor_gmem_ptr = cute.make_ptr( + dtype, ptr_i64, cute.AddressSpace.gmem, assumed_align=16 + ) + + c1 = cutlass.Int32(1) + if cutlass.const_expr(tensor_index == 0): # tensor SFA + m = problem_shape_mnk[0] + k = problem_shape_mnk[2] + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( + (m, k, c1), self.sf_vec_size + ) + return cute.make_tensor( + tensor_gmem_ptr, + sfa_layout, + ) + else: # tensor SFB + n = problem_shape_mnk[1] + k = problem_shape_mnk[2] + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( + (n, k, c1), self.sf_vec_size + ) + return cute.make_tensor( + tensor_gmem_ptr, + sfb_layout, + ) + + def mainloop_s2t_copy_and_partition( + self, + sSF: cute.Tensor, + tSF: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for smem to tmem load for scale factor tensor, then use it to partition smem memory (source) and tensor memory (destination). + + :param sSF: The scale factor tensor in smem + :type sSF: cute.Tensor + :param tSF: The scale factor tensor in tmem + :type tSF: cute.Tensor + + :return: A tuple containing (tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t) where: + - tiled_copy_s2t: The tiled copy operation for smem to tmem load for scale factor tensor(s2t) + - tCsSF_compact_s2t: The partitioned scale factor tensor in smem + - tSF_compact_s2t: The partitioned scale factor tensor in tmem + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSF_compact = cute.filter_zeros(sSF) + # (MMA, MMA_MN, MMA_K) + tCtSF_compact = cute.filter_zeros(tSF) + + # Make S2T CopyAtom and tiledCopy + copy_atom_s2t = cute.make_copy_atom( + tcgen05.Cp4x32x128bOp(self.cta_group), + self.sf_dtype, + ) + tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact) + thr_copy_s2t = tiled_copy_s2t.get_slice(0) + + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t, tCsSF_compact_s2t_ + ) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) + tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact) + + return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0, None)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy( + copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_rmem_tensor( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). + + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + :type sepi: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tidx: cutlass.Int32, + atom: Union[cute.CopyAtom, cute.TiledCopy], + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to: + partition shared memory (source) and global memory (destination) for TMA store version. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version + :type atom: cute.CopyAtom or cute.TiledCopy + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing (tma_atom_c, bSG_sC, bSG_gC) where: + - tma_atom_c: The TMA copy atom + - bSG_sC: The partitioned shared memory tensor C + - bSG_gC: The partitioned global tensor C + :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + + tma_atom_c = atom + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + smem_capacity: int, + occupancy: int, + ) -> Tuple[int, int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout enum of operand C. + :type c_layout: utils.LayoutEnum + :param sf_dtype: Data type of Scale factor. + :type sf_dtype: type[cutlass.Numeric] + :param sf_vec_size: Scale factor vector size. + :type sf_vec_size: int + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + + :return: A tuple containing the computed number of stages for: + (ACC stages, A/B operand stages, C stages) + :rtype: tuple[int, int, int] + """ + # ACC stages + num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2 + + # Default C stages + num_c_stage = 2 + + # Calculate smem layout and size for one stage of A, B, SFA, SFB and C + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # a tmp 1 stage is provided + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # a tmp 1 stage is provided + ) + sfa_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + 1, # a tmp 1 stage is provided + ) + sfb_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + 1, # a tmp 1 stage is provided + ) + + c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, + ) + + ab_bytes_per_stage = ( + cute.size_in_bytes(a_dtype, a_smem_layout_stage_one) + + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfa_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one) + ) + mbar_helpers_bytes = 1024 + c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + c_bytes = c_bytes_per_stage * num_c_stage + + # Calculate A/B/SFA/SFB stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial C stages bytes + # Divide remaining by bytes needed per A/B/SFA/SFB stage + num_ab_stage = ( + smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes) + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B/SFA/SFB stages and reserved bytes + # Add remaining unused smem to epilogue + num_c_stage += ( + smem_capacity + - occupancy * ab_bytes_per_stage * num_ab_stage + - occupancy * (mbar_helpers_bytes + c_bytes) + ) // (occupancy * c_bytes_per_stage) + + return num_acc_stage, num_ab_stage, num_c_stage + + @staticmethod + def _compute_grid( + total_num_clusters: int, + cluster_shape_mn: tuple[int, int], + max_active_clusters: cutlass.Constexpr[int], + ) -> tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]]: + """Compute tile scheduler parameters and grid shape for grouped GEMM operations. + + :param total_num_clusters: Total number of clusters to process across all groups. + :type total_num_clusters: int + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr[int] + + :return: A tuple containing: + - tile_sched_params: Parameters for the persistent tile scheduler. + - grid: Grid shape for kernel launch. + :rtype: tuple[utils.PersistentTileSchedulerParams, tuple[int, ...]] + """ + # Create problem shape with M, N dimensions from cluster shape + # and L dimension representing the total number of clusters. + problem_shape_ntile_mnl = ( + cluster_shape_mn[0], + cluster_shape_mn[1], + cutlass.Int32(total_num_clusters), + ) + + tile_sched_params = utils.PersistentTileSchedulerParams( + problem_shape_ntile_mnl, (*cluster_shape_mn, 1) + ) + + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def _get_mbar_smem_bytes(**kwargs_stages: int) -> int: + """Calculate shared memory consumption for memory barriers based on provided stages. + + Each stage requires 2 barriers, and each barrier consumes 8 bytes of shared memory. + The total consumption is the sum across all provided stages. This function calculates the total + shared memory needed for these barriers. + + :param kwargs_stages: Variable keyword arguments where each key is a stage name + (e.g., num_acc_stage, num_ab_stage) and each value is the + number of stages of that type. + :type kwargs_stages: int + :return: Total shared memory bytes required for all memory barriers. + :rtype: int + """ + num_barriers_per_stage = 2 + num_bytes_per_barrier = 8 + mbar_smem_consumption = sum( + [ + num_barriers_per_stage * num_bytes_per_barrier * stage + for stage in kwargs_stages.values() + ] + ) + return mbar_smem_consumption + + @staticmethod + def is_valid_dtypes_and_scale_factor_vec_size( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + c_dtype: Type[cutlass.Numeric], + ) -> bool: + """ + Check if the dtypes and sf_vec_size are valid combinations + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param sf_dtype: The data type of the scale factor + :type sf_dtype: Type[cutlass.Numeric] + :param sf_vec_size: The vector size of the scale factor + :type sf_vec_size: int + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + + :return: True if the dtypes and sf_vec_size are valid, False otherwise + :rtype: bool + """ + is_valid = True + + # Check valid ab_dtype + if ab_dtype not in { + cutlass.Float4E2M1FN, + cutlass.Float8E5M2, + cutlass.Float8E4M3FN, + }: + is_valid = False + + # Check valid sf_vec_size + if sf_vec_size not in {16, 32}: + is_valid = False + + # Check valid sf_dtype + if sf_dtype not in {cutlass.Float8E8M0FNU, cutlass.Float8E4M3FN}: + is_valid = False + + # Check valid sf_dtype and sf_vec_size combinations + if sf_dtype == cutlass.Float8E4M3FN and sf_vec_size == 32: + is_valid = False + if ab_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} and sf_vec_size == 16: + is_valid = False + + # Check valid c_dtype + if c_dtype not in { + cutlass.Float32, + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float8E5M2, + cutlass.Float8E4M3FN, + }: + is_valid = False + + return is_valid + + @staticmethod + def is_valid_layouts( + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if layouts and dtypes are valid combinations + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major dimension of the A tensor + :type a_major: str + :param b_major: The major dimension of the B tensor + :type b_major: str + :param c_major: The major dimension of the C tensor + :type c_major: str + + :return: True if the layouts are valid, False otherwise + :rtype: bool + """ + is_valid = True + + if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"): + is_valid = False + return is_valid + + @staticmethod + def is_valid_mma_tiler_and_cluster_shape( + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ) -> bool: + """ + Check if the mma tiler and cluster shape are valid + + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + + :return: True if the mma tiler and cluster shape are valid, False otherwise + :rtype: bool + """ + is_valid = True + # Skip invalid mma tile shape + if mma_tiler_mn[0] not in [128, 256]: + is_valid = False + if mma_tiler_mn[1] not in [128, 256]: + is_valid = False + # Skip illegal cluster shape + if cluster_shape_mn[0] % (2 if mma_tiler_mn[0] == 256 else 1) != 0: + is_valid = False + # Skip invalid cluster shape + is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + # Special cluster shape check for scale factor multicasts. + # Due to limited size of scale factors, we can't multicast among more than 4 CTAs. + or cluster_shape_mn[0] > 4 + or cluster_shape_mn[1] > 4 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + problem_sizes_mnkl: List[Tuple[int, int, int, int]], + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the tensor alignment is valid + + :param problem_sizes_mnkl: The problem shape for each group + :type problem_sizes_mnkl: List[Tuple[int, int, int, int]] + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the problem shape is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + for m, n, k, l in problem_sizes_mnkl: + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment( + ab_dtype, b_major == "n", (n, k, l) + ) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + is_valid = False + return is_valid + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + c_dtype: Type[cutlass.Numeric], + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + problem_sizes_mnkl: List[Tuple[int, int, int, int]], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the gemm can be implemented + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param sf_dtype: The data type of the scale factor tensor + :type sf_dtype: Type[cutlass.Numeric] + :param sf_vec_size: The vector size + :type sf_vec_size: int + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the gemm can be implemented, False otherwise + :rtype: bool + """ + can_implement = True + # Skip unsupported types + if not Sm100GroupedBlockScaledGemmKernel.is_valid_dtypes_and_scale_factor_vec_size( + ab_dtype, sf_dtype, sf_vec_size, c_dtype + ): + can_implement = False + # Skip unsupported layouts + if not Sm100GroupedBlockScaledGemmKernel.is_valid_layouts( + ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + # Skip invalid mma tile shape and cluster shape + if not Sm100GroupedBlockScaledGemmKernel.is_valid_mma_tiler_and_cluster_shape( + mma_tiler_mn, cluster_shape_mn + ): + can_implement = False + # Skip illegal problem shape for load/store alignment + if not Sm100GroupedBlockScaledGemmKernel.is_valid_tensor_alignment( + problem_sizes_mnkl, ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + return can_implement + + # Size of smem we reserved for mbarrier, tensor memory management and tensormap update + reserved_smem_bytes = 1024 + bytes_per_tensormap = 128 + num_tensormaps = 5 + # size of smem used for tensor memory management + tensor_memory_management_bytes = 12 + + +# Create tensor and return the pointer, tensor, and stride +def create_tensor_and_stride( + l: int, + mode0: int, + mode1: int, + is_mode0_major: bool, + dtype: type[cutlass.Numeric], + is_dynamic_layout: bool = True, +) -> tuple[int, torch.Tensor, cute.Tensor, torch.Tensor, tuple[int, int]]: + """Create GPU tensor from either a new or existing CPU tensor. + + :param torch_tensor_cpu: Optional existing CPU tensor to reuse. If None, creates a new one. + :type torch_tensor_cpu: torch.Tensor, optional + """ + + # Create new CPU tensor + torch_tensor_cpu = cutlass_torch.matrix( + l, + mode0, + mode1, + is_mode0_major, + cutlass.Float32, + ) + + # Create GPU tensor from CPU tensor (new or existing) + cute_tensor, torch_tensor = cutlass_torch.cute_tensor_like( + torch_tensor_cpu, dtype, is_dynamic_layout, assumed_align=16 + ) + + # omit stride for L mode as it is always 1 + stride = (1, mode0) if is_mode0_major else (mode1, 1) + + return ( + torch_tensor.data_ptr(), + torch_tensor, + cute_tensor, + torch_tensor_cpu, + stride, + ) + + +def create_tensors_abc_for_all_groups( + problem_sizes_mnkl: List[tuple[int, int, int, int]], + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, +) -> tuple[ + List[List[int]], + List[List[torch.Tensor]], + List[tuple], + List[List[tuple]], + List[List[torch.Tensor]], +]: + ref_torch_fp32_tensors_abc = [] + torch_tensors_abc = [] + cute_tensors_abc = [] + strides_abc = [] + ptrs_abc = [] + + # Iterate through all groups and create tensors for each group + for group_idx, (m, n, k, l) in enumerate(problem_sizes_mnkl): + # Create tensors A, B, C + ( + ptr_a, + torch_tensor_a, + cute_tensor_a, + ref_torch_fp32_tensor_a, + stride_mk_a, + ) = create_tensor_and_stride(l, m, k, a_major == "m", ab_dtype) + + ( + ptr_b, + torch_tensor_b, + cute_tensor_b, + ref_torch_fp32_tensor_b, + stride_nk_b, + ) = create_tensor_and_stride(l, n, k, b_major == "n", ab_dtype) + + ( + ptr_c, + torch_tensor_c, + cute_tensor_c, + ref_torch_fp32_tensor_c, + stride_mn_c, + ) = create_tensor_and_stride(l, m, n, c_major == "m", c_dtype) + + ref_torch_fp32_tensors_abc.append( + [ref_torch_fp32_tensor_a, ref_torch_fp32_tensor_b, ref_torch_fp32_tensor_c] + ) + + ptrs_abc.append([ptr_a, ptr_b, ptr_c]) + torch_tensors_abc.append([torch_tensor_a, torch_tensor_b, torch_tensor_c]) + strides_abc.append([stride_mk_a, stride_nk_b, stride_mn_c]) + cute_tensors_abc.append( + ( + cute_tensor_a, + cute_tensor_b, + cute_tensor_c, + ) + ) + + return ( + ptrs_abc, + torch_tensors_abc, + cute_tensors_abc, + strides_abc, + ref_torch_fp32_tensors_abc, + ) + + +@cute.jit +def cvt_sf_MKL_to_M32x4xrm_K4xrk_L( + sf_ref_tensor: cute.Tensor, + sf_mma_tensor: cute.Tensor, +): + """Convert scale factor tensor from MKL layout to mma specification M(32x4xrest_m)xK(4xrest_k)xL layout""" + # sf_mma_tensor has flatten shape (32, 4, rest_m, 4, rest_k, l) + # group to ((32, 4, rest_m), (4, rest_k), l) + sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3) + sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3) + for i in cutlass.range(cute.size(sf_ref_tensor)): + mkl_coord = sf_ref_tensor.layout.get_hier_coord(i) + sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord] + + +# Create scale factor tensor SFA/SFB +def create_scale_factor_tensor(l, mn, k, sf_vec_size, dtype): + def ceil_div(a, b): + return (a + b - 1) // b + + sf_k = max(1, ceil_div(k, sf_vec_size)) + ref_shape = (l, mn, sf_k) + + atom_m = (32, 4) + atom_k = 4 + mma_shape = ( + l, + ceil_div(mn, atom_m[0] * atom_m[1]), + ceil_div(sf_k, atom_k), + atom_m[0], + atom_m[1], + atom_k, + ) + + ref_permute_order = (1, 2, 0) + mma_permute_order = (3, 4, 1, 5, 2, 0) + + # Create f32 ref torch tensor (cpu) + ref_f32_torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( + ref_shape, + torch.float32, + permute_order=ref_permute_order, + init_type=cutlass_torch.TensorInitType.RANDOM, + init_config=cutlass_torch.RandomInitConfig( + min_val=1, + max_val=3, + ), + ) + + # Create f32 cute torch tensor (cpu) + cute_f32_torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( + mma_shape, + torch.float32, + permute_order=mma_permute_order, + init_type=cutlass_torch.TensorInitType.RANDOM, + init_config=cutlass_torch.RandomInitConfig( + min_val=0, + max_val=1, + ), + ) + + # convert ref f32 tensor to cute f32 tensor + cvt_sf_MKL_to_M32x4xrm_K4xrk_L( + from_dlpack(ref_f32_torch_tensor_cpu), + from_dlpack(cute_f32_torch_tensor_cpu), + ) + cute_f32_torch_tensor = cute_f32_torch_tensor_cpu.cuda() + + # reshape makes memory contiguous + ref_f32_torch_tensor_cpu = ( + ref_f32_torch_tensor_cpu.permute(2, 0, 1) + .unsqueeze(-1) + .expand(l, mn, sf_k, sf_vec_size) + .reshape(l, mn, sf_k * sf_vec_size) + .permute(*ref_permute_order) + ) + # prune to mkl for reference check. + ref_f32_torch_tensor_cpu = ref_f32_torch_tensor_cpu[:, :k, :] + + # Create dtype cute torch tensor (cpu) + cute_tensor, cute_torch_tensor = cutlass_torch.cute_tensor_like( + cute_f32_torch_tensor_cpu, + dtype, + is_dynamic_layout=True, + assumed_align=16, + ) + + # Convert f32 cute tensor to dtype cute tensor + cute_tensor = cutlass_torch.convert_cute_tensor( + cute_f32_torch_tensor, + cute_tensor, + dtype, + is_dynamic_layout=True, + ) + # get pointer of the tensor + ptr = cute_torch_tensor.data_ptr() + return ref_f32_torch_tensor_cpu, ptr, cute_tensor, cute_torch_tensor + + +def create_tensors_sfasfb_for_all_groups( + problem_sizes_mnkl: List[tuple[int, int, int, int]], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, +) -> tuple[ + List[List[int]], + List[List[torch.Tensor]], + List[tuple], + List[List[torch.Tensor]], +]: + ptrs_sfasfb = [] + torch_tensors_sfasfb = [] + cute_tensors_sfasfb = [] + refs_sfasfb = [] + + # Iterate through all groups and create tensors for each group + for group_idx, (m, n, k, l) in enumerate(problem_sizes_mnkl): + sfa_ref, ptr_sfa, sfa_tensor, sfa_torch = create_scale_factor_tensor( + l, m, k, sf_vec_size, sf_dtype + ) + sfb_ref, ptr_sfb, sfb_tensor, sfb_torch = create_scale_factor_tensor( + l, n, k, sf_vec_size, sf_dtype + ) + ptrs_sfasfb.append([ptr_sfa, ptr_sfb]) + torch_tensors_sfasfb.append([sfa_torch, sfb_torch]) + cute_tensors_sfasfb.append( + ( + sfa_tensor, + sfb_tensor, + ) + ) + refs_sfasfb.append([sfa_ref, sfb_ref]) + + return ( + ptrs_sfasfb, + torch_tensors_sfasfb, + cute_tensors_sfasfb, + refs_sfasfb, + ) + + +def run( + num_groups: int, + problem_sizes_mnkl: List[Tuple[int, int, int, int]], + host_problem_shape_available: bool, + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + tolerance: float = 1e-01, + warmup_iterations: int = 0, + iterations: int = 1, + skip_ref_check: bool = False, + use_cold_l2: bool = False, + **kwargs, +): + """Run SM100 grouped blockscaledGEMM example with specified configurations. + + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False + :type use_cold_l2: bool, optional + :return: Execution time of the GEMM kernel in microseconds + :rtype: float + """ + print("Running Blackwell Grouped GEMM test with:") + print(f"{num_groups} groups") + for i, (m, n, k, l) in enumerate(problem_sizes_mnkl): + print(f"Group {i}: {m}x{n}x{k}x{l}") + print(f"AB dtype: {ab_dtype}, SF dtype: {sf_dtype}, SF Vec size: {sf_vec_size}") + print(f"C dtype: {c_dtype}") + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}") + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}") + + # Skip unsupported testcase + if not Sm100GroupedBlockScaledGemmKernel.can_implement( + ab_dtype, + sf_dtype, + sf_vec_size, + c_dtype, + mma_tiler_mn, + cluster_shape_mn, + problem_sizes_mnkl, + a_major, + b_major, + c_major, + ): + raise TypeError( + f"Unsupported testcase {ab_dtype}, {sf_dtype}, {sf_vec_size}, {c_dtype}, {mma_tiler_mn}, {cluster_shape_mn}, {problem_sizes_mnkl}, {a_major}, {b_major}, {c_major}" + ) + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + torch.manual_seed(2025) + + # Create tensors A, B, C for all groups + ( + ptrs_abc, + torch_tensors_abc, + cute_tensors_abc, + strides_abc, + ref_f32_torch_tensors_abc, + ) = create_tensors_abc_for_all_groups( + problem_sizes_mnkl, + ab_dtype, + c_dtype, + a_major, + b_major, + c_major, + ) + # Create tensors SFA, SFB for all groups + ( + ptrs_sfasfb, + torch_tensors_sfasfb, + cute_tensors_sfasfb, + refs_f32_torch_tensors_sfasfb, + ) = create_tensors_sfasfb_for_all_groups( + problem_sizes_mnkl, + sf_dtype, + sf_vec_size, + ) + + # Setup inital tensors for TMA of A,B and C + alignment = 16 # 16 bytes aligned + divisibility_ab = 32 if ab_dtype == cutlass.Float4E2M1FN else 16 + divisibility_c = 32 if c_dtype == cutlass.Float4E2M1FN else 16 + divisibility_sf = 32 if sf_dtype == cutlass.Float4E2M1FN else 16 + + min_ab_size = alignment * 8 // ab_dtype.width # alignment bytes of width + div_mul_ab = (divisibility_ab + min_ab_size - 1) // min_ab_size + min_ab_size = min_ab_size * div_mul_ab + + min_c_size = alignment * 8 // c_dtype.width + div_mul_c = (divisibility_c + min_c_size - 1) // min_c_size + min_c_size = min_c_size * div_mul_c + + min_sf_size = alignment * 8 // sf_dtype.width + div_mul_sf = (divisibility_sf + min_sf_size - 1) // min_sf_size + min_sf_size = min_sf_size * div_mul_sf + + initial_cute_tensors_abc = [ + create_tensor_and_stride(1, min_ab_size, min_ab_size, a_major == "m", ab_dtype)[ + 2 + ], + create_tensor_and_stride(1, min_ab_size, min_ab_size, b_major == "n", ab_dtype)[ + 2 + ], + create_tensor_and_stride(1, min_c_size, min_c_size, c_major == "m", c_dtype)[2], + ] + initial_cute_tensors_sfasfb = [ + create_tensor_and_stride(1, min_sf_size, min_sf_size, a_major == "m", sf_dtype)[ + 2 + ], + create_tensor_and_stride(1, min_sf_size, min_sf_size, b_major == "n", sf_dtype)[ + 2 + ], + ] + + hardware_info = cutlass.utils.HardwareInfo() + sm_count = hardware_info.get_max_active_clusters(1) + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1] + ) + # Prepare tensormap buffer for each SM + num_tensormap_buffers = sm_count + tensormap_shape = ( + num_tensormap_buffers, + Sm100GroupedBlockScaledGemmKernel.num_tensormaps, + Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap // 8, + ) + tensor_of_tensormap, tensor_of_tensormap_torch = cutlass_torch.cute_tensor_like( + torch.empty(tensormap_shape, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + ) + + grouped_blockscaled_gemm = Sm100GroupedBlockScaledGemmKernel( + sf_vec_size, + mma_tiler_mn, + cluster_shape_mn, + ) + + # layout (num_groups, 4):(4, 1) + ( + tensor_of_dim_size_mnkl, + tensor_of_dim_size_mnkl_torch, + ) = cutlass_torch.cute_tensor_like( + torch.tensor(problem_sizes_mnkl, dtype=torch.int32), + cutlass.Int32, + is_dynamic_layout=False, + assumed_align=16, + ) + + # layout (num_groups, 3, 2):(6, 2, 1) + tensor_of_strides_abc, tensor_of_strides_abc_torch = cutlass_torch.cute_tensor_like( + torch.tensor(strides_abc, dtype=torch.int32), + cutlass.Int32, + is_dynamic_layout=False, + assumed_align=16, + ) + + # layout (num_groups,3):(3, 1) + tensor_of_ptrs_abc, tensor_of_ptrs_abc_torch = cutlass_torch.cute_tensor_like( + torch.tensor(ptrs_abc, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + assumed_align=16, + ) + + # layout (num_groups,2):(2, 1) + tensor_of_ptrs_sfasfb, tensor_of_ptrs_sfasfb_torch = cutlass_torch.cute_tensor_like( + torch.tensor(ptrs_sfasfb, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + assumed_align=16, + ) + + # Compute total number of cluster tiles we need to compute for given grouped GEMM problem + def compute_total_num_clusters( + problem_sizes_mnkl: List[tuple[int, int, int, int]], + cluster_tile_shape_mn: tuple[int, int], + ) -> int: + total_num_clusters = 0 + for m, n, _, _ in problem_sizes_mnkl: + num_clusters_mn = tuple( + (x + y - 1) // y for x, y in zip((m, n), cluster_tile_shape_mn) + ) + total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn) + return total_num_clusters + + # Compute cluster tile shape + def compute_cluster_tile_shape( + mma_tiler_mn: tuple[int, int], + cluster_shape_mn: tuple[int, int], + ) -> tuple[int, int]: + cta_tile_shape_mn = [128, mma_tiler_mn[1]] + return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn)) + + cluster_tile_shape_mn = compute_cluster_tile_shape(mma_tiler_mn, cluster_shape_mn) + total_num_clusters = compute_total_num_clusters( + problem_sizes_mnkl, cluster_tile_shape_mn + ) + + # Initialize Stream + current_stream = cutlass_torch.default_stream() + + # If the host problem shape is available, we will launch the grid with only + # the necessary clusters. The function compute_total_num_clusters() does that. + # If the problem shape only exists on device, we will need to launch all active + # clusters possible on a device. + if host_problem_shape_available: + print("Problem shapes available on host and device") + total_num_clusters = compute_total_num_clusters( + problem_sizes_mnkl, cluster_tile_shape_mn + ) + else: + print("Problem shapes available only on device") + total_num_clusters = max_active_clusters + + # Compile grouped GEMM kernel + compiled_grouped_gemm = cute.compile( + grouped_blockscaled_gemm, + initial_cute_tensors_abc[0], + initial_cute_tensors_abc[1], + initial_cute_tensors_abc[2], + initial_cute_tensors_sfasfb[0], + initial_cute_tensors_sfasfb[1], + num_groups, + tensor_of_dim_size_mnkl, + tensor_of_strides_abc, + tensor_of_ptrs_abc, + tensor_of_ptrs_sfasfb, + total_num_clusters, + tensor_of_tensormap, + max_active_clusters, + current_stream, + options=f"--opt-level 2", + ) + + # reference check + if not skip_ref_check: + compiled_grouped_gemm( + initial_cute_tensors_abc[0], + initial_cute_tensors_abc[1], + initial_cute_tensors_abc[2], + initial_cute_tensors_sfasfb[0], + initial_cute_tensors_sfasfb[1], + tensor_of_dim_size_mnkl, + tensor_of_strides_abc, + tensor_of_ptrs_abc, + tensor_of_ptrs_sfasfb, + tensor_of_tensormap, + current_stream, + ) + print("Verifying results...") + + for i, ( + (a_ref, b_ref, c_ref), + (sfa_ref, sfb_ref), + (a_tensor, b_tensor, c_tensor), + (m, n, k, l), + ) in enumerate( + zip( + ref_f32_torch_tensors_abc, + refs_f32_torch_tensors_sfasfb, + cute_tensors_abc, + problem_sizes_mnkl, + ) + ): + ref_res_a = torch.einsum("mkl,mkl->mkl", a_ref, sfa_ref) + ref_res_b = torch.einsum("nkl,nkl->nkl", b_ref, sfb_ref) + ref = torch.einsum("mkl,nkl->mnl", ref_res_a, ref_res_b) + + print(f"checking group {i}") + c_ref_device = c_ref.cuda() + + cute.testing.convert( + c_tensor, + from_dlpack(c_ref_device, assumed_align=16).mark_layout_dynamic( + leading_dim=(1 if c_major == "n" else 0) + ), + ) + + c_ref = c_ref_device.cpu() + + if c_dtype in (cutlass.Float32, cutlass.Float16, cutlass.BFloat16): + torch.testing.assert_close(c_ref, ref, atol=tolerance, rtol=1e-02) + elif c_dtype in (cutlass.Float8E5M2, cutlass.Float8E4M3FN): + # Convert ref : f32 -> f8 -> f32 + ref_f8_ = torch.empty( + *(l, m, n), dtype=torch.uint8, device="cuda" + ).permute(1, 2, 0) + ref_f8 = from_dlpack(ref_f8_, assumed_align=16).mark_layout_dynamic( + leading_dim=1 + ) + ref_f8.element_type = c_dtype + ref_device = ref.permute(2, 0, 1).contiguous().permute(1, 2, 0).cuda() + ref_tensor = from_dlpack( + ref_device, assumed_align=16 + ).mark_layout_dynamic(leading_dim=1) + cute.testing.convert(ref_tensor, ref_f8) + cute.testing.convert(ref_f8, ref_tensor) + ref = ref_device.cpu() + torch.testing.assert_close(c_ref, ref, atol=tolerance, rtol=1e-02) + def generate_tensors(): + ( + ptrs_abc_workspace, + torch_tensors_abc_workspace, + cute_tensors_abc_workspace, + strides_abc_workspace, + _, + ) = create_tensors_abc_for_all_groups( + problem_sizes_mnkl, + ab_dtype, + c_dtype, + a_major, + b_major, + c_major, + ) + + ( + ptrs_sfasfb_workspace, + torch_tensors_sfasfb_workspace, + cute_tensors_sfasfb_workspace, + _, + ) = create_tensors_sfasfb_for_all_groups( + problem_sizes_mnkl, + sf_dtype, + sf_vec_size, + ) + + initial_cute_tensors_abc_workspace = [ + create_tensor_and_stride( + 1, min_ab_size, min_ab_size, a_major == "m", ab_dtype + )[2], + create_tensor_and_stride( + 1, min_ab_size, min_ab_size, b_major == "n", ab_dtype + )[2], + create_tensor_and_stride( + 1, min_c_size, min_c_size, c_major == "m", c_dtype + )[2], + ] + initial_cute_tensors_sfasfb_workspace = [ + create_tensor_and_stride( + 1, min_sf_size, min_sf_size, a_major == "m", sf_dtype + )[2], + create_tensor_and_stride( + 1, min_sf_size, min_sf_size, b_major == "n", sf_dtype + )[2], + ] + + # Create new tensors for this workspace + tensor_of_strides_abc_workspace, _ = cutlass_torch.cute_tensor_like( + torch.tensor(strides_abc_workspace, dtype=torch.int32), + cutlass.Int32, + is_dynamic_layout=False, + assumed_align=16, + ) + + tensor_of_ptrs_abc_workspace, _ = cutlass_torch.cute_tensor_like( + torch.tensor(ptrs_abc_workspace, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + assumed_align=16, + ) + + tensor_of_ptrs_sfasfb_workspace, _ = cutlass_torch.cute_tensor_like( + torch.tensor(ptrs_sfasfb_workspace, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + assumed_align=16, + ) + + tensormap_workspace, _ = cutlass_torch.cute_tensor_like( + torch.empty(tensormap_shape, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + ) + + args = cute.testing.JitArguments( + initial_cute_tensors_abc_workspace[0], + initial_cute_tensors_abc_workspace[1], + initial_cute_tensors_abc_workspace[2], + initial_cute_tensors_sfasfb_workspace[0], + initial_cute_tensors_sfasfb_workspace[1], + tensor_of_dim_size_mnkl, + tensor_of_strides_abc_workspace, + tensor_of_ptrs_abc_workspace, + tensor_of_ptrs_sfasfb_workspace, + tensormap_workspace, + current_stream, + ) + args.add_to_scope([torch_tensors_abc_workspace, torch_tensors_sfasfb_workspace]) + return args + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + sum( + [ + sum( + [ + torch_tensor.numel() * torch_tensor.element_size() + for torch_tensor in group_tensors + ] + ) + for group_tensors in torch_tensors_abc + torch_tensors_sfasfb + ] + ) + + + # Add size of strides tensor + tensor_of_strides_abc_torch.numel() + * tensor_of_strides_abc_torch.element_size() + + + # Add size of ptrs tensor A, B, C + tensor_of_ptrs_abc_torch.numel() * tensor_of_ptrs_abc_torch.element_size() + + + # Add size of ptrs tensor SFA, SFB + tensor_of_ptrs_sfasfb_torch.numel() + * tensor_of_ptrs_sfasfb_torch.element_size() + + + # Add size of tensormap tensor + tensor_of_tensormap_torch.numel() * tensor_of_tensormap_torch.element_size() + ) + workspace_count = cute.testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = cute.testing.benchmark( + compiled_grouped_gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + runtime_s = exec_time / 1.0e6 + fmas = 0 + for group in range(num_groups): + [M, N, K, _] = problem_sizes_mnkl[group] + fmas += M * N * K + flop = 2 * fmas + gflop = flop / 1.0e9 + gflops = gflop / runtime_s + + print("Average Runtime : ", exec_time / 1000, "ms") + print("GFLOPS : ", gflops) + + return exec_time # Return execution time in microseconds + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + def parse_comma_separated_tuples(s: str) -> List[tuple[int, ...]]: + if s.strip().startswith("("): + # Split on ),( to separate tuples + tuples = s.strip("()").split("),(") + result = [] + tuple_len = None + + for t in tuples: + # Parse individual tuple + nums = [int(x.strip()) for x in t.split(",")] + + # Validate tuple length consistency + if tuple_len is None: + tuple_len = len(nums) + elif len(nums) != tuple_len: + raise argparse.ArgumentTypeError( + "All tuples must have the same length" + ) + + result.append(tuple(nums)) + return result + + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers or list of tuples" + ) + + parser = argparse.ArgumentParser( + description="Example of Grouped GEMM on Blackwell." + ) + parser.add_argument( + "--num_groups", + type=int, + default=2, + help="Number of groups", + ) + parser.add_argument( + "--problem_sizes_mnkl", + type=parse_comma_separated_tuples, + default=((128, 128, 128, 1), (128, 128, 128, 1)), + help="a tuple of problem sizes for each group (comma-separated tuples)", + ) + parser.add_argument( + "--mma_tiler_mn", + type=parse_comma_separated_ints, + default=(128, 128), + help="Mma tile shape (comma-separated)", + ) + parser.add_argument( + "--host_problem_shape_available", + action="store_true", + help="Enable the compute of grid based upon host problem shape", + ) + parser.add_argument( + "--cluster_shape_mn", + type=parse_comma_separated_ints, + default=(1, 1), + help="Cluster shape (comma-separated)", + ) + parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.Float4E2M1FN) + parser.add_argument("--sf_dtype", type=cutlass.dtype, default=cutlass.Float8E8M0FNU) + parser.add_argument("--sf_vec_size", type=int, default=16) + parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.Float16) + parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k") + parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations to run the kernel", + ) + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + + if ( + len(args.problem_sizes_mnkl) != 0 + and len(args.problem_sizes_mnkl) != args.num_groups + ): + parser.error("--problem_sizes_mnkl must contain exactly num_groups tuples") + + # l mode must be 1 for all groups + for _, _, _, l in args.problem_sizes_mnkl: + if l != 1: + parser.error("l must be 1 for all groups") + + if len(args.mma_tiler_mn) != 2: + parser.error("--mma_tiler_mn must contain exactly 2 values") + + if len(args.cluster_shape_mn) != 2: + parser.error("--cluster_shape_mn must contain exactly 2 values") + + run( + args.num_groups, + args.problem_sizes_mnkl, + args.host_problem_shape_available, + args.ab_dtype, + args.sf_dtype, + args.sf_vec_size, + args.c_dtype, + args.a_major, + args.b_major, + args.c_major, + args.mma_tiler_mn, + args.cluster_shape_mn, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + print("PASS") diff --git a/reference/moe_moe_persistent_scheduler.py b/reference/moe_moe_persistent_scheduler.py new file mode 100644 index 00000000..6cf6c0c9 --- /dev/null +++ b/reference/moe_moe_persistent_scheduler.py @@ -0,0 +1,695 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" +MoE Persistent Tile Scheduler + +A specialized tile scheduler for MoE (Mixture of Experts) grouped GEMM operations. +This scheduler handles tile iteration across all experts, producing MoEWorkTileInfo +(expert_idx, tile_m_idx, tile_n_idx, k_tile_cnt) for each tile. + +Scenarios: +- 2Dx3D (Forward): A(tokens_sum, hidden) x B(experts, intermediate, hidden) -> C(tokens_sum, intermediate) +- 2Dx2D (Backward): A(intermediate, tokens_sum) x B(hidden, tokens_sum) -> C(experts, intermediate, hidden) + +Key design principle: +- Scheduler is ONLY responsible for tile iteration (tensor-agnostic, TMA-agnostic) +- Domain conversion (fake tensor -> real expert tensor) is handled by MoESchedExtension +- TMA descriptor management is handled by OnlineTensormapDescCreator +- The kernel orchestrates all three components +""" + +from typing import List, Tuple, Literal + +import cutlass +import cutlass.cute as cute +from cutlass.cutlass_dsl import ( + Boolean, + Int32, + Integer, + extract_mlir_values, + new_from_mlir_values, + const_expr, + dsl_user_op, +) +from cutlass._mlir import ir + +# ============================================================================= +# Work Tile Info +# ============================================================================= + +class MoEWorkTileInfo: + """ + Work tile information for MoE scheduler. + + Contains CTA-level tile information for executor warps: + - expert_idx: Which expert (-1 means invalid/done) + - tile_m_idx: CTA tile index along GEMM M dimension + - tile_n_idx: CTA tile index along GEMM N dimension + - k_tile_cnt: Number of CTA tiles along K dimension + + Note: These are CTA-level indices, not cluster-level. + tile_l_idx is always 0 for MoE, executor can hardcode it. + + For 2Dx3D (Forward): + M = tokens_i (dynamic), N = intermediate (fixed), K = hidden (fixed) + + For 2Dx2D (Backward): + M = intermediate (fixed), N = hidden (fixed), K = tokens_i (dynamic) + """ + + def __init__( + self, + expert_idx: Int32, # -1 means invalid tile + tile_m_idx: Int32, + tile_n_idx: Int32, + k_tile_cnt: Int32, + ): + self.expert_idx = expert_idx + self.tile_m_idx = tile_m_idx + self.tile_n_idx = tile_n_idx + self.k_tile_cnt = k_tile_cnt + + @property + def is_valid_tile(self) -> Boolean: + """Check if this is a valid work tile (expert_idx >= 0).""" + return self.expert_idx >= Int32(0) + + def __extract_mlir_values__(self) -> List[ir.Value]: + values = extract_mlir_values(self.expert_idx) + values.extend(extract_mlir_values(self.tile_m_idx)) + values.extend(extract_mlir_values(self.tile_n_idx)) + values.extend(extract_mlir_values(self.k_tile_cnt)) + return values + + def __new_from_mlir_values__(self, values: List[ir.Value]) -> "MoEWorkTileInfo": + assert len(values) == 4 + return MoEWorkTileInfo( + expert_idx=new_from_mlir_values(self.expert_idx, [values[0]]), + tile_m_idx=new_from_mlir_values(self.tile_m_idx, [values[1]]), + tile_n_idx=new_from_mlir_values(self.tile_n_idx, [values[2]]), + k_tile_cnt=new_from_mlir_values(self.k_tile_cnt, [values[3]]), + ) + + def to_rmem_tensor(self): + """Pack work tile info fields into an rmem tensor of shape (4,) for vectorized smem copy.""" + rmem = cute.make_rmem_tensor((4,), Int32) + rmem[0] = self.expert_idx + rmem[1] = self.tile_m_idx + rmem[2] = self.tile_n_idx + rmem[3] = self.k_tile_cnt + return rmem + + @staticmethod + def from_rmem_tensor(rmem) -> "MoEWorkTileInfo": + """Unpack work tile info from an rmem tensor of shape (4,).""" + return MoEWorkTileInfo( + expert_idx=rmem[0], # type: ignore[arg-type] + tile_m_idx=rmem[1], # type: ignore[arg-type] + tile_n_idx=rmem[2], # type: ignore[arg-type] + k_tile_cnt=rmem[3], # type: ignore[arg-type] + ) + + +# ============================================================================= +# Scheduler Parameters +# ============================================================================= + +class MoEStaticSchedulerParams: + """ + Parameters for MoE tile scheduler. + + Uses unified semantics for both scenarios: + - expert_shape: (expert_cnt, intermediate, hidden) + + For 2Dx3D: GEMM is (M=tokens_i, N=intermediate, K=hidden) per expert + For 2Dx2D: GEMM is (M=hidden, N=intermediate, K=tokens_i) per expert + + Tile hierarchy: + - cta_tile_shape_mnk: Single CTA tile shape (tile_m, tile_n, tile_k) + - cluster_shape_mn: CTAs per cluster (cluster_m, cluster_n) + - cluster_tile_shape_mn: Cluster tile shape = cta_tile_shape * cluster_shape + + This class is used both on host (for grid shape calculation) and on device + (stored in scheduler). Codegen-time constants (scenario, cta_tile_shape_mnk, + cluster_shape_mn) are NOT serialized to MLIR values. + """ + + def __init__( + self, + scenario: Literal["2Dx3D", "2Dx2D"], + expert_shape: Tuple[int | Int32, int | Int32, int | Int32], # (expert_cnt, intermediate, hidden) + cta_tile_shape_mnk: Tuple[int, int, int], # (tile_m, tile_n, tile_k) + cluster_shape_mn: Tuple[int, int], # (cluster_m, cluster_n) + ): + self.scenario = scenario + e, i, h = expert_shape + self.expert_cnt = e if isinstance(e, Int32) else Int32(e) + self.intermediate = i if isinstance(i, Int32) else Int32(i) + self.hidden = h if isinstance(h, Int32) else Int32(h) + self.cta_tile_shape_mnk = cta_tile_shape_mnk + self.cluster_shape_mn = cluster_shape_mn + + @property + def cluster_tile_m(self) -> int: + """Cluster tile size along M = cta_tile_m * cluster_m.""" + return self.cta_tile_shape_mnk[0] * self.cluster_shape_mn[0] + + @property + def cluster_tile_n(self) -> int: + """Cluster tile size along N = cta_tile_n * cluster_n.""" + return self.cta_tile_shape_mnk[1] * self.cluster_shape_mn[1] + + @property + def cta_tile_k(self) -> int: + """CTA tile size along K (same as cluster since cluster_k = 1).""" + return self.cta_tile_shape_mnk[2] + + def __extract_mlir_values__(self) -> List[ir.Value]: + """Only serialize runtime values, not codegen-time constants.""" + values = [] + values.extend(extract_mlir_values(self.expert_cnt)) + values.extend(extract_mlir_values(self.intermediate)) + values.extend(extract_mlir_values(self.hidden)) + return values + + def __new_from_mlir_values__(self, values: List[ir.Value]) -> "MoEStaticSchedulerParams": + assert len(values) == 3 + return MoEStaticSchedulerParams( + scenario=self.scenario, + expert_shape=( + new_from_mlir_values(self.expert_cnt, [values[0]]), + new_from_mlir_values(self.intermediate, [values[1]]), + new_from_mlir_values(self.hidden, [values[2]]), + ), + cta_tile_shape_mnk=self.cta_tile_shape_mnk, + cluster_shape_mn=self.cluster_shape_mn, + ) + + @staticmethod + def get_grid_shape( + params: "MoEStaticSchedulerParams", + max_active_clusters: int, + ) -> Tuple[int, int, int]: + """ + Compute grid shape for kernel launch. + + Since host doesn't know token distribution across experts, + we launch max_active_clusters and let device-side scheduler + determine which tiles are valid. + """ + return ( + params.cluster_shape_mn[0], + params.cluster_shape_mn[1], + max_active_clusters, + ) + + +# ============================================================================= +# Scheduler (Device-side) +# ============================================================================= + +class MoEStaticPersistentTileScheduler: + """ + Persistent tile scheduler specialized for MoE grouped GEMM. + + This scheduler is ONLY responsible for tile iteration. It does NOT know + about tensor types, TMA descriptors, or domain conversion. Those concerns + are handled by MoESchedExtension and OnlineTensormapDescCreator respectively. + + Architecture: + - Scheduler warp: Holds scheduler instance, iterates tiles, broadcasts work_tile_info + - Executor warps: Read work_tile_info from smem, use MoESchedExtension for + domain conversion and TMA desc selection + + The scheduler handles: + - 2Dx3D: Dynamic M per expert (from offs), fixed N (intermediate) and K (hidden) + - 2Dx2D: Fixed M (intermediate) and N (hidden), dynamic K per expert (reduction axis) + + Usage (Scheduler warp): + scheduler = MoEStaticPersistentTileScheduler.create(params, offs, block_idx, grid_dim) + work_tile_info = scheduler.initial_work_tile_info() + # Broadcast work_tile_info to smem... + + while work_tile_info.is_valid_tile: + # ... do work ... + work_tile_info = scheduler.advance_to_next_work() + # Broadcast work_tile_info to smem... + + Usage (Executor warps - via MoESchedExtension): + # Read work_tile_info from smem... + real_a, desc_a = ext.get_gmem_tensor("a", tma_tensor_a, offs, work_tile_info) + real_b, desc_b = ext.get_gmem_tensor("b", tma_tensor_b, offs, work_tile_info) + real_c, desc_c = ext.get_gmem_tensor("c", tma_tensor_c, offs, work_tile_info) + """ + + def __init__( + self, + # Params (contains scenario, expert_cnt, intermediate, hidden, tile/cluster shapes) + params: MoEStaticSchedulerParams, + # Runtime tensor for scheduling + offs: cute.Tensor, # (experts,) cumsum of token counts + # Scheduling state + num_persistent_clusters: Int32, + current_work_linear_idx: Int32, + cta_id_in_cluster: cute.Coord, + # Expert tracking state (for O(1) advance within same expert) + current_expert_idx: Int32, + expert_tile_start: Int32, # cumsum of tiles before current expert + expert_tile_end: Int32, # cumsum of tiles including current expert + ): + self.params = params + self.offs = offs + self.num_persistent_clusters = num_persistent_clusters + self._current_work_linear_idx = current_work_linear_idx + self.cta_id_in_cluster = cta_id_in_cluster + # Expert tracking + self.current_expert_idx = current_expert_idx + self.expert_tile_start = expert_tile_start + self.expert_tile_end = expert_tile_end + + # ========================================================================= + # Convenience accessors for params + # ========================================================================= + + @property + def scenario(self) -> Literal["2Dx3D", "2Dx2D"]: + return self.params.scenario + + @property + def expert_cnt(self) -> Int32: + return self.params.expert_cnt + + @property + def intermediate(self) -> Int32: + return self.params.intermediate + + @property + def hidden(self) -> Int32: + return self.params.hidden + + @property + def cta_tile_shape_mnk(self) -> Tuple[int, int, int]: + return self.params.cta_tile_shape_mnk + + @property + def cluster_shape_mn(self) -> Tuple[int, int]: + return self.params.cluster_shape_mn + + @property + def cluster_tile_m(self) -> int: + return self.params.cluster_tile_m + + @property + def cluster_tile_n(self) -> int: + return self.params.cluster_tile_n + + @property + def cta_tile_k(self) -> int: + return self.params.cta_tile_k + + # ========================================================================= + # MLIR value serialization (for SSA value passing in device code) + # ========================================================================= + + def __extract_mlir_values__(self) -> List[ir.Value]: + values = [] + # Params (only runtime values are extracted) + values.extend(extract_mlir_values(self.params)) + # Runtime tensor for scheduling + values.extend(extract_mlir_values(self.offs)) + # Scheduling state + values.extend(extract_mlir_values(self.num_persistent_clusters)) + values.extend(extract_mlir_values(self._current_work_linear_idx)) + values.extend(extract_mlir_values(self.cta_id_in_cluster)) + # Expert tracking state + values.extend(extract_mlir_values(self.current_expert_idx)) + values.extend(extract_mlir_values(self.expert_tile_start)) + values.extend(extract_mlir_values(self.expert_tile_end)) + return values + + def __new_from_mlir_values__( + self, values: List[ir.Value] + ) -> "MoEStaticPersistentTileScheduler": + idx = 0 + + # Params (3 values: expert_cnt, intermediate, hidden) + new_params = new_from_mlir_values(self.params, values[idx:idx + 3]) + idx += 3 + + # Runtime tensor for scheduling (variable size) + offs_len = len(extract_mlir_values(self.offs)) + new_offs = new_from_mlir_values(self.offs, values[idx:idx + offs_len]) + idx += offs_len + + # Scheduling state + new_num_persistent_clusters = new_from_mlir_values( + self.num_persistent_clusters, [values[idx]] + ) + idx += 1 + new_current_work_linear_idx = new_from_mlir_values( + self._current_work_linear_idx, [values[idx]] + ) + idx += 1 + + # cta_id_in_cluster (3 values for Coord) + new_cta_id_in_cluster = new_from_mlir_values( + self.cta_id_in_cluster, values[idx:idx + 3] + ) + idx += 3 + + # Expert tracking state + new_current_expert_idx = new_from_mlir_values( + self.current_expert_idx, [values[idx]] + ) + idx += 1 + new_expert_tile_start = new_from_mlir_values( + self.expert_tile_start, [values[idx]] + ) + idx += 1 + new_expert_tile_end = new_from_mlir_values( + self.expert_tile_end, [values[idx]] + ) + idx += 1 + + return MoEStaticPersistentTileScheduler( + params=new_params, + offs=new_offs, + num_persistent_clusters=new_num_persistent_clusters, + current_work_linear_idx=new_current_work_linear_idx, + cta_id_in_cluster=new_cta_id_in_cluster, + current_expert_idx=new_current_expert_idx, + expert_tile_start=new_expert_tile_start, + expert_tile_end=new_expert_tile_end, + ) + + # ========================================================================= + # Factory method + # ========================================================================= + + @staticmethod + @dsl_user_op + def create( + params: MoEStaticSchedulerParams, + offs: cute.Tensor, + block_idx: Tuple[Integer, Integer, Integer], + grid_dim: Tuple[Integer, Integer, Integer], + *, + loc=None, + ip=None, + ) -> "MoEStaticPersistentTileScheduler": + """ + Create a MoE persistent tile scheduler. + + :param params: Scheduler parameters (from host) + :param offs: Cumsum tensor of token counts per expert, shape (experts,) + :param block_idx: CUDA block index + :param grid_dim: CUDA grid dimensions + """ + num_persistent_clusters = cute.size(grid_dim, loc=loc, ip=ip) // cute.size( + params.cluster_shape_mn, loc=loc, ip=ip + ) + + bidx, bidy, bidz = block_idx + current_work_linear_idx = Int32(bidz) + + cta_id_in_cluster = ( + Int32(bidx % params.cluster_shape_mn[0]), + Int32(bidy % params.cluster_shape_mn[1]), + Int32(0), + ) + + # Initialize expert tracking to "before expert 0" + # The first call to _get_work_tile_for_linear_idx will advance to the correct expert + current_expert_idx = Int32(0) + expert_tile_start = Int32(0) + expert_tile_end = Int32(0) # Will be computed on first access + + return MoEStaticPersistentTileScheduler( + params=params, + offs=offs, + num_persistent_clusters=num_persistent_clusters, + current_work_linear_idx=current_work_linear_idx, + cta_id_in_cluster=cta_id_in_cluster, + current_expert_idx=current_expert_idx, + expert_tile_start=expert_tile_start, + expert_tile_end=expert_tile_end, + ) + + # ========================================================================= + # Tile iteration methods + # ========================================================================= + + @dsl_user_op + @cute.jit + def initial_work_tile_info(self, *, loc=None, ip=None) -> MoEWorkTileInfo: + """Get the initial work tile info.""" + return self._get_work_tile_for_linear_idx( + self._current_work_linear_idx, loc=loc, ip=ip + ) + + @dsl_user_op + @cute.jit + def advance_to_next_work(self, *, loc=None, ip=None) -> MoEWorkTileInfo: + """Advance to the next work tile and return its info.""" + self._current_work_linear_idx += self.num_persistent_clusters + return self._get_work_tile_for_linear_idx( + self._current_work_linear_idx, loc=loc, ip=ip + ) + + @dsl_user_op + @cute.jit + def _get_work_tile_for_linear_idx( + self, + cluster_linear_idx: Int32, + *, + loc=None, + ip=None + ) -> MoEWorkTileInfo: + """ + Convert a linear cluster index to MoEWorkTileInfo. + + Uses cached expert tracking state for O(1) fast path when staying + within the same expert. Advances expert state when needed. + + Returns an invalid tile (expert_idx = -1) if cluster_linear_idx is out of range. + """ + # Ensure expert tracking is initialized and up-to-date + self._advance_expert_to_contain(cluster_linear_idx, loc=loc, ip=ip) + + # Check if valid (still within expert range after advancing) + is_valid = self.current_expert_idx < self.expert_cnt + + work_tile_info = MoEWorkTileInfo( + expert_idx=Int32(-1), + tile_m_idx=Int32(0), + tile_n_idx=Int32(0), + k_tile_cnt=Int32(0), + ) + + if is_valid: + # Compute local cluster tile indices within current expert + local_idx = cluster_linear_idx - self.expert_tile_start + cluster_tile_m_idx, cluster_tile_n_idx = self._decompose_local_idx( + local_idx, self.current_expert_idx, loc=loc, ip=ip + ) + + # Convert cluster tile indices to CTA tile indices + # cta_tile_idx = cluster_tile_idx * cluster_shape + cta_id_in_cluster + cta_tile_m_idx = ( + cluster_tile_m_idx * self.cluster_shape_mn[0] + + self.cta_id_in_cluster[0] # type: ignore[index] + ) + cta_tile_n_idx = ( + cluster_tile_n_idx * self.cluster_shape_mn[1] + + self.cta_id_in_cluster[1] # type: ignore[index] + ) + # Compute k_tile_cnt + k_tile_cnt = self._compute_k_tile_cnt(self.current_expert_idx, loc=loc, ip=ip) + + work_tile_info = MoEWorkTileInfo( + expert_idx=self.current_expert_idx, + tile_m_idx=cta_tile_m_idx, + tile_n_idx=cta_tile_n_idx, + k_tile_cnt=k_tile_cnt, + ) + return work_tile_info + + @dsl_user_op + @cute.jit + def _advance_expert_to_contain( + self, + cluster_linear_idx: Int32, + *, + loc=None, + ip=None, + ) -> None: + """ + Advance expert tracking state until current expert contains cluster_linear_idx, + or we run out of experts. + + Fast path: If already in correct expert, no work needed. + """ + # Initialize expert_tile_end if this is the first call (expert_tile_end == 0) + if self.expert_tile_end == Int32(0): + tiles_for_expert_0 = self._compute_tiles_for_expert(Int32(0), loc=loc, ip=ip) + self.expert_tile_end = tiles_for_expert_0 + + # Advance until cluster_linear_idx < expert_tile_end or no more experts + while cluster_linear_idx >= self.expert_tile_end and self.current_expert_idx < self.expert_cnt: + self.current_expert_idx = self.current_expert_idx + 1 + self.expert_tile_start = self.expert_tile_end + + if self.current_expert_idx < self.expert_cnt: + tiles_for_expert = self._compute_tiles_for_expert( + self.current_expert_idx, loc=loc, ip=ip + ) + self.expert_tile_end = self.expert_tile_end + tiles_for_expert + + @dsl_user_op + @cute.jit + def _compute_tiles_for_expert( + self, + expert_idx: Int32, + *, + loc=None, + ip=None, + ) -> Int32: + """Compute total cluster tiles for a given expert.""" + if const_expr(self.scenario == "2Dx2D"): + # Fixed M=hidden, N=intermediate + cluster_tile_m_cnt = (self.hidden + self.cluster_tile_m - 1) // self.cluster_tile_m + cluster_tile_n_cnt = (self.intermediate + self.cluster_tile_n - 1) // self.cluster_tile_n + return cluster_tile_m_cnt * cluster_tile_n_cnt + else: # 2Dx3D + # Variable M (tokens), fixed N + tokens_i = self.offs[expert_idx] + if expert_idx > 0: + tokens_i = tokens_i - self.offs[expert_idx - 1] # type: ignore[operator] + cluster_tile_m_cnt = ( + tokens_i + self.cluster_tile_m - 1 # type: ignore[operator] + ) // self.cluster_tile_m + cluster_tile_n_cnt = ( + self.intermediate + self.cluster_tile_n - 1 + ) // self.cluster_tile_n + return cluster_tile_m_cnt * cluster_tile_n_cnt + + @dsl_user_op + @cute.jit + def _decompose_local_idx( + self, + local_idx: Int32, + expert_idx: Int32, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32]: + """ + Decompose local cluster tile index within expert to (cluster_tile_m_idx, cluster_tile_n_idx). + + Uses "short side first" strategy: the shorter dimension changes faster. + This maximizes overlap between adjacent clusters for better L2 cache utilization. + + For example, if m_cnt=2, n_cnt=8: + - N is longer, so M changes faster: local_idx = n_idx * m_cnt + m_idx + - Linearization order: (0,0), (1,0), (0,1), (1,1), (0,2), (1,2), ... + """ + # Get tile counts for M and N + cluster_tile_m_cnt, cluster_tile_n_cnt = self._get_cluster_tile_counts( + expert_idx, loc=loc, ip=ip + ) + cluster_tile_m_idx = -1 + cluster_tile_n_idx = -1 + + # Short side first: shorter dimension changes faster + # If m_cnt <= n_cnt: m is shorter, m changes faster + # local_idx = n_idx * m_cnt + m_idx + # If n_cnt < m_cnt: n is shorter, n changes faster + # local_idx = m_idx * n_cnt + n_idx + if cluster_tile_m_cnt <= cluster_tile_n_cnt: + # M is shorter or equal, M changes faster + cluster_tile_m_idx = local_idx % cluster_tile_m_cnt + cluster_tile_n_idx = local_idx // cluster_tile_m_cnt + else: + # N is shorter, N changes faster + cluster_tile_n_idx = local_idx % cluster_tile_n_cnt + cluster_tile_m_idx = local_idx // cluster_tile_n_cnt + + return (cluster_tile_m_idx, cluster_tile_n_idx) + + @dsl_user_op + @cute.jit + def _get_cluster_tile_counts( + self, + expert_idx: Int32, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32]: + """Get (cluster_tile_m_cnt, cluster_tile_n_cnt) for a given expert.""" + if const_expr(self.scenario == "2Dx2D"): + # Fixed M=hidden, N=intermediate + cluster_tile_m_cnt = (self.hidden + self.cluster_tile_m - 1) // self.cluster_tile_m + cluster_tile_n_cnt = (self.intermediate + self.cluster_tile_n - 1) // self.cluster_tile_n + else: # 2Dx3D + # Variable M (tokens), fixed N + tokens_i = self.offs[expert_idx] + if expert_idx > 0: + tokens_i = tokens_i - self.offs[expert_idx - 1] # type: ignore[operator] + cluster_tile_m_cnt = ( + tokens_i + self.cluster_tile_m - 1 # type: ignore[operator] + ) // self.cluster_tile_m + cluster_tile_n_cnt = ( + self.intermediate + self.cluster_tile_n - 1 + ) // self.cluster_tile_n + return (cluster_tile_m_cnt, cluster_tile_n_cnt) + + @dsl_user_op + @cute.jit + def _compute_k_tile_cnt( + self, + expert_idx: Int32, + *, + loc=None, + ip=None, + ) -> Int32: + """ + Compute the number of K tiles for this expert. + + 2Dx3D: K = hidden (fixed) -> k_tile_cnt = ceil(hidden / cta_tile_k) + 2Dx2D: K = tokens_i (variable) -> k_tile_cnt = ceil(tokens_i / cta_tile_k) + """ + if const_expr(self.scenario == "2Dx3D"): + # K is hidden (fixed) + return (self.hidden + self.cta_tile_k - 1) // self.cta_tile_k + else: # 2Dx2D + # K is tokens_i (variable per expert) + tokens_i = self.offs[expert_idx] + if expert_idx > cutlass.Int32(0): + tokens_i = tokens_i - self.offs[expert_idx - 1] # type: ignore[operator] + return (tokens_i + self.cta_tile_k - 1) // self.cta_tile_k # type: ignore[return-value, operator] diff --git a/reference/moe_moe_sched_extension.py b/reference/moe_moe_sched_extension.py new file mode 100644 index 00000000..e07ce22b --- /dev/null +++ b/reference/moe_moe_sched_extension.py @@ -0,0 +1,443 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" +MoE Scheduler Extension. + +Bridges the MoE tile scheduler (MoEStaticPersistentTileScheduler) with tensor-level +domain conversion and TMA descriptor selection. This is the "glue" layer between: + +- Scheduler: produces MoEWorkTileInfo (expert_idx, tile_m, tile_n, k_tile_cnt) +- OnlineTensormapDescCreator: builds/retrieves TMA descriptors from workspace +- Kernel: orchestrates everything + +Different kernel types (grouped_mm, scaled_grouped_mm, etc.) provide their own +MoESchedExtension subclass with kernel-specific domain conversion logic. + +Key design principles: +- Unified interface: get_gmem_tensor() for all tensor types +- Free implementation: no role-based templates, each subclass writes its own logic +- Composable utilities: compute_expert_token_range, rewrite_tensor_shape, etc. + are available as tools but not mandatory + +Architecture: + + Scheduler ──(produces)──> MoEWorkTileInfo + │ + expert_idx, tile_m, tile_n, k_cnt + │ + v + Extension ──(uses)──> OnlineTensormapDescCreator + │ │ + │ get_gmem_tensor() │ get_desc_ptr() + │ prefetch_for_expert() │ construct_and_write() + │ │ + └── internal calls ───────┘ + + Kernel (caller): the only place that knows all three exist +""" + +from abc import ABC, abstractmethod +from typing import Literal, Tuple, Union + +import cutlass +import cutlass.cute as cute +from cutlass.cute.typing import Pointer +from cutlass.cutlass_dsl import Int32 + +from dataclasses import dataclass + +from cutlass.utils.blockscaled_layout import tile_atom_to_shape_SF +from blackwell.kernel.moe.moe_utils import ( + OnlineTensormapDescCreator, + tensormap_ptr_for_copy, + compute_expert_token_range, + rewrite_tensor_shape, + prefetch_tma_descriptor, +) +from blackwell.kernel.moe.moe_persistent_scheduler import MoEWorkTileInfo + + +@dataclass(frozen=True) +class MoESchedExtension(ABC): + """ + Abstract base class for MoE scheduler extensions. + + Bridges MoEWorkTileInfo with tensor-level domain conversion and TMA + descriptor selection. Each kernel type (grouped_mm, scaled_grouped_mm, etc.) + provides its own subclass with kernel-specific logic. + + The extension: + - Holds a reference to an OnlineTensormapDescCreator for expert-wise desc retrieval + - Implements get_gmem_tensor() to convert MoE-view tensors to per-expert tensors + - Implements prefetch_for_expert() to prefetch expert-wise TMA descriptors + + Subclasses are free to add any additional attributes in __init__ (scenario, + codegen configs, etc.) and implement get_gmem_tensor with arbitrary logic + per tensor_name. No role-based templates or rigid patterns are imposed. + + Usage in kernel (caller): + ext = ConcreteSchedExtension(tensormap_ctor, scenario=...) + + while work_tile_info.is_valid_tile: + real_a, desc_a = ext.get_gmem_tensor("a", tma_tensor_a, offs, work_tile_info) + real_b, desc_b = ext.get_gmem_tensor("b", tma_tensor_b, offs, work_tile_info) + # Use real_a, desc_a in cute.copy ... + """ + + def __init__(self, tensormap_ctor: OnlineTensormapDescCreator): + super().__init__() + self.tensormap_ctor = tensormap_ctor + + @abstractmethod + def get_gmem_tensor( + self, + tensor_name: str, + gmem_tensor_in_moe_view: cute.Tensor, + offs: Union[cute.Tensor, Tuple[cute.Tensor, cute.Tensor]], + work_tile_info: MoEWorkTileInfo, + ) -> Tuple[cute.Tensor, "Pointer | None"]: + """ + Convert an MoE-view tensor to the real per-expert tensor for the + current work tile, and return the appropriate TMA descriptor pointer. + + The MoE-view tensor uses "fake" GEMM domain dimensions that span all + experts (e.g., fake_m = tokens_sum). This method slices/offsets it + to the current expert's actual region. + + :param tensor_name: Identifies which tensor (e.g., "a", "b", "c", "sfa") + :param gmem_tensor_in_moe_view: Tensor in fake GEMM MNKL domain + :param offs: Either a single cumsum tensor (experts,), or a tuple of + (offs_token, offs_padded) where offs_padded provides + padded offsets for scale-factor domain conversion. + :param work_tile_info: Current work tile from the scheduler + :return: (real_tensor, tma_desc_ptr_or_none) + - real_tensor: domain-offset and shape-rewritten tensor for this expert + - tma_desc_ptr: expert-wise desc ptr (already converted for cute.copy), + or None if the caller should use the global TMA descriptor + """ + ... + + @abstractmethod + def prefetch_for_expert(self, expert_idx: Int32) -> None: + """ + Prefetch expert-wise TMA descriptors for the given expert. + + Called when the scheduler advances to a new expert, allowing the TMA + descriptor cache to be warmed up before the descriptors are needed. + + :param expert_idx: Index of the expert whose descriptors to prefetch + """ + ... + + +# ============================================================================= +# Grouped MM Extension +# ============================================================================= + + +class GroupedMmSchedExtension(MoESchedExtension): + """ + MoE scheduler extension for grouped_mm: handles tensors a, b, c. + + Domain conversion logic per scenario: + + 2Dx3D: + A: (fake_m, k, 1) -> offset fake_m by token_offset, global desc + B: (n, k, fake_l) -> offset fake_l by expert_idx, global desc + C: (fake_m, n, 1) -> rewrite shape only, expert-wise desc + + 2Dx2D: + A: (m, fake_k, 1) -> rewrite shape only, expert-wise desc + B: (n, fake_k, 1) -> rewrite shape only, expert-wise desc + C: (m, n, fake_l) -> offset fake_l by expert_idx, global desc + """ + + def __init__( + self, + scenario: Literal["2Dx3D", "2Dx2D"], + tensormap_ctor: OnlineTensormapDescCreator, + ): + super().__init__(tensormap_ctor) + self.scenario = scenario + + @cute.jit + def get_gmem_tensor( + self, + tensor_name: str, + gmem_tensor_in_moe_view: cute.Tensor, + offs: cute.Tensor, + work_tile_info: MoEWorkTileInfo, + ): + expert_idx = work_tile_info.expert_idx + token_offset, tokens_i = compute_expert_token_range(offs, expert_idx) + + shape = gmem_tensor_in_moe_view.shape + c1 = cutlass.Int32(1) + + if cutlass.const_expr(self.scenario == "2Dx3D"): + if cutlass.const_expr(tensor_name == "a"): + # A: (fake_m, k, 1) -> offset fake_m, global desc + real = cute.domain_offset((token_offset, 0, 0), gmem_tensor_in_moe_view) + real = rewrite_tensor_shape(real, (tokens_i, shape[1], c1)) # type: ignore[index] + return (real, None) + elif cutlass.const_expr(tensor_name == "b"): + # B: (n, k, fake_l) -> offset fake_l, global desc + real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view) + real = rewrite_tensor_shape(real, (shape[0], shape[1], c1)) # type: ignore[index] + return (real, None) + elif cutlass.const_expr(tensor_name == "c"): + # C: (fake_m, n, 1) -> expert-wise desc, no offset + real = rewrite_tensor_shape( + gmem_tensor_in_moe_view, + (tokens_i, shape[1], c1), # type: ignore[index] + ) + desc = tensormap_ptr_for_copy( + self.tensormap_ctor.get_desc_ptr("c", expert_idx) + ) + return (real, desc) + + elif cutlass.const_expr(self.scenario == "2Dx2D"): + if cutlass.const_expr(tensor_name == "a"): + # A: (m, fake_k, 1) -> expert-wise desc, no offset + real = rewrite_tensor_shape( + gmem_tensor_in_moe_view, + (shape[0], tokens_i, c1), # type: ignore[index] + ) + desc = tensormap_ptr_for_copy( + self.tensormap_ctor.get_desc_ptr("a", expert_idx) + ) + return (real, desc) + elif cutlass.const_expr(tensor_name == "b"): + # B: (n, fake_k, 1) -> expert-wise desc, no offset + real = rewrite_tensor_shape( + gmem_tensor_in_moe_view, + (shape[0], tokens_i, c1), # type: ignore[index] + ) + desc = tensormap_ptr_for_copy( + self.tensormap_ctor.get_desc_ptr("b", expert_idx) + ) + return (real, desc) + elif cutlass.const_expr(tensor_name == "c"): + # C: (m, n, fake_l) -> offset fake_l, global desc + real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view) + real = rewrite_tensor_shape(real, (shape[0], shape[1], c1)) # type: ignore[index] + return (real, None) + + raise ValueError("Invalid scenario or GEMM tensor name.") + + @cute.jit + def prefetch_for_expert(self, expert_idx: Int32) -> None: + if cutlass.const_expr(self.scenario == "2Dx3D"): + prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("c", expert_idx)) + elif cutlass.const_expr(self.scenario == "2Dx2D"): + prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("a", expert_idx)) + prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("b", expert_idx)) + else: + raise ValueError("Invalid scenario.") + + + +# ============================================================================= +# Scaled Grouped MM Extension (block-scaled MoE) +# ============================================================================= + + +class ScaledGroupedMmSchedExtension(MoESchedExtension): + """ + MoE scheduler extension for scaled_grouped_mm: handles a, b, c, sfa, sfb. + + Extends GroupedMmSchedExtension with scale-factor tensor support. + SFA/SFB are passed as flat GEMM-domain tensors and atom-tiled per expert + via tile_atom_to_shape_SF. + + The offs parameter is always a tuple (offs_token, offs_padded): + - offs_token: cumsum offsets in data (activation) domain + - offs_padded: cumsum offsets in scale-factor domain (padded to atom granularity) + + sf_vec_size is obtained from self.tensormap_ctor.sf_vec_size. + + Domain conversion logic per scenario: + + 2Dx3D: + A: (fake_m, k, 1) -> offset fake_m by token_offset, global desc + B: (n, k, fake_l) -> offset fake_l by expert_idx, global desc + C: (fake_m, n, 1) -> rewrite shape, expert-wise desc + SFA: (fake_m_pad, k_pad, 1) -> offset fake_m_pad by padded_offset, + atom-tile, global desc + SFB: (n_pad, k_pad, fake_l) -> offset fake_l by expert_idx, + atom-tile, global desc + + 2Dx2D: + A: (m, fake_k, 1) -> rewrite shape, expert-wise desc + B: (n, fake_k, 1) -> rewrite shape, expert-wise desc + C: (m, n, fake_l) -> offset fake_l by expert_idx, global desc + SFA: (m_pad, fake_k_pad, 1) -> offset fake_k_pad by padded_offset, + atom-tile, expert-wise desc + SFB: (n_pad, fake_k_pad, 1) -> offset fake_k_pad by padded_offset, + atom-tile, expert-wise desc + """ + + def __init__( + self, + scenario: Literal["2Dx3D", "2Dx2D"], + tensormap_ctor: OnlineTensormapDescCreator, + ): + super().__init__(tensormap_ctor) + self.scenario = scenario + + @cute.jit + def get_gmem_tensor( + self, + tensor_name: str, + gmem_tensor_in_moe_view: cute.Tensor, + offs: Tuple[cute.Tensor, cute.Tensor], + work_tile_info: MoEWorkTileInfo, + ): + # Unpack the offs tuple + offs_token, offs_padded = offs + + expert_idx = work_tile_info.expert_idx + token_offset, tokens_i = compute_expert_token_range(offs_token, expert_idx) + padded_offset, padded_size_i = compute_expert_token_range( + offs_padded, expert_idx + ) + + shape = gmem_tensor_in_moe_view.shape + stride = gmem_tensor_in_moe_view.stride + c1 = cutlass.Int32(1) + sf_vec_size = self.tensormap_ctor.sf_vec_size + + if cutlass.const_expr(self.scenario == "2Dx3D"): + if cutlass.const_expr(tensor_name == "a"): + # A: (fake_m, k, 1) -> offset fake_m, global desc + real = cute.domain_offset((token_offset, 0, 0), gmem_tensor_in_moe_view) + real = rewrite_tensor_shape(real, (tokens_i, shape[1], c1)) # type: ignore[index] + return (real, None) + + elif cutlass.const_expr(tensor_name == "b"): + # B: (n, k, fake_l) -> offset fake_l, global desc + real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view) + real = rewrite_tensor_shape(real, (shape[0], shape[1], c1)) # type: ignore[index] + return (real, None) + + elif cutlass.const_expr(tensor_name == "c"): + # C: (fake_m, n, 1) -> expert-wise desc + real = rewrite_tensor_shape( + gmem_tensor_in_moe_view, + (tokens_i, shape[1], c1), # type: ignore[index] + ) + desc = tensormap_ptr_for_copy( + self.tensormap_ctor.get_desc_ptr("c", expert_idx) + ) + return (real, desc) + + elif cutlass.const_expr(tensor_name == "sfa"): + # SFA: (fake_m_pad, k_pad, 1) -> offset fake_m_pad, atom-tile, global desc + real = cute.domain_offset( + (padded_offset, 0, 0), gmem_tensor_in_moe_view + ) + per_expert_shape = (padded_size_i, shape[1], c1) # type: ignore[index] + sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size) + real = cute.make_tensor( + real.iterator, cute.make_layout(sf_layout.shape, stride=stride) + ) + return (real, None) + + elif cutlass.const_expr(tensor_name == "sfb"): + # SFB: (n_pad, k_pad, fake_l) -> offset fake_l, atom-tile, global desc + real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view) + per_expert_shape = (shape[0], shape[1], c1) # type: ignore[index] + sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size) + real = cute.make_tensor( + real.iterator, cute.make_layout(sf_layout.shape, stride=stride) + ) + return (real, None) + + elif cutlass.const_expr(self.scenario == "2Dx2D"): + if cutlass.const_expr(tensor_name == "a"): + # A: (m, fake_k, 1) -> expert-wise desc + real = rewrite_tensor_shape( + gmem_tensor_in_moe_view, + (shape[0], tokens_i, c1), # type: ignore[index] + ) + desc = tensormap_ptr_for_copy( + self.tensormap_ctor.get_desc_ptr("a", expert_idx) + ) + return (real, desc) + + elif cutlass.const_expr(tensor_name == "b"): + # B: (n, fake_k, 1) -> expert-wise desc + real = rewrite_tensor_shape( + gmem_tensor_in_moe_view, + (shape[0], tokens_i, c1), # type: ignore[index] + ) + desc = tensormap_ptr_for_copy( + self.tensormap_ctor.get_desc_ptr("b", expert_idx) + ) + return (real, desc) + + elif cutlass.const_expr(tensor_name == "c"): + # C: (m, n, fake_l) -> offset fake_l, global desc + real = cute.domain_offset((0, 0, expert_idx), gmem_tensor_in_moe_view) + real = rewrite_tensor_shape(real, (shape[0], shape[1], c1)) # type: ignore[index] + return (real, None) + + elif cutlass.const_expr(tensor_name == "sfa"): + # SFA: (m_pad, fake_k_pad, 1) -> offset fake_k_pad, atom-tile, expert-wise desc + per_expert_shape = (shape[0], padded_size_i, c1) # type: ignore[index] + sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size) + real = rewrite_tensor_shape(gmem_tensor_in_moe_view, sf_layout.shape) + desc = tensormap_ptr_for_copy( + self.tensormap_ctor.get_desc_ptr("sfa", expert_idx) + ) + return (real, desc) + + elif cutlass.const_expr(tensor_name == "sfb"): + # SFB: (n_pad, fake_k_pad, 1) -> offset fake_k_pad, atom-tile, expert-wise desc + per_expert_shape = (shape[0], padded_size_i, c1) # type: ignore[index] + sf_layout = tile_atom_to_shape_SF(per_expert_shape, sf_vec_size) + real = rewrite_tensor_shape(gmem_tensor_in_moe_view, sf_layout.shape) + desc = tensormap_ptr_for_copy( + self.tensormap_ctor.get_desc_ptr("sfb", expert_idx) + ) + return (real, desc) + + raise ValueError("Invalid scenario or tensor name.") + + @cute.jit + def prefetch_for_expert(self, expert_idx: Int32) -> None: + if cutlass.const_expr(self.scenario == "2Dx3D"): + prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("c", expert_idx)) + elif cutlass.const_expr(self.scenario == "2Dx2D"): + prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("a", expert_idx)) + prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("b", expert_idx)) + prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("sfa", expert_idx)) + prefetch_tma_descriptor(self.tensormap_ctor.get_desc_ptr("sfb", expert_idx)) + else: + raise ValueError("Invalid scenario.") diff --git a/reference/moe_moe_utils.py b/reference/moe_moe_utils.py new file mode 100644 index 00000000..e21d0389 --- /dev/null +++ b/reference/moe_moe_utils.py @@ -0,0 +1,910 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" +Online TMA Descriptor Construction Utilities. + +Provides utilities for dynamically creating TMA descriptors at kernel runtime +based on runtime-provided information (problem sizes, pointers, etc.). + +Key components: +- OnlineTensormapDescCreator: Simplified ABC for TMA descriptor builders (2 abstract methods) +- TensormapWorkspace: Helper for linear workspace layout of TMA descriptors +- MoEGroupedGemmTensormapConstructor: TMA descriptor constructor for MoE Grouped GEMM +- GeneralGroupedGemmTensormapConstructor: TMA descriptor constructor for general Grouped GEMM +- Pointer utility functions (ptr_offset_bytes, gmem_ptr_to_generic, etc.) +- tensormap_ptr_for_copy: Convert raw desc ptr to cute.copy-compatible type +- compute_expert_token_range: Compute per-expert token offset and count from offs +- rewrite_tensor_shape: Debug-friendly tensor shape rewrite utility +""" + +from abc import ABC, abstractmethod +from typing import Optional, Literal, Tuple, Union + +import cutlass +import cutlass.cute as cute +from cutlass.cute.typing import AddressSpace, Pointer +from cutlass.cute.nvgpu import cpasync +from cutlass.cutlass_dsl import dsl_user_op, Int32 +from cutlass._mlir import ir +from cutlass._mlir.dialects import llvm +from cutlass._mlir.dialects import cute as _cute_ir +from cutlass._mlir.dialects import cute_nvgpu as _cute_nvgpu_ir +from dataclasses import dataclass + +from cutlass.utils.blockscaled_layout import tile_atom_to_shape_SF + +TensormapDescBytes = 128 + +# ============================================================================= +# Pointer Utilities +# ============================================================================= + + +@dsl_user_op +@cute.jit +def spin_wait( + ptr: Pointer, condition, fail_sleep_cycles: int = 100, *, loc=None, ip=None +) -> None: + """ + Generic spin-wait. + Example usage: + # Wait until counter >= total_blocks + spin_wait(counter_ptr, lambda x: x >= total_blocks, fail_sleep_cycles=100) + + # Wait until flag == 1 + spin_wait(flag_ptr, lambda x: x == 1) + """ + current = cute.arch.load(ptr, ptr.dtype, cop="cg", loc=loc, ip=ip) + while not condition(current): + # Load with L1 cache bypass (ld.global.cg) + if cutlass.const_expr(fail_sleep_cycles > 0): + cute.arch.nanosleep(sleep_time=fail_sleep_cycles, loc=loc, ip=ip) + current = cute.arch.load(ptr, ptr.dtype, cop="cg", loc=loc, ip=ip) + + +@dsl_user_op +def gmem_ptr_to_generic( + gmem_ptr: Pointer, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Pointer: + if gmem_ptr.memspace != AddressSpace.gmem: + raise ValueError( + f"gmem_ptr_to_generic requires pointer in gmem address space, " + f"got {gmem_ptr.memspace}" + ) + # Get LLVM pointer and cast to generic address space + llvm_ptr = gmem_ptr.to_llvm_ptr(loc=loc, ip=ip) + generic_llvm_ptr = llvm.addrspacecast( + llvm.PointerType.get(AddressSpace.generic), llvm_ptr, loc=loc, ip=ip + ) + # Create a new cute.Pointer with generic address space, preserving alignment + return cute.make_ptr( + gmem_ptr.dtype, + generic_llvm_ptr, + AddressSpace.generic, + assumed_align=gmem_ptr.alignment, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def generic_ptr_to_gmem( + generic_ptr: Pointer, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Pointer: + if generic_ptr.memspace != AddressSpace.generic: + raise ValueError( + f"generic_ptr_to_gmem requires pointer in generic address space, " + f"got {generic_ptr.memspace}" + ) + # Get LLVM pointer and cast to gmem address space + llvm_ptr = generic_ptr.to_llvm_ptr(loc=loc, ip=ip) + gmem_llvm_ptr = llvm.addrspacecast( + llvm.PointerType.get(AddressSpace.gmem), llvm_ptr, loc=loc, ip=ip + ) + # Create a new cute.Pointer with gmem address space, preserving alignment + return cute.make_ptr( + generic_ptr.dtype, + gmem_llvm_ptr, + AddressSpace.gmem, + assumed_align=generic_ptr.alignment, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def prefetch_tma_descriptor(tma_desc_ptr: Pointer, *, loc=None, ip=None) -> None: + """ + Prefetch a TMA descriptor from global memory. + + This function prefetches the TMA descriptor pointed to by tma_desc_ptr + into the TMA descriptor cache. The pointer must be in generic or global + address space. If a gmem pointer is passed, it will be automatically + converted to generic address space. + + :param tma_desc_ptr: Pointer to the TMA descriptor in global or generic memory + :type tma_desc_ptr: Pointer + :raises ValueError: If pointer is not in generic or global address space + """ + if tma_desc_ptr.memspace not in (AddressSpace.gmem, AddressSpace.generic): + raise ValueError( + f"prefetch_tma_descriptor requires pointer in gmem or generic address space, " + f"got {tma_desc_ptr.memspace}" + ) + # Convert gmem pointer to generic if needed + if tma_desc_ptr.memspace == AddressSpace.gmem: + tma_desc_ptr = gmem_ptr_to_generic(tma_desc_ptr, loc=loc, ip=ip) + # Convert cute.Pointer to LLVM pointer for prefetch + llvm_ptr = tma_desc_ptr.to_llvm_ptr(loc=loc, ip=ip) + from cutlass.cute.arch.nvvm_wrappers import prefetch as nvvm_prefetch + + nvvm_prefetch(llvm_ptr, tensormap=True, loc=loc, ip=ip) + + +def ptr_offset_bytes(ptr: Pointer, byte_offset: int) -> Pointer: + """Offset a pointer by a given number of bytes.""" + element_offset = byte_offset * 8 // ptr.dtype.width + return ptr + element_offset + + +@dsl_user_op +def tensormap_ptr_for_copy(raw_ptr: Pointer, *, loc=None, ip=None) -> Pointer: + """ + Convert a raw TMA descriptor gmem pointer to the type expected by cute.copy. + + cute.copy requires the tma_desc_ptr to be in generic address space and + recast to TmaDescriptorTiledType. This utility performs both conversions. + + :param raw_ptr: Raw pointer to TMA descriptor in gmem + :type raw_ptr: Pointer + :return: Pointer compatible with cute.copy's tma_desc_ptr parameter + :rtype: Pointer + """ + generic_ptr = gmem_ptr_to_generic(raw_ptr, loc=loc, ip=ip) + tma_desc_ptr_ty = _cute_ir.PtrType.get( + _cute_nvgpu_ir.TmaDescriptorTiledType.get(), + generic_ptr.memspace, + generic_ptr.alignment, + ) + return _cute_ir.recast_iter(tma_desc_ptr_ty, generic_ptr.value) + + +# ============================================================================= +# MoE Utilities +# ============================================================================= + + +@dsl_user_op +@cute.jit +def compute_expert_token_range( + offs: cute.Tensor, + expert_idx: Int32, + *, + loc=None, + ip=None, +) -> Tuple[Int32, Int32]: + """ + Compute token offset and count for a given expert from the cumsum offs tensor. + + :param offs: Cumulative sum tensor of token counts per expert, shape (experts,) + :param expert_idx: Index of the expert + :return: (token_offset, tokens_i) where token_offset is the start position + and tokens_i is the number of tokens for this expert + """ + token_offset = Int32(0) + if expert_idx > Int32(0): + token_offset = offs[expert_idx - 1] # type: ignore[assignment] + tokens_i = offs[expert_idx] - token_offset + return token_offset, tokens_i + + +@dsl_user_op +def rewrite_tensor_shape( + tensor: cute.Tensor, + new_shape: Tuple, + *, + loc=None, + ip=None, +) -> cute.Tensor: + """ + Rewrite tensor shape while keeping the same stride and iterator. + + This is primarily for debug friendliness - shows the actual expert's shape + instead of the fake global shape. No runtime overhead as it becomes + dead code in non-debug builds. + + :param tensor: Source tensor whose stride and iterator to preserve + :param new_shape: New shape to apply + :return: New tensor with the given shape but original stride and iterator + """ + new_layout = cute.make_layout(new_shape, stride=tensor.stride, loc=loc, ip=ip) + return cute.make_tensor(tensor.iterator, new_layout, loc=loc, ip=ip) + + +# ============================================================================= +# TMA Descriptor Workspace Helper +# ============================================================================= + + +class TensormapWorkspace: + """ + Helper for linear workspace layout of TMA descriptors. + + Manages address calculation for a workspace buffer containing TMA descriptors + organized as: for each executor (e.g., expert or group), a fixed set of + named descriptor slots. + + Layout: [slot_0_exec_0, slot_1_exec_0, ..., slot_0_exec_1, slot_1_exec_1, ...] + + Example: + # 2Dx3D MoE: only C is expert-wise + workspace = TensormapWorkspace(workspace_ptr, ["c"]) + + # 2Dx2D MoE: A and B are expert-wise + workspace = TensormapWorkspace(workspace_ptr, ["a", "b"]) + + # General grouped GEMM: all three tensors + workspace = TensormapWorkspace(workspace_ptr, ["a", "b", "c"]) + """ + + def __init__(self, workspace_ptr: Pointer, slot_names: list): + """ + :param workspace_ptr: Pointer to the beginning of the workspace buffer + :param slot_names: Ordered list of tensor names, defining the slot layout + per executor. e.g., ["a", "b", "c"] + """ + self.workspace_ptr = workspace_ptr + self._name_to_slot = {name: i for i, name in enumerate(slot_names)} + self._slots_per_executor = len(slot_names) + + @cute.jit + def get_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer: + """ + Get the workspace pointer for a specific TMA descriptor. + + :param tensor_name: Name of the tensor (must be one of the slot_names) + :param executor_idx: Index of the executor (e.g., group_idx or expert_idx) + :return: Aligned pointer to the TMA descriptor in workspace + """ + if cutlass.const_expr(tensor_name not in self._name_to_slot): + raise ValueError( + f"Invalid tensor_name '{tensor_name}', " + f"expected one of {list(self._name_to_slot.keys())}" + ) + slot = self._name_to_slot[tensor_name] + byte_offset = ( + executor_idx * self._slots_per_executor + slot + ) * TensormapDescBytes + return ptr_offset_bytes(self.workspace_ptr, byte_offset).align( + TensormapDescBytes + ) + + @staticmethod + def size_bytes(num_slots: int, num_executors: int) -> int: + """ + Calculate workspace size in bytes. + + :param num_slots: Number of descriptor slots per executor + :param num_executors: Total number of executors (e.g., expert_cnt or group_cnt) + :return: Total workspace size in bytes + """ + return num_slots * num_executors * TensormapDescBytes + + +# ============================================================================= +# Online TMA Descriptor Creator (Abstract Base Class) +# ============================================================================= + + +@dataclass(frozen=True) +class OnlineTensormapDescCreator(ABC): + """ + Abstract base class for building TMA descriptors online (at kernel runtime). + + Subclasses store all needed parameters (both codegen-time configs and runtime + values) as explicit instance attributes in __init__. No dict-based APIs. + + Subclasses must implement exactly 2 abstract methods: + - construct_and_write: Build TMA descriptor(s) for one executor and write to workspace + - get_desc_ptr: Return raw gmem pointer to a specific descriptor in workspace + + To convert the raw pointer for use with cute.copy, callers should use the + standalone tensormap_ptr_for_copy() utility. + """ + + @abstractmethod + def construct_and_write(self, executor_idx: Int32, dependency=None) -> None: + """ + Build TMA descriptor(s) for one executor and write to workspace. + + :param executor_idx: Index of the executor (e.g., group_idx or expert_idx). + Semantics may vary by subclass when ``dependency`` is provided. + :param dependency: Optional pipeline consumer for inter-warp-group + synchronization. When provided, the subclass decides when to wait + (via ``dependency.wait_and_advance()``) and release. The subclass + also decides how to interpret ``executor_idx`` in this mode. + """ + ... + + @abstractmethod + def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer: + """ + Get the raw gmem pointer to a specific TMA descriptor in workspace. + + :param tensor_name: Name identifying which tensor's descriptor + :param executor_idx: Index of the executor (e.g., group_idx or expert_idx) + :return: Raw pointer (gmem) to the TMA descriptor + """ + ... + + +# ============================================================================= +# MoE Grouped GEMM Tensormap Constructor +# ============================================================================= + + +class MoEGroupedGemmTensormapConstructor(OnlineTensormapDescCreator): + """ + Tensormap descriptor constructor for MoE Grouped GEMM (expert-wise descriptors only). + + Non-expert-wise descriptors are passed directly at kernel launch. + This class only handles: + - 2Dx3D: C descriptors (expert-wise, to avoid write conflicts) + - 2Dx2D: A and B descriptors (expert-wise, tokens is reduction axis) + + All parameters are stored as explicit instance attributes (no dicts). + + Workspace layout: + - 2Dx3D: [C_0, C_1, ..., C_{n-1}] + - 2Dx2D: [A_0, A_1, ..., A_{n-1}, B_0, B_1, ..., B_{n-1}] + """ + + def __init__( + self, + scenario: Literal["2Dx3D", "2Dx2D"], + # Codegen-time configs + a_dtype, + b_dtype, + c_dtype, + a_smem_layout, + b_smem_layout, + epi_smem_layout, + a_tma_op, + b_tma_op, + c_tma_op, + tiled_mma, + mma_tiler, + cluster_layout_vmnk_shape, + epi_tile, + # Runtime params + a_tensor: cute.Tensor, # fake GEMM domain A + b_tensor: cute.Tensor, # fake GEMM domain B + c_tensor: cute.Tensor, # fake GEMM domain C + offs: cute.Tensor, # (experts,) cumsum + workspace_ptr: Pointer, + ) -> None: + super().__init__() + self.scenario = scenario + # Codegen-time configs + self.a_dtype = a_dtype + self.b_dtype = b_dtype + self.c_dtype = c_dtype + self.a_smem_layout = a_smem_layout + self.b_smem_layout = b_smem_layout + self.epi_smem_layout = epi_smem_layout + self.a_tma_op = a_tma_op + self.b_tma_op = b_tma_op + self.c_tma_op = c_tma_op + self.tiled_mma = tiled_mma + self.mma_tiler = mma_tiler + self.cluster_layout_vmnk_shape = cluster_layout_vmnk_shape + self.epi_tile = epi_tile + # Runtime params + self.a_tensor = a_tensor + self.b_tensor = b_tensor + self.c_tensor = c_tensor + self.offs = offs + # Workspace with scenario-specific slot layout + if scenario == "2Dx3D": + self.workspace = TensormapWorkspace(workspace_ptr, ["c"]) + else: + self.workspace = TensormapWorkspace(workspace_ptr, ["a", "b"]) + + @staticmethod + def get_workspace_size(scenario: Literal["2Dx3D", "2Dx2D"], expert_cnt: int) -> int: + """Calculate workspace size in bytes for tensormap descriptors.""" + if scenario == "2Dx3D": + return TensormapWorkspace.size_bytes(1, expert_cnt) # only C + else: + return TensormapWorkspace.size_bytes(2, expert_cnt) # A and B + + @cute.jit + def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer: + return self.workspace.get_ptr(tensor_name, executor_idx) + + @cute.jit + def construct_and_write(self, executor_idx: Int32, dependency=None) -> None: + """ + Create expert-wise tensormap descriptors for the given expert. + + - 2Dx3D: Creates C descriptor for this expert + - 2Dx2D: Creates A and B descriptors for this expert + """ + if cutlass.const_expr(self.scenario == "2Dx3D"): + self._construct_c_desc_2dx3d(executor_idx) + else: # 2Dx2D + self._construct_ab_descs_2dx2d(executor_idx) + + @cute.jit + def _construct_c_desc_2dx3d(self, expert_idx: Int32) -> None: + """ + 2Dx3D: Create expert-wise C descriptor. + C tensor: (fake_m, n, 1) = (tokens_sum, intermediate, 1) + Slice fake_m -> (tokens_i, intermediate, 1) per expert. + """ + token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx) + + c_ptr = self.c_tensor.iterator + c_stride = self.c_tensor.stride + intermediate = self.c_tensor.shape[1] # type: ignore[index] + + c1 = cutlass.Int32(1) + c0 = cutlass.Int32(0) + + c_ptr_i = c_ptr + token_offset * c_stride[0] # type: ignore[index] + c_layout_i = cute.make_layout( + (tokens_i, intermediate, c1), + stride=(c_stride[0], c_stride[1], c0), # type: ignore[index] + ) + c_tensor_i = cute.make_tensor(c_ptr_i, c_layout_i) + + tma_atom_c, _ = cpasync.make_tiled_tma_atom( + self.c_tma_op, + c_tensor_i, + self.epi_smem_layout, + self.epi_tile, + ) + cpasync.copy_tensormap(tma_atom_c, self.get_desc_ptr("c", expert_idx)) + + @cute.jit + def _construct_ab_descs_2dx2d(self, expert_idx: Int32) -> None: + """ + 2Dx2D: Create expert-wise A and B descriptors. + A: (m, fake_k, 1) -> slice to (m, tokens_i, 1) + B: (n, fake_k, 1) -> slice to (n, tokens_i, 1) + """ + token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx) + + c1 = cutlass.Int32(1) + c0 = cutlass.Int32(0) + + # A tensor: (m, fake_k, 1) -> (m, tokens_i, 1) + a_ptr = self.a_tensor.iterator + a_stride = self.a_tensor.stride + a_m = self.a_tensor.shape[0] # type: ignore[index] + + a_ptr_i = a_ptr + token_offset * a_stride[1] # type: ignore[index] + a_layout_i = cute.make_layout( + (a_m, tokens_i, c1), + stride=(a_stride[0], a_stride[1], c0), # type: ignore[index] + ) + a_tensor_i = cute.make_tensor(a_ptr_i, a_layout_i) + + tma_atom_a, _ = cute.nvgpu.make_tiled_tma_atom_A( + self.a_tma_op, + a_tensor_i, + self.a_smem_layout, + self.mma_tiler, + self.tiled_mma, + self.cluster_layout_vmnk_shape, + ) + cpasync.copy_tensormap(tma_atom_a, self.get_desc_ptr("a", expert_idx)) + + # B tensor: (n, fake_k, 1) -> (n, tokens_i, 1) + b_ptr = self.b_tensor.iterator + b_stride = self.b_tensor.stride + b_n = self.b_tensor.shape[0] # type: ignore[index] + + b_ptr_i = b_ptr + token_offset * b_stride[1] # type: ignore[index] + b_layout_i = cute.make_layout( + (b_n, tokens_i, c1), + stride=(b_stride[0], b_stride[1], c0), # type: ignore[index] + ) + b_tensor_i = cute.make_tensor(b_ptr_i, b_layout_i) + + tma_atom_b, _ = cute.nvgpu.make_tiled_tma_atom_B( + self.b_tma_op, + b_tensor_i, + self.b_smem_layout, + self.mma_tiler, + self.tiled_mma, + self.cluster_layout_vmnk_shape, + ) + cpasync.copy_tensormap(tma_atom_b, self.get_desc_ptr("b", expert_idx)) + + +# ============================================================================= +# MoE Scaled Grouped GEMM Tensormap Constructor +# ============================================================================= + + +class MoEScaledGroupedGemmTensormapConstructor(OnlineTensormapDescCreator): + """ + Tensormap descriptor constructor for MoE Scaled Grouped GEMM (block-scaled). + + .. py:attribute:: ChunkSize + :value: 128 + + Number of experts processed per chunk in the desc_init_kernel. + Must match the warp-group width (4 warps × 32 threads). + + Extends MoEGroupedGemmTensormapConstructor with SFA/SFB descriptor support. + + Expert-wise descriptors only — non-expert-wise descriptors are passed + directly at kernel launch. + + Workspace layout: + - 2Dx3D: [C_0, C_1, ..., C_{n-1}] (1 slot per expert) + - 2Dx2D: [A_0, B_0, SFA_0, SFB_0, A_1, B_1, SFA_1, SFB_1, ...] (4 slots per expert) + + :param scenario: "2Dx3D" or "2Dx2D" + :param sf_vec_size: Scale factor vector size (32 for MXFP8/MXFP4, 16 for NVFP4) + :param a_dtype: Data type for tensor A + :param b_dtype: Data type for tensor B + :param c_dtype: Data type for tensor C + :param sf_dtype: Data type for scale factors (SFA/SFB) + :param a_smem_layout: SMEM layout for A TMA + :param b_smem_layout: SMEM layout for B TMA + :param epi_smem_layout: SMEM layout for epilogue (C) TMA + :param sfa_smem_layout: SMEM layout for SFA TMA + :param sfb_smem_layout: SMEM layout for SFB TMA + :param a_tma_op: TMA operation for A + :param b_tma_op: TMA operation for B + :param c_tma_op: TMA operation for C (S2G store or reduce) + :param sfa_tma_op: TMA operation for SFA + :param sfb_tma_op: TMA operation for SFB + :param tiled_mma: TiledMma for A/B/SFA/C TMA atom construction + :param tiled_mma_sfb: TiledMma for SFB (separate due to 2CTA replication) + :param mma_tiler: MMA tiler shape (M, N, K) + :param mma_tiler_sfb: MMA tiler shape for SFB + :param cluster_layout_vmnk_shape: Cluster layout shape for A/B/SFA multicast + :param cluster_layout_sfb_vmnk_shape: Cluster layout shape for SFB multicast + :param epi_tile: Epilogue tile shape + :param a_tensor: Fake GEMM domain A tensor + :param b_tensor: Fake GEMM domain B tensor + :param c_tensor: Fake GEMM domain C tensor + :param sfa_tensor: Fake GEMM domain SFA tensor (atom-tiled layout) + :param sfb_tensor: Fake GEMM domain SFB tensor (atom-tiled layout) + :param offs: (experts,) cumsum offsets in data domain + :param offs_padded: (experts,) cumsum offsets in padded scale domain + :param workspace_ptr: Pointer to workspace for TMA descriptors + :param expert_cnt: Total number of experts + """ + + ChunkSize = 128 + + def __init__( + self, + scenario: Literal["2Dx3D", "2Dx2D"], + sf_vec_size: int, + # Codegen-time configs: dtypes + a_dtype, + b_dtype, + c_dtype, + sf_dtype, + # Codegen-time configs: SMEM layouts + a_smem_layout, + b_smem_layout, + epi_smem_layout, + sfa_smem_layout, + sfb_smem_layout, + # Codegen-time configs: TMA ops + a_tma_op, + b_tma_op, + c_tma_op, + sfa_tma_op, + sfb_tma_op, + # Codegen-time configs: MMA / cluster / tile + tiled_mma, + tiled_mma_sfb, + mma_tiler, + mma_tiler_sfb, + cluster_layout_vmnk_shape, + cluster_layout_sfb_vmnk_shape, + epi_tile, + # Runtime params + a_tensor: cute.Tensor, + b_tensor: cute.Tensor, + c_tensor: cute.Tensor, + sfa_tensor: cute.Tensor, + sfb_tensor: cute.Tensor, + offs: cute.Tensor, + offs_padded: cute.Tensor, + workspace_ptr: Pointer, + expert_cnt: Optional[Union[Int32, int]] = None, + ) -> None: + super().__init__() + self.scenario = scenario + self.sf_vec_size = sf_vec_size + # Dtypes + self.a_dtype = a_dtype + self.b_dtype = b_dtype + self.c_dtype = c_dtype + self.sf_dtype = sf_dtype + # SMEM layouts + self.a_smem_layout = a_smem_layout + self.b_smem_layout = b_smem_layout + self.epi_smem_layout = epi_smem_layout + self.sfa_smem_layout = sfa_smem_layout + self.sfb_smem_layout = sfb_smem_layout + # TMA ops + self.a_tma_op = a_tma_op + self.b_tma_op = b_tma_op + self.c_tma_op = c_tma_op + self.sfa_tma_op = sfa_tma_op + self.sfb_tma_op = sfb_tma_op + # MMA / cluster / tile + self.tiled_mma = tiled_mma + self.tiled_mma_sfb = tiled_mma_sfb + self.mma_tiler = mma_tiler + self.mma_tiler_sfb = mma_tiler_sfb + self.cluster_layout_vmnk_shape = cluster_layout_vmnk_shape + self.cluster_layout_sfb_vmnk_shape = cluster_layout_sfb_vmnk_shape + self.epi_tile = epi_tile + # Runtime params + self.a_tensor = a_tensor + self.b_tensor = b_tensor + self.c_tensor = c_tensor + self.sfa_tensor = sfa_tensor + self.sfb_tensor = sfb_tensor + self.offs = offs + self.offs_padded = offs_padded + self.expert_cnt = expert_cnt + # Workspace with scenario-specific slot layout + if scenario == "2Dx3D": + self.workspace = TensormapWorkspace(workspace_ptr, ["c"]) + else: + self.workspace = TensormapWorkspace(workspace_ptr, ["a", "b", "sfa", "sfb"]) + + @staticmethod + def get_workspace_size(scenario: Literal["2Dx3D", "2Dx2D"], expert_cnt: int) -> int: + """Calculate workspace size in bytes for tensormap descriptors.""" + if scenario == "2Dx3D": + return TensormapWorkspace.size_bytes(1, expert_cnt) # C only + else: + return TensormapWorkspace.size_bytes(4, expert_cnt) # A, B, SFA, SFB + + @cute.jit + def get_desc_ptr(self, tensor_name: str, executor_idx: Int32) -> Pointer: + return self.workspace.get_ptr(tensor_name, executor_idx) + + @cute.jit + def construct_and_write(self, lane_in_group: Int32, dependency=None) -> None: + """ + Create expert-wise tensormap descriptors for all experts. + + ``lane_in_group`` is the thread's position within its warp group + (0..ChunkSize-1). The method loops internally over all experts in + chunks of ``ChunkSize``, with two-phase pipeline synchronization + per chunk. + + Per-chunk execution: + + 1. Phase 1: Build descriptors that do NOT depend on ``offs_padded`` + (A/B for 2Dx2D, C for 2Dx3D). Overlaps with Group A's prefix sum. + 2. Barrier: ``consumer.wait_and_advance()`` — all threads participate. + 3. Phase 2: Build descriptors that depend on ``offs_padded`` + (SFA/SFB for 2Dx2D). Reads padded offsets from SMEM buffer. + 4. Release: ``handle.release()`` — all threads participate. + + :param lane_in_group: Thread's position within the warp group (0..127). + :param dependency: ``(PipelineConsumer, smem_offs_padded)`` — the + consumer for mbarrier sync, and the SMEM tensor of shape + ``(ChunkSize + 1,)`` with layout ``[carry, offs_padded[0..127]]``. + """ + consumer, smem_offs_padded = dependency + assert self.expert_cnt is not None + num_chunks = (self.expert_cnt + self.ChunkSize - 1) // self.ChunkSize + + chunk_idx = cutlass.Int32(0) + while chunk_idx < num_chunks: + expert_idx = chunk_idx * self.ChunkSize + lane_in_group + in_bounds = expert_idx < self.expert_cnt + + # Phase 1: non-dependent descriptors + if in_bounds: + if cutlass.const_expr(self.scenario == "2Dx2D"): + self._construct_ab_descs_2dx2d(expert_idx) + else: + self._construct_c_desc_2dx3d(expert_idx) + + # All threads participate in barrier (fixed arrive count) + handle = consumer.wait_and_advance() + + # Phase 2: dependent descriptors (read padded offsets from SMEM) + if in_bounds: + if cutlass.const_expr(self.scenario == "2Dx2D"): + # smem_offs_padded layout: [carry, chunk[0], ..., chunk[127]] + # padded_offset = smem[lane] (prev expert's cumulative) + # padded_end = smem[lane + 1] (this expert's cumulative) + padded_offset = smem_offs_padded[lane_in_group] + padded_size_i = smem_offs_padded[lane_in_group + 1] - padded_offset + self._construct_sf_descs_2dx2d_direct( + expert_idx, padded_offset, padded_size_i + ) + + # All threads release (fixed arrive count) + handle.release() + + chunk_idx += 1 + + # ----------------------------------------------------------------- + # 2Dx3D: C descriptor (same as MoEGroupedGemmTensormapConstructor) + # ----------------------------------------------------------------- + + @cute.jit + def _construct_c_desc_2dx3d(self, expert_idx: Int32) -> None: + """ + 2Dx3D: Create expert-wise C descriptor. + C: (fake_m, n, 1) -> slice to (tokens_i, n, 1) per expert. + """ + token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx) + c1 = cutlass.Int32(1) + + c_i = cute.domain_offset((token_offset, 0, 0), self.c_tensor) + c_i = rewrite_tensor_shape(c_i, (tokens_i, self.c_tensor.shape[1], c1)) # type: ignore[index] + + tma_atom_c, _ = cpasync.make_tiled_tma_atom( + self.c_tma_op, + c_i, + self.epi_smem_layout, + self.epi_tile, + ) + cpasync.copy_tensormap(tma_atom_c, self.get_desc_ptr("c", expert_idx)) + + # ----------------------------------------------------------------- + # 2Dx2D: A, B descriptors (same as MoEGroupedGemmTensormapConstructor) + # ----------------------------------------------------------------- + + @cute.jit + def _construct_ab_descs_2dx2d(self, expert_idx: Int32) -> None: + """ + 2Dx2D: Create expert-wise A and B descriptors. + A: (m, fake_k, 1) -> slice to (m, tokens_i, 1) + B: (n, fake_k, 1) -> slice to (n, tokens_i, 1) + """ + token_offset, tokens_i = compute_expert_token_range(self.offs, expert_idx) + c1 = cutlass.Int32(1) + + # A: (m, fake_k, 1) -> domain_offset + rewrite shape + a_i = cute.domain_offset((0, token_offset, 0), self.a_tensor) + a_i = rewrite_tensor_shape(a_i, (self.a_tensor.shape[0], tokens_i, c1)) # type: ignore[index] + + tma_atom_a, _ = cute.nvgpu.make_tiled_tma_atom_A( + self.a_tma_op, + a_i, + self.a_smem_layout, + self.mma_tiler, + self.tiled_mma, + self.cluster_layout_vmnk_shape, + ) + cpasync.copy_tensormap(tma_atom_a, self.get_desc_ptr("a", expert_idx)) + + # B: (n, fake_k, 1) -> domain_offset + rewrite shape + b_i = cute.domain_offset((0, token_offset, 0), self.b_tensor) + b_i = rewrite_tensor_shape(b_i, (self.b_tensor.shape[0], tokens_i, c1)) # type: ignore[index] + + tma_atom_b, _ = cute.nvgpu.make_tiled_tma_atom_B( + self.b_tma_op, + b_i, + self.b_smem_layout, + self.mma_tiler, + self.tiled_mma, + self.cluster_layout_vmnk_shape, + ) + cpasync.copy_tensormap(tma_atom_b, self.get_desc_ptr("b", expert_idx)) + + # ----------------------------------------------------------------- + # 2Dx2D: SFA, SFB descriptors (new for block-scaled) + # ----------------------------------------------------------------- + + @cute.jit + def _construct_sf_descs_2dx2d_direct( + self, + expert_idx: Int32, + padded_offset: Int32, + padded_size_i: Int32, + ) -> None: + """ + 2Dx2D: Create expert-wise SFA and SFB descriptors with pre-computed + padded offset and size. + + This variant allows the caller to supply padded offsets from SMEM + (in desc_init_kernel) instead of reading from ``self.offs_padded`` in GMEM. + """ + c1 = cutlass.Int32(1) + + a_chunks_to_move = ( + padded_offset + // self.sf_vec_size + * cute.size(self.sfa_tensor, mode=[0]) + // 128 + ) + a_elems_to_move = ( + cute.size(self.sfa_tensor, mode=[0]) * padded_offset // self.sf_vec_size + ) + b_chunks_to_move = ( + padded_offset + // self.sf_vec_size + * cute.size(self.sfb_tensor, mode=[0]) + // 128 + ) + b_elems_to_move = ( + cute.size(self.sfb_tensor, mode=[0]) * padded_offset // self.sf_vec_size + ) + + per_expert_sfa_shape = (self.sfa_tensor.shape[0], padded_size_i, c1) # type: ignore[index] + sfa_layout_i = tile_atom_to_shape_SF(per_expert_sfa_shape, self.sf_vec_size) + sfa_i = cute.make_tensor( + self.sfa_tensor.iterator + a_elems_to_move, sfa_layout_i + ) + + tma_atom_sfa, _ = cute.nvgpu.make_tiled_tma_atom_A( + self.sfa_tma_op, + sfa_i, + self.sfa_smem_layout, + self.mma_tiler, + self.tiled_mma, + self.cluster_layout_vmnk_shape, + internal_type=cutlass.Uint64, + ) + cpasync.copy_tensormap(tma_atom_sfa, self.get_desc_ptr("sfa", expert_idx)) + + per_expert_sfb_shape = (self.sfb_tensor.shape[0], padded_size_i, c1) # type: ignore[index] + sfb_layout_i = tile_atom_to_shape_SF(per_expert_sfb_shape, self.sf_vec_size) + sfb_i = cute.make_tensor( + self.sfb_tensor.iterator + b_elems_to_move, sfb_layout_i + ) + + tma_atom_sfb, _ = cute.nvgpu.make_tiled_tma_atom_B( + self.sfb_tma_op, + sfb_i, + self.sfb_smem_layout, + self.mma_tiler_sfb, + self.tiled_mma_sfb, + self.cluster_layout_sfb_vmnk_shape, + internal_type=cutlass.Uint64, + ) + cpasync.copy_tensormap(tma_atom_sfb, self.get_desc_ptr("sfb", expert_idx)) diff --git a/reference/moe_torch_grouped_mm.py b/reference/moe_torch_grouped_mm.py new file mode 100644 index 00000000..13c6352e --- /dev/null +++ b/reference/moe_torch_grouped_mm.py @@ -0,0 +1,2019 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import os +import sys +from typing import Optional, Tuple, Literal, Type, Union + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass.cute.typing import Pointer +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait + +if __name__ == "__main__": + current_dir = os.path.dirname(os.path.abspath(__file__)) + sys.path.insert(0, os.path.join(current_dir, "../../..")) + +from blackwell.kernel.moe.moe_utils import ( + MoEGroupedGemmTensormapConstructor, +) +from blackwell.kernel.moe.moe_persistent_scheduler import ( + MoEStaticSchedulerParams, + MoEStaticPersistentTileScheduler, + MoEWorkTileInfo, +) +from blackwell.kernel.moe.moe_sched_extension import GroupedMmSchedExtension +from cutlass.utils.gemm.sm100 import ( + transform_partitioned_tensor_layout, + epilogue_tmem_copy_and_partition, + epilogue_smem_copy_and_partition, +) + + +class GroupedGemmKernel: + """ + Grouped GEMM kernel for MoE operations. + + PyTorch interface (from torch.nn.functional.grouped_mm): + - 2Dx3D (Forward): mat_a(tokens_sum, K) x mat_b(experts, K, N) -> out(tokens_sum, N) + - 2Dx2D (Weight grad): mat_a(hidden, tokens_sum) x mat_b(tokens_sum, intermediate) -> out(experts, hidden, intermediate) + + Kernel interface uses "fake" GEMM MNKL domain: + + 2Dx3D: + A_cute: (gemm_fake_m, gemm_k, 1) # fake_m = tokens_sum, scheduler will offset + B_cute: (gemm_n, gemm_k, gemm_fake_l) # fake_l = expert_idx, scheduler will select + C_cute: (gemm_fake_m, gemm_n, 1) # fake_m = tokens_sum, scheduler will offset + + 2Dx2D: + A_cute: (gemm_m, gemm_fake_k, 1) # fake_k = tokens_sum, scheduler will offset + B_cute: (gemm_n, gemm_fake_k, 1) # fake_k = tokens_sum, scheduler will offset + C_cute: (gemm_m, gemm_n, gemm_fake_l) # fake_l = expert_idx, scheduler will select + + The scheduler handles the fake dimensions by: + - For fake_m/fake_k: Computing token_offset from offs and adjusting tensor coord + - For fake_l: Selecting expert slice via L coordinate + """ + + def __init__( + self, + scenario: Literal["2Dx3D", "2Dx2D"], + out_dtype: Type[cutlass.Numeric], + accumulate_on_output: bool, + separate_tensormap_init: bool = True, + fixed_expert_cnt: Optional[int] = None, + acc_dtype: Type[cutlass.Numeric] = cutlass.Float32, + mma_tiler_mnk: Tuple[int, int, int] = (128, 128, 64), + cluster_shape_mnk: Tuple[int, int, int] = (1, 1, 1), + use_2cta_instrs: bool = False, + ): + # User-provided configs + self.scenario = scenario + self.out_dtype = out_dtype + self.accumulate_on_output = accumulate_on_output + self.separate_tensormap_init = separate_tensormap_init + self.fixed_expert_cnt = fixed_expert_cnt # Not used yet... + self.acc_dtype = acc_dtype + self.mma_tiler_mnk = mma_tiler_mnk + self.cluster_shape_mn = (cluster_shape_mnk[0], cluster_shape_mnk[1]) + self.use_2cta_instrs = use_2cta_instrs + self.arch = "sm_100" + + if accumulate_on_output and scenario == "2Dx3D": + raise ValueError( + "Non-sense config: grad accumulate should only happens in 2Dx2D." + ) + + self._validate_mma_tiler_and_cluster_shape() + + # K dimension is deferred in _setup_attributes + self.mma_tiler = (mma_tiler_mnk[0], mma_tiler_mnk[1], 1) + + # CTA group for tcgen05 MMA + self.cta_group = ( + tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + # Occupancy and warp specialization + self.occupancy = 1 + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4 + self.tma_warp_id = 5 + self.sched_warp_id = 6 + self.threads_per_cta = 32 * len( + ( + self.mma_warp_id, + self.tma_warp_id, + self.sched_warp_id, + *self.epilogue_warp_id, + ) + ) + + # Barrier IDs for synchronization + self.epilog_sync_bar_id = 1 + self.tmem_alloc_sync_bar_id = 2 + self.tmem_dealloc_sync_bar_id = 3 + + def _validate_mma_tiler_and_cluster_shape(self): + """Validate codegen-time MMA tiler and cluster shape constraints.""" + m, n = self.mma_tiler_mnk[0], self.mma_tiler_mnk[1] + cm, cn = self.cluster_shape_mn + + if self.use_2cta_instrs: + valid_m = [128, 256] + else: + valid_m = [64, 128] + if m not in valid_m: + raise ValueError( + f"mma_tiler M ({m}) must be one of {valid_m} " + f"(use_2cta_instrs={self.use_2cta_instrs})" + ) + + if n not in range(32, 257, 32): + raise ValueError(f"mma_tiler N ({n}) must be a multiple of 32 in [32, 256]") + + if cm % (2 if self.use_2cta_instrs else 1) != 0: + raise ValueError( + f"cluster_shape M ({cm}) must be even when use_2cta_instrs=True" + ) + + is_pow2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if cm * cn > 16 or not is_pow2(cm) or not is_pow2(cn): + raise ValueError( + f"Invalid cluster_shape ({cm}, {cn}): each dim must be " + f"a power of 2, and product must be <= 16" + ) + + def _create_tiled_mma(self) -> cute.TiledMma: + """Create tiled MMA atom based on input dtypes and major modes.""" + return utils.sm100.make_trivial_tiled_mma( + self.a_dtype, + self.b_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + + def _setup_attributes(self) -> None: + """ + Set up configurations that depend on GEMM inputs. + + This method configures: + - tiled_mma with correct dtypes and major modes + - MMA/cluster/tile shapes + - Cluster layout + - Multicast CTA counts + - Epilogue tile shape + - Stage counts (ACC, A/B, C) + - SMEM layouts for A/B/C + - Tensor memory allocation columns + - TMA load bytes + """ + tiled_mma = self._create_tiled_mma() + + # Use user-specified K dimension directly from mma_tiler_mnk + # Verify K is a multiple of the MMA instruction's native K size + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + assert self.mma_tiler_mnk[2] % mma_inst_shape_k == 0, ( + f"mma_tiler K ({self.mma_tiler_mnk[2]}) must be a multiple of " + f"MMA instruction K ({mma_inst_shape_k})" + ) + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + self.mma_tiler_mnk[2], + ) + + # CTA tile shape + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + + # Cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + + # Multicast CTA counts + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + # Epilogue tile shape (always use TMA store for MoE) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + + # C SMEM layout for epilogue + c_smem_layout = utils.sm100.make_smem_layout_epi( + self.c_dtype, self.c_layout, self.epi_tile, 1 + ) + + self.smem_capacity = utils.get_smem_capacity_in_bytes() + + # Compute stage counts + self.num_acc_stage = 2 + self.num_c_stage = 2 # Always use TMA store for MoE + + a_smem_layout_stage_one = utils.sm100.make_smem_layout_a( + tiled_mma, self.mma_tiler, self.a_dtype, 1 + ) + b_smem_layout_stage_one = utils.sm100.make_smem_layout_b( + tiled_mma, self.mma_tiler, self.b_dtype, 1 + ) + + ab_bytes_per_stage = cute.size_in_bytes( + self.a_dtype, a_smem_layout_stage_one + ) + cute.size_in_bytes(self.b_dtype, b_smem_layout_stage_one) + mbar_helpers_bytes = 1024 + c_bytes_per_stage = cute.size_in_bytes(self.c_dtype, c_smem_layout) + c_bytes = c_bytes_per_stage * self.num_c_stage + + self.num_sched_stages = 2 + sched_work_tile_bytes_per_stage = 16 # 4 fields * sizeof(Int32) + sched_bytes = sched_work_tile_bytes_per_stage * self.num_sched_stages + + fixed_overhead = mbar_helpers_bytes + c_bytes + sched_bytes + + self.num_ab_stage = ( + self.smem_capacity // self.occupancy - fixed_overhead + ) // ab_bytes_per_stage + + # Refine epilogue stages with remaining SMEM + self.num_c_stage += ( + self.smem_capacity + - self.occupancy * ab_bytes_per_stage * self.num_ab_stage + - self.occupancy * fixed_overhead + ) // (self.occupancy * c_bytes_per_stage) + + # SMEM layouts + self.a_smem_layout_staged = utils.sm100.make_smem_layout_a( + tiled_mma, self.mma_tiler, self.a_dtype, self.num_ab_stage + ) + self.b_smem_layout_staged = utils.sm100.make_smem_layout_b( + tiled_mma, self.mma_tiler, self.b_dtype, self.num_ab_stage + ) + self.c_smem_layout_staged = utils.sm100.make_smem_layout_epi( + self.c_dtype, self.c_layout, self.epi_tile, self.num_c_stage + ) + + # Tensor memory allocation columns + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, self.num_acc_stage) + ) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols( + tCtAcc_fake, arch=self.arch + ) + + # TMA load bytes + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size + + def get_workspace_size(self, expert_cnt: int) -> int: + """ + Workspace size for expert-wise TMA descriptors. + + 2Dx3D: Need C desc per expert -> expert_cnt * TensormapDescBytes + 2Dx2D: Need A and B desc per expert -> 2 * expert_cnt * TensormapDescBytes + """ + return MoEGroupedGemmTensormapConstructor.get_workspace_size( + self.scenario, expert_cnt + ) + + @cute.jit + def __call__( + self, + mat_a: cute.Tensor, # PyTorch mat_a + mat_b: cute.Tensor, # PyTorch mat_b + out: cute.Tensor, # PyTorch output + offs: cute.Tensor, # (experts,) cumsum + bias: Optional[cute.Tensor], + workspace: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + ) -> None: + """ + Launch the grouped GEMM kernel. + + This method: + 1. Transforms PyTorch tensors to GEMM domain tensors + 2. Infers dtypes and major modes from GEMM domain tensors + 3. Sets up kernel attributes + 4. Creates TMA atoms for A, B, C + 5. Creates scheduler parameters + 6. Launches the kernel + """ + if cutlass.const_expr(bias is not None): + raise NotImplementedError("bias is not supported yet (align with torch).") + + # ===================================================================== + # Step 1: Transform PyTorch tensors to GEMM domain (fake MNKL) + # ===================================================================== + + c1 = cutlass.Int32(1) + c0 = cutlass.Int32(0) + + if cutlass.const_expr(self.scenario == "2Dx3D"): + # mat_a: (tokens_sum, hidden) -> A_cute: (fake_m, k, 1) + tokens_sum, hidden = mat_a.shape + a_gemm = cute.make_tensor( + mat_a.iterator, + cute.make_layout( + (tokens_sum, hidden, c1), + stride=(mat_a.stride[0], mat_a.stride[1], c0), + ), + ) + + # mat_b: (experts, hidden, intermediate) -> B_cute: (n, k, fake_l) + experts, hidden_b, intermediate = mat_b.shape + b_gemm = cute.make_tensor( + mat_b.iterator, + cute.make_layout( + (intermediate, hidden_b, experts), + stride=(mat_b.stride[2], mat_b.stride[1], mat_b.stride[0]), + ), + ) + + # out: (tokens_sum, intermediate) -> C_cute: (fake_m, n, 1) + c_gemm = cute.make_tensor( + out.iterator, + cute.make_layout( + (tokens_sum, intermediate, c1), + stride=(out.stride[0], out.stride[1], c0), + ), + ) + + expert_cnt = experts + intermediate_dim = intermediate + hidden_dim = hidden + + else: # 2Dx2D + # mat_a: (hidden, tokens_sum) -> A_cute: (m, fake_k, 1) + hidden, tokens_sum = mat_a.shape + a_gemm = cute.make_tensor( + mat_a.iterator, + cute.make_layout( + (hidden, tokens_sum, c1), + stride=(mat_a.stride[0], mat_a.stride[1], c0), + ), + ) + + # mat_b: (tokens_sum, intermediate) -> B_cute: (n, fake_k, 1) + tokens_sum_b, intermediate = mat_b.shape + b_gemm = cute.make_tensor( + mat_b.iterator, + cute.make_layout( + (intermediate, tokens_sum_b, c1), + stride=(mat_b.stride[1], mat_b.stride[0], c0), + ), + ) + + # out: (experts, hidden, intermediate) -> C_cute: (m, n, fake_l) + experts, hidden_c, intermediate_c = out.shape + c_gemm = cute.make_tensor( + out.iterator, + cute.make_layout( + (hidden_c, intermediate_c, experts), + stride=(out.stride[1], out.stride[2], out.stride[0]), + ), + ) + + expert_cnt = experts + intermediate_dim = intermediate + hidden_dim = hidden + + # ===================================================================== + # Step 2: Infer dtypes and major modes from GEMM domain tensors + # ===================================================================== + + self.a_dtype: Type[cutlass.Numeric] = a_gemm.element_type + self.b_dtype: Type[cutlass.Numeric] = b_gemm.element_type + self.c_dtype: Type[cutlass.Numeric] = c_gemm.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a_gemm).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b_gemm).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c_gemm) + + # ===================================================================== + # Step 3: Setup kernel attributes + # ===================================================================== + + k = self.mma_tiler_mnk[2] + a_tile_bits = self.a_dtype.width * k + b_tile_bits = self.b_dtype.width * k + if cutlass.const_expr(a_tile_bits % 256 != 0): + raise ValueError( + f"a_dtype ({self.a_dtype.width}b) * mma_tiler K ({k}) = " + f"{a_tile_bits}b, must be a multiple of 256b (MMA instruction K width)" + ) + if cutlass.const_expr(b_tile_bits % 256 != 0): + raise ValueError( + f"b_dtype ({self.b_dtype.width}b) * mma_tiler K ({k}) = " + f"{b_tile_bits}b, must be a multiple of 256b (MMA instruction K width)" + ) + + self._setup_attributes() + tiled_mma = self._create_tiled_mma() + + # ===================================================================== + # Step 4: Create TMA atoms for A, B, C + # ===================================================================== + + # TMA load for A + a_op = utils.sm100.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a_gemm, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # TMA load for B + b_op = utils.sm100.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b_gemm, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # TMA store for C (or TMA reduce for accumulate_on_output) + if cutlass.const_expr(self.accumulate_on_output): + c_tma_op = cpasync.CopyReduceBulkTensorTileS2GOp() + else: + c_tma_op = cpasync.CopyBulkTensorTileS2GOp() + + epi_smem_layout = cute.select(self.c_smem_layout_staged, mode=[0, 1]) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + c_tma_op, c_gemm, epi_smem_layout, self.epi_tile + ) + + # ===================================================================== + # Step 5: Create MoEStaticSchedulerParams and compute grid + # ===================================================================== + + sched_params = MoEStaticSchedulerParams( + scenario=self.scenario, + expert_shape=(expert_cnt, intermediate_dim, hidden_dim), + cta_tile_shape_mnk=self.cta_tile_shape_mnk, + cluster_shape_mn=self.cluster_shape_mn, + ) + + grid = MoEStaticSchedulerParams.get_grid_shape( + sched_params, max_active_clusters + ) + + # ===================================================================== + # Step 5.5: Launch desc init kernel (if separate_tensormap_init) + # ===================================================================== + # + # Pre-initialize expert-wise TMA descriptors in workspace before + # the main kernel. Stream ordering guarantees completion before + # the main kernel starts. + # + # 2Dx3D: C desc per expert (C has dynamic fake_m per expert) + # 2Dx2D: A,B desc per expert (A,B have dynamic fake_k per expert) + # + + if cutlass.const_expr(self.separate_tensormap_init): + self.desc_init_kernel( + tiled_mma, + a_gemm, + b_gemm, + c_gemm, + offs, + expert_cnt, + workspace.iterator, + self.cluster_layout_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.c_smem_layout_staged, + self.epi_tile, + ).launch( + grid=(expert_cnt, 1, 1), + block=[32, 1, 1], + stream=stream, + min_blocks_per_mp=1, + ) + + # ===================================================================== + # Step 6: Launch kernel + # ===================================================================== + + self.kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c, + a_gemm, + b_gemm, + c_gemm, + offs, + sched_params, + workspace.iterator, + self.cluster_layout_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.c_smem_layout_staged, + self.epi_tile, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + stream=stream, + min_blocks_per_mp=self.occupancy, + ) + + # GPU device kernel: TMA descriptor initialization + @cute.kernel + def desc_init_kernel( + self, + tiled_mma: cute.TiledMma, + a_gemm: cute.Tensor, # GEMM domain A (fake MNKL) + b_gemm: cute.Tensor, # GEMM domain B (fake MNKL) + c_gemm: cute.Tensor, # GEMM domain C (fake MNKL) + offs: cute.Tensor, # (experts,) cumsum + expert_cnt: Union[cutlass.Int32, int], + workspace_ptr: Pointer, + cluster_layout_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout], + epi_tile: cute.Tile, + ): + """ + Separate kernel to pre-initialize expert-wise TMA descriptors. + + Grid: (expert_cnt, 1, 1) - one block per expert + Block: (32, 1, 1) - one warp per block + + Each block constructs and writes TMA descriptors for one expert + to the pre-allocated workspace buffer. + + 2Dx3D: Creates C descriptor per expert (C has dynamic fake_m per expert) + 2Dx2D: Creates A and B descriptors per expert (A/B have dynamic fake_k per expert) + """ + # ================================================================= + # Reconstruct TMA constructor with explicit attributes + # ================================================================= + + a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0)) + b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, None, 0)) + epi_smem_layout = cute.select(c_smem_layout_staged, mode=[0, 1]) + + a_tma_op = utils.sm100.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_tma_op = utils.sm100.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + if cutlass.const_expr(self.accumulate_on_output): + c_tma_op = cpasync.CopyReduceBulkTensorTileS2GOp() + else: + c_tma_op = cpasync.CopyBulkTensorTileS2GOp() + + tensormap_ctor = MoEGroupedGemmTensormapConstructor( + scenario=self.scenario, + a_dtype=self.a_dtype, + b_dtype=self.b_dtype, + c_dtype=self.c_dtype, + a_smem_layout=a_smem_layout, + b_smem_layout=b_smem_layout, + epi_smem_layout=epi_smem_layout, + a_tma_op=a_tma_op, + b_tma_op=b_tma_op, + c_tma_op=c_tma_op, + tiled_mma=tiled_mma, + mma_tiler=self.mma_tiler, + cluster_layout_vmnk_shape=cluster_layout_vmnk.shape, + epi_tile=epi_tile, + a_tensor=a_gemm, + b_tensor=b_gemm, + c_tensor=c_gemm, + offs=offs, + workspace_ptr=workspace_ptr, + ) + + # ================================================================= + # Each block constructs descriptors for one expert + # ================================================================= + + expert_idx, _, _ = cute.arch.block_idx() + tensormap_ctor.construct_and_write(expert_idx) + + # GPU device kernel: main GEMM kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + tma_tensor_a: cute.Tensor, + tma_atom_b: cute.CopyAtom, + tma_tensor_b: cute.Tensor, + tma_atom_c: cute.CopyAtom, + tma_tensor_c: cute.Tensor, + a_gemm: cute.Tensor, # GEMM domain A (fake MNKL) + b_gemm: cute.Tensor, # GEMM domain B (fake MNKL) + c_gemm: cute.Tensor, # GEMM domain C (fake MNKL) + offs: cute.Tensor, # (experts,) cumsum + sched_params: MoEStaticSchedulerParams, + workspace_ptr: Pointer, + cluster_layout_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout], + epi_tile: cute.Tile, + ): + """ + GPU device kernel for MoE Grouped GEMM. + + Warp specialization: + - Warps 0-3: Epilogue warps (TMEM -> RMEM -> SMEM -> GMEM) + - Warp 4: MMA warp (tcgen05.mma) + - Warp 5: TMA load warp (also prefetches expert-wise TMA descriptors) + + The kernel uses MoEStaticPersistentTileScheduler to iterate over tiles + across all experts. For each tile: + 1. TMA load warp fetches A/B tiles using get_gmem_tensor + 2. MMA warp performs matrix multiply-accumulate + 3. Epilogue warps store results using TMA store/reduce + + Note: Python objects holding MLIR values cannot be kernel params. + The following are constructed inside the kernel from individually-passed params: + - tensormap_ctor: MoEGroupedGemmTensormapConstructor (online tensormap builder) + - ext: GroupedMmSchedExtension (domain conversion + TMA desc selection) + """ + # ================================================================= + # Reconstruct dicts that can't be passed as kernel params + # ================================================================= + + # Construct TMA descriptor creator and scheduler extension + a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0)) + b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, None, 0)) + epi_smem_layout = cute.select(c_smem_layout_staged, mode=[0, 1]) + + a_tma_op = utils.sm100.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_tma_op = utils.sm100.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + if cutlass.const_expr(self.accumulate_on_output): + c_tma_op = cpasync.CopyReduceBulkTensorTileS2GOp() + else: + c_tma_op = cpasync.CopyBulkTensorTileS2GOp() + + tensormap_ctor = MoEGroupedGemmTensormapConstructor( + scenario=self.scenario, + a_dtype=self.a_dtype, + b_dtype=self.b_dtype, + c_dtype=self.c_dtype, + a_smem_layout=a_smem_layout, + b_smem_layout=b_smem_layout, + epi_smem_layout=epi_smem_layout, + a_tma_op=a_tma_op, + b_tma_op=b_tma_op, + c_tma_op=c_tma_op, + tiled_mma=tiled_mma, + mma_tiler=self.mma_tiler, + cluster_layout_vmnk_shape=cluster_layout_vmnk.shape, + epi_tile=epi_tile, + a_tensor=a_gemm, + b_tensor=b_gemm, + c_tensor=c_gemm, + offs=offs, + workspace_ptr=workspace_ptr, + ) + ext = GroupedMmSchedExtension( + scenario=self.scenario, tensormap_ctor=tensormap_ctor + ) + + # ================================================================= + # Kernel setup + # ================================================================= + + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # CTA/thread coordinates + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + tidx, _, _ = cute.arch.thread_idx() + + # ================================================================= + # SharedStorage + # ================================================================= + + @cute.struct + class SharedStorage: + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + acc_full_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_acc_stage * 2 + ] + sched_buf: cute.struct.MemRange[cutlass.Int32, self.num_sched_stages * 4] + sched_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_sched_stages * 2 + ] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + # ================================================================= + # Pipelines + # ================================================================= + + # AB pipeline (TMA load → MMA) + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + + # ACC pipeline (MMA → epilogue) + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = ( + len(self.epilogue_warp_id) * 32 * (2 if use_2cta_instrs else 1) + ) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_acc_consumer_threads + ) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ) + + # Scheduler pipeline (sched warp → tma/mma/epi warps) + sched_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32) + num_sched_consumer_threads = 32 * len( + (self.tma_warp_id, self.mma_warp_id, *self.epilogue_warp_id) + ) + sched_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_sched_consumer_threads + ) + sched_pipeline = pipeline.PipelineAsync.create( + num_stages=self.num_sched_stages, + producer_group=sched_producer_group, + consumer_group=sched_consumer_group, + barrier_storage=storage.sched_mbar_ptr.data_ptr(), + defer_sync=True, + ) + + # TMEM allocator + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)), + ) + tmem = utils.TmemAllocator( + storage.tmem_holding_buf.ptr, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.epilogue_warp_id[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr.ptr, + ) + + # Cluster barrier sync after init + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True) + + # ================================================================= + # SMEM tensors A/B + # ================================================================= + + # (MMA, MMA_M, MMA_K, STAGE) + sA = smem.allocate_tensor( + element_type=self.a_dtype, + layout=a_smem_layout_staged.outer, + byte_alignment=128, + swizzle=a_smem_layout_staged.inner, + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = smem.allocate_tensor( + element_type=self.b_dtype, + layout=b_smem_layout_staged.outer, + byte_alignment=128, + swizzle=b_smem_layout_staged.inner, + ) + + # Multicast masks + a_full_mcast_mask = None + b_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + + # MMA fragments (SMEM → TMEM partitions) + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, self.num_acc_stage) + ) + + # Cluster wait before TMEM alloc + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn) + + # ================================================================= + # Scheduler warp (warp 6) + # ================================================================= + + sched_buf_ptr = storage.sched_buf.data_ptr() + sched_copy_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), cutlass.Int32, num_bits_per_copy=128 + ) + sched_buf_tensor = cute.make_tensor( + sched_buf_ptr, cute.make_layout((4, self.num_sched_stages), stride=(1, 4)) + ) + + if warp_idx == self.sched_warp_id: + scheduler = MoEStaticPersistentTileScheduler.create( + sched_params, offs, cute.arch.block_idx(), cute.arch.grid_dim() + ) + + sched_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_sched_stages + ) + + # Always produce the initial work_tile_info first + work_tile_info = scheduler.initial_work_tile_info() + sched_pipeline.producer_acquire(sched_producer_state) + rmem = work_tile_info.to_rmem_tensor() + cute.copy( + sched_copy_atom, + rmem, + sched_buf_tensor[(None, sched_producer_state.index)], + ) + cute.arch.fence_proxy("async.shared", space="cta") + sched_pipeline.producer_commit(sched_producer_state) + sched_producer_state.advance() + + # Iterate remaining tiles starting from the first advance + work_tile_info = scheduler.advance_to_next_work() + while work_tile_info.is_valid_tile: + ext.prefetch_for_expert(work_tile_info.expert_idx) + sched_pipeline.producer_acquire(sched_producer_state) + rmem = work_tile_info.to_rmem_tensor() + cute.copy( + sched_copy_atom, + rmem, + sched_buf_tensor[(None, sched_producer_state.index)], + ) + cute.arch.fence_proxy("async.shared", space="cta") + sched_pipeline.producer_commit(sched_producer_state) + sched_producer_state.advance() + + work_tile_info = scheduler.advance_to_next_work() + + # Write invalid sentinel (expert_idx = -1) so consumers exit + sched_pipeline.producer_acquire(sched_producer_state) + sentinel = MoEWorkTileInfo( + cutlass.Int32(-1), cutlass.Int32(0), cutlass.Int32(0), cutlass.Int32(0) + ) + rmem = sentinel.to_rmem_tensor() + cute.copy( + sched_copy_atom, + rmem, + sched_buf_tensor[(None, sched_producer_state.index)], + ) + cute.arch.fence_proxy("async.shared", space="cta") + sched_pipeline.producer_commit(sched_producer_state) + + sched_pipeline.producer_tail(sched_producer_state) + + # ================================================================= + # TMA load warp (warp 5) + # ================================================================= + + if warp_idx == self.tma_warp_id: + sched_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_sched_stages + ) + + # Read initial work_tile_info + sched_pipeline.consumer_wait(sched_consumer_state) + rmem = cute.make_rmem_tensor((4,), cutlass.Int32) + cute.copy( + sched_copy_atom, + sched_buf_tensor[(None, sched_consumer_state.index)], + rmem, + ) + work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem) + cute.arch.fence_acq_rel_cta() + sched_pipeline.consumer_release(sched_consumer_state) + sched_consumer_state.advance() + + while work_tile_info.is_valid_tile: + k_tile_cnt = work_tile_info.k_tile_cnt + + # Get real GEMM domain tensors + TMA desc ptrs via extension + real_a, desc_ptr_a = ext.get_gmem_tensor( + "a", + tma_tensor_a, + offs, + work_tile_info, + ) + real_b, desc_ptr_b = ext.get_gmem_tensor( + "b", + tma_tensor_b, + offs, + work_tile_info, + ) + + # local_tile for this tile's A and B + gA_mkl = cute.local_tile( + real_a, + cute.slice_(self.mma_tiler, (None, 0, None)), + (None, None, None), + ) + gB_nkl = cute.local_tile( + real_b, + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), + ) + + # MMA partition for TMA + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + tCgA = thr_mma.partition_A(gA_mkl) + tCgB = thr_mma.partition_B(gB_nkl) + + # TMA partition + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # Slice to current tile coords (L=0 for MoE, expert already selected) + mma_tile_m = work_tile_info.tile_m_idx // cute.size( + tiled_mma.thr_id.shape + ) + tAgA_slice = tAgA[(None, mma_tile_m, None, 0)] + tBgB_slice = tBgB[(None, work_tile_info.tile_n_idx, None, 0)] + + # TMA load loop + ab_producer.reset() + peek_ab_empty_status = ab_producer.try_acquire() + + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + handle = ab_producer.acquire_and_advance(peek_ab_empty_status) + peek_ab_empty_status = cutlass.Boolean(1) + if handle.count + 1 < k_tile_cnt: + peek_ab_empty_status = ab_producer.try_acquire() + cute.copy( + tma_atom_a, + tAgA_slice[(None, handle.count)], + tAsA[(None, handle.index)], + tma_bar_ptr=handle.barrier, + tma_desc_ptr=desc_ptr_a, + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_slice[(None, handle.count)], + tBsB[(None, handle.index)], + tma_bar_ptr=handle.barrier, + tma_desc_ptr=desc_ptr_b, + mcast_mask=b_full_mcast_mask, + ) + + # Read next work_tile_info + sched_pipeline.consumer_wait(sched_consumer_state) + rmem = cute.make_rmem_tensor((4,), cutlass.Int32) + cute.copy( + sched_copy_atom, + sched_buf_tensor[(None, sched_consumer_state.index)], + rmem, + ) + work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem) + cute.arch.fence_acq_rel_cta() + sched_pipeline.consumer_release(sched_consumer_state) + sched_consumer_state.advance() + + ab_producer.tail() + + # ================================================================= + # MMA warp (warp 4) + # ================================================================= + + if warp_idx == self.mma_warp_id: + # Retrieve TMEM + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + sched_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_sched_stages + ) + + # Read initial work_tile_info + sched_pipeline.consumer_wait(sched_consumer_state) + rmem = cute.make_rmem_tensor((4,), cutlass.Int32) + cute.copy( + sched_copy_atom, + sched_buf_tensor[(None, sched_consumer_state.index)], + rmem, + ) + work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem) + cute.arch.fence_acq_rel_cta() + sched_pipeline.consumer_release(sched_consumer_state) + sched_consumer_state.advance() + + while work_tile_info.is_valid_tile: + k_tile_cnt = work_tile_info.k_tile_cnt + if is_leader_cta: + tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] + + # AB consumer mainloop + ab_consumer.reset() + peek_ab_full_status = cutlass.Boolean(1) + if k_tile_cnt > 0: + peek_ab_full_status = ab_consumer.try_wait() + acc_pipeline.producer_acquire(acc_producer_state) + + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + handle = ab_consumer.wait_and_advance(peek_ab_full_status) + peek_ab_full_status = cutlass.Boolean(1) + if handle.count + 1 < k_tile_cnt: + peek_ab_full_status = ab_consumer.try_wait() + tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile != 0) + tile_crd = (None, None, None, handle.index) + cute.gemm( + tiled_mma, tCtAcc, tCrA[tile_crd], tCrB[tile_crd], tCtAcc + ) + handle.release() + + if k_tile_cnt > 0: + acc_pipeline.producer_commit(acc_producer_state) + if k_tile_cnt > 0: + acc_producer_state.advance() + + # Read next work_tile_info + sched_pipeline.consumer_wait(sched_consumer_state) + rmem = cute.make_rmem_tensor((4,), cutlass.Int32) + cute.copy( + sched_copy_atom, + sched_buf_tensor[(None, sched_consumer_state.index)], + rmem, + ) + work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem) + cute.arch.fence_acq_rel_cta() + sched_pipeline.consumer_release(sched_consumer_state) + sched_consumer_state.advance() + + acc_pipeline.producer_tail(acc_producer_state) + + # ================================================================= + # SMEM tensor C (allocated after MMA section, same as dense) + # ================================================================= + + sC = smem.allocate_tensor( + element_type=self.c_dtype, + layout=c_smem_layout_staged.outer, + byte_alignment=128, + swizzle=c_smem_layout_staged.inner, + ) + + # ================================================================= + # Epilogue warps (warps 0-3) + # ================================================================= + + if warp_idx < self.mma_warp_id: + # Allocate TMEM + tmem.allocate(self.num_tmem_alloc_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + sched_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_sched_stages + ) + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilogue_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, producer_group=c_producer_group + ) + + epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=self.epilog_sync_bar_id, + num_threads=32 * len(self.epilogue_warp_id), + ) + + # Epilogue copy setup (same for all tiles, depends only on shapes) + # Transform ACC layout: ((ATOM_M, ATOM_N), MMA_M, MMA_N, STAGE) + # -> ((ATOM_M, MMA_M), (ATOM_N, MMA_N), STAGE) + tCtAcc_transformed = transform_partitioned_tensor_layout(tCtAcc_base) + + num_tiles_executed = cutlass.Int32(0) + + # Read initial work_tile_info + sched_pipeline.consumer_wait(sched_consumer_state) + rmem = cute.make_rmem_tensor((4,), cutlass.Int32) + cute.copy( + sched_copy_atom, + sched_buf_tensor[(None, sched_consumer_state.index)], + rmem, + ) + work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem) + cute.arch.fence_acq_rel_cta() + sched_pipeline.consumer_release(sched_consumer_state) + sched_consumer_state.advance() + + while work_tile_info.is_valid_tile: + k_tile_cnt = work_tile_info.k_tile_cnt + # Get real C tensor + TMA desc ptr via extension + real_c, desc_ptr_c = ext.get_gmem_tensor( + "c", + tma_tensor_c, + offs, + work_tile_info, + ) + + # local_tile + partition for C + gC_mnl = cute.local_tile( + real_c, + cute.slice_(self.mma_tiler, (None, None, 0)), + (None, None, None), + ) + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + tCgC = thr_mma.partition_C(gC_mnl) + tCgC_transformed = transform_partitioned_tensor_layout(tCgC) + + mma_tile_coord_mnl = ( + work_tile_info.tile_m_idx // cute.size(tiled_mma.thr_id.shape), + work_tile_info.tile_n_idx, + cutlass.Int32(0), + ) + + # Partition for TMEM → RMEM copy + tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = ( + epilogue_tmem_copy_and_partition( + self, + tidx, + tCtAcc_transformed, + tCgC_transformed, + epi_tile, + use_2cta_instrs, + ) + ) + tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = epilogue_smem_copy_and_partition( + self, tiled_copy_t2r, tTR_rC, tidx, sC + ) + + # TMA partition for C store (with expert-wise desc_ptr) + tCgC_epi = cute.flat_divide(tCgC_transformed, epi_tile) + bSG_sC, bSG_gC_partitioned = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + cute.group_modes(sC, 0, 2), + cute.group_modes(tCgC_epi, 0, 2), + ) + bSG_gC = bSG_gC_partitioned[(None, None, None, *mma_tile_coord_mnl)] + + # Set TMEM buffer for current tile + tTR_tAcc = tTR_tAcc_base[ + (None, None, None, None, None, acc_consumer_state.index) + ] + + # Wait for accumulator buffer full + if k_tile_cnt > 0: + acc_pipeline.consumer_wait(acc_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + + # Store accumulator to global memory in subtiles + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + num_prev_subtiles = num_tiles_executed * subtile_cnt + + for subtile_idx in range(subtile_cnt): + # TMEM → RMEM + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + if cutlass.const_expr(self.scenario == "2Dx2D"): + if k_tile_cnt > 0: + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + else: + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # Convert to output dtype + acc_vec = cute.zeros_like(tiled_copy_r2s.retile(tTR_rAcc)) + if cutlass.const_expr(self.scenario == "2Dx2D"): + if k_tile_cnt > 0: + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + else: + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + acc_vec = acc_vec.to(self.c_dtype) + tRS_rC.store(acc_vec) + + # RMEM → SMEM + c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage + cute.copy( + tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)] + ) + cute.arch.fence_proxy("async.shared", space="cta") + epilog_sync_barrier.arrive_and_wait() + + # SMEM → GMEM (TMA store or TMA reduce) + if warp_idx == self.epilogue_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, subtile_idx)], + tma_desc_ptr=desc_ptr_c, + ) + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + epilog_sync_barrier.arrive_and_wait() + + # Release accumulator buffer + if k_tile_cnt > 0: + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + num_tiles_executed += cutlass.Int32(1) + + # Read next work_tile_info + sched_pipeline.consumer_wait(sched_consumer_state) + rmem = cute.make_rmem_tensor((4,), cutlass.Int32) + cute.copy( + sched_copy_atom, + sched_buf_tensor[(None, sched_consumer_state.index)], + rmem, + ) + work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem) + cute.arch.fence_acq_rel_cta() + sched_pipeline.consumer_release(sched_consumer_state) + sched_consumer_state.advance() + + # Wait for C store complete + c_pipeline.producer_tail() + + # Free TMEM + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr) + + +# ============================================================================= +# Host Validation +# ============================================================================= + +from dataclasses import dataclass, field +import re + +import numpy as np +import torch +import cutlass.torch as cutlass_torch + + +def torch_version_lt(major: int, minor: int) -> bool: + """Best-effort torch version check that tolerates local build suffixes.""" + match = re.match(r"^\s*(\d+)\.(\d+)", torch.__version__) + if match is None: + print( + "WARNING: failed to parse torch.__version__, " + "falling back to torch._grouped_mm host reference." + ) + return True + version = (int(match.group(1)), int(match.group(2))) + return version < (major, minor) + + +@dataclass +class ProblemDesc: + tokens: int + experts: int + top_k_select: int + balance_route: bool + hidden: int + intermediate: int + scenario: Literal["2Dx3D", "2Dx2D"] + ab_dtype: torch.dtype + out_dtype: torch.dtype + acc_dtype: torch.dtype + grad_accumulate: bool = False + # GEMM-domain layout control (which axis is stride-1) + # A (M, K): "k_major" (default) or "m_major" + # B (N, K): "n_major" (default) or "k_major" + # C (M, N): "n_major" (default) or "m_major" + a_layout: Literal["k_major", "m_major"] = "k_major" + b_layout: Literal["k_major", "n_major"] = "n_major" + c_layout: Literal["m_major", "n_major"] = "n_major" + + def __str__(self) -> str: + d = lambda t: str(t).split(".")[-1] + route = "balanced" if self.balance_route else "random" + return ( + f"ProblemDesc: {self.scenario} | tokens={self.tokens} experts={self.experts} " + f"top_k={self.top_k_select} route={route} | hidden={self.hidden} intermediate={self.intermediate} | " + f"{d(self.ab_dtype)}->{d(self.out_dtype)}(acc={d(self.acc_dtype)}) grad_acc={self.grad_accumulate} | " + f"layout: A={self.a_layout} B={self.b_layout} C={self.c_layout}" + ) + + +@dataclass +class ImplDesc: + mma_tiler_mnk: Tuple[int, int, int] + cluster_shape_mnk: Tuple[int, int, int] + use_2cta_instrs: bool + static_expert_cnt: Optional[int] = None + separate_tensormap_init: bool = True + + def __str__(self) -> str: + tile = ",".join(map(str, self.mma_tiler_mnk)) + cluster = ",".join(map(str, self.cluster_shape_mnk)) + static_e = ( + self.static_expert_cnt if self.static_expert_cnt is not None else "dynamic" + ) + return ( + f"ImplDesc: tile={tile} cluster={cluster} 2cta={self.use_2cta_instrs} | " + f"static_E={static_e} sep_tmap={self.separate_tensormap_init}" + ) + + +@dataclass +class MiscDesc: + perf_run: bool = False + perf_e2e: bool = False + compare_with_bmm: bool = False + compare_with_sol: bool = False + no_torch_210: bool = field(init=False) + + def __post_init__(self): + self.no_torch_210 = torch_version_lt(2, 10) + if self.perf_e2e and not self.perf_run: + raise ValueError("--perf_e2e requires --perf_run to be enabled.") + if self.perf_e2e and self.compare_with_sol: + raise ValueError( + "--perf_e2e and --compare_with_sol are mutually exclusive." + ) + + def __str__(self) -> str: + ref = "bmm" if self.compare_with_bmm else "grouped_mm" + return ( + f"MiscDesc: perf={self.perf_run} perf_e2e={self.perf_e2e} " + f"ref={ref} sol={self.compare_with_sol} no_torch_210={self.no_torch_210}" + ) + + +def l2_flush(size_mb: int = 400) -> None: + """Best-effort L2 flush by touching a large temporary tensor.""" + num_bytes = size_mb * 1024 * 1024 + flush_buf = torch.randint(0, 256, (num_bytes,), dtype=torch.uint8, device="cuda") + del flush_buf + + +class GroupedGemmTester: + def __init__(self, problem: ProblemDesc, impl: ImplDesc, misc: MiscDesc): + self.problem = problem + self.impl = impl + self.misc = misc + + self.tokens_after_repeat = problem.tokens * problem.top_k_select + self.expert_cnt = problem.experts + self.hidden = problem.hidden + self.intermediate = problem.intermediate + + self.A_tensor: torch.Tensor = None + self.B_tensor: torch.Tensor = None + self.C_tensor: torch.Tensor = None + self.C_ref_tensor: torch.Tensor = None + self.offs_tensor: torch.Tensor = None + self.workspace_tensor: torch.Tensor = None + + # This should be a common func + self.temp_type_mapping = { + torch.float32: cutlass.Float32, + torch.bfloat16: cutlass.BFloat16, + torch.float16: cutlass.Float16, + } + + def _generate_offs(self) -> torch.Tensor: + """Generate group-end offsets. + + Some experts may receive 0 tokens (valid in real MoE routing). + """ + total = self.tokens_after_repeat + expert_cnt = self.expert_cnt + + if self.problem.balance_route: + base = total // expert_cnt + remainder = total % expert_cnt + sizes = [base + (1 if i < remainder else 0) for i in range(expert_cnt)] + else: + proportions = np.random.dirichlet([0.5] * expert_cnt) + raw = np.floor(proportions * total).astype(int) + deficit = total - raw.sum() + while deficit > 0: + idx = int(np.argmin(raw / (proportions * total + 1e-12))) + raw[idx] += 1 + deficit -= 1 + while deficit < 0: + ratios = np.where( + raw > 0, + raw / (proportions * total + 1e-12), + -np.inf, + ) + idx = int(np.argmax(ratios)) + raw[idx] -= 1 + deficit += 1 + sizes = raw.tolist() + + assert sum(sizes) == total + + cum = 0 + offsets = [] + for s in sizes: + cum += s + offsets.append(cum) + return torch.tensor(offsets, dtype=torch.int32, device="cuda") + + def _generate_tensor(self, shape: Tuple) -> torch.Tensor: + if self.misc.perf_run: + return torch.randn(shape, dtype=self.problem.ab_dtype, device="cuda") + else: + return torch.randint(-1, 2, shape, device="cuda", dtype=torch.int8).to( + self.problem.ab_dtype + ) + + def _get_stream(self) -> cuda.CUstream: + return cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + def generate_inputs(self) -> None: + self.offs_tensor = self._generate_offs() + + tokens = self.tokens_after_repeat + hidden = self.hidden + intermediate = self.intermediate + expert_cnt = self.expert_cnt + + if self.problem.scenario == "2Dx3D": + # PyTorch shape: A (tokens, hidden), B (expert_cnt, hidden, intermediate), C (tokens, intermediate) + # GEMM domain: A (M=tokens, K=hidden), B (N=intermediate, K=hidden), C (M=tokens, N=intermediate) + + # GEMM A: k_major → K(hidden) stride-1; m_major → M(tokens) stride-1 + if self.problem.a_layout == "k_major": + self.A_tensor = self._generate_tensor((tokens, hidden)) + else: + self.A_tensor = self._generate_tensor((hidden, tokens)).T + + # GEMM B: n_major → N(intermediate) stride-1; k_major → K(hidden) stride-1 + if self.problem.b_layout == "n_major": + self.B_tensor = self._generate_tensor( + (expert_cnt, hidden, intermediate) + ) + else: + self.B_tensor = self._generate_tensor( + (expert_cnt, intermediate, hidden) + ).transpose(1, 2) + + # GEMM C: n_major → N(intermediate) stride-1; m_major → M(tokens) stride-1 + if self.problem.c_layout == "n_major": + self.C_tensor = torch.full( + (tokens, intermediate), + -1, + dtype=self.problem.out_dtype, + device="cuda", + ) + else: + self.C_tensor = torch.full( + (intermediate, tokens), + -1, + dtype=self.problem.out_dtype, + device="cuda", + ).T + + elif self.problem.scenario == "2Dx2D": + # PyTorch shape: mat_a (hidden, tokens), mat_b (tokens, intermediate), out (expert_cnt, hidden, intermediate) + # out matches weight shape (expert_cnt, hidden, intermediate) for weight gradient + # GEMM domain: A (M=hidden, K=tokens), B (N=intermediate, K=tokens), C (M=hidden, N=intermediate) + + # GEMM A: k_major → K(tokens) stride-1; m_major → M(hidden) stride-1 + if self.problem.a_layout == "k_major": + self.A_tensor = self._generate_tensor((hidden, tokens)) + else: + self.A_tensor = self._generate_tensor((tokens, hidden)).T + + # GEMM B: n_major → N(intermediate) stride-1; k_major → K(tokens) stride-1 + if self.problem.b_layout == "n_major": + self.B_tensor = self._generate_tensor((tokens, intermediate)) + else: + self.B_tensor = self._generate_tensor((intermediate, tokens)).T + + # GEMM C: n_major → N(intermediate) stride-1; m_major → M(hidden) stride-1 + if self.problem.c_layout == "n_major": + self.C_tensor = torch.full( + (expert_cnt, hidden, intermediate), + -1, + dtype=self.problem.out_dtype, + device="cuda", + ) + else: + self.C_tensor = torch.full( + (expert_cnt, intermediate, hidden), + -1, + dtype=self.problem.out_dtype, + device="cuda", + ).transpose(1, 2) + if self.problem.grad_accumulate: + self.C_tensor *= 0 + else: + raise ValueError(f"Unknown scenario: {self.problem.scenario}") + + def compute_reference(self) -> None: + if self.misc.perf_run: + return + if self.misc.compare_with_bmm: + self._compute_reference_bmm() + else: + self._compute_reference_grouped_mm() + + def _compute_reference_grouped_mm(self) -> None: + grouped_mm_op = ( + torch._grouped_mm + if self.misc.no_torch_210 + else torch.nn.functional.grouped_mm + ) + self.C_ref_tensor = grouped_mm_op( + self.A_tensor, + self.B_tensor, + offs=self.offs_tensor, + out_dtype=self.problem.out_dtype, + ) + + def _compute_reference_bmm(self) -> None: + """Manual per-expert torch.mm loop as reference (avoids grouped_mm bugs on small cases).""" + # Preallocate the full reference output to avoid keeping both the per-expert + # results list and the final cat/stack result alive at the same time. + self.C_ref_tensor = torch.empty_like(self.C_tensor) + + prev = 0 + for i in range(self.expert_cnt): + cur = self.offs_tensor[i].item() + if self.problem.scenario == "2Dx3D": + # A (tokens, hidden), B (E, hidden, intermediate) → C_i (tokens_i, intermediate) + a_slice = self.A_tensor[prev:cur, :] + b_slice = self.B_tensor[i] + self.C_ref_tensor[prev:cur, :].copy_(torch.mm(a_slice, b_slice)) + else: # 2Dx2D + # A (hidden, tokens), B (tokens, intermediate) → C_i (hidden, intermediate) + a_slice = self.A_tensor[:, prev:cur] + b_slice = self.B_tensor[prev:cur, :] + self.C_ref_tensor[i, :, :].copy_(torch.mm(a_slice, b_slice)) + prev = cur + + def create_kernel(self) -> GroupedGemmKernel: + return GroupedGemmKernel( + scenario=self.problem.scenario, + out_dtype=self.temp_type_mapping[self.problem.out_dtype], + accumulate_on_output=self.problem.grad_accumulate + and self.problem.scenario == "2Dx2D", + separate_tensormap_init=self.impl.separate_tensormap_init, + fixed_expert_cnt=self.impl.static_expert_cnt, + acc_dtype=self.temp_type_mapping[self.problem.acc_dtype], + mma_tiler_mnk=self.impl.mma_tiler_mnk, + cluster_shape_mnk=self.impl.cluster_shape_mnk, + use_2cta_instrs=self.impl.use_2cta_instrs, + ) + + def run_kernel(self, kernel: GroupedGemmKernel) -> Optional[float]: + """Run our CuTe kernel. + + Returns: + Average kernel time in ms when perf_e2e is enabled, None otherwise. + """ + workspace_size = kernel.get_workspace_size(self.expert_cnt) + self.workspace_tensor = torch.full( + (workspace_size,), 255, dtype=torch.uint8, device="cuda" + ) + + torch.cuda.synchronize() + + ab_cutlass_dtype = self.temp_type_mapping[self.problem.ab_dtype] + out_cutlass_dtype = self.temp_type_mapping[self.problem.out_dtype] + + a_cute, self.A_tensor = cutlass_torch.cute_tensor_like( + self.A_tensor, ab_cutlass_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_cute, self.B_tensor = cutlass_torch.cute_tensor_like( + self.B_tensor, ab_cutlass_dtype, is_dynamic_layout=True, assumed_align=16 + ) + c_cute, self.C_tensor = cutlass_torch.cute_tensor_like( + self.C_tensor, out_cutlass_dtype, is_dynamic_layout=True, assumed_align=16 + ) + is_dynamic_expert_cnt = self.impl.static_expert_cnt is None + offs_cute, self.offs_tensor = cutlass_torch.cute_tensor_like( + self.offs_tensor, + cutlass.Int32, + is_dynamic_layout=is_dynamic_expert_cnt, + assumed_align=16, + ) + workspace_cute, self.workspace_tensor = cutlass_torch.cute_tensor_like( + self.workspace_tensor, + cutlass.Uint8, + is_dynamic_layout=is_dynamic_expert_cnt, + assumed_align=128, + ) + + # Query max active clusters from hardware + cluster_size = self.impl.cluster_shape_mnk[0] * self.impl.cluster_shape_mnk[1] + max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size) + print(f"A_tensor: {tuple(self.A_tensor.shape)}:{self.A_tensor.stride()}") + print(f"B_tensor: {tuple(self.B_tensor.shape)}:{self.B_tensor.stride()}") + print( + f"offset_tensor: {tuple(self.offs_tensor.shape)}:{self.offs_tensor.stride()}" + ) + print(f"C_tensor: {tuple(self.C_tensor.shape)}:{self.C_tensor.stride()}") + + stream = self._get_stream() + + if self.misc.perf_e2e: + compiled = cute.compile( + kernel, + a_cute, + b_cute, + c_cute, + offs_cute, + None, # bias + workspace_cute, + max_active_clusters, + stream, + ) + + warmup_iters = 4 + timed_iters = 4 + + for _ in range(warmup_iters): + l2_flush() + compiled( + a_cute, + b_cute, + c_cute, + offs_cute, + None, # bias + workspace_cute, + stream, + ) + torch.cuda.synchronize() + + times = [] + for _ in range(timed_iters): + l2_flush() + torch.cuda.synchronize() + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + start_evt.record() + compiled( + a_cute, + b_cute, + c_cute, + offs_cute, + None, # bias + workspace_cute, + stream, + ) + end_evt.record() + torch.cuda.synchronize() + times.append(start_evt.elapsed_time(end_evt)) + + avg_ms = sum(times) / len(times) + print(f"[perf_e2e] Individual times (ms): {[f'{t:.4f}' for t in times]}") + print(f"[perf_e2e] Average kernel time: {avg_ms:.4f} ms") + return avg_ms + else: + l2_flush() + kernel( + a_cute, + b_cute, + c_cute, + offs_cute, + None, # bias + workspace_cute, + max_active_clusters, + stream, + ) + torch.cuda.synchronize() + return None + + def validate(self) -> None: + if not self.misc.perf_run: + assert torch.equal(self.C_tensor, self.C_ref_tensor), ( + "Validation failed: C_tensor != C_ref_tensor" + ) + + def run_sol_comparison(self) -> None: + """Run a dense batched GEMM as Speed-of-Light reference. + + Reuses the same tensor memory from the grouped run by + view/reshape/permute -- zero GPU allocation. + """ + import sys, os + + _examples_root = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "..", "..") + ) + if _examples_root not in sys.path: + sys.path.insert(0, _examples_root) + + from blackwell.kernel.dense_gemm.dense_gemm_persistent import ( + PersistentDenseGemmKernel, + ) + + tokens = self.tokens_after_repeat + experts = self.expert_cnt + assert tokens % experts == 0, ( + f"compare_with_sol requires tokens*top_k ({tokens}) " + f"evenly divisible by experts ({experts}) so every group " + f"has exactly the same size" + ) + tpe = tokens // experts + + if self.problem.scenario == "2Dx3D": + M, N, K, L = tpe, self.intermediate, self.hidden, experts + else: # 2Dx2D + M, N, K, L = self.hidden, self.intermediate, tpe, experts + + # Reshape into GEMM-domain batch-last: A(M,K,L), B(N,K,L), C(M,N,L). + # Data values are irrelevant (perf only) — just need correct shape + # and stride pattern so the dense kernel sees the right major mode. + if self.problem.a_layout == "k_major": + a_sol = self.A_tensor.contiguous().view(L, M, K).permute(1, 2, 0) + leading_dim_a = 1 + else: + a_sol = self.A_tensor.contiguous().view(L, K, M).permute(2, 1, 0) + leading_dim_a = 0 + + if self.problem.b_layout == "n_major": + b_sol = self.B_tensor.contiguous().view(L, K, N).permute(2, 1, 0) + leading_dim_b = 0 + else: + b_sol = self.B_tensor.contiguous().view(L, N, K).permute(1, 2, 0) + leading_dim_b = 1 + + if self.problem.c_layout == "n_major": + c_sol = self.C_tensor.contiguous().view(L, M, N).permute(1, 2, 0) + leading_dim_c = 1 + else: + c_sol = self.C_tensor.contiguous().view(L, N, M).permute(2, 1, 0) + leading_dim_c = 0 + + from cutlass.cute.runtime import from_dlpack + + a_cute_sol = from_dlpack(a_sol, assumed_align=16).mark_layout_dynamic( + leading_dim=leading_dim_a + ) + b_cute_sol = from_dlpack(b_sol, assumed_align=16).mark_layout_dynamic( + leading_dim=leading_dim_b + ) + c_cute_sol = from_dlpack(c_sol, assumed_align=16).mark_layout_dynamic( + leading_dim=leading_dim_c + ) + + mma_tiler_mn = self.impl.mma_tiler_mnk[:2] + cluster_shape_mn = self.impl.cluster_shape_mnk[:2] + cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1] + max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size) + + sol_kernel = PersistentDenseGemmKernel( + acc_dtype=self.temp_type_mapping[self.problem.acc_dtype], + use_2cta_instrs=self.impl.use_2cta_instrs, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + use_tma_store=True, + ) + + print(f"\n[SOL] Dense BMM: M={M} N={N} K={K} L={L}") + print(f"[SOL] a_sol: {tuple(a_sol.shape)}:{a_sol.stride()}") + print(f"[SOL] b_sol: {tuple(b_sol.shape)}:{b_sol.stride()}") + print(f"[SOL] c_sol: {tuple(c_sol.shape)}:{c_sol.stride()}") + + l2_flush() + sol_kernel( + a_cute_sol, + b_cute_sol, + c_cute_sol, + max_active_clusters, + self._get_stream(), + ) + torch.cuda.synchronize() + + def run(self) -> None: + from torch.profiler import profile, ProfilerActivity + + print(self.problem) + print(self.impl) + print(self.misc) + self.generate_inputs() + kernel = self.create_kernel() + + if self.misc.perf_e2e: + self.run_kernel(kernel) + else: + with profile( + activities=[ProfilerActivity.CUDA], record_shapes=True + ) as prof: + self.compute_reference() + self.run_kernel(kernel) + if ( + self.misc.compare_with_sol + and self.misc.perf_run + and self.problem.balance_route + ): + self.run_sol_comparison() + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20)) + + self.validate() + + +if __name__ == "__main__": + import argparse + + def parse_dtype(s: str) -> torch.dtype: + return getattr(torch, s) + + def parse_tuple(s: str) -> Tuple[int, ...]: + return tuple(int(x) for x in s.split(",")) + + parser = argparse.ArgumentParser() + parser.add_argument("--tokens", type=int, default=128) + parser.add_argument("--experts", type=int, default=128) + parser.add_argument("--top_k_select", type=int, default=8) + parser.add_argument("--balance_route", action="store_true", default=False) + parser.add_argument("--hidden", type=int, default=2048) + parser.add_argument("--intermediate", type=int, default=7168) + parser.add_argument( + "--scenario", type=str, default="2Dx3D", choices=["2Dx3D", "2Dx2D"] + ) + parser.add_argument("--ab_dtype", type=str, default="bfloat16") + parser.add_argument("--out_dtype", type=str, default="bfloat16") + parser.add_argument("--acc_dtype", type=str, default="float32") + parser.add_argument("--grad_accumulate", action="store_true", default=False) + parser.add_argument( + "--a_layout", type=str, default="k_major", choices=["k_major", "m_major"] + ) + parser.add_argument( + "--b_layout", type=str, default="n_major", choices=["k_major", "n_major"] + ) + parser.add_argument( + "--c_layout", type=str, default="n_major", choices=["m_major", "n_major"] + ) + parser.add_argument("--mma_tiler_mnk", type=str, default="128,128,64") + parser.add_argument("--cluster_shape_mnk", type=str, default="1,1,1") + parser.add_argument("--use_2cta_instrs", action="store_true", default=False) + parser.add_argument("--static_expert_cnt", type=int, default=None) + parser.add_argument("--separate_tensormap_init", action="store_true", default=False) + parser.add_argument("--perf_run", action="store_true", default=False) + parser.add_argument("--perf_e2e", action="store_true", default=False) + parser.add_argument("--compare_with_bmm", action="store_true", default=False) + parser.add_argument("--compare_with_sol", action="store_true", default=False) + args = parser.parse_args() + + problem = ProblemDesc( + tokens=args.tokens, + experts=args.experts, + top_k_select=args.top_k_select, + balance_route=args.balance_route, + hidden=args.hidden, + intermediate=args.intermediate, + scenario=args.scenario, + ab_dtype=parse_dtype(args.ab_dtype), + out_dtype=parse_dtype(args.out_dtype), + acc_dtype=parse_dtype(args.acc_dtype), + grad_accumulate=args.grad_accumulate, + a_layout=args.a_layout, + b_layout=args.b_layout, + c_layout=args.c_layout, + ) + if not args.separate_tensormap_init: + print( + "Change separate_tensormap_init to True as current the fused version not implmented yet." + ) + args.separate_tensormap_init = True + impl = ImplDesc( + mma_tiler_mnk=parse_tuple(args.mma_tiler_mnk), + cluster_shape_mnk=parse_tuple(args.cluster_shape_mnk), + use_2cta_instrs=args.use_2cta_instrs, + static_expert_cnt=args.static_expert_cnt, + separate_tensormap_init=args.separate_tensormap_init, + ) + + misc = MiscDesc( + perf_run=args.perf_run, + perf_e2e=args.perf_e2e, + compare_with_bmm=args.compare_with_bmm, + compare_with_sol=args.compare_with_sol, + ) + if misc.no_torch_210: + misc.compare_with_bmm = True + print("Override to set --compare_with_bmm to avoid possible torch crash.") + + tester = GroupedGemmTester(problem, impl, misc) + tester.run() + print("PASS") diff --git a/reference/moe_torch_scaled_grouped_mm.py b/reference/moe_torch_scaled_grouped_mm.py new file mode 100644 index 00000000..5a2d98fe --- /dev/null +++ b/reference/moe_torch_scaled_grouped_mm.py @@ -0,0 +1,3901 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" +Scaled Grouped GEMM for MoE operations with block scaling (MXFP8, MXFP4, NVFP4). + +PyTorch interface (from torch.nn.functional.scaled_grouped_mm): +- 2Dx3D (Forward): mat_a(tokens_sum, K) x mat_b(experts, K, N) -> out(tokens_sum, N) +- 2Dx2D (Weight grad): mat_a(M, tokens_sum) x mat_b(tokens_sum, N) -> out(experts, M, N) + +Kernel interface uses GEMM MNKL domain (same as torch_grouped_mm.py): + A_cute: (M, K, L) + B_cute: (N, K, L) + C_cute: (M, N, L) + SFA_cute, SFB_cute: scale factors with block-scaled atom layout + +The scheduler handles fake dimensions by computing token_offset from offs. +""" + +import os +import sys +from typing import Optional, Tuple, Literal, Type, Union + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass.cute.typing import Pointer +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait + +if __name__ == "__main__": + current_dir = os.path.dirname(os.path.abspath(__file__)) + sys.path.insert(0, os.path.join(current_dir, "../../..")) + +from blackwell.kernel.moe.moe_utils import ( + MoEScaledGroupedGemmTensormapConstructor, +) +from blackwell.kernel.moe.moe_persistent_scheduler import ( + MoEStaticSchedulerParams, + MoEStaticPersistentTileScheduler, + MoEWorkTileInfo, +) +from blackwell.kernel.moe.moe_sched_extension import ScaledGroupedMmSchedExtension +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils +from cutlass.utils.gemm.sm100 import ( + transform_partitioned_tensor_layout, + epilogue_tmem_copy_and_partition, + epilogue_smem_copy_and_partition, +) + +# ============================================================================= +# ScaledGroupedGemmKernel +# ============================================================================= + + +class ScaledGroupedGemmKernel: + """ + Scaled Grouped GEMM kernel for MoE operations with block scaling. + + Combines: + - MoE grouped structure from GroupedGemmKernel (scheduler warp, expert-wise + TMA descriptors, MoEStaticPersistentTileScheduler) + - Block-scaled MMA from Sm100BlockScaledPersistentDenseGemmKernel (SFA/SFB + tensors, blockscaled tiled_mma, SMEM→TMEM SF copy) + + Warp specialization (7 warps): + - Warps 0-3: Epilogue (TMEM → RMEM → SMEM → GMEM, global_scale multiply) + - Warp 4: MMA (tcgen05.mma.block_scale with SFA/SFB in TMEM) + - Warp 5: TMA load (A, B, SFA, SFB from GMEM → SMEM) + - Warp 6: Scheduler (MoEStaticPersistentTileScheduler, produces work tiles) + + __init__ parameters are codegen-time configuration only. + Runtime dtypes (a_dtype, b_dtype, sf_dtype, c_dtype) and layout modes + (a_major_mode, b_major_mode, c_layout) are inferred from input tensors + in __call__. + """ + + def __init__( + self, + scenario: Literal["2Dx3D", "2Dx2D"], + sf_vec_size: int, + accumulate_on_output: bool, + separate_tensormap_init: bool, + consistent_token_padding: bool, + acc_dtype: Type[cutlass.Numeric] = cutlass.Float32, + mma_tiler_mnk: Tuple[int, int, int] = (128, 128, 64), + cluster_shape_mnk: Tuple[int, int, int] = (1, 1, 1), + use_2cta_instrs: bool = False, + fixed_expert_cnt: Optional[int] = None, + ): + # ── User-provided codegen-time configuration ── + self.scenario = scenario + self.sf_vec_size = sf_vec_size + self.accumulate_on_output = accumulate_on_output + self.separate_tensormap_init = separate_tensormap_init + self.consistent_token_padding = consistent_token_padding + self.acc_dtype = acc_dtype + self.mma_tiler_mnk = mma_tiler_mnk + self.cluster_shape_mn = (cluster_shape_mnk[0], cluster_shape_mnk[1]) + self.use_2cta_instrs = use_2cta_instrs + self.fixed_expert_cnt = fixed_expert_cnt + self.arch = "sm_100" + + if accumulate_on_output and scenario == "2Dx3D": + raise ValueError( + "accumulate_on_output only makes sense for 2Dx2D (weight grad)." + ) + + self._validate_mma_tiler_and_cluster_shape() + + # ── MMA tiler — K is refined in _setup_attributes ── + self.mma_tiler = (mma_tiler_mnk[0], mma_tiler_mnk[1], 1) + + # ── CTA group for tcgen05 MMA ── + self.cta_group = ( + tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + # ── Warp specialization (7 warps) ── + self.occupancy = 1 + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4 + self.tma_warp_id = 5 + self.sched_warp_id = 6 + self.threads_per_cta = 32 * len( + ( + self.mma_warp_id, + self.tma_warp_id, + self.sched_warp_id, + *self.epilogue_warp_id, + ) + ) + + # ── Barrier IDs for synchronization ── + self.epilog_sync_bar_id = 1 + self.tmem_alloc_sync_bar_id = 2 + self.tmem_dealloc_sync_bar_id = 3 + + self.smem_capacity = utils.get_smem_capacity_in_bytes(self.arch) + self.num_tmem_alloc_cols = cute.arch.get_max_tmem_alloc_cols(self.arch) + + # ----------------------------------------------------------------- + # Workspace size + # ----------------------------------------------------------------- + + def get_workspace_size(self, expert_cnt: int) -> int: + """Workspace size for the aux init kernel. + + Layout: [TMA descriptors (managed by tensormap ctor)] [padded scale offsets] + """ + desc_bytes = MoEScaledGroupedGemmTensormapConstructor.get_workspace_size( + self.scenario, expert_cnt + ) + padded_offs_bytes = expert_cnt * 4 if not self.consistent_token_padding else 0 + return desc_bytes + padded_offs_bytes + + # ----------------------------------------------------------------- + # Static validation + # ----------------------------------------------------------------- + + def _validate_mma_tiler_and_cluster_shape(self): + """Validate codegen-time MMA tiler and cluster shape constraints.""" + m, n, k = self.mma_tiler_mnk + cm, cn = self.cluster_shape_mn + + if m not in [128, 256]: + raise ValueError(f"mma_tiler M ({m}) must be one of [128, 256]") + + per_cta_m = m // (2 if self.use_2cta_instrs else 1) + if per_cta_m != 128: + raise ValueError( + f"per-CTA mma_tiler M must be 128, got {per_cta_m} " + f"(mma_tiler_m={m}, use_2cta_instrs={self.use_2cta_instrs})" + ) + + if n not in [64, 128, 256]: + raise ValueError(f"mma_tiler N ({n}) must be one of [64, 128, 256]") + + sf_k_granularity = self.sf_vec_size * 4 + if k % sf_k_granularity != 0: + raise ValueError( + f"mma_tiler K ({k}) must be a multiple of " + f"sf_vec_size * 4 = {sf_k_granularity}" + ) + + if cm % (2 if self.use_2cta_instrs else 1) != 0: + raise ValueError( + f"cluster_shape M ({cm}) must be even when use_2cta_instrs=True" + ) + + is_pow2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if cm * cn > 16 or not is_pow2(cm) or not is_pow2(cn) or cm > 4 or cn > 4: + raise ValueError( + f"Invalid cluster_shape ({cm}, {cn}): each dim must be " + f"a power of 2 and <= 4, product must be <= 16" + ) + + if self.sf_vec_size not in {16, 32}: + raise ValueError(f"sf_vec_size ({self.sf_vec_size}) must be 16 or 32") + + # ----------------------------------------------------------------- + # _create_tiled_mma / _create_tiled_mma_sfb + # ----------------------------------------------------------------- + + def _create_tiled_mma(self) -> cute.TiledMma: + """Create blockscaled tiled MMA atom.""" + return sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.b_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mn, + ) + + def _create_tiled_mma_sfb(self) -> cute.TiledMma: + """Create blockscaled tiled MMA atom for SFB (always CtaGroup.ONE).""" + return sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.b_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mn_sfb, + ) + + # ----------------------------------------------------------------- + # _setup_attributes + # ----------------------------------------------------------------- + + def _setup_attributes(self) -> None: + """ + Set up configurations that depend on GEMM inputs. + + Configures: + - tiled_mma / tiled_mma_sfb with correct dtypes and major modes + - MMA/cluster/tile shapes + - Cluster layouts (main + sfb) + - Multicast CTA counts + - Epilogue tile shape + - Stage counts (ACC, AB+SF, C) + - SMEM layouts for A/B/SFA/SFB/C + - TMEM column counts (accumulator + SFA + SFB) + - TMA load bytes + - Overlapping accumulator support + """ + # ── MMA instruction shapes ── + self.mma_inst_shape_mn = (self.mma_tiler[0], self.mma_tiler[1]) + self.mma_inst_shape_mn_sfb = ( + self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1), + cute.round_up(self.mma_inst_shape_mn[1], 128), + ) + + tiled_mma = self._create_tiled_mma() + tiled_mma_sfb = self._create_tiled_mma_sfb() + + # ── MMA / cluster / tile shapes ── + # Use user-specified K dimension from mma_tiler_mnk + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + assert self.mma_tiler_mnk[2] % mma_inst_shape_k == 0, ( + f"mma_tiler K ({self.mma_tiler_mnk[2]}) must be a multiple of " + f"MMA instruction K ({mma_inst_shape_k})" + ) + mma_inst_tile_k = self.mma_tiler_mnk[2] // mma_inst_shape_k + self.mma_tiler = ( + self.mma_inst_shape_mn[0], + self.mma_inst_shape_mn[1], + self.mma_tiler_mnk[2], + ) + self.mma_tiler_sfb = ( + self.mma_inst_shape_mn_sfb[0], + self.mma_inst_shape_mn_sfb[1], + self.mma_tiler_mnk[2], + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + self.cta_tile_shape_mnk_sfb = ( + self.mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler_sfb[1], + self.mma_tiler_sfb[2], + ) + + # ── Cluster layouts ── + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + self.cluster_layout_sfb_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma_sfb.thr_id.shape,), + ) + + # ── Multicast CTA counts ── + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.num_mcast_ctas_sfb = cute.size(self.cluster_layout_sfb_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + self.is_sfb_mcast = self.num_mcast_ctas_sfb > 1 + + # ── Epilogue tile shape ── + self.epi_tile = sm100_utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + self.epi_tile_n = cute.size(self.epi_tile[1]) + + # ── Stage counts ── + self.num_acc_stage, self.num_ab_stage, self.num_c_stage = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.sf_dtype, + self.sf_vec_size, + self.smem_capacity, + self.occupancy, + ) + + self.num_sched_stages = 2 + + # ── SMEM layouts ── + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_c_stage, + ) + + # ── Overlapping accumulator ── + # N=256: TMEM can't fit 2 full acc buffers + SF, so acc and SF share columns. + # The acc pipeline uses 1 barrier stage with phase-based toggling. + # N<256: TMEM fits 2 independent acc buffers, normal 2-stage pipeline. + self.overlapping_accum = self.cta_tile_shape_mnk[1] == 256 + self.num_acc_pipeline_stages = ( + 1 if self.overlapping_accum else self.num_acc_stage + ) + + # ── TMEM column counts ── + sf_atom_mn = 32 + self.num_sfa_tmem_cols = ( + self.cta_tile_shape_mnk[0] // sf_atom_mn + ) * mma_inst_tile_k + self.num_sfb_tmem_cols = ( + self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn + ) * mma_inst_tile_k + self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols + self.num_accumulator_tmem_cols = self.cta_tile_shape_mnk[ + 1 + ] * self.num_acc_stage - ( + self.num_sf_tmem_cols if self.overlapping_accum else 0 + ) + + # Only when overlapping_accum, release accumulator buffer early in epilogue + self.iter_acc_early_release_in_epilogue = ( + self.num_sf_tmem_cols // self.epi_tile_n + ) + + # ── TMA load bytes (A + B + SFA + SFB per stage) ── + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + sfa_smem_layout = cute.slice_( + self.sfa_smem_layout_staged, (None, None, None, 0) + ) + sfb_smem_layout = cute.slice_( + self.sfb_smem_layout_staged, (None, None, None, 0) + ) + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) + sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) + self.num_tma_load_bytes = ( + a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size + ) * atom_thr_size + + # ----------------------------------------------------------------- + # _compute_stages (static) + # ----------------------------------------------------------------- + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + smem_capacity: int, + occupancy: int, + ) -> Tuple[int, int, int]: + """Compute stage counts for ACC, A/B/SFA/SFB, and C.""" + num_acc_stage = 2 + num_c_stage = 2 + + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, + ) + sfa_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + 1, + ) + sfb_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + 1, + ) + c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, + ) + + ab_bytes_per_stage = ( + cute.size_in_bytes(a_dtype, a_smem_layout_stage_one) + + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfa_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one) + ) + mbar_helpers_bytes = 1024 + c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + c_bytes = c_bytes_per_stage * num_c_stage + + sched_work_tile_bytes_per_stage = 16 # 4 fields * sizeof(Int32) + num_sched_stages = 2 + sched_bytes = sched_work_tile_bytes_per_stage * num_sched_stages + + fixed_overhead = mbar_helpers_bytes + c_bytes + sched_bytes + + num_ab_stage = ( + smem_capacity // occupancy - fixed_overhead + ) // ab_bytes_per_stage + + num_c_stage += ( + smem_capacity + - occupancy * ab_bytes_per_stage * num_ab_stage + - occupancy * fixed_overhead + ) // (occupancy * c_bytes_per_stage) + + return num_acc_stage, num_ab_stage, num_c_stage + + # ----------------------------------------------------------------- + # mainloop_s2t_copy_and_partition (from dense_blockscaled) + # ----------------------------------------------------------------- + + def mainloop_s2t_copy_and_partition( + self, + sSF: cute.Tensor, + tSF: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for smem → tmem load of a scale factor tensor, + then partition smem (source) and tmem (destination). + """ + tCsSF_compact = cute.filter_zeros(sSF) + tCtSF_compact = cute.filter_zeros(tSF) + + copy_atom_s2t = cute.make_copy_atom( + tcgen05.Cp4x32x128bOp(self.cta_group), + self.sf_dtype, + ) + tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact) + thr_copy_s2t = tiled_copy_s2t.get_slice(0) + + tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact) + tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t, tCsSF_compact_s2t_ + ) + tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact) + + return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t + + # ----------------------------------------------------------------- + # __call__ (JIT entry point) + # ----------------------------------------------------------------- + + @cute.jit + def __call__( + self, + mat_a: cute.Tensor, # PyTorch mat_a (data) + mat_b: cute.Tensor, # PyTorch mat_b (data) + scale_a: cute.Tensor, # SFA (assembled block-scaled layout) + scale_b: cute.Tensor, # SFB (assembled block-scaled layout) + out: cute.Tensor, # Output C + offs: cute.Tensor, # (experts,) cumsum end offsets, int32 + workspace: cute.Tensor, # Expert-wise TMA desc + padded offs + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + global_scale_a: Optional[cute.Tensor] = None, # NVFP4: per-expert f32 scalar + global_scale_b: Optional[cute.Tensor] = None, # NVFP4: per-expert f32 scalar + bias: Optional[cute.Tensor] = None, + ) -> None: + """Launch the scaled grouped GEMM kernel.""" + if cutlass.const_expr(bias is not None): + raise NotImplementedError("bias is not supported yet (align with torch).") + + # ================================================================= + # Step 1: Transform PyTorch tensors to GEMM domain (fake MNKL) + # ================================================================= + c1 = cutlass.Int32(1) + c0 = cutlass.Int32(0) + + if cutlass.const_expr(self.scenario == "2Dx3D"): + # mat_a: (tokens_sum, hidden) -> A: (fake_m, k, 1) + tokens_sum, hidden = mat_a.shape + a_gemm = cute.make_tensor( + mat_a.iterator, + cute.make_layout( + (tokens_sum, hidden, c1), + stride=(mat_a.stride[0], mat_a.stride[1], c0), + ), + ) + # mat_b: (experts, hidden, intermediate) -> B: (n, k, fake_l) + experts, hidden_b, intermediate = mat_b.shape + b_gemm = cute.make_tensor( + mat_b.iterator, + cute.make_layout( + (intermediate, hidden_b, experts), + stride=(mat_b.stride[2], mat_b.stride[1], mat_b.stride[0]), + ), + ) + # out: (tokens_sum, intermediate) -> C: (fake_m, n, 1) + c_gemm = cute.make_tensor( + out.iterator, + cute.make_layout( + (tokens_sum, intermediate, c1), + stride=(out.stride[0], out.stride[1], c0), + ), + ) + expert_cnt = experts + intermediate_dim = intermediate + hidden_dim = hidden + + # SFA/SFB: scale tensors have host-padded dimensions. + # Use their own shape as the "data shape" for atom tiling. + tokens_sum_padded = scale_a.shape[0] + hidden_padded = scale_a.shape[1] * self.sf_vec_size + sfa_gemm = cute.make_tensor( + scale_a.iterator, + blockscaled_utils.tile_atom_to_shape_SF( + (tokens_sum_padded, hidden_padded, c1), self.sf_vec_size + ), + ) + intermediate_padded_mul_hidden_padded = scale_b.shape[1] + intermediate_padded = ( + intermediate_padded_mul_hidden_padded * self.sf_vec_size + ) // hidden_padded + sfb_gemm = cute.make_tensor( + scale_b.iterator, + blockscaled_utils.tile_atom_to_shape_SF( + (intermediate_padded, hidden_padded, experts), self.sf_vec_size + ), + ) + + else: # 2Dx2D + # mat_a: (hidden, tokens_sum) -> A: (m, fake_k, 1) + hidden, tokens_sum = mat_a.shape + a_gemm = cute.make_tensor( + mat_a.iterator, + cute.make_layout( + (hidden, tokens_sum, c1), + stride=(mat_a.stride[0], mat_a.stride[1], c0), + ), + ) + # mat_b: (tokens_sum, intermediate) -> B: (n, fake_k, 1) + tokens_sum_b, intermediate = mat_b.shape + b_gemm = cute.make_tensor( + mat_b.iterator, + cute.make_layout( + (intermediate, tokens_sum_b, c1), + stride=(mat_b.stride[1], mat_b.stride[0], c0), + ), + ) + # out: (experts, hidden, intermediate) -> C: (m, n, fake_l) + experts, hidden_c, intermediate_c = out.shape + c_gemm = cute.make_tensor( + out.iterator, + cute.make_layout( + (hidden_c, intermediate_c, experts), + stride=(out.stride[1], out.stride[2], out.stride[0]), + ), + ) + expert_cnt = experts + intermediate_dim = intermediate + hidden_dim = hidden + + # SFA/SFB: scale tensors have host-padded dimensions. + hidden_padded = scale_a.shape[0] + tokens_sum_padded = scale_a.shape[1] * self.sf_vec_size + sfa_gemm = cute.make_tensor( + scale_a.iterator, + blockscaled_utils.tile_atom_to_shape_SF( + (hidden_padded, tokens_sum_padded, c1), self.sf_vec_size + ), + ) + intermediate_padded = scale_b.shape[0] + sfb_gemm = cute.make_tensor( + scale_b.iterator, + blockscaled_utils.tile_atom_to_shape_SF( + (intermediate_padded, tokens_sum_padded, c1), self.sf_vec_size + ), + ) + + # ================================================================= + # Step 2: Infer dtypes and major modes + # ================================================================= + + self.a_dtype: Type[cutlass.Numeric] = a_gemm.element_type + self.b_dtype: Type[cutlass.Numeric] = b_gemm.element_type + self.c_dtype: Type[cutlass.Numeric] = c_gemm.element_type + self.sf_dtype: Type[cutlass.Numeric] = sfa_gemm.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a_gemm).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b_gemm).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c_gemm) + + # ================================================================= + # Step 3: Setup kernel attributes + # ================================================================= + + self._setup_attributes() + tiled_mma = self._create_tiled_mma() + tiled_mma_sfb = self._create_tiled_mma_sfb() + + # ================================================================= + # Step 4: Create TMA atoms for A, B, SFA, SFB, C + # ================================================================= + + # ── TMA load A ── + a_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a_gemm, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # ── TMA load B ── + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b_gemm, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # ── TMA load SFA ── + # sfa_gemm is already atom-tiled from tile_atom_to_shape_SF + sfa_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfa_smem_layout = cute.slice_( + self.sfa_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( + sfa_op, + sfa_gemm, + sfa_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=cutlass.Uint64, + ) + + # ── TMA load SFB ── + # sfb_gemm is already atom-tiled from tile_atom_to_shape_SF + sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfb_smem_layout = cute.slice_( + self.sfb_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + sfb_op, + sfb_gemm, + sfb_smem_layout, + self.mma_tiler_sfb, + tiled_mma_sfb, + self.cluster_layout_sfb_vmnk.shape, + internal_type=cutlass.Uint64, + ) + + # ── TMA store/reduce C ── + if cutlass.const_expr(self.accumulate_on_output): + c_tma_op = cpasync.CopyReduceBulkTensorTileS2GOp() + else: + c_tma_op = cpasync.CopyBulkTensorTileS2GOp() + + epi_smem_layout = cute.select(self.c_smem_layout_staged, mode=[0, 1]) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + c_tma_op, c_gemm, epi_smem_layout, self.epi_tile + ) + + # ================================================================= + # Step 5: offs_padded tensor (written by desc_init_kernel) + # ================================================================= + + # consistent_token_padding=True → offs_padded=None, main kernel reuses offs + # consistent_token_padding=False → offs_padded in GMEM workspace, written by desc_init + if cutlass.const_expr(self.consistent_token_padding): + offs_padded = None + else: + desc_bytes = MoEScaledGroupedGemmTensormapConstructor.get_workspace_size( + self.scenario, expert_cnt + ) + offs_padded = cute.make_tensor( + cute.recast_ptr(workspace.iterator + desc_bytes, dtype=offs.dtype), + cute.make_layout((expert_cnt,)), + ) + + # ================================================================= + # Step 6: Create MoEStaticSchedulerParams and compute grid + # ================================================================= + + sched_params = MoEStaticSchedulerParams( + scenario=self.scenario, + expert_shape=(expert_cnt, intermediate_dim, hidden_dim), + cta_tile_shape_mnk=self.cta_tile_shape_mnk, + cluster_shape_mn=self.cluster_shape_mn, + ) + + grid = MoEStaticSchedulerParams.get_grid_shape( + sched_params, max_active_clusters + ) + + # ================================================================= + # Step 7: Launch desc_init_kernel (if separate_tensormap_init) + # ================================================================= + + if cutlass.const_expr(self.separate_tensormap_init): + self.desc_init_kernel( + tiled_mma, + tiled_mma_sfb, + a_gemm, + b_gemm, + c_gemm, + sfa_gemm, + sfb_gemm, + offs, + expert_cnt, + workspace.iterator, + self.cluster_layout_vmnk, + self.cluster_layout_sfb_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + self.c_smem_layout_staged, + self.epi_tile, + ).launch( + grid=(1, 1, 1), + block=[self._desc_init_block_threads, 1, 1], + stream=stream, + min_blocks_per_mp=1, + ) + + # ================================================================= + # Step 8: Launch main kernel + # ================================================================= + + self.kernel( + tiled_mma, + tiled_mma_sfb, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + tma_atom_c, + tma_tensor_c, + a_gemm, + b_gemm, + c_gemm, + sfa_gemm, + sfb_gemm, + offs, + sched_params, + workspace.iterator, + self.cluster_layout_vmnk, + self.cluster_layout_sfb_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + self.c_smem_layout_staged, + self.epi_tile, + offs_padded, + global_scale_a, + global_scale_b, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + stream=stream, + min_blocks_per_mp=self.occupancy, + ) + + # ----------------------------------------------------------------- + # desc_init_kernel (GPU device kernel) + # ----------------------------------------------------------------- + + # Number of warps per warp-group in desc_init_kernel. + _desc_init_warps_per_group = 4 + # Threads per warp-group (must equal MoEScaledGroupedGemmTensormapConstructor.ChunkSize). + _desc_init_group_threads = _desc_init_warps_per_group * 32 # 128 + # Total threads in desc_init_kernel (2 warp-groups × 4 warps each). + _desc_init_block_threads = _desc_init_group_threads * 2 # 256 + # Named barrier ID for warp-group-internal sync within Group A. + _desc_init_group_a_bar_id = 1 + + @cute.kernel + def desc_init_kernel( + self, + # ── MMA atoms ── + tiled_mma: cute.TiledMma, + tiled_mma_sfb: cute.TiledMma, + # ── GEMM domain tensors (fake MNKL) ── + a_gemm: cute.Tensor, + b_gemm: cute.Tensor, + c_gemm: cute.Tensor, + sfa_gemm: cute.Tensor, + sfb_gemm: cute.Tensor, + # ── Scheduling / workspace ── + offs: cute.Tensor, + expert_cnt: Union[cutlass.Int32, int], + workspace_ptr: Pointer, + # ── Cluster layouts ── + cluster_layout_vmnk: cute.Layout, + cluster_layout_sfb_vmnk: cute.Layout, + # ── SMEM layouts ── + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout], + epi_tile: cute.Tile, + ): + """ + Pre-initialize expert-wise TMA descriptors and compute padded scale + offsets (``offs_padded``). + + Grid: (1, 1, 1) + Block: (256, 1, 1) — 8 warps split into two groups of 4: + + - **Group A** (warps 0-3, threads 0..127): Compute ``offs_padded`` + prefix sum, write to SMEM + GMEM. + - **Group B** (warps 4-7, threads 128..255): Create TMA descriptors + via ``construct_and_write`` (chunked, with pipeline sync). + + Synchronization: + - Group A internal: NamedBarrier (for cross-warp prefix sum) + - Group A → Group B: PipelineAsync (mbarrier producer-consumer) + """ + chunk_size = self._desc_init_group_threads # 128 + full_mask = 0xFFFFFFFF + warp_size = 32 + + # ================================================================= + # Thread identity + # ================================================================= + + tidx, _, _ = cute.arch.thread_idx() + warp_idx = cute.arch.warp_idx() + lane_in_group = tidx % chunk_size # 0..127 within each group + + # ================================================================= + # Reconstruct TMA ops (same as before) + # ================================================================= + + a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0)) + b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, None, 0)) + sfa_smem_layout = cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)) + sfb_smem_layout = cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)) + epi_smem_layout = cute.select(c_smem_layout_staged, mode=[0, 1]) + + a_tma_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_tma_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfa_tma_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfb_tma_op = sm100_utils.cluster_shape_to_tma_atom_SFB( + self.cluster_shape_mn, tiled_mma.thr_id + ) + if cutlass.const_expr(self.accumulate_on_output): + c_tma_op = cpasync.CopyReduceBulkTensorTileS2GOp() + else: + c_tma_op = cpasync.CopyBulkTensorTileS2GOp() + + # ================================================================= + # GMEM offs_padded tensor (written by Group A, read by main kernel) + # Only allocated when consistent_token_padding=False. + # ================================================================= + + if cutlass.const_expr(not self.consistent_token_padding): + desc_bytes = MoEScaledGroupedGemmTensormapConstructor.get_workspace_size( + self.scenario, expert_cnt + ) + gmem_offs_padded = cute.make_tensor( + cute.recast_ptr(workspace_ptr + desc_bytes, dtype=offs.dtype), + cute.make_layout((expert_cnt,)), + ) + + # ================================================================= + # SMEM allocation + # ================================================================= + + smem = utils.SmemAllocator() + + @cute.struct + class DescInitStorage: + # offs_padded SMEM buffer: [carry, chunk[0..127]] + offs_padded_buf: cute.struct.MemRange[cutlass.Int32, chunk_size + 1] + # Cross-warp prefix sum scratch (one per warp in Group A) + warp_sums: cute.struct.MemRange[ + cutlass.Int32, self._desc_init_warps_per_group + ] + # Pipeline mbarrier storage (PipelineAsync with 1 stage needs 2 mbarriers) + pipeline_mbar: cute.struct.MemRange[cutlass.Int64, 2] + + storage = smem.allocate(DescInitStorage) + + # Make a tensor view for the SMEM offs_padded buffer + smem_offs_padded = cute.make_tensor( + storage.offs_padded_buf.data_ptr(), + cute.make_layout((chunk_size + 1,)), + ) + smem_warp_sums = cute.make_tensor( + storage.warp_sums.data_ptr(), + cute.make_layout((self._desc_init_warps_per_group,)), + ) + + # ================================================================= + # Pipeline: Group A (producer) → Group B (consumer) + # ================================================================= + + producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, chunk_size) + consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, chunk_size) + pipe = pipeline.PipelineAsync.create( + num_stages=1, + producer_group=producer_group, + consumer_group=consumer_group, + barrier_storage=storage.pipeline_mbar.data_ptr(), + ) + producer, consumer = pipe.make_participants() + + # Named barrier for Group A internal sync (cross-warp prefix sum) + group_a_sync = pipeline.NamedBarrier( + barrier_id=self._desc_init_group_a_bar_id, + num_threads=chunk_size, + ) + + # ================================================================= + # Padding granularity P + # ================================================================= + + if cutlass.const_expr(self.scenario == "2Dx2D"): + # tokens = K (reduce dim): pad scale cols → P = sf_vec_size × 4 + pad_granularity = self.sf_vec_size * 4 + else: + # tokens = M (non-reduce dim): pad scale rows → P = 128 + pad_granularity = 128 + + # ================================================================= + # Tensormap constructor (for Group B) + # ================================================================= + + tensormap_ctor = MoEScaledGroupedGemmTensormapConstructor( + scenario=self.scenario, + sf_vec_size=self.sf_vec_size, + a_dtype=self.a_dtype, + b_dtype=self.b_dtype, + c_dtype=self.c_dtype, + sf_dtype=self.sf_dtype, + a_smem_layout=a_smem_layout, + b_smem_layout=b_smem_layout, + epi_smem_layout=epi_smem_layout, + sfa_smem_layout=sfa_smem_layout, + sfb_smem_layout=sfb_smem_layout, + a_tma_op=a_tma_op, + b_tma_op=b_tma_op, + c_tma_op=c_tma_op, + sfa_tma_op=sfa_tma_op, + sfb_tma_op=sfb_tma_op, + tiled_mma=tiled_mma, + tiled_mma_sfb=tiled_mma_sfb, + mma_tiler=self.mma_tiler, + mma_tiler_sfb=self.mma_tiler_sfb, + cluster_layout_vmnk_shape=cluster_layout_vmnk.shape, + cluster_layout_sfb_vmnk_shape=cluster_layout_sfb_vmnk.shape, + epi_tile=epi_tile, + a_tensor=a_gemm, + b_tensor=b_gemm, + c_tensor=c_gemm, + sfa_tensor=sfa_gemm, + sfb_tensor=sfb_gemm, + offs=offs, + offs_padded=offs + if cutlass.const_expr(self.consistent_token_padding) + else gmem_offs_padded, + workspace_ptr=workspace_ptr, + expert_cnt=expert_cnt, + ) + + # ================================================================= + # Warp-group split + # ================================================================= + + num_chunks = (expert_cnt + chunk_size - 1) // chunk_size + + if warp_idx < self._desc_init_warps_per_group: + # ============================================================= + # Group A: produce offs_padded into SMEM (+ GMEM if needed) + # ============================================================= + + warp_in_group = warp_idx # 0..3 + lane_in_warp = tidx % warp_size + + carry = cutlass.Int32(0) + chunk_idx = cutlass.Int32(0) + + while chunk_idx < num_chunks: + expert_idx = chunk_idx * chunk_size + lane_in_group + + if cutlass.const_expr(self.consistent_token_padding): + # ── Fast path: offs_padded == offs, just load ── + offs_val = cutlass.Int32(0) + if expert_idx < expert_cnt: + offs_val = offs[expert_idx] + + # Wait for consumer to release SMEM from previous chunk + producer.acquire_and_advance() + + # Write SMEM: [carry, offs[chunk_base..chunk_base+127]] + if lane_in_group == cutlass.Int32(0): + smem_offs_padded[0] = carry + smem_offs_padded[lane_in_group + 1] = offs_val + + # Ensure all SMEM writes visible, then signal consumer + group_a_sync.arrive_and_wait() + producer.commit() + + # Only thread 0 needs carry (to write smem[0] next iteration) + if lane_in_group == cutlass.Int32(0): + carry = smem_offs_padded[chunk_size] + + else: + # ── Full path: compute prefix sum of padded sizes ── + + # Load and compute per-thread padded size + padded_size = cutlass.Int32(0) + if expert_idx < expert_cnt: + prev_off = cutlass.Int32(0) + if expert_idx > cutlass.Int32(0): + prev_off = offs[expert_idx - 1] + size_i = offs[expert_idx] - prev_off + padded_size = ( + (size_i + pad_granularity - 1) // pad_granularity + ) * pad_granularity + + # Stage 1: warp-level inclusive prefix sum (shfl_up) + val = padded_size + for d in [1, 2, 4, 8, 16]: + n = cute.arch.shuffle_sync_up( + val, d, mask=full_mask, mask_and_clamp=0 + ) + if lane_in_warp >= d: + val = val + n + + # Lane 31 of each warp holds the warp total + if lane_in_warp == warp_size - 1: + smem_warp_sums[warp_in_group] = val + + # Group A internal sync (warp_sums visible) + group_a_sync.arrive_and_wait() + + # Stage 2: cross-warp correction + cross_warp_prefix = cutlass.Int32(0) + if warp_in_group >= 1: + cross_warp_prefix = smem_warp_sums[0] + if warp_in_group >= 2: + cross_warp_prefix = cross_warp_prefix + smem_warp_sums[1] + if warp_in_group >= 3: + cross_warp_prefix = cross_warp_prefix + smem_warp_sums[2] + + offs_padded_val = carry + val + cross_warp_prefix + + # Wait for consumer to release SMEM from previous chunk + producer.acquire_and_advance() + + # Write SMEM: [carry, offs_padded[chunk_base..chunk_base+127]] + if lane_in_group == cutlass.Int32(0): + smem_offs_padded[0] = carry + smem_offs_padded[lane_in_group + 1] = offs_padded_val + + # Ensure all SMEM writes visible, then signal consumer + group_a_sync.arrive_and_wait() + producer.commit() + + # Write GMEM (overlaps with Group B's phase 2) + if expert_idx < expert_cnt: + gmem_offs_padded[expert_idx] = offs_padded_val + + # Update carry + carry = smem_offs_padded[chunk_size] + + chunk_idx += 1 + + else: + # ============================================================= + # Group B: create TMA descriptors (chunked, with pipeline sync) + # ============================================================= + + tensormap_ctor.construct_and_write( + lane_in_group, + dependency=(consumer, smem_offs_padded), + ) + + # ----------------------------------------------------------------- + # kernel (GPU device kernel) + # ----------------------------------------------------------------- + + @cute.kernel + def kernel( + self, + # ── MMA atoms ── + tiled_mma: cute.TiledMma, + tiled_mma_sfb: cute.TiledMma, + # ── TMA atoms and tensors: A ── + tma_atom_a: cute.CopyAtom, + tma_tensor_a: cute.Tensor, + # ── TMA atoms and tensors: B ── + tma_atom_b: cute.CopyAtom, + tma_tensor_b: cute.Tensor, + # ── TMA atoms and tensors: SFA ── + tma_atom_sfa: cute.CopyAtom, + tma_tensor_sfa: cute.Tensor, + # ── TMA atoms and tensors: SFB ── + tma_atom_sfb: cute.CopyAtom, + tma_tensor_sfb: cute.Tensor, + # ── TMA atoms and tensors: C ── + tma_atom_c: cute.CopyAtom, + tma_tensor_c: cute.Tensor, + # ── GEMM domain tensors ── + a_gemm: cute.Tensor, + b_gemm: cute.Tensor, + c_gemm: cute.Tensor, + sfa_gemm: cute.Tensor, + sfb_gemm: cute.Tensor, + # ── Scheduling / workspace ── + offs: cute.Tensor, + sched_params: MoEStaticSchedulerParams, + workspace_ptr: Pointer, + # ── Cluster layouts ── + cluster_layout_vmnk: cute.Layout, + cluster_layout_sfb_vmnk: cute.Layout, + # ── SMEM layouts ── + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout], + epi_tile: cute.Tile, + # ── Optional: padded offsets ── + offs_padded: Optional[cute.Tensor], + # ── Optional: NVFP4 per-expert global scales ── + global_scale_a: Optional[cute.Tensor], + global_scale_b: Optional[cute.Tensor], + ): + """ + GPU device kernel for MoE Scaled Grouped GEMM with block scaling. + + Backbone: torch_grouped_mm.py (7-warp MoE scheduler structure) + GEMM internals: dense_blockscaled_gemm_persistent.py + """ + # ================================================================= + # Reconstruct objects that can't be passed as kernel params + # ================================================================= + + a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0)) + b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, None, 0)) + sfa_smem_layout = cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)) + sfb_smem_layout = cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)) + epi_smem_layout = cute.select(c_smem_layout_staged, mode=[0, 1]) + + a_tma_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_tma_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfa_tma_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfb_tma_op = sm100_utils.cluster_shape_to_tma_atom_SFB( + self.cluster_shape_mn, tiled_mma.thr_id + ) + if cutlass.const_expr(self.accumulate_on_output): + c_tma_op = cpasync.CopyReduceBulkTensorTileS2GOp() + else: + c_tma_op = cpasync.CopyBulkTensorTileS2GOp() + + # Build offs tuple for the extension + if cutlass.const_expr(offs_padded is not None): + offs_for_ext = (offs, offs_padded) + else: + offs_for_ext = (offs, offs) + + tensormap_ctor = MoEScaledGroupedGemmTensormapConstructor( + scenario=self.scenario, + sf_vec_size=self.sf_vec_size, + a_dtype=self.a_dtype, + b_dtype=self.b_dtype, + c_dtype=self.c_dtype, + sf_dtype=self.sf_dtype, + a_smem_layout=a_smem_layout, + b_smem_layout=b_smem_layout, + epi_smem_layout=epi_smem_layout, + sfa_smem_layout=sfa_smem_layout, + sfb_smem_layout=sfb_smem_layout, + a_tma_op=a_tma_op, + b_tma_op=b_tma_op, + c_tma_op=c_tma_op, + sfa_tma_op=sfa_tma_op, + sfb_tma_op=sfb_tma_op, + tiled_mma=tiled_mma, + tiled_mma_sfb=tiled_mma_sfb, + mma_tiler=self.mma_tiler, + mma_tiler_sfb=self.mma_tiler_sfb, + cluster_layout_vmnk_shape=cluster_layout_vmnk.shape, + cluster_layout_sfb_vmnk_shape=cluster_layout_sfb_vmnk.shape, + epi_tile=epi_tile, + a_tensor=a_gemm, + b_tensor=b_gemm, + c_tensor=c_gemm, + sfa_tensor=sfa_gemm, + sfb_tensor=sfb_gemm, + offs=offs, + offs_padded=offs_padded if offs_padded is not None else offs, + workspace_ptr=workspace_ptr, + ) + ext = ScaledGroupedMmSchedExtension( + scenario=self.scenario, tensormap_ctor=tensormap_ctor + ) + + # ================================================================= + # Kernel setup + # ================================================================= + + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + tidx, _, _ = cute.arch.thread_idx() + + # ================================================================= + # SharedStorage + # ================================================================= + + @cute.struct + class SharedStorage: + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + acc_full_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_acc_pipeline_stages * 2 + ] + sched_buf: cute.struct.MemRange[cutlass.Int32, self.num_sched_stages * 4] + sched_mbar_ptr: cute.struct.MemRange[ + cutlass.Int64, self.num_sched_stages * 2 + ] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + # ================================================================= + # Pipelines + # ================================================================= + + # AB pipeline (TMA load → MMA) — same as grouped_mm + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ).make_participants() + + # ACC pipeline (MMA → epilogue) + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = ( + len(self.epilogue_warp_id) * 32 * (2 if use_2cta_instrs else 1) + ) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_acc_consumer_threads + ) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), + num_stages=self.num_acc_pipeline_stages, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ) + + # Scheduler pipeline (sched warp → tma/mma/epi warps) + sched_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32) + num_sched_consumer_threads = 32 * len( + (self.tma_warp_id, self.mma_warp_id, *self.epilogue_warp_id) + ) + sched_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_sched_consumer_threads + ) + sched_pipeline = pipeline.PipelineAsync.create( + num_stages=self.num_sched_stages, + producer_group=sched_producer_group, + consumer_group=sched_consumer_group, + barrier_storage=storage.sched_mbar_ptr.data_ptr(), + defer_sync=True, + ) + + # TMEM allocator + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)), + ) + tmem = utils.TmemAllocator( + storage.tmem_holding_buf.ptr, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.epilogue_warp_id[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr.ptr, + ) + + # Cluster barrier sync after init + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True) + + # ================================================================= + # SMEM tensors A/B/SFA/SFB + # ================================================================= + + sA = smem.allocate_tensor( + element_type=self.a_dtype, + layout=a_smem_layout_staged.outer, + byte_alignment=128, + swizzle=a_smem_layout_staged.inner, + ) + sB = smem.allocate_tensor( + element_type=self.b_dtype, + layout=b_smem_layout_staged.outer, + byte_alignment=128, + swizzle=b_smem_layout_staged.inner, + ) + sSFA = smem.allocate_tensor( + element_type=self.sf_dtype, + layout=sfa_smem_layout_staged, + byte_alignment=128, + ) + sSFB = smem.allocate_tensor( + element_type=self.sf_dtype, + layout=sfb_smem_layout_staged, + byte_alignment=128, + ) + + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + + # (MMA, MMA_M, MMA_N, STAGE=2) + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, self.num_acc_stage) + ) + if cutlass.const_expr(self.overlapping_accum): + # Overlapping: two acc buffers share TMEM with SF columns, + # so the stage stride is smaller than a full N-width. + tCtAcc_fake = cute.make_tensor( + tCtAcc_fake.iterator, + cute.make_layout( + tCtAcc_fake.shape, + stride=( + tCtAcc_fake.stride[0], + tCtAcc_fake.stride[1], + tCtAcc_fake.stride[2], + (256 - self.num_sf_tmem_cols) * tCtAcc_fake.stride[0][1], + ), + ), + ) + + # Cluster wait before TMEM alloc + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn) + + # ================================================================= + # Scheduler warp (warp 6) — same as grouped_mm + # ================================================================= + + sched_buf_ptr = storage.sched_buf.data_ptr() + sched_copy_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), cutlass.Int32, num_bits_per_copy=128 + ) + sched_buf_tensor = cute.make_tensor( + sched_buf_ptr, cute.make_layout((4, self.num_sched_stages), stride=(1, 4)) + ) + + if warp_idx == self.sched_warp_id: + scheduler = MoEStaticPersistentTileScheduler.create( + sched_params, offs, cute.arch.block_idx(), cute.arch.grid_dim() + ) + + sched_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_sched_stages + ) + + work_tile_info = scheduler.initial_work_tile_info() + sched_pipeline.producer_acquire(sched_producer_state) + rmem = work_tile_info.to_rmem_tensor() + cute.copy( + sched_copy_atom, + rmem, + sched_buf_tensor[(None, sched_producer_state.index)], + ) + cute.arch.fence_proxy("async.shared", space="cta") + sched_pipeline.producer_commit(sched_producer_state) + sched_producer_state.advance() + + work_tile_info = scheduler.advance_to_next_work() + while work_tile_info.is_valid_tile: + ext.prefetch_for_expert(work_tile_info.expert_idx) + sched_pipeline.producer_acquire(sched_producer_state) + rmem = work_tile_info.to_rmem_tensor() + cute.copy( + sched_copy_atom, + rmem, + sched_buf_tensor[(None, sched_producer_state.index)], + ) + cute.arch.fence_proxy("async.shared", space="cta") + sched_pipeline.producer_commit(sched_producer_state) + sched_producer_state.advance() + + work_tile_info = scheduler.advance_to_next_work() + + sched_pipeline.producer_acquire(sched_producer_state) + sentinel = MoEWorkTileInfo( + cutlass.Int32(-1), + cutlass.Int32(0), + cutlass.Int32(0), + cutlass.Int32(0), + ) + rmem = sentinel.to_rmem_tensor() + cute.copy( + sched_copy_atom, + rmem, + sched_buf_tensor[(None, sched_producer_state.index)], + ) + cute.arch.fence_proxy("async.shared", space="cta") + sched_pipeline.producer_commit(sched_producer_state) + + sched_pipeline.producer_tail(sched_producer_state) + + # ================================================================= + # TMA load warp (warp 5) + # ================================================================= + + if warp_idx == self.tma_warp_id: + # Multicast masks, only used in TMA load warp + a_full_mcast_mask = None + b_full_mcast_mask = None + sfa_full_mcast_mask = None + sfb_full_mcast_mask = None + if cutlass.const_expr( + self.is_a_mcast or self.is_b_mcast or use_2cta_instrs + ): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + sfa_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + sfb_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_sfb_vmnk, + block_in_cluster_coord_sfb_vmnk, + mcast_mode=1, + ) + + sched_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_sched_stages + ) + + # Read initial work_tile_info + sched_pipeline.consumer_wait(sched_consumer_state) + rmem = cute.make_rmem_tensor((4,), cutlass.Int32) + cute.copy( + sched_copy_atom, + sched_buf_tensor[(None, sched_consumer_state.index)], + rmem, + ) + work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem) + cute.arch.fence_acq_rel_cta() + sched_pipeline.consumer_release(sched_consumer_state) + sched_consumer_state.advance() + + while work_tile_info.is_valid_tile: + k_tile_cnt = work_tile_info.k_tile_cnt + + # Get real GEMM domain tensors + TMA desc ptrs via extension + real_a, desc_ptr_a = ext.get_gmem_tensor( + "a", + tma_tensor_a, + offs_for_ext, + work_tile_info, + ) + real_b, desc_ptr_b = ext.get_gmem_tensor( + "b", + tma_tensor_b, + offs_for_ext, + work_tile_info, + ) + real_sfa, desc_ptr_sfa = ext.get_gmem_tensor( + "sfa", + tma_tensor_sfa, + offs_for_ext, + work_tile_info, + ) + real_sfb, desc_ptr_sfb = ext.get_gmem_tensor( + "sfb", + tma_tensor_sfb, + offs_for_ext, + work_tile_info, + ) + + # local_tile for A, B + gA_mkl = cute.local_tile( + real_a, + cute.slice_(self.mma_tiler, (None, 0, None)), + (None, None, None), + ) + gB_nkl = cute.local_tile( + real_b, + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), + ) + + # local_tile for SFA, SFB + gSFA_mkl = cute.local_tile( + real_sfa, + cute.slice_(self.mma_tiler, (None, 0, None)), + (None, None, None), + ) + gSFB_nkl = cute.local_tile( + real_sfb, + cute.slice_(self.mma_tiler_sfb, (0, None, None)), + (None, None, None), + ) + + # MMA partition for TMA + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_coord_v) + tCgA = thr_mma.partition_A(gA_mkl) + tCgB = thr_mma.partition_B(gB_nkl) + tCgSFA = thr_mma.partition_A(gSFA_mkl) + tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl) + + # TMA partition A + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA partition B + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + # TMA partition SFA + sfa_cta_layout = a_cta_layout + tAsSFA, tAgSFA = cpasync.tma_partition( + tma_atom_sfa, + block_in_cluster_coord_vmnk[2], + sfa_cta_layout, + cute.group_modes(sSFA, 0, 3), + cute.group_modes(tCgSFA, 0, 3), + ) + tAsSFA = cute.filter_zeros(tAsSFA) + tAgSFA = cute.filter_zeros(tAgSFA) + # TMA partition SFB + sfb_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape + ) + tBsSFB, tBgSFB = cpasync.tma_partition( + tma_atom_sfb, + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB, 0, 3), + ) + tBsSFB = cute.filter_zeros(tBsSFB) + tBgSFB = cute.filter_zeros(tBgSFB) + + # Slice to current tile coords (L=0, expert already selected) + mma_tile_m = work_tile_info.tile_m_idx // cute.size( + tiled_mma.thr_id.shape + ) + tAgA_slice = tAgA[(None, mma_tile_m, None, 0)] + tBgB_slice = tBgB[(None, work_tile_info.tile_n_idx, None, 0)] + tAgSFA_slice = tAgSFA[(None, mma_tile_m, None, 0)] + + # SFB slice — N=64 + slice_n = work_tile_info.tile_n_idx + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): + slice_n = work_tile_info.tile_n_idx // 2 + tBgSFB_slice = tBgSFB[(None, slice_n, None, 0)] + + # TMA load loop + ab_producer.reset() + peek_ab_empty_status = ab_producer.try_acquire() + + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + handle = ab_producer.acquire_and_advance(peek_ab_empty_status) + peek_ab_empty_status = cutlass.Boolean(1) + if handle.count + 1 < k_tile_cnt: + peek_ab_empty_status = ab_producer.try_acquire() + # TMA load A + cute.copy( + tma_atom_a, + tAgA_slice[(None, handle.count)], + tAsA[(None, handle.index)], + tma_bar_ptr=handle.barrier, + tma_desc_ptr=desc_ptr_a, + mcast_mask=a_full_mcast_mask, + ) + # TMA load B + cute.copy( + tma_atom_b, + tBgB_slice[(None, handle.count)], + tBsB[(None, handle.index)], + tma_bar_ptr=handle.barrier, + tma_desc_ptr=desc_ptr_b, + mcast_mask=b_full_mcast_mask, + ) + # TMA load SFA + cute.copy( + tma_atom_sfa, + tAgSFA_slice[(None, handle.count)], + tAsSFA[(None, handle.index)], + tma_bar_ptr=handle.barrier, + tma_desc_ptr=desc_ptr_sfa, + mcast_mask=sfa_full_mcast_mask, + ) + # TMA load SFB + cute.copy( + tma_atom_sfb, + tBgSFB_slice[(None, handle.count)], + tBsSFB[(None, handle.index)], + tma_bar_ptr=handle.barrier, + tma_desc_ptr=desc_ptr_sfb, + mcast_mask=sfb_full_mcast_mask, + ) + + # Read next work_tile_info + sched_pipeline.consumer_wait(sched_consumer_state) + rmem = cute.make_rmem_tensor((4,), cutlass.Int32) + cute.copy( + sched_copy_atom, + sched_buf_tensor[(None, sched_consumer_state.index)], + rmem, + ) + work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem) + cute.arch.fence_acq_rel_cta() + sched_pipeline.consumer_release(sched_consumer_state) + sched_consumer_state.advance() + + ab_producer.tail() + + # ================================================================= + # MMA warp (warp 4) + # ================================================================= + + if warp_idx == self.mma_warp_id: + # MMA fragments (SMEM → TMEM partitions), only used in this warp + tCrA = tiled_mma.make_fragment_A(sA) + tCrB = tiled_mma.make_fragment_B(sB) + + tmem.wait_for_alloc() + acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # SFA TMEM tensor + sfa_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + self.num_accumulator_tmem_cols, + dtype=self.sf_dtype, + ) + tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) + + # SFB TMEM tensor + sfb_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + self.num_accumulator_tmem_cols + self.num_sfa_tmem_cols, + dtype=self.sf_dtype, + ) + tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) + + # S2T copy partitions for SFA/SFB + ( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t, + tCtSFA_compact_s2t, + ) = self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA) + ( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t, + tCtSFB_compact_s2t, + ) = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB) + + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_pipeline_stages + ) + sched_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_sched_stages + ) + + # Read initial work_tile_info + sched_pipeline.consumer_wait(sched_consumer_state) + rmem = cute.make_rmem_tensor((4,), cutlass.Int32) + cute.copy( + sched_copy_atom, + sched_buf_tensor[(None, sched_consumer_state.index)], + rmem, + ) + work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem) + cute.arch.fence_acq_rel_cta() + sched_pipeline.consumer_release(sched_consumer_state) + sched_consumer_state.advance() + + while work_tile_info.is_valid_tile: + k_tile_cnt = work_tile_info.k_tile_cnt + + # Get accumulator stage index + if cutlass.const_expr(self.overlapping_accum): + acc_stage_index = acc_producer_state.phase ^ 1 + else: + acc_stage_index = acc_producer_state.index + + if is_leader_cta: + tCtAcc = tCtAcc_base[(None, None, None, acc_stage_index)] + + # SFB TMEM pointer offset for N=64 + tCtSFB_mma = tCtSFB + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): + offset = cutlass.Int32((work_tile_info.tile_n_idx % 2) * 2) + shifted_ptr = cute.recast_ptr( + acc_tmem_ptr + + self.num_accumulator_tmem_cols + + self.num_sfa_tmem_cols + + offset, + dtype=self.sf_dtype, + ) + tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout) + + # AB consumer mainloop + ab_consumer.reset() + peek_ab_full_status = cutlass.Boolean(1) + if k_tile_cnt > 0: + peek_ab_full_status = ab_consumer.try_wait() + acc_pipeline.producer_acquire(acc_producer_state) + + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + handle = ab_consumer.wait_and_advance(peek_ab_full_status) + peek_ab_full_status = cutlass.Boolean(1) + if handle.count + 1 < k_tile_cnt: + peek_ab_full_status = ab_consumer.try_wait() + + # S2T copy SFA/SFB from SMEM to TMEM + s2t_stage_coord = ( + None, + None, + None, + None, + handle.index, + ) + cute.copy( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t[s2t_stage_coord], + tCtSFA_compact_s2t, + ) + cute.copy( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t[s2t_stage_coord], + tCtSFB_compact_s2t, + ) + + # Block-scaled GEMM with paired operands + tiled_mma.set(tcgen05.Field.ACCUMULATE, k_tile != 0) + tile_crd = (None, None, None, handle.index) + cute.gemm( + tiled_mma, + tCtAcc, + [tCrA[tile_crd], tCtSFA], + [tCrB[tile_crd], tCtSFB_mma], + tCtAcc, + ) + handle.release() + + if k_tile_cnt > 0: + acc_pipeline.producer_commit(acc_producer_state) + if k_tile_cnt > 0: + acc_producer_state.advance() + + # Read next work_tile_info + sched_pipeline.consumer_wait(sched_consumer_state) + rmem = cute.make_rmem_tensor((4,), cutlass.Int32) + cute.copy( + sched_copy_atom, + sched_buf_tensor[(None, sched_consumer_state.index)], + rmem, + ) + work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem) + cute.arch.fence_acq_rel_cta() + sched_pipeline.consumer_release(sched_consumer_state) + sched_consumer_state.advance() + + acc_pipeline.producer_tail(acc_producer_state) + + # ================================================================= + # SMEM tensor C (allocated after MMA section) + # ================================================================= + + sC = smem.allocate_tensor( + element_type=self.c_dtype, + layout=c_smem_layout_staged.outer, + byte_alignment=128, + swizzle=c_smem_layout_staged.inner, + ) + + # ================================================================= + # Epilogue warps (warps 0-3) + # ================================================================= + + if warp_idx < self.mma_warp_id: + tmem.allocate(self.num_tmem_alloc_cols) + tmem.wait_for_alloc() + acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_pipeline_stages + ) + sched_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_sched_stages + ) + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilogue_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, producer_group=c_producer_group + ) + + epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=self.epilog_sync_bar_id, + num_threads=32 * len(self.epilogue_warp_id), + ) + + # Layout transformation for epilogue + tCtAcc_transformed = transform_partitioned_tensor_layout(tCtAcc_base) + + num_tiles_executed = cutlass.Int32(0) + + # Read initial work_tile_info + sched_pipeline.consumer_wait(sched_consumer_state) + rmem = cute.make_rmem_tensor((4,), cutlass.Int32) + cute.copy( + sched_copy_atom, + sched_buf_tensor[(None, sched_consumer_state.index)], + rmem, + ) + work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem) + cute.arch.fence_acq_rel_cta() + sched_pipeline.consumer_release(sched_consumer_state) + sched_consumer_state.advance() + + while work_tile_info.is_valid_tile: + k_tile_cnt = work_tile_info.k_tile_cnt + + # Get real C tensor + TMA desc ptr + real_c, desc_ptr_c = ext.get_gmem_tensor( + "c", + tma_tensor_c, + offs_for_ext, + work_tile_info, + ) + # local_tile + partition for C + gC_mnl = cute.local_tile( + real_c, + cute.slice_(self.mma_tiler, (None, None, 0)), + (None, None, None), + ) + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + tCgC = thr_mma.partition_C(gC_mnl) + tCgC_transformed = transform_partitioned_tensor_layout(tCgC) + + mma_tile_coord_mnl = ( + work_tile_info.tile_m_idx // cute.size(tiled_mma.thr_id.shape), + work_tile_info.tile_n_idx, + cutlass.Int32(0), + ) + + # Partition for TMEM → RMEM copy + tiled_copy_t2r, tTR_tAcc_base_epi, tTR_rAcc = ( + epilogue_tmem_copy_and_partition( + self, + tidx, + tCtAcc_transformed, + tCgC_transformed, + epi_tile, + use_2cta_instrs, + ) + ) + tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = epilogue_smem_copy_and_partition( + self, tiled_copy_t2r, tTR_rC, tidx, sC + ) + + # TMA partition for C store + tCgC_epi = cute.flat_divide(tCgC_transformed, epi_tile) + bSG_sC, bSG_gC_partitioned = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + cute.group_modes(sC, 0, 2), + cute.group_modes(tCgC_epi, 0, 2), + ) + bSG_gC = bSG_gC_partitioned[(None, None, None, *mma_tile_coord_mnl)] + + # Get accumulator stage index + if cutlass.const_expr(self.overlapping_accum): + acc_stage_index = acc_consumer_state.phase + reverse_subtile = True if acc_stage_index == 0 else False + else: + acc_stage_index = acc_consumer_state.index + + # Set TMEM buffer for current tile + tTR_tAcc = tTR_tAcc_base_epi[ + (None, None, None, None, None, acc_stage_index) + ] + + # Wait for accumulator buffer full + if k_tile_cnt > 0: + acc_pipeline.consumer_wait(acc_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + + # Compute per-expert global_scale alpha for NVFP4 + if cutlass.const_expr(global_scale_a is not None): + expert_idx = work_tile_info.expert_idx + alpha = cute.arch.load( + global_scale_a.iterator + expert_idx, + cutlass.Float32, + ) * cute.arch.load( + global_scale_b.iterator + expert_idx, + cutlass.Float32, + ) + else: + alpha = None + + # Store accumulator to global memory in subtiles + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + num_prev_subtiles = num_tiles_executed * subtile_cnt + + for subtile_idx in cutlass.range(subtile_cnt): + real_subtile_idx = subtile_idx + if cutlass.const_expr(self.overlapping_accum): + if reverse_subtile: + real_subtile_idx = ( + self.cta_tile_shape_mnk[1] // self.epi_tile_n + - 1 + - subtile_idx + ) + + # TMEM → RMEM + tTR_tAcc_mn = tTR_tAcc[(None, None, None, real_subtile_idx)] + if cutlass.const_expr(self.scenario == "2Dx2D"): + if k_tile_cnt > 0: + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + else: + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # Early release for overlapping_accum + if cutlass.const_expr(self.overlapping_accum): + if subtile_idx == self.iter_acc_early_release_in_epilogue: + cute.arch.fence_view_async_tmem_load() + if k_tile_cnt > 0: + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + # Convert to output dtype, apply global_scale + acc_vec = cute.zeros_like(tiled_copy_r2s.retile(tTR_rAcc)) + if cutlass.const_expr(self.scenario == "2Dx2D"): + if k_tile_cnt > 0: + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + else: + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + if cutlass.const_expr(global_scale_a is not None): + acc_vec = acc_vec * alpha + acc_vec = acc_vec.to(self.c_dtype) + tRS_rC.store(acc_vec) + + # RMEM → SMEM + c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage + cute.copy( + tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)] + ) + cute.arch.fence_proxy("async.shared", space="cta") + epilog_sync_barrier.arrive_and_wait() + + # SMEM → GMEM (TMA store or TMA reduce) + if warp_idx == self.epilogue_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, real_subtile_idx)], + tma_desc_ptr=desc_ptr_c, + ) + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + epilog_sync_barrier.arrive_and_wait() + + # Release accumulator buffer (non-overlapping path) + if cutlass.const_expr(not self.overlapping_accum): + if k_tile_cnt > 0: + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + num_tiles_executed += cutlass.Int32(1) + + # Read next work_tile_info + sched_pipeline.consumer_wait(sched_consumer_state) + rmem = cute.make_rmem_tensor((4,), cutlass.Int32) + cute.copy( + sched_copy_atom, + sched_buf_tensor[(None, sched_consumer_state.index)], + rmem, + ) + work_tile_info = MoEWorkTileInfo.from_rmem_tensor(rmem) + cute.arch.fence_acq_rel_cta() + sched_pipeline.consumer_release(sched_consumer_state) + sched_consumer_state.advance() + + # Wait for C store complete + c_pipeline.producer_tail() + + # Free TMEM + tmem.relinquish_alloc_permit() + epilog_sync_barrier.arrive_and_wait() + tmem.free(acc_tmem_ptr) + + +# ============================================================================= +# Non-Kernel Part +# ============================================================================= + +from dataclasses import dataclass, field +import re + +import numpy as np +import torch +import cutlass.torch as cutlass_torch + +# ============================================================================= +# Utility functions +# ============================================================================= + + +def ceil_div(a: int, b: int) -> int: + return (a + b - 1) // b + + +def round_up(a: int, b: int) -> int: + return ceil_div(a, b) * b + + +def torch_version_lt(major: int, minor: int) -> bool: + """Best-effort torch version check that tolerates local build suffixes.""" + match = re.match(r"^\s*(\d+)\.(\d+)", torch.__version__) + if match is None: + print( + "WARNING: failed to parse torch.__version__, " + "falling back to manual host reference." + ) + return True + version = (int(match.group(1)), int(match.group(2))) + return version < (major, minor) + + +def offs_to_group_sizes(offs: torch.Tensor) -> list[int]: + """Convert cumulative end offsets to per-group sizes.""" + offs_cpu = offs.cpu().tolist() + prev = 0 + sizes = [] + for end in offs_cpu: + sizes.append(end - prev) + prev = end + return sizes + + +def l2_flush(size_mb: int = 400) -> None: + """Best-effort L2 flush by touching a large temporary tensor.""" + num_bytes = size_mb * 1024 * 1024 + flush_buf = torch.randint(0, 256, (num_bytes,), dtype=torch.uint8, device="cuda") + del flush_buf + + +# ============================================================================= +# Format configuration +# +# Note: For all current formats, sf_vec_size == blocksize. +# The kernel can derive sf_vec_size from blocksize directly. +# ============================================================================= + +_FORMAT_CONFIG = { + "mxfp8": { + "data_dtype": torch.float8_e4m3fn, + "blocksize": 32, + "scale_dtype": torch.float8_e8m0fnu, + "has_global_scale": False, + }, + "mxfp4": { + "data_dtype": torch.float4_e2m1fn_x2, + "blocksize": 32, + "scale_dtype": torch.float8_e8m0fnu, + "has_global_scale": False, + }, + "nvfp4": { + "data_dtype": torch.float4_e2m1fn_x2, + "blocksize": 16, + "scale_dtype": torch.float8_e4m3fn, + "has_global_scale": True, + }, +} + +# FP4 nibble encoding: value → 4-bit nibble (float4 e2m1 format) +# 0 → 0x0 +# 0.5 → 0x1 1.0 → 0x2 1.5 → 0x3 +# 2.0 → 0x4 3.0 → 0x5 4.0 → 0x6 6.0 → 0x7 +# -0 → 0x8 -0.5 → 0x9 -1.0 → 0xA -1.5 → 0xB +# -2.0 → 0xC -3.0 → 0xD -4.0 → 0xE -6.0 → 0xF + +# Correctness-friendly: only {0, 1, -1} → nibbles {0x0, 0x2, 0xA} +_FP4_CORRECTNESS_NIBBLES = torch.tensor([0x0, 0x2, 0xA], dtype=torch.uint8) +# Perf: all 16 valid nibbles (index == nibble value) +_FP4_PERF_NIBBLES = torch.arange(16, dtype=torch.uint8) +_FP4_DECODE_TABLE = torch.tensor( + [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ], + dtype=torch.float32, +) + + +# ============================================================================= +# Scale shape computation +# ============================================================================= + + +def compute_scale_shape( + scenario: str, + operand: str, + group_sizes: list[int], + hidden: int, + intermediate: int, + K_fixed: int, + blocksize: int, + expert_cnt: int, +) -> tuple[int, ...]: + """ + Compute the assembled (swizzled 32_4_4) scale tensor shape. + + Swizzle 32_4_4 pads each group's scale to rows=round_up(non_K, 128), + cols=round_up(ceil_div(K, blocksize), 4), then flattens per group. + + Scale layout per scenario/operand: + 2Dx3D A: groups along M (variable per expert), K fixed + -> (sum(round_up(M_g, 128)), round_up(ceil_div(K, bs), 4)) + 2Dx3D B: per-expert (K, N same for all) + -> (G, round_up(N, 128) * round_up(ceil_div(K, bs), 4)) + 2Dx2D A: M fixed, groups along K (variable per expert) + -> (round_up(M, 128), sum(round_up(ceil_div(K_g, bs), 4))) + 2Dx2D B: N fixed, groups along K (variable per expert) + -> (round_up(N, 128), sum(round_up(ceil_div(K_g, bs), 4))) + + Args: + scenario: "2Dx3D" or "2Dx2D" + operand: "a" or "b" + group_sizes: per-expert sizes of the grouped dimension + (M sizes for 2Dx3D, K sizes for 2Dx2D) + hidden: M dimension (hidden_size) + intermediate: N dimension (intermediate_size) + K_fixed: K dimension (used where K is fixed across experts) + blocksize: 32 for MXFP8/MXFP4, 16 for NVFP4 + expert_cnt: number of experts (G) + """ + if scenario == "2Dx3D": + # group_sizes = per-expert M sizes; K is fixed for all experts + if operand == "a": + total_rows = sum(round_up(mg, 128) for mg in group_sizes) + total_cols = round_up(ceil_div(K_fixed, blocksize), 4) + return (total_rows, total_cols) + else: + padded_N = round_up(intermediate, 128) + padded_K_scale = round_up(ceil_div(K_fixed, blocksize), 4) + return (expert_cnt, padded_N * padded_K_scale) + else: # 2Dx2D + # group_sizes = per-expert K sizes; M and N are fixed + if operand == "a": + padded_M = round_up(hidden, 128) + total_cols = sum(round_up(ceil_div(kg, blocksize), 4) for kg in group_sizes) + return (padded_M, total_cols) + else: + padded_N = round_up(intermediate, 128) + total_cols = sum(round_up(ceil_div(kg, blocksize), 4) for kg in group_sizes) + return (padded_N, total_cols) + + +def to_blocked(scale_2d: torch.Tensor) -> torch.Tensor: + """Pad and apply the Blackwell 32_4_4 scale swizzle to one raw scale tensor.""" + if scale_2d.dim() != 2: + raise ValueError(f"Expected 2D scale tensor, got {scale_2d.dim()}D.") + rows, cols = scale_2d.shape + if rows == 0 or cols == 0: + return scale_2d.new_empty((0,)) + + row_blocks = ceil_div(rows, 128) + col_blocks = ceil_div(cols, 4) + padded_rows = row_blocks * 128 + padded_cols = col_blocks * 4 + + padded = scale_2d + if (rows, cols) != (padded_rows, padded_cols): + padded = torch.zeros( + (padded_rows, padded_cols), dtype=scale_2d.dtype, device=scale_2d.device + ) + padded[:rows, :cols] = scale_2d + + blocks = padded.view(row_blocks, 128, col_blocks, 4).permute(0, 2, 1, 3) + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + return rearranged.flatten() + + +def pad_and_swizzle_single(raw_scale_2d: torch.Tensor) -> torch.Tensor: + if raw_scale_2d.dim() != 2: + raise ValueError(f"Expected 2D scale tensor, got {raw_scale_2d.dim()}D.") + return to_blocked(raw_scale_2d) + + +def create_raw_scale_tensor( + non_k_size: int, + k_size: int, + blocksize: int, + scale_dtype: torch.dtype, + device: str = "cuda", +) -> torch.Tensor: + """Create one raw, non-swizzled scale tensor with exact values in {1, 2}.""" + scale_cols = ceil_div(k_size, blocksize) + return ( + torch.randint( + 1, + 3, + (non_k_size, scale_cols), + dtype=torch.float32, + device=device, + ) + .to(scale_dtype) + .reshape(non_k_size, scale_cols) + ) + + +def cat_byte_reinterpretable_tensors( + tensors: list[torch.Tensor], dim: int = 0 +) -> torch.Tensor: + """Concatenate byte-backed float tensors via uint8 view when native cat is unsupported.""" + if not tensors: + raise ValueError("Expected at least one tensor to concatenate.") + first = tensors[0] + if first.is_floating_point() and first.element_size() == 1: + concatenated = torch.cat( + [tensor.view(torch.uint8) for tensor in tensors], dim=dim + ) + return concatenated.view(first.dtype) + return torch.cat(tensors, dim=dim) + + +def stack_byte_reinterpretable_tensors( + tensors: list[torch.Tensor], dim: int = 0 +) -> torch.Tensor: + """Stack byte-backed float tensors via uint8 view when native stack is unsupported.""" + if not tensors: + raise ValueError("Expected at least one tensor to stack.") + first = tensors[0] + if first.is_floating_point() and first.element_size() == 1: + stacked = torch.stack([tensor.view(torch.uint8) for tensor in tensors], dim=dim) + return stacked.view(first.dtype) + return torch.stack(tensors, dim=dim) + + +def assemble_raw_scales_2d2d( + raw_scales: list[torch.Tensor], non_k_size: int +) -> torch.Tensor: + flat_parts = [pad_and_swizzle_single(scale) for scale in raw_scales] + all_flat = cat_byte_reinterpretable_tensors(flat_parts, dim=0) + return all_flat.reshape(round_up(non_k_size, 128), -1) + + +def assemble_raw_scales_2d3d_3d_side(raw_scales: list[torch.Tensor]) -> torch.Tensor: + flat_parts = [pad_and_swizzle_single(scale) for scale in raw_scales] + return stack_byte_reinterpretable_tensors(flat_parts, dim=0) + + +def assemble_raw_scales_2d3d_2d_side(raw_scales: list[torch.Tensor]) -> torch.Tensor: + flat_parts = [pad_and_swizzle_single(scale) for scale in raw_scales] + all_flat = cat_byte_reinterpretable_tensors(flat_parts, dim=0) + total_rows = sum(round_up(scale.shape[0], 128) for scale in raw_scales) + return all_flat.reshape(total_rows, -1) + + +def fp4_packed_dim(tensor: torch.Tensor) -> int: + positive_strides = [ + (abs(stride), idx) for idx, stride in enumerate(tensor.stride()) if stride > 0 + ] + if not positive_strides: + return tensor.dim() - 1 + return min(positive_strides)[1] + + +def unpack_fp4_to_f32(packed: torch.Tensor) -> torch.Tensor: + """Unpack a float4_e2m1fn_x2 tensor into float32 along the packed dimension.""" + packed_dim = fp4_packed_dim(packed) + raw = packed.view(torch.uint8) + + if packed_dim != raw.dim() - 1: + perm = list(range(raw.dim())) + perm[packed_dim], perm[-1] = perm[-1], perm[packed_dim] + raw = raw.permute(perm).contiguous() + else: + perm = None + + lo = (raw & 0x0F).to(torch.int64) + hi = (raw >> 4).to(torch.int64) + lut = _FP4_DECODE_TABLE.to(raw.device) + + unpacked_shape = list(raw.shape) + unpacked_shape[-1] *= 2 + unpacked = torch.empty(unpacked_shape, dtype=torch.float32, device=raw.device) + unpacked[..., ::2] = lut[lo] + unpacked[..., 1::2] = lut[hi] + + if perm is not None: + unpacked = unpacked.permute(perm) + return unpacked + + +def slice_tensor_logical_dim( + tensor: torch.Tensor, dim: int, start: int, end: int +) -> torch.Tensor: + """Slice along the logical dimension, compensating for FP4 packing when needed.""" + if tensor.dtype == torch.float4_e2m1fn_x2 and dim == fp4_packed_dim(tensor): + if start % 2 != 0 or end % 2 != 0: + raise ValueError( + f"FP4 packed slicing requires even indices, got start={start}, end={end}." + ) + start = start // 2 + end = end // 2 + return tensor.narrow(dim, start, end - start) + + +def dequant_block_scale_to_fp32( + data: torch.Tensor, + raw_scale: torch.Tensor, + blocksize: int, + global_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Dequantize a single 2D tensor using raw block scales into fp32.""" + if data.dtype == torch.float4_e2m1fn_x2: + data_fp32 = unpack_fp4_to_f32(data) + else: + data_fp32 = data.to(torch.float32) + + if data_fp32.dim() != 2 or raw_scale.dim() != 2: + raise ValueError( + f"Expected 2D tensors, got data={data_fp32.dim()}D raw_scale={raw_scale.dim()}D." + ) + + expected_scale_shape = (data_fp32.shape[0], ceil_div(data_fp32.shape[1], blocksize)) + if tuple(raw_scale.shape) != expected_scale_shape: + raise ValueError( + f"Scale shape mismatch: expected {expected_scale_shape}, got {tuple(raw_scale.shape)}." + ) + + scale_fp32 = raw_scale.to(torch.float32) + expanded_scale = scale_fp32.repeat_interleave(blocksize, dim=-1)[ + :, : data_fp32.shape[1] + ] + result = data_fp32 * expanded_scale + + if global_scale is not None: + result = result * global_scale.to(torch.float32).reshape(1, 1) + return result + + +def transpose_rhs_for_block_dequant(data: torch.Tensor) -> torch.Tensor: + """Convert a (K, N) RHS slice into an (N, K) tensor for block dequant along K.""" + if data.dim() != 2: + raise ValueError(f"Expected 2D RHS tensor, got {data.dim()}D.") + if data.dtype == torch.float4_e2m1fn_x2: + # Avoid contiguous()/copy_ on FP4 tensors; unpack first, then transpose in fp32. + return unpack_fp4_to_f32(data).transpose(0, 1) + return data.transpose(0, 1) + + +# ============================================================================= +# Host Validation +# ============================================================================= + + +@dataclass +class ProblemDesc: + tokens: int + experts: int + top_k_select: int + balance_route: bool + hidden: int + intermediate: int + scenario: Literal["2Dx3D", "2Dx2D"] + kind: Literal["mxfp8", "mxfp4", "nvfp4"] + out_dtype: torch.dtype = torch.bfloat16 + acc_dtype: torch.dtype = torch.float32 + grad_accumulate: bool = False + # If True, the user guarantees activation tensors (with tokens_sum dim) + # are padded per-group to the same granularity as the block-scale layout: + # 2Dx3D (groups along M): each group's M_g padded to 128 + # 2Dx2D (groups along K): each group's K_g padded to sf_vec_size * 4 + # This enables the kernel to skip padded-offset computation. + # Currently NOT implemented — forced to False at CLI level. + consistent_token_padding: bool = False + # GEMM-domain layout control (which axis is stride-1) + # Only effective for FP8. FP4 always uses the torch-expected layout + # (K stride-1 for both A and B). + # A (M, K): "k_major" → K stride-1 (default) | "m_major" → M stride-1 + # B (N, K): "k_major" → K stride-1 (default) | "n_major" → N stride-1 + # C (M, N): "n_major" → N stride-1 (default) | "m_major" → M stride-1 + # Note: default b_layout is "k_major" (unlike torch_grouped_mm.py's "n_major") + # because torch.nn.functional.scaled_grouped_mm expects K stride-1 for B. + a_layout: Literal["k_major", "m_major"] = "k_major" + b_layout: Literal["k_major", "n_major"] = "k_major" + c_layout: Literal["n_major", "m_major"] = "n_major" + + def __str__(self) -> str: + d = lambda t: str(t).split(".")[-1] + route = "balanced" if self.balance_route else "random" + return ( + f"ProblemDesc: {self.scenario} | kind={self.kind} | " + f"tokens={self.tokens} experts={self.experts} " + f"top_k={self.top_k_select} route={route} | " + f"hidden={self.hidden} intermediate={self.intermediate} | " + f"out={d(self.out_dtype)} acc={d(self.acc_dtype)} " + f"grad_acc={self.grad_accumulate} " + f"consistent_pad={self.consistent_token_padding} | " + f"layout: A={self.a_layout} B={self.b_layout} C={self.c_layout}" + ) + + +@dataclass +class ImplDesc: + mma_tiler_mnk: Tuple[int, int, int] = (128, 128, 64) + cluster_shape_mnk: Tuple[int, int, int] = (1, 1, 1) + use_2cta_instrs: bool = False + static_expert_cnt: Optional[int] = None + separate_tensormap_init: bool = True + + def __str__(self) -> str: + tile = ",".join(map(str, self.mma_tiler_mnk)) + cluster = ",".join(map(str, self.cluster_shape_mnk)) + static_e = ( + self.static_expert_cnt if self.static_expert_cnt is not None else "dynamic" + ) + return ( + f"ImplDesc: tile={tile} cluster={cluster} " + f"2cta={self.use_2cta_instrs} | " + f"static_E={static_e} sep_tmap={self.separate_tensormap_init}" + ) + + +@dataclass +class MiscDesc: + perf_run: bool = False + perf_e2e: bool = False + compare_with_sol: bool = False + no_torch_210: bool = field(init=False) + + def __post_init__(self): + self.no_torch_210 = torch_version_lt(2, 10) + if self.perf_e2e and not self.perf_run: + raise ValueError("--perf_e2e requires --perf_run to be enabled.") + if self.perf_e2e and self.compare_with_sol: + raise ValueError( + "--perf_e2e and --compare_with_sol are mutually exclusive." + ) + + def __str__(self) -> str: + return ( + f"MiscDesc: perf={self.perf_run} perf_e2e={self.perf_e2e} " + f"sol={self.compare_with_sol} no_torch_210={self.no_torch_210}" + ) + + +class ScaledGroupedGemmTester: + def __init__(self, problem: ProblemDesc, impl: ImplDesc, misc: MiscDesc): + self.problem = problem + self.impl = impl + self.misc = misc + + self.cfg = _FORMAT_CONFIG[problem.kind] + self.tokens_after_repeat = problem.tokens * problem.top_k_select + self.expert_cnt = problem.experts + self.hidden = problem.hidden + self.intermediate = problem.intermediate + + self.A_tensor: Optional[torch.Tensor] = None + self.B_tensor: Optional[torch.Tensor] = None + self.C_tensor: Optional[torch.Tensor] = None + self.C_ref_tensor: Optional[torch.Tensor] = None + self.scale_a_tensor: Optional[torch.Tensor] = None + self.scale_b_tensor: Optional[torch.Tensor] = None + self.raw_scale_a_tensors: Optional[list[torch.Tensor]] = None + self.raw_scale_b_tensors: Optional[list[torch.Tensor]] = None + self.global_scale_a: Optional[torch.Tensor] = None + self.global_scale_b: Optional[torch.Tensor] = None + self.offs_tensor: Optional[torch.Tensor] = None + self.workspace_tensor: Optional[torch.Tensor] = None + + if problem.grad_accumulate and problem.scenario == "2Dx3D": + raise ValueError( + "grad_accumulate only makes sense for 2Dx2D (weight grad) scenario." + ) + + # ----------------------------------------------------------------- + # Offs generation (aligned to blocksize) + # ----------------------------------------------------------------- + + def _generate_offs(self) -> torch.Tensor: + """Generate group-end offsets aligned to blocksize. + + Some experts may receive 0 tokens (valid in real MoE routing). + Each non-empty group's size is a multiple of blocksize. + """ + blocksize = self.cfg["blocksize"] + total = self.tokens_after_repeat + expert_cnt = self.expert_cnt + + assert total % blocksize == 0, ( + f"tokens_after_repeat ({total}) must be divisible by " + f"blocksize ({blocksize})" + ) + n_slots = total // blocksize + + if self.problem.balance_route: + # Distribute as evenly as possible; some experts get 0 if n_slots < expert_cnt + base = n_slots // expert_cnt + remainder = n_slots % expert_cnt + slots = [base + (1 if i < remainder else 0) for i in range(expert_cnt)] + else: + # Dirichlet distribution: naturally allows 0-size groups + # alpha=1.0 → uniform on simplex (moderate variation) + # alpha<1.0 → skewed (few experts get most tokens) + # alpha>1.0 → more uniform + proportions = np.random.dirichlet([0.5] * expert_cnt) + raw = np.floor(proportions * n_slots).astype(int) + deficit = n_slots - raw.sum() + while deficit > 0: + idx = int(np.argmin(raw / (proportions * n_slots + 1e-12))) + raw[idx] += 1 + deficit -= 1 + while deficit < 0: + ratios = np.where( + raw > 0, + raw / (proportions * n_slots + 1e-12), + -np.inf, + ) + idx = int(np.argmax(ratios)) + raw[idx] -= 1 + deficit += 1 + slots = raw.tolist() + + assert sum(slots) == n_slots + + cum = 0 + offsets = [] + for s in slots: + cum += s * blocksize + offsets.append(cum) + return torch.tensor(offsets, dtype=torch.int32, device="cuda") + + # ----------------------------------------------------------------- + # Tensor creation helpers + # ----------------------------------------------------------------- + + def _create_fp8_tensor(self, shape: tuple) -> torch.Tensor: + """Create FP8 tensor. + + - correctness mode: randint {-1, 0, 1} via bf16 cast + - perf mode: random valid fp8 bit patterns via uint8 + (float8_e4m3fn NaN encodings 0x7F/0xFF are replaced) + """ + data_dtype = self.cfg["data_dtype"] + elem_cnt = 1 + for s in shape: + elem_cnt *= s + if self.misc.perf_run: + raw = torch.randint(0, 256, (elem_cnt,), dtype=torch.uint8, device="cuda") + # float8_e4m3fn: 0x7F and 0xFF are NaN → clamp to valid max + if data_dtype == torch.float8_e4m3fn: + raw[raw == 0x7F] = 0x7E + raw[raw == 0xFF] = 0xFE + return raw.view(data_dtype).reshape(shape) + else: + return ( + torch.randint(-1, 2, (elem_cnt,), dtype=torch.bfloat16, device="cuda") + .to(data_dtype) + .reshape(shape) + ) + + def _create_fp4_tensor( + self, logical_shape: tuple, packed_dim: int = -1 + ) -> torch.Tensor: + """Create FP4 tensor. + + Args: + logical_shape: shape in FP4 elements (packed_dim size must be even). + packed_dim: dimension to pack (halve). This dim becomes stride-1. + + - perf mode: random uint8 bytes (all 256 values are valid FP4 pairs, + FP4 e2m1 has no NaN/inf). No nibble mapping needed. + - correctness mode: index→nibble mapping for values {0, 1, -1}, + then explicit nibble packing. + + Returns: + float4_e2m1fn_x2 tensor with packed_dim halved and stride-1. + """ + ndim = len(logical_shape) + packed_dim = packed_dim % ndim + assert logical_shape[packed_dim] % 2 == 0, ( + f"packed_dim {packed_dim} size ({logical_shape[packed_dim]}) must be even" + ) + + if self.misc.perf_run: + # All 256 byte values are valid FP4 pairs — just random bytes + elem_cnt = 1 + for s in logical_shape: + elem_cnt *= s + byte_cnt = elem_cnt // 2 + + flat = torch.randint(0, 256, (byte_cnt,), dtype=torch.uint8, device="cuda") + + # Build shape with packed dim moved to last and halved + shape_reordered = list(logical_shape) + need_perm = packed_dim != ndim - 1 + if need_perm: + shape_reordered[packed_dim], shape_reordered[-1] = ( + shape_reordered[-1], + shape_reordered[packed_dim], + ) + shape_reordered[-1] //= 2 + + tensor = flat.view(torch.float4_e2m1fn_x2).reshape(shape_reordered) + + if need_perm: + perm = list(range(ndim)) + perm[packed_dim], perm[-1] = perm[-1], perm[packed_dim] + tensor = tensor.permute(perm) + return tensor + + # ── Correctness mode: index→nibble mapping + explicit pack ── + # Use uint8 + masked_fill_ instead of int64 fancy indexing to avoid + # 16x memory overhead (int64 = 8 bytes vs FP4 = 0.5 bytes per element). + + nibbles = torch.randint(0, 3, logical_shape, dtype=torch.uint8, device="cuda") + nibbles.masked_fill_(nibbles == 2, 0xA) + nibbles.masked_fill_(nibbles == 1, 0x2) + + # Move packed_dim to last for packing + need_perm = packed_dim != ndim - 1 + if need_perm: + perm_to_last = list(range(ndim)) + perm_to_last[packed_dim], perm_to_last[-1] = ( + perm_to_last[-1], + perm_to_last[packed_dim], + ) + nibbles = nibbles.permute(perm_to_last).contiguous() + + # Pack pairs along last dim: byte = (odd_nibble << 4) | even_nibble + even = nibbles[..., ::2] + odd = nibbles[..., 1::2] + packed_uint8 = (odd << 4) | even + + tensor = packed_uint8.view(torch.float4_e2m1fn_x2) + + if need_perm: + inv_perm = list(range(ndim)) + inv_perm[packed_dim], inv_perm[-1] = inv_perm[-1], inv_perm[packed_dim] + tensor = tensor.permute(inv_perm) + + return tensor + + def _create_scale_tensor(self, shape: tuple) -> torch.Tensor: + """Scale tensor: random values {1, 2} (exact in all scale dtypes).""" + elem_cnt = 1 + for s in shape: + elem_cnt *= s + return ( + torch.randint(1, 3, (elem_cnt,), dtype=torch.float32, device="cuda") + .to(self.cfg["scale_dtype"]) + .reshape(shape) + ) + + def _generate_raw_scales( + self, group_sizes: list[int] + ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + blocksize = self.cfg["blocksize"] + scale_dtype = self.cfg["scale_dtype"] + device = self.A_tensor.device.type if self.A_tensor is not None else "cuda" + + if self.problem.scenario == "2Dx3D": + raw_scale_a = [ + create_raw_scale_tensor( + non_k_size=group_size, + k_size=self.hidden, + blocksize=blocksize, + scale_dtype=scale_dtype, + device=device, + ) + for group_size in group_sizes + ] + raw_scale_b = [ + create_raw_scale_tensor( + non_k_size=self.intermediate, + k_size=self.hidden, + blocksize=blocksize, + scale_dtype=scale_dtype, + device=device, + ) + for _ in range(self.expert_cnt) + ] + else: + raw_scale_a = [ + create_raw_scale_tensor( + non_k_size=self.hidden, + k_size=group_size, + blocksize=blocksize, + scale_dtype=scale_dtype, + device=device, + ) + for group_size in group_sizes + ] + raw_scale_b = [ + create_raw_scale_tensor( + non_k_size=self.intermediate, + k_size=group_size, + blocksize=blocksize, + scale_dtype=scale_dtype, + device=device, + ) + for group_size in group_sizes + ] + + return raw_scale_a, raw_scale_b + + def _assemble_scales_from_raw( + self, raw_scale_a: list[torch.Tensor], raw_scale_b: list[torch.Tensor] + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.problem.scenario == "2Dx3D": + scale_a = assemble_raw_scales_2d3d_2d_side(raw_scale_a) + scale_b = assemble_raw_scales_2d3d_3d_side(raw_scale_b) + else: + scale_a = assemble_raw_scales_2d2d(raw_scale_a, self.hidden) + scale_b = assemble_raw_scales_2d2d(raw_scale_b, self.intermediate) + return scale_a, scale_b + + # ----------------------------------------------------------------- + # generate_inputs + # ----------------------------------------------------------------- + + def generate_inputs(self) -> None: + self.offs_tensor = self._generate_offs() + group_sizes = offs_to_group_sizes(self.offs_tensor) + + tokens = self.tokens_after_repeat + hidden = self.hidden + intermediate = self.intermediate + expert_cnt = self.expert_cnt + blocksize = self.cfg["blocksize"] + is_fp4 = self.cfg["data_dtype"] == torch.float4_e2m1fn_x2 + + if is_fp4: + if self.problem.a_layout != "k_major": + print("WARNING: FP4 ignores a_layout, always uses k_major (K stride-1)") + if self.problem.b_layout != "k_major": + print("WARNING: FP4 ignores b_layout, always uses k_major (K stride-1)") + + if self.problem.scenario == "2Dx3D": + # ── Data tensors ── + # PyTorch: A (tokens, hidden), B (expert_cnt, hidden, intermediate) + # GEMM: A (M=tokens, K=hidden), B (N=intermediate, K=hidden, L=expert_cnt) + + # A: (tokens, hidden) — K=hidden is last dim + if is_fp4: + self.A_tensor = self._create_fp4_tensor((tokens, hidden), packed_dim=-1) + elif self.problem.a_layout == "k_major": + self.A_tensor = self._create_fp8_tensor((tokens, hidden)) + else: # m_major + self.A_tensor = self._create_fp8_tensor((hidden, tokens)).T + + # B: (expert_cnt, hidden, intermediate) — K=hidden is dim 1 + if is_fp4: + self.B_tensor = self._create_fp4_tensor( + (expert_cnt, hidden, intermediate), packed_dim=1 + ) + elif self.problem.b_layout == "k_major": + self.B_tensor = self._create_fp8_tensor( + (expert_cnt, intermediate, hidden) + ).transpose(1, 2) + else: # n_major + self.B_tensor = self._create_fp8_tensor( + (expert_cnt, hidden, intermediate) + ) + + # C: (tokens, intermediate) + # GEMM C (M=tokens, N=intermediate): n_major → N stride-1; m_major → M stride-1 + if self.problem.c_layout == "n_major": + self.C_tensor = torch.full( + (tokens, intermediate), + -1, + dtype=self.problem.out_dtype, + device="cuda", + ) + else: # m_major + self.C_tensor = torch.full( + (intermediate, tokens), + -1, + dtype=self.problem.out_dtype, + device="cuda", + ).T + + # ── Scale tensors ── + K_fixed = hidden + sfa_shape = compute_scale_shape( + "2Dx3D", + "a", + group_sizes, + hidden, + intermediate, + K_fixed, + blocksize, + expert_cnt, + ) + sfb_shape = compute_scale_shape( + "2Dx3D", + "b", + group_sizes, + hidden, + intermediate, + K_fixed, + blocksize, + expert_cnt, + ) + + elif self.problem.scenario == "2Dx2D": + # ── Data tensors ── + # PyTorch: A (hidden, tokens), B (tokens, intermediate) + # GEMM: A (M=hidden, K=tokens), B (N=intermediate, K=tokens, L=expert_cnt) + + # A: (hidden, tokens) — K=tokens is last dim + if is_fp4: + self.A_tensor = self._create_fp4_tensor((hidden, tokens), packed_dim=-1) + elif self.problem.a_layout == "k_major": + self.A_tensor = self._create_fp8_tensor((hidden, tokens)) + else: # m_major + self.A_tensor = self._create_fp8_tensor((tokens, hidden)).T + + # B: (tokens, intermediate) — K=tokens is dim 0 + if is_fp4: + self.B_tensor = self._create_fp4_tensor( + (tokens, intermediate), packed_dim=0 + ) + elif self.problem.b_layout == "k_major": + self.B_tensor = self._create_fp8_tensor((intermediate, tokens)).T + else: # n_major + self.B_tensor = self._create_fp8_tensor((tokens, intermediate)) + + # C: (expert_cnt, hidden, intermediate) + # GEMM C (M=hidden, N=intermediate): n_major → N stride-1; m_major → M stride-1 + if self.problem.c_layout == "n_major": + if self.problem.grad_accumulate: + self.C_tensor = torch.zeros( + (expert_cnt, hidden, intermediate), + dtype=self.problem.out_dtype, + device="cuda", + ) + else: + self.C_tensor = torch.full( + (expert_cnt, hidden, intermediate), + -1, + dtype=self.problem.out_dtype, + device="cuda", + ) + else: # m_major + if self.problem.grad_accumulate: + self.C_tensor = torch.zeros( + (expert_cnt, intermediate, hidden), + dtype=self.problem.out_dtype, + device="cuda", + ).transpose(1, 2) + else: + self.C_tensor = torch.full( + (expert_cnt, intermediate, hidden), + -1, + dtype=self.problem.out_dtype, + device="cuda", + ).transpose(1, 2) + + # ── Scale tensors ── + K_total = tokens + sfa_shape = compute_scale_shape( + "2Dx2D", + "a", + group_sizes, + hidden, + intermediate, + K_total, + blocksize, + expert_cnt, + ) + sfb_shape = compute_scale_shape( + "2Dx2D", + "b", + group_sizes, + hidden, + intermediate, + K_total, + blocksize, + expert_cnt, + ) + else: + raise ValueError(f"Unknown scenario: {self.problem.scenario}") + + self.raw_scale_a_tensors, self.raw_scale_b_tensors = self._generate_raw_scales( + group_sizes + ) + self.scale_a_tensor, self.scale_b_tensor = self._assemble_scales_from_raw( + self.raw_scale_a_tensors, self.raw_scale_b_tensors + ) + assert tuple(self.scale_a_tensor.shape) == tuple(sfa_shape), ( + f"scale_a shape mismatch: expected {sfa_shape}, " + f"got {tuple(self.scale_a_tensor.shape)}" + ) + assert tuple(self.scale_b_tensor.shape) == tuple(sfb_shape), ( + f"scale_b shape mismatch: expected {sfb_shape}, " + f"got {tuple(self.scale_b_tensor.shape)}" + ) + + # NVFP4: per-expert global scales + if self.cfg["has_global_scale"]: + self.global_scale_a = torch.randint( + 1, 3, (expert_cnt,), dtype=torch.float32, device="cuda" + ) + self.global_scale_b = torch.randint( + 1, 3, (expert_cnt,), dtype=torch.float32, device="cuda" + ) + + # ----------------------------------------------------------------- + # Reference preparation + # ----------------------------------------------------------------- + + @staticmethod + def _prepare_ref_ab( + tensor: torch.Tensor, + k_dim: int, + pad_k_size: Optional[int] = None, + pad_non_k_size: Optional[int] = None, + ) -> torch.Tensor: + """Prepare a ref tensor: make ``k_dim`` stride-1 and optionally pad. + + Args: + tensor: input data tensor (A or B). + k_dim: which dimension is K (must become stride-1). + pad_k_size: zero-pad K to this size (workaround: PyTorch 3D + scale validation uses floor division for K // blocksize). + pad_non_k_size: zero-pad the trailing dim (N) to this size + (workaround: PyTorch requires trailing dim % 16 == 0). + Only effective when ``k_dim`` is not the trailing dim. + + All padding happens in the permuted-contiguous space (standard layout) + so it is safe for packed sub-byte types like float4_e2m1fn_x2. + After permute(k_dim↔last), K is last and N is second-to-last: + F.pad(t, (0, k_pad)) -> pads K (last dim) + F.pad(t, (0, 0, 0, n_pad)) -> pads N (second-to-last dim) + The final permute restores the original dim order with K stride-1. + """ + ndim = tensor.dim() + k_dim = k_dim % ndim + needs_k_pad = pad_k_size is not None and pad_k_size > tensor.shape[k_dim] + needs_n_pad = ( + pad_non_k_size is not None + and k_dim != ndim - 1 + and pad_non_k_size > tensor.shape[-1] + ) + if tensor.stride(k_dim) == 1 and not needs_k_pad and not needs_n_pad: + return tensor + print( + f"WARNING: _prepare_ref_ab is copying/padding k_dim={k_dim} " + f"(stride={tensor.stride(k_dim)}, " + f"pad_k={'yes' if needs_k_pad else 'no'}, " + f"pad_n={'yes' if needs_n_pad else 'no'}); " + f"perf comparison with the kernel is not apples-to-apples." + ) + perm = list(range(ndim)) + perm[k_dim], perm[-1] = perm[-1], perm[k_dim] + orig_dtype = tensor.dtype + t = tensor.permute(perm).contiguous() + if needs_k_pad or needs_n_pad: + t = t.view(torch.uint8) + if needs_k_pad: + t = torch.nn.functional.pad(t, (0, pad_k_size - t.shape[-1])) + if needs_n_pad: + t = torch.nn.functional.pad(t, (0, 0, 0, pad_non_k_size - t.shape[-2])) + t = t.view(orig_dtype) + res = t.permute(perm) + return res + + def _prepare_ref_tensors( + self, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Prepare A and B for torch.nn.functional.scaled_grouped_mm. + + The torch API requires K to be stride-1 for both A and B. + For FP8 with non-standard layout, we permute+contiguous. + For FP4, tensors are already created with K stride-1. + + WORKAROUND (two PyTorch bugs in scaled_grouped_mm): + 1. 3D scale validation uses K // blocksize (floor) instead of ceil_div, + producing zero-sized expectations when K < blocksize. + Fix: zero-pad data along K to the next blocksize multiple. + Safe because K is the reduction dimension (zero * scale = zero). + 2. Requires mat_a.size(-1) % 16 == 0 and mat_b.size(-1) % 16 == 0 + regardless of which dimension is stride-1. + Fix: zero-pad B's trailing dim (N=intermediate) to next 16-multiple. + Safe because padded N columns produce zero output columns; the + reference output is sliced back in compute_reference. + """ + blocksize = self.cfg["blocksize"] + # For the torch's incomplete and unreasonable check. + N_padded = round_up(self.problem.intermediate, 16) + + if self.problem.scenario == "2Dx3D": + K_padded = round_up(self.problem.hidden, blocksize) + if self.problem.kind in ["nvfp4", "mxfp4"]: + K_padded = K_padded // 2 + # A: (tokens, hidden) — K=hidden is dim -1 + ref_a = self._prepare_ref_ab(self.A_tensor, k_dim=-1, pad_k_size=K_padded) + # B: (expert_cnt, hidden, intermediate) — K=hidden dim 1, N=intermediate dim -1 + ref_b = self._prepare_ref_ab( + self.B_tensor, k_dim=1, pad_k_size=K_padded, pad_non_k_size=N_padded + ) + else: + # A: (hidden, tokens) — K=tokens is dim -1 + # 2Dx2D: K=total_tokens, already blocksize-aligned by _generate_offs + ref_a = self._prepare_ref_ab(self.A_tensor, k_dim=-1) + # B: (tokens, intermediate) — K=tokens dim 0, N=intermediate dim -1 + ref_b = self._prepare_ref_ab( + self.B_tensor, k_dim=0, pad_non_k_size=N_padded + ) + return ref_a, ref_b + + def _compute_reference_manual_2d2d(self) -> torch.Tensor: + group_sizes = offs_to_group_sizes(self.offs_tensor) + results = [] + prev = 0 + blocksize = self.cfg["blocksize"] + + for expert_idx, group_size in enumerate(group_sizes): + cur = prev + group_size + a_slice = slice_tensor_logical_dim( + self.A_tensor, dim=1, start=prev, end=cur + ) + b_slice = slice_tensor_logical_dim( + self.B_tensor, dim=0, start=prev, end=cur + ) + + global_scale_a = ( + self.global_scale_a[expert_idx : expert_idx + 1] + if self.global_scale_a is not None + else None + ) + global_scale_b = ( + self.global_scale_b[expert_idx : expert_idx + 1] + if self.global_scale_b is not None + else None + ) + + a_fp32 = dequant_block_scale_to_fp32( + a_slice, + self.raw_scale_a_tensors[expert_idx], + blocksize, + global_scale_a, + ) + b_fp32_t = dequant_block_scale_to_fp32( + transpose_rhs_for_block_dequant(b_slice), + self.raw_scale_b_tensors[expert_idx], + blocksize, + global_scale_b, + ) + b_fp32 = b_fp32_t.transpose(0, 1) + results.append((a_fp32 @ b_fp32).to(self.problem.out_dtype)) + prev = cur + + return torch.stack(results, dim=0) + + def _compute_reference_manual_2d3d(self) -> torch.Tensor: + group_sizes = offs_to_group_sizes(self.offs_tensor) + results = [] + prev = 0 + blocksize = self.cfg["blocksize"] + + for expert_idx, group_size in enumerate(group_sizes): + cur = prev + group_size + a_slice = slice_tensor_logical_dim( + self.A_tensor, dim=0, start=prev, end=cur + ) + b_slice = self.B_tensor[expert_idx] + + global_scale_a = ( + self.global_scale_a[expert_idx : expert_idx + 1] + if self.global_scale_a is not None + else None + ) + global_scale_b = ( + self.global_scale_b[expert_idx : expert_idx + 1] + if self.global_scale_b is not None + else None + ) + + a_fp32 = dequant_block_scale_to_fp32( + a_slice, + self.raw_scale_a_tensors[expert_idx], + blocksize, + global_scale_a, + ) + b_fp32_t = dequant_block_scale_to_fp32( + transpose_rhs_for_block_dequant(b_slice), + self.raw_scale_b_tensors[expert_idx], + blocksize, + global_scale_b, + ) + b_fp32 = b_fp32_t.transpose(0, 1) + results.append((a_fp32 @ b_fp32).to(self.problem.out_dtype)) + prev = cur + + return torch.cat(results, dim=0) + + def _compute_reference_manual(self) -> None: + if self.raw_scale_a_tensors is None or self.raw_scale_b_tensors is None: + raise RuntimeError("Raw scale tensors must be generated before manual ref.") + + if self.problem.scenario == "2Dx2D": + self.C_ref_tensor = self._compute_reference_manual_2d2d() + else: + self.C_ref_tensor = self._compute_reference_manual_2d3d() + + def _compute_reference_torch(self) -> None: + from torch.nn.functional import scaled_grouped_mm, ScalingType, SwizzleType + + ref_a, ref_b = self._prepare_ref_tensors() + + if self.problem.kind in ("mxfp8", "mxfp4"): + scale_a_arg = self.scale_a_tensor + scale_b_arg = self.scale_b_tensor + recipe_a = ScalingType.BlockWise1x32 + recipe_b = ScalingType.BlockWise1x32 + else: # nvfp4 + scale_a_arg = [self.scale_a_tensor, self.global_scale_a] + scale_b_arg = [self.scale_b_tensor, self.global_scale_b] + recipe_a = [ScalingType.BlockWise1x16, ScalingType.TensorWise] + recipe_b = [ScalingType.BlockWise1x16, ScalingType.TensorWise] + + swizzle = SwizzleType.SWIZZLE_32_4_4 + ref_result = scaled_grouped_mm( + ref_a, + ref_b, + scale_a=scale_a_arg, + scale_recipe_a=recipe_a, + scale_b=scale_b_arg, + scale_recipe_b=recipe_b, + swizzle_a=swizzle, + swizzle_b=swizzle, + offs=self.offs_tensor, + output_dtype=self.problem.out_dtype, + ) + + self.C_ref_tensor = ref_result[..., : self.problem.intermediate] + + # ----------------------------------------------------------------- + # compute_reference + # ----------------------------------------------------------------- + + def compute_reference(self) -> None: + if self.misc.perf_run: + return + if self.misc.no_torch_210: + self._compute_reference_manual() + else: + self._compute_reference_torch() + + # ----------------------------------------------------------------- + # Kernel execution (stub — to be filled when kernel is implemented) + # ----------------------------------------------------------------- + + def create_kernel(self) -> ScaledGroupedGemmKernel: + _torch_to_cutlass = { + torch.float32: cutlass.Float32, + torch.bfloat16: cutlass.BFloat16, + torch.float16: cutlass.Float16, + } + return ScaledGroupedGemmKernel( + scenario=self.problem.scenario, + sf_vec_size=self.cfg["blocksize"], + accumulate_on_output=( + self.problem.grad_accumulate and self.problem.scenario == "2Dx2D" + ), + separate_tensormap_init=self.impl.separate_tensormap_init, + consistent_token_padding=self.problem.consistent_token_padding, + acc_dtype=_torch_to_cutlass[self.problem.acc_dtype], + mma_tiler_mnk=self.impl.mma_tiler_mnk, + cluster_shape_mnk=self.impl.cluster_shape_mnk, + use_2cta_instrs=self.impl.use_2cta_instrs, + fixed_expert_cnt=self.impl.static_expert_cnt, + ) + + def run_kernel(self, kernel: ScaledGroupedGemmKernel) -> Optional[float]: + """Run our CuTe kernel. + + Returns: + Average kernel time in ms when perf_e2e is enabled, None otherwise. + """ + _torch_to_cutlass = { + torch.float32: cutlass.Float32, + torch.bfloat16: cutlass.BFloat16, + torch.float16: cutlass.Float16, + torch.float8_e4m3fn: cutlass.Float8E4M3FN, + torch.float8_e5m2: cutlass.Float8E5M2, + torch.float4_e2m1fn_x2: cutlass.Float4E2M1FN, + } + if hasattr(torch, "float8_e8m0fnu"): + _torch_to_cutlass[torch.float8_e8m0fnu] = cutlass.Float8E8M0FNU + + # Allocate workspace + workspace_size = kernel.get_workspace_size(self.expert_cnt) + self.workspace_tensor = torch.full( + (workspace_size,), 255, dtype=torch.uint8, device="cuda" + ) + torch.cuda.synchronize() + + # Convert torch tensors → cute tensors + data_dtype = _torch_to_cutlass[self.cfg["data_dtype"]] + sf_cutlass_dtype = _torch_to_cutlass[self.cfg["scale_dtype"]] + out_cutlass_dtype = _torch_to_cutlass[self.problem.out_dtype] + + is_dynamic_expert_cnt = self.impl.static_expert_cnt is None + + def torch_tensor_to_cute_tensor_with_dyn_layout( + torch_tensor: torch.Tensor, + ) -> cute.Tensor: + cute_tensor = cutlass_torch.from_dlpack(torch_tensor) + leading_dim = cutlass_torch.get_leading_dim(torch_tensor) + cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=leading_dim) + return cute_tensor + + a_cute = torch_tensor_to_cute_tensor_with_dyn_layout(self.A_tensor) + b_cute = torch_tensor_to_cute_tensor_with_dyn_layout(self.B_tensor) + scale_a_cute = torch_tensor_to_cute_tensor_with_dyn_layout(self.scale_a_tensor) + scale_b_cute = torch_tensor_to_cute_tensor_with_dyn_layout(self.scale_b_tensor) + c_cute = torch_tensor_to_cute_tensor_with_dyn_layout(self.C_tensor) + offs_cute = torch_tensor_to_cute_tensor_with_dyn_layout(self.offs_tensor) + workspace_cute = torch_tensor_to_cute_tensor_with_dyn_layout( + self.workspace_tensor + ) + + # Query max active clusters from hardware + cluster_size = self.impl.cluster_shape_mnk[0] * self.impl.cluster_shape_mnk[1] + max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size) + + # Prepare optional NVFP4 global scales + global_scale_a_cute = None + global_scale_b_cute = None + if self.global_scale_a is not None: + global_scale_a_cute = torch_tensor_to_cute_tensor_with_dyn_layout( + self.global_scale_a + ) + global_scale_b_cute = torch_tensor_to_cute_tensor_with_dyn_layout( + self.global_scale_b + ) + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + if self.misc.perf_e2e: + compiled = cute.compile( + kernel, + a_cute, + b_cute, + scale_a_cute, + scale_b_cute, + c_cute, + offs_cute, + workspace_cute, + max_active_clusters, + stream, + global_scale_a=global_scale_a_cute, + global_scale_b=global_scale_b_cute, + ) + + warmup_iters = 4 + timed_iters = 4 + + for _ in range(warmup_iters): + compiled( + a_cute, + b_cute, + scale_a_cute, + scale_b_cute, + c_cute, + offs_cute, + workspace_cute, + stream, + global_scale_a=global_scale_a_cute, + global_scale_b=global_scale_b_cute, + ) + torch.cuda.synchronize() + + times = [] + for _ in range(timed_iters): + l2_flush() + torch.cuda.synchronize() + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + start_evt.record() + compiled( + a_cute, + b_cute, + scale_a_cute, + scale_b_cute, + c_cute, + offs_cute, + workspace_cute, + stream, + global_scale_a=global_scale_a_cute, + global_scale_b=global_scale_b_cute, + ) + end_evt.record() + torch.cuda.synchronize() + times.append(start_evt.elapsed_time(end_evt)) + + avg_ms = sum(times) / len(times) + print(f"[perf_e2e] Individual times (ms): {[f'{t:.4f}' for t in times]}") + print(f"[perf_e2e] Average kernel time: {avg_ms:.4f} ms") + return avg_ms + else: + l2_flush() + kernel( + a_cute, + b_cute, + scale_a_cute, + scale_b_cute, + c_cute, + offs_cute, + workspace_cute, + max_active_clusters, + stream, + global_scale_a=global_scale_a_cute, + global_scale_b=global_scale_b_cute, + ) + torch.cuda.synchronize() + return None + + # ----------------------------------------------------------------- + # Validation + # ----------------------------------------------------------------- + + def validate(self) -> None: + if self.misc.perf_run: + return + using_torch_ref = not self.misc.no_torch_210 + if using_torch_ref and self.problem.scenario == "2Dx2D": + # Pytorch bug: zero token does not write out due to the incorrect arg setting. + self.C_ref_tensor = self.C_ref_tensor.contiguous() + group_sizes = offs_to_group_sizes(self.offs_tensor) + for i, g in enumerate(group_sizes): + if g == 0: + self.C_ref_tensor[i].zero_() + if using_torch_ref and ( + self.problem.scenario == "2Dx3D" + and self.tokens_after_repeat // self.expert_cnt == 0 + ): + print( + "Warning: Due to the Pytorch 2.10 FBGEMM bug (incorrect `M/G` early exit), ref tensor will be all 0 in this case, skip ref check." + ) + return + try: + diff = (self.C_tensor - self.C_ref_tensor).float().abs() + max_diff = diff.max().item() + if max_diff == 0.0: + print("Validation PASSED (exact match)") + else: + print( + f"C_tensor: shape={tuple(self.C_tensor.shape)} " + f"stride={self.C_tensor.stride()} dtype={self.C_tensor.dtype}" + ) + print( + f"C_ref_tensor: shape={tuple(self.C_ref_tensor.shape)} " + f"stride={self.C_ref_tensor.stride()} dtype={self.C_ref_tensor.dtype}" + ) + print( + f"Validation FAILED: " + f"max_diff={max_diff} " + f"mean_diff={diff.mean().item()}" + ) + assert False, "C_tensor != C_ref_tensor" + except torch.cuda.OutOfMemoryError: + print("OOM during diff computation, falling back to torch.equal") + assert torch.equal(self.C_tensor, self.C_ref_tensor), ( + "C_tensor != C_ref_tensor" + ) + + # ----------------------------------------------------------------- + # SOL comparison + # ----------------------------------------------------------------- + + def run_sol_comparison(self) -> None: + """Run a dense batched block-scaled GEMM as Speed-of-Light reference. + + Reuses the same tensor memory from the grouped run by passing + raw pointers with a batched problem_mnkl -- zero GPU allocation. + """ + import sys, os + + _examples_root = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "..", "..") + ) + if _examples_root not in sys.path: + sys.path.insert(0, _examples_root) + + from blackwell.kernel.blockscaled_gemm.dense_blockscaled_gemm_persistent import ( + Sm100BlockScaledPersistentDenseGemmKernel, + ) + from cutlass.cute.nvgpu import OperandMajorMode + from cutlass.cute.runtime import make_ptr + + tokens = self.tokens_after_repeat + experts = self.expert_cnt + blocksize = self.cfg["blocksize"] + n_slots = tokens // blocksize + assert tokens % blocksize == 0 and n_slots % experts == 0, ( + f"compare_with_sol requires tokens*top_k ({tokens}) to be " + f"divisible by blocksize ({blocksize}), and the resulting " + f"n_slots ({n_slots}) evenly divisible by experts ({experts}) " + f"so every group has exactly the same size" + ) + tpe = tokens // experts + + if self.problem.scenario == "2Dx3D": + M, N, K, L = tpe, self.intermediate, self.hidden, experts + else: # 2Dx2D + M, N, K, L = self.hidden, self.intermediate, tpe, experts + + # Dtype mapping + _torch_to_cutlass = { + torch.float32: cutlass.Float32, + torch.bfloat16: cutlass.BFloat16, + torch.float16: cutlass.Float16, + torch.float8_e4m3fn: cutlass.Float8E4M3FN, + torch.float8_e5m2: cutlass.Float8E5M2, + torch.float4_e2m1fn_x2: cutlass.Float4E2M1FN, + } + if hasattr(torch, "float8_e8m0fnu"): + _torch_to_cutlass[torch.float8_e8m0fnu] = cutlass.Float8E8M0FNU + + data_dtype = _torch_to_cutlass[self.cfg["data_dtype"]] + sf_dtype = _torch_to_cutlass[self.cfg["scale_dtype"]] + out_dtype = _torch_to_cutlass[self.problem.out_dtype] + + # Layout mapping + a_major = ( + OperandMajorMode.K + if self.problem.a_layout == "k_major" + else OperandMajorMode.MN + ) + b_major = ( + OperandMajorMode.K + if self.problem.b_layout == "k_major" + else OperandMajorMode.MN + ) + c_layout = ( + utils.LayoutEnum.ROW_MAJOR + if self.problem.c_layout == "n_major" + else utils.LayoutEnum.COL_MAJOR + ) + layouts = (a_major, b_major, c_layout) + + # Construct pointers from existing grouped tensors + a_ptr = make_ptr( + data_dtype, + self.A_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + b_ptr = make_ptr( + data_dtype, + self.B_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + sfa_ptr = make_ptr( + sf_dtype, + self.scale_a_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + sfb_ptr = make_ptr( + sf_dtype, + self.scale_b_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32, + ) + c_ptr = make_ptr( + out_dtype, + self.C_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + + mma_tiler_mn = self.impl.mma_tiler_mnk[:2] + cluster_shape_mn = self.impl.cluster_shape_mnk[:2] + cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1] + max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size) + + sol_kernel = Sm100BlockScaledPersistentDenseGemmKernel( + sf_vec_size=self.cfg["blocksize"], + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + ) + + problem_mnkl = ( + cutlass.Int32(M), + cutlass.Int32(N), + cutlass.Int32(K), + cutlass.Int32(L), + ) + + print(f"\n[SOL] Dense block-scaled BMM: M={M} N={N} K={K} L={L}") + print(f"[SOL] kind={self.problem.kind} sf_vec_size={self.cfg['blocksize']}") + + l2_flush() + sol_kernel( + a_ptr, + b_ptr, + sfa_ptr, + sfb_ptr, + c_ptr, + layouts, + problem_mnkl, + max_active_clusters, + cuda.CUstream(torch.cuda.current_stream().cuda_stream), + ) + torch.cuda.synchronize() + + # ----------------------------------------------------------------- + # Run + # ----------------------------------------------------------------- + + def run(self) -> None: + print(self.problem) + print(self.impl) + print(self.misc) + + self.generate_inputs() + + group_sizes = offs_to_group_sizes(self.offs_tensor) + print( + f"A: shape={tuple(self.A_tensor.shape)} " + f"stride={self.A_tensor.stride()} dtype={self.A_tensor.dtype}" + ) + print( + f"B: shape={tuple(self.B_tensor.shape)} " + f"stride={self.B_tensor.stride()} dtype={self.B_tensor.dtype}" + ) + print( + f"C: shape={tuple(self.C_tensor.shape)} " + f"stride={self.C_tensor.stride()} dtype={self.C_tensor.dtype}" + ) + print( + f"scale_a: shape={tuple(self.scale_a_tensor.shape)} " + f"stride={self.scale_a_tensor.stride()} dtype={self.scale_a_tensor.dtype}" + ) + print( + f"scale_b: shape={tuple(self.scale_b_tensor.shape)} " + f"stride={self.scale_b_tensor.stride()} dtype={self.scale_a_tensor.dtype}" + ) + if self.cfg["has_global_scale"]: + print(f"global_scale_a: {self.global_scale_a.cpu().tolist()}") + print(f"global_scale_b: {self.global_scale_b.cpu().tolist()}") + print(f"offs: {self.offs_tensor.cpu().tolist()} group_sizes={group_sizes}") + + kernel = self.create_kernel() + + if self.misc.perf_e2e: + self.run_kernel(kernel) + else: + from torch.profiler import profile, ProfilerActivity + + with profile( + activities=[ProfilerActivity.CUDA], record_shapes=True + ) as prof: + self.compute_reference() + self.run_kernel(kernel) + if ( + self.misc.compare_with_sol + and self.misc.perf_run + and self.problem.balance_route + ): + self.run_sol_comparison() + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20)) + + self.validate() + print("PASS") + + +# ============================================================================= +# CLI entry point +# ============================================================================= + +if __name__ == "__main__": + import argparse + + def parse_tuple(s: str) -> Tuple[int, ...]: + return tuple(int(x) for x in s.split(",")) + + parser = argparse.ArgumentParser( + description="Scaled Grouped GEMM for MoE (MXFP8 / MXFP4 / NVFP4)", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # ── Problem ── + parser.add_argument("--tokens", type=int, default=128) + parser.add_argument("--experts", type=int, default=4) + parser.add_argument("--top_k_select", type=int, default=2) + parser.add_argument("--balance_route", action="store_true", default=False) + parser.add_argument("--hidden", type=int, default=512) + parser.add_argument("--intermediate", type=int, default=384) + parser.add_argument( + "--scenario", type=str, default="2Dx3D", choices=["2Dx3D", "2Dx2D"] + ) + parser.add_argument( + "--kind", type=str, default="mxfp8", choices=["mxfp8", "mxfp4", "nvfp4"] + ) + parser.add_argument("--out_dtype", type=str, default="bfloat16") + parser.add_argument("--acc_dtype", type=str, default="float32") + parser.add_argument("--grad_accumulate", action="store_true", default=False) + parser.add_argument( + "--consistent_token_padding", action="store_true", default=False + ) + parser.add_argument( + "--a_layout", type=str, default="k_major", choices=["k_major", "m_major"] + ) + parser.add_argument( + "--b_layout", type=str, default="k_major", choices=["k_major", "n_major"] + ) + parser.add_argument( + "--c_layout", type=str, default="n_major", choices=["n_major", "m_major"] + ) + + # ── Impl ── + parser.add_argument("--mma_tiler_mnk", type=str, default="128,128,128") + parser.add_argument("--cluster_shape_mnk", type=str, default="1,1,1") + parser.add_argument("--use_2cta_instrs", action="store_true", default=False) + parser.add_argument("--static_expert_cnt", type=int, default=None) + parser.add_argument("--separate_tensormap_init", action="store_true", default=True) + + # ── Misc ── + parser.add_argument("--perf_run", action="store_true", default=False) + parser.add_argument("--perf_e2e", action="store_true", default=False) + parser.add_argument("--compare_with_sol", action="store_true", default=False) + + args = parser.parse_args() + + if args.consistent_token_padding: + print( + "WARNING: Overriding consistent_token_padding to False " + "(not implemented yet)." + ) + args.consistent_token_padding = False + + problem = ProblemDesc( + tokens=args.tokens, + experts=args.experts, + top_k_select=args.top_k_select, + balance_route=args.balance_route, + hidden=args.hidden, + intermediate=args.intermediate, + scenario=args.scenario, + kind=args.kind, + out_dtype=getattr(torch, args.out_dtype), + acc_dtype=getattr(torch, args.acc_dtype), + grad_accumulate=args.grad_accumulate, + consistent_token_padding=args.consistent_token_padding, + a_layout=args.a_layout, + b_layout=args.b_layout, + c_layout=args.c_layout, + ) + + if not args.separate_tensormap_init: + print( + "Overriding separate_tensormap_init to True " + "(fused version not implemented yet)." + ) + args.separate_tensormap_init = True + + impl = ImplDesc( + mma_tiler_mnk=parse_tuple(args.mma_tiler_mnk), + cluster_shape_mnk=parse_tuple(args.cluster_shape_mnk), + use_2cta_instrs=args.use_2cta_instrs, + static_expert_cnt=args.static_expert_cnt, + separate_tensormap_init=args.separate_tensormap_init, + ) + misc = MiscDesc( + perf_run=args.perf_run, + perf_e2e=args.perf_e2e, + compare_with_sol=args.compare_with_sol, + ) + + tester = ScaledGroupedGemmTester(problem, impl, misc) + tester.run() + print("DONE")