# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # 1. Redistributions of source code must retain the above copyright notice, this # list of conditions and the following disclaimer. # 2. Redistributions in binary form must reproduce the above copyright notice, # this list of conditions and the following disclaimer in the documentation # and/or other materials provided with the distribution. # 3. Neither the name of the copyright holder nor the names of its # contributors may be used to endorse or promote products derived from # this software without specific prior written permission. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import argparse import functools from typing import List, Type, Tuple, Union from inspect import isclass import torch import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute from cutlass.cute.nvgpu import cpasync, tcgen05 import cutlass.torch as cutlass_torch import cutlass.utils as utils import cutlass.pipeline as pipeline from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait import cutlass.utils.blackwell_helpers as sm100_utils import cutlass.utils.blockscaled_layout as blockscaled_utils from cutlass.cute.runtime import from_dlpack """ This example provides an experimental implementation of the SM100 grouped blockscaled GEMM kernel, please note that the APIs and implementation details related to this kernel may change in future releases. A grouped blockscaled GEMM example for the NVIDIA Blackwell SM100 architecture using CUTE DSL This example demonstrates an implementation of grouped blockscaled GEMM using a TMA plus Blackwell SM100 TensorCore warp-specialized persistent kernel. The grouped GEMM workload computes a batch of GEMM operations with distinct problem sizes. Pointers to matrices in global memory are passed to the kernel in an array (also held in global memory). Similarly, problem shapes and strides are also stored in arrays in GMEM. This differs from "Batched Array" GEMM since the size of each GEMM problem in the grouped GEMM concept may be distinct. To run this example: .. code-block:: bash python examples/blackwell/grouped_blockscaled_gemm.py \ --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU --sf_vec_size 16 \ --c_dtype Float16 \ --mma_tiler_mn 128,128 --cluster_shape_mn 1,1 \ --problem_sizes_mnkl "(8192,1280,32,1),(32,384,1536,1),(640,1280,32,1),(640,160,32,1)" \ --num_groups 4 The above example command makes 4 groups of different m, n, k sizes. The Blackwell tcgen05 MMA tile shape is specified as (128, 64) and the cluster shape is (1,1). The input, mma accumulator and output data type are set as fp16, fp32 and fp16, respectively. To collect performance with NCU profiler: .. code-block:: bash ncu python examples/blackwell/grouped_blockscaled_gemm.py \ --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU --sf_vec_size 16 \ --c_dtype Float16 \ --mma_tiler_mn 128,128 --cluster_shape_mn 1,1 \ --problem_sizes_mnkl "(8192,1280,32,1),(32,384,1536,1),(640,1280,32,1),(640,160,32,1)" \ --num_groups 4 --warmup_iterations 1 --iterations 10 --skip_ref_check Constraints: * Supported input data types: mxf8, mxf4, nvf4 see detailed valid dtype combinations in below Sm100GroupedBlockScaledGemmKernel class documentation * A/B tensors must have the same data type, mixed data type is not supported (e.g., mxf8 x mxf4) * Mma tiler M must be 128 or 256(use_2cta_instrs) * Mma tiler N must be 128 or 256 * Cluster shape M/N must be positive and power of 2, total cluster size <= 16 * Cluster shape M/N must be <= 4 for scale factor multicasts due to limited size of scale factors * Cluster shape M must be multiple of 2 if Mma tiler M is 256(use_2cta_instrs) * The l mode(aka, batch size) for each group must be 1. * The majorness for A, B and C must be the same across all groups. * The contiguous dimension of A/B/C tensors in each group must be at least 16 bytes aligned, i.e, number of elements is a multiple of 16 and 32 for Float8 and Float4, respectively. """ class Sm100GroupedBlockScaledGemmKernel: """This example demonstrates an implementation of grouped blockscaled GEMM using a TMA plus Blackwell SM100 TensorCore warp-specialized persistent kernel. :param sf_vec_size: Scalefactor vector size. :type sf_vec_size: int :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N) :type mma_tiler_mn: Tuple[int, int] :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing :type cluster_shape_mn: Tuple[int, int] :note: In current version, A and B tensors must have the same data type - i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported :note: Supported combinations of A/B data types, SF data typs and SF vector size: - MXF8: A/B: Float8E5M2/Float8E4M3FN + SF: Float8E8M0FNU + sf_vec_size: 32 - MXF4: A/B: Float4E2M1FN + SF: Float8E8M0FNU + sf_vec_size: 32 - NVF4: A/B: Float4E2M1FN + SF: Float8E8M0FNU/Float8E4M3FN + sf_vec_size: 16 :note: Supported accumulator data types: - Float32 :note: Supported C data types: - Float32 - Float16/BFloat16 - Float8E4M3FN/Float8E5M2 :note: Constraints: - MMA tiler M must be 128 or 256 (use_2cta_instrs) - MMA tiler N must be 128/256 - Cluster shape M must be multiple of 2 if Mma tiler M is 256 - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 - Cluster shape M/N must be <= 4 for scale factor multicasts due to limited size of scale factors """ def __init__( self, sf_vec_size: int, mma_tiler_mn: Tuple[int, int], cluster_shape_mn: Tuple[int, int], ): """Initializes the configuration for a Blackwell grouped blockscaled GEMM kernel. Besides configurations for dense persistent blockscaled GEMM, there is an extra config specific to grouped blockscaled GEMM: :param sf_vec_size: Scalefactor vector size. :type sf_vec_size: int :param mma_tiler_mn: tuple (M, N) shape of the MMA instruction. :type mma_tiler_mn: tuple[int, int] :param cluster_shape_mn: tuple (ClusterM, ClusterN) shape of the cluster. :type cluster_shape_mn: tuple[int, int] """ self.acc_dtype = cutlass.Float32 self.sf_vec_size = sf_vec_size self.use_2cta_instrs = mma_tiler_mn[0] == 256 self.cluster_shape_mn = cluster_shape_mn # K dimension is deferred in _setup_attributes self.mma_tiler = (*mma_tiler_mn, 1) self.cta_group = ( tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE ) self.tensormap_update_mode = utils.TensorMapUpdateMode.SMEM self.occupancy = 1 # Set specialized warp ids self.epilog_warp_id = ( 0, 1, 2, 3, ) self.mma_warp_id = 4 self.tma_warp_id = 5 self.threads_per_cta = 32 * len( (self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id) ) # Set barrier for epilogue sync and tmem ptr sync self.epilog_sync_barrier = pipeline.NamedBarrier( barrier_id=1, num_threads=32 * len(self.epilog_warp_id), ) self.tmem_alloc_barrier = pipeline.NamedBarrier( barrier_id=2, num_threads=32 * len((self.mma_warp_id, *self.epilog_warp_id)), ) # Barrier used by MMA/TMA warps to signal A/B tensormap initialization completion self.tensormap_ab_init_barrier = pipeline.NamedBarrier( barrier_id=3, num_threads=64, ) self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") self.num_tmem_alloc_cols = cute.arch.get_max_tmem_alloc_cols("sm_100") # Set up configurations that dependent on gemm inputs. def _setup_attributes(self): """Set up configurations that are dependent on GEMM inputs This method configures various attributes based on the input tensor properties (data types, leading dimensions) and kernel settings: - Configuring tiled MMA - Computing MMA/cluster/tile shapes - Computing cluster layout - Computing multicast CTAs for A/B/SFA/SFB - Computing epilogue subtile - Setting up A/B/SFA/SFB/C stage counts in shared memory - Computing A/B/SFA/SFB/C shared memory layout - Checking reserved smem bytes size capacity for mbar, tensor memory management and tensormap updates utilization """ # Compute mma instruction shapes # (MMA_Tile_Shape_M, MMA_Tile_Shape_N, MMA_Inst_Shape_K) self.mma_inst_shape_mn = ( self.mma_tiler[0], self.mma_tiler[1], ) # (CTA_Tile_Shape_M, Round_Up(MMA_Tile_Shape_N, 128), MMA_Inst_Shape_K) self.mma_inst_shape_mn_sfb = ( self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1), cute.round_up(self.mma_inst_shape_mn[1], 128), ) tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( self.a_dtype, self.a_major_mode, self.b_major_mode, self.sf_dtype, self.sf_vec_size, self.cta_group, self.mma_inst_shape_mn, ) tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( self.a_dtype, self.a_major_mode, self.b_major_mode, self.sf_dtype, self.sf_vec_size, cute.nvgpu.tcgen05.CtaGroup.ONE, self.mma_inst_shape_mn_sfb, ) # Compute mma/cluster/tile shapes mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) mma_inst_tile_k = 4 self.mma_tiler = ( self.mma_inst_shape_mn[0], self.mma_inst_shape_mn[1], mma_inst_shape_k * mma_inst_tile_k, ) self.mma_tiler_sfb = ( self.mma_inst_shape_mn_sfb[0], self.mma_inst_shape_mn_sfb[1], mma_inst_shape_k * mma_inst_tile_k, ) self.cta_tile_shape_mnk = ( self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), self.mma_tiler[1], self.mma_tiler[2], ) self.cluster_tile_shape_mnk = tuple( x * y for x, y in zip(self.cta_tile_shape_mnk, (*self.cluster_shape_mn, 1)) ) # Compute cluster layout self.cluster_layout_vmnk = cute.tiled_divide( cute.make_layout((*self.cluster_shape_mn, 1)), (tiled_mma.thr_id.shape,), ) self.cluster_layout_sfb_vmnk = cute.tiled_divide( cute.make_layout((*self.cluster_shape_mn, 1)), (tiled_mma_sfb.thr_id.shape,), ) # Compute number of multicast CTAs for A/B self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) self.num_mcast_ctas_sfb = cute.size(self.cluster_layout_sfb_vmnk.shape[1]) self.is_a_mcast = self.num_mcast_ctas_a > 1 self.is_b_mcast = self.num_mcast_ctas_b > 1 self.is_sfb_mcast = self.num_mcast_ctas_sfb > 1 # Compute epilogue subtile self.epi_tile = sm100_utils.compute_epilogue_tile_shape( self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.c_dtype, ) # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory self.num_acc_stage, self.num_ab_stage, self.num_c_stage = self._compute_stages( tiled_mma, self.mma_tiler, self.a_dtype, self.b_dtype, self.epi_tile, self.c_dtype, self.c_layout, self.sf_dtype, self.sf_vec_size, self.smem_capacity, self.occupancy, ) # Compute A/B/SFA/SFB/C shared memory layout self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( tiled_mma, self.mma_tiler, self.a_dtype, self.num_ab_stage, ) self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( tiled_mma, self.mma_tiler, self.b_dtype, self.num_ab_stage, ) self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( tiled_mma, self.mma_tiler, self.sf_vec_size, self.num_ab_stage, ) self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( tiled_mma, self.mma_tiler, self.sf_vec_size, self.num_ab_stage, ) self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( self.c_dtype, self.c_layout, self.epi_tile, self.num_c_stage, ) mbar_smem_bytes = self._get_mbar_smem_bytes( num_acc_stage=self.num_acc_stage, num_ab_stage=self.num_ab_stage, num_c_stage=self.num_c_stage, ) # Use utils.TensorMapUpdateMode.SMEM by default tensormap_smem_bytes = ( Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap * Sm100GroupedBlockScaledGemmKernel.num_tensormaps ) if ( mbar_smem_bytes + tensormap_smem_bytes + Sm100GroupedBlockScaledGemmKernel.tensor_memory_management_bytes > self.reserved_smem_bytes ): raise ValueError( f"smem consumption for mbar and tensormap {mbar_smem_bytes + tensormap_smem_bytes} exceeds the " f"reserved smem bytes {self.reserved_smem_bytes}" ) @cute.jit def __call__( self, initial_a: cute.Tensor, initial_b: cute.Tensor, initial_c: cute.Tensor, initial_sfa: cute.Tensor, initial_sfb: cute.Tensor, group_count: cutlass.Constexpr[int], problem_shape_mnkl: cute.Tensor, strides_abc: cute.Tensor, tensor_address_abc: cute.Tensor, tensor_address_sfasfb: cute.Tensor, total_num_clusters: cutlass.Constexpr[int], tensormap_cute_tensor: cute.Tensor, max_active_clusters: cutlass.Constexpr[int], stream: cuda.CUstream, ): """Execute the GEMM operation in steps: - Setup static attributes before smem/grid/tma computation - Setup TMA load/store atoms and tensors - Compute grid size with regard to hardware constraints - Define shared storage for kernel - Launch the kernel synchronously For grouped GEMM, tensor shapes, tensor strides, and tensor address are all provided by different tensors in global memory. The "initial" tensors only carry data type and majorness information. :param initial_a: Initial tensor A, used for data type and majorness information. :type initial_a: cute.Tensor :param initial_b: Initial tensor B, used for data type and majorness information. :type initial_b: cute.Tensor :param initial_c: Initial tensor C, used for data type and majorness information. :type initial_c: cute.Tensor :param initial_sfa: Initial tensor SFA, used for data type and majorness information. :type initial_sfa: cute.Tensor :param initial_sfb: Initial tensor SFB, used for data type and majorness information. :type initial_sfb: cute.Tensor :param group_count: The number of GEMM groups. :type group_count: cutlass.Constexpr[int] :param problem_shape_mnkl: Tensor containing the (M, N, K, L) shape for each group. :type problem_shape_mnkl: cute.Tensor :param strides_abc: Tensor containing the strides for A, B, and C for each group. :type strides_abc: cute.Tensor :param tensor_address_abc: Tensor containing the base addresses for A, B, and C for each group. :type tensor_address_abc: cute.Tensor :param tensor_address_sfasfb: Tensor containing the base addresses for SFA and SFB for each group. :type tensor_address_sfasfb: cute.Tensor :param total_num_clusters: Total number of clusters needed for all groups. :type total_num_clusters: cutlass.Constexpr[int] :param tensormap_cute_tensor: Tensor for storing tensormaps. :type tensormap_cute_tensor: cute.Tensor :param max_active_clusters: Maximum number of active clusters. :type max_active_clusters: cutlass.Constexpr[int] :param stream: CUDA stream for asynchronous execution. :type stream: cuda.CUstream :raises TypeError: If A and B data types do not match. """ self.a_dtype = initial_a.element_type self.b_dtype = initial_b.element_type self.sf_dtype = initial_sfa.element_type self.c_dtype = initial_c.element_type self.is_nvfp4_output = self.c_dtype is cutlass.Float4E2M1FN self.a_major_mode = utils.LayoutEnum.from_tensor(initial_a).mma_major_mode() self.b_major_mode = utils.LayoutEnum.from_tensor(initial_b).mma_major_mode() self.c_layout = utils.LayoutEnum.from_tensor(initial_c) if cutlass.const_expr(self.a_dtype != self.b_dtype): raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}") # Setup attributes that dependent on gemm inputs self._setup_attributes() # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL) sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( initial_a.shape, self.sf_vec_size ) initial_sfa = cute.make_tensor(initial_sfa.iterator, sfa_layout) # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( initial_b.shape, self.sf_vec_size ) initial_sfb = cute.make_tensor(initial_sfb.iterator, sfb_layout) tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( self.a_dtype, self.a_major_mode, self.b_major_mode, self.sf_dtype, self.sf_vec_size, self.cta_group, self.mma_inst_shape_mn, ) tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( self.a_dtype, self.a_major_mode, self.b_major_mode, self.sf_dtype, self.sf_vec_size, cute.nvgpu.tcgen05.CtaGroup.ONE, self.mma_inst_shape_mn_sfb, ) atom_thr_size = cute.size(tiled_mma.thr_id.shape) # Setup TMA load for A a_op = sm100_utils.cluster_shape_to_tma_atom_A( self.cluster_shape_mn, tiled_mma.thr_id ) a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( a_op, initial_a, a_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape, ) # Setup TMA load for B b_op = sm100_utils.cluster_shape_to_tma_atom_B( self.cluster_shape_mn, tiled_mma.thr_id ) b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( b_op, initial_b, b_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape, ) # Setup TMA load for SFA sfa_op = sm100_utils.cluster_shape_to_tma_atom_A( self.cluster_shape_mn, tiled_mma.thr_id ) sfa_smem_layout = cute.slice_( self.sfa_smem_layout_staged, (None, None, None, 0) ) tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( sfa_op, initial_sfa, sfa_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape, internal_type=cutlass.Int16, ) # Setup TMA load for SFB sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB( self.cluster_shape_mn, tiled_mma.thr_id ) sfb_smem_layout = cute.slice_( self.sfb_smem_layout_staged, (None, None, None, 0) ) tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( sfb_op, initial_sfb, sfb_smem_layout, self.mma_tiler_sfb, tiled_mma_sfb, self.cluster_layout_sfb_vmnk.shape, internal_type=cutlass.Int16, ) a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) self.num_tma_load_bytes = ( a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size ) * atom_thr_size # Setup TMA store for C epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( cpasync.CopyBulkTensorTileS2GOp(), initial_c, epi_smem_layout, self.epi_tile, ) # Compute grid size self.tile_sched_params, grid = self._compute_grid( total_num_clusters, self.cluster_shape_mn, max_active_clusters ) self.buffer_align_bytes = 1024 self.size_tensormap_in_i64 = ( Sm100GroupedBlockScaledGemmKernel.num_tensormaps * Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap // 8 ) # Define shared storage for kernel @cute.struct class SharedStorage: tensormap_buffer: cute.struct.MemRange[ cutlass.Int64, self.size_tensormap_in_i64 ] ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] tmem_dealloc_mbar: cutlass.Int64 tmem_holding_buf: cutlass.Int32 # (EPI_TILE_M, EPI_TILE_N, STAGE) sC: cute.struct.Align[ cute.struct.MemRange[ self.c_dtype, cute.cosize(self.c_smem_layout_staged.outer), ], self.buffer_align_bytes, ] # (MMA, MMA_M, MMA_K, STAGE) sA: cute.struct.Align[ cute.struct.MemRange[ self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) ], self.buffer_align_bytes, ] # (MMA, MMA_N, MMA_K, STAGE) sB: cute.struct.Align[ cute.struct.MemRange[ self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) ], self.buffer_align_bytes, ] # (MMA, MMA_M, MMA_K, STAGE) sSFA: cute.struct.Align[ cute.struct.MemRange[ self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged) ], self.buffer_align_bytes, ] # (MMA, MMA_N, MMA_K, STAGE) sSFB: cute.struct.Align[ cute.struct.MemRange[ self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged) ], self.buffer_align_bytes, ] self.shared_storage = SharedStorage # Launch the kernel synchronously self.kernel( tiled_mma, tiled_mma_sfb, tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b, tma_atom_sfa, tma_tensor_sfa, tma_atom_sfb, tma_tensor_sfb, tma_atom_c, tma_tensor_c, self.cluster_layout_vmnk, self.cluster_layout_sfb_vmnk, self.a_smem_layout_staged, self.b_smem_layout_staged, self.sfa_smem_layout_staged, self.sfb_smem_layout_staged, self.c_smem_layout_staged, self.epi_tile, self.tile_sched_params, group_count, problem_shape_mnkl, strides_abc, tensor_address_abc, tensor_address_sfasfb, tensormap_cute_tensor, ).launch( grid=grid, block=[self.threads_per_cta, 1, 1], cluster=(*self.cluster_shape_mn, 1), stream=stream, min_blocks_per_mp=1, ) return # GPU device kernel @cute.kernel def kernel( self, tiled_mma: cute.TiledMma, tiled_mma_sfb: cute.TiledMma, tma_atom_a: cute.CopyAtom, mA_mkl: cute.Tensor, tma_atom_b: cute.CopyAtom, mB_nkl: cute.Tensor, tma_atom_sfa: cute.CopyAtom, mSFA_mkl: cute.Tensor, tma_atom_sfb: cute.CopyAtom, mSFB_nkl: cute.Tensor, tma_atom_c: cute.CopyAtom, mC_mnl: cute.Tensor, cluster_layout_vmnk: cute.Layout, cluster_layout_sfb_vmnk: cute.Layout, a_smem_layout_staged: cute.ComposedLayout, b_smem_layout_staged: cute.ComposedLayout, sfa_smem_layout_staged: cute.Layout, sfb_smem_layout_staged: cute.Layout, c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout], epi_tile: cute.Tile, tile_sched_params: utils.PersistentTileSchedulerParams, group_count: cutlass.Constexpr, problem_sizes_mnkl: cute.Tensor, strides_abc: cute.Tensor, ptrs_abc: cute.Tensor, ptrs_sfasfb: cute.Tensor, tensormaps: cute.Tensor, ): """ GPU device kernel performing the grouped GEMM computation. """ warp_idx = cute.arch.warp_idx() warp_idx = cute.arch.make_warp_uniform(warp_idx) if warp_idx == self.tma_warp_id: cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_a) cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_b) cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_sfa) cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_sfb) cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_c) use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 # # Setup cta/thread coordinates # # Coords inside cluster bidx, bidy, bidz = cute.arch.block_idx() mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) is_leader_cta = mma_tile_coord_v == 0 cta_rank_in_cluster = cute.arch.make_warp_uniform( cute.arch.block_idx_in_cluster() ) block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( cta_rank_in_cluster ) block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord( cta_rank_in_cluster ) # coord inside cta tidx, _, _ = cute.arch.thread_idx() # # Alloc and init: tensormap buffer, a+b full/empty, accumulator full/empty, tensor memory dealloc barrier # smem = utils.SmemAllocator() storage = smem.allocate(self.shared_storage) tensormap_smem_ptr = storage.tensormap_buffer.data_ptr() tensormap_a_smem_ptr = tensormap_smem_ptr tensormap_b_smem_ptr = ( tensormap_a_smem_ptr + Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap // 8 ) tensormap_sfa_smem_ptr = ( tensormap_b_smem_ptr + Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap // 8 ) tensormap_sfb_smem_ptr = ( tensormap_sfa_smem_ptr + Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap // 8 ) tensormap_c_smem_ptr = ( tensormap_sfb_smem_ptr + Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap // 8 ) tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar.ptr tmem_holding_buf_ptr = storage.tmem_holding_buf.ptr # Initialize mainloop ab_pipeline (barrier) and states ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 ab_pipeline_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, num_tma_producer ) ab_pipeline = pipeline.PipelineTmaUmma.create( barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), num_stages=self.num_ab_stage, producer_group=ab_pipeline_producer_group, consumer_group=ab_pipeline_consumer_group, tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cluster_layout_vmnk, ) # Initialize acc_pipeline (barrier) and states acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) num_acc_consumer_threads = len(self.epilog_warp_id) * ( 2 if use_2cta_instrs else 1 ) acc_pipeline_consumer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, num_acc_consumer_threads ) acc_pipeline = pipeline.PipelineUmmaAsync.create( barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), num_stages=self.num_acc_stage, producer_group=acc_pipeline_producer_group, consumer_group=acc_pipeline_consumer_group, cta_layout_vmnk=cluster_layout_vmnk, ) # Tensor memory dealloc barrier init if use_2cta_instrs: if warp_idx == self.tma_warp_id: num_tmem_dealloc_threads = 32 with cute.arch.elect_one(): cute.arch.mbarrier_init( tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads ) # Cluster arrive after barrier init pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True) # # Setup smem tensor A/B/SFA/SFB/C # sC = storage.sC.get_tensor( c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner ) # (MMA, MMA_M, MMA_K, STAGE) sA = storage.sA.get_tensor( a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner ) # (MMA, MMA_N, MMA_K, STAGE) sB = storage.sB.get_tensor( b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner ) # (MMA, MMA_M, MMA_K, STAGE) sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) # (MMA, MMA_N, MMA_K, STAGE) sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged) # # Compute multicast mask for A/B/SFA/SFB buffer full # a_full_mcast_mask = None b_full_mcast_mask = None sfa_full_mcast_mask = None sfb_full_mcast_mask = None if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): a_full_mcast_mask = cpasync.create_tma_multicast_mask( cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 ) b_full_mcast_mask = cpasync.create_tma_multicast_mask( cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 ) sfa_full_mcast_mask = cpasync.create_tma_multicast_mask( cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 ) sfb_full_mcast_mask = cpasync.create_tma_multicast_mask( cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1 ) # # Local_tile partition global tensors # # (bM, bK, RestM, RestK, RestL) gA_mkl = cute.local_tile( mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) ) # (bN, bK, RestN, RestK, RestL) gB_nkl = cute.local_tile( mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) ) # (bM, bK, RestM, RestK, RestL) gSFA_mkl = cute.local_tile( mSFA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) ) # (bN, bK, RestN, RestK, RestL) gSFB_nkl = cute.local_tile( mSFB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) ) # (bM, bN, RestM, RestN, RestL) gC_mnl = cute.local_tile( mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) ) # # Partition global tensor for TiledMMA_A/B/C # thr_mma = tiled_mma.get_slice(mma_tile_coord_v) thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_coord_v) # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) tCgA = thr_mma.partition_A(gA_mkl) # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) tCgB = thr_mma.partition_B(gB_nkl) # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) tCgSFA = thr_mma.partition_A(gSFA_mkl) # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl) # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) tCgC = thr_mma.partition_C(gC_mnl) # # Partition global/shared tensor for TMA load A/B # # TMA load A partition_S/D a_cta_layout = cute.make_layout( cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape ) # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestM, RestK, RestL) tAsA, tAgA = cpasync.tma_partition( tma_atom_a, block_in_cluster_coord_vmnk[2], a_cta_layout, cute.group_modes(sA, 0, 3), cute.group_modes(tCgA, 0, 3), ) # TMA load B partition_S/D b_cta_layout = cute.make_layout( cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape ) # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestN, RestK, RestL) tBsB, tBgB = cpasync.tma_partition( tma_atom_b, block_in_cluster_coord_vmnk[1], b_cta_layout, cute.group_modes(sB, 0, 3), cute.group_modes(tCgB, 0, 3), ) # TMA load scaled factor A partition_S/D sfa_cta_layout = a_cta_layout # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestM, RestK, RestL) tAsSFA, tAgSFA = cute.nvgpu.cpasync.tma_partition( tma_atom_sfa, block_in_cluster_coord_vmnk[2], sfa_cta_layout, cute.group_modes(sSFA, 0, 3), cute.group_modes(tCgSFA, 0, 3), ) tAsSFA = cute.filter_zeros(tAsSFA) tAgSFA = cute.filter_zeros(tAgSFA) # TMA load scaled factor B partition_S/D sfb_cta_layout = cute.make_layout( cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape ) # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestN, RestK, RestL) tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition( tma_atom_sfb, block_in_cluster_coord_sfb_vmnk[1], sfb_cta_layout, cute.group_modes(sSFB, 0, 3), cute.group_modes(tCgSFB, 0, 3), ) tBsSFB = cute.filter_zeros(tBsSFB) tBgSFB = cute.filter_zeros(tBgSFB) # # Partition shared/tensor memory tensor for TiledMMA_A/B/C # # (MMA, MMA_M, MMA_K, STAGE) tCrA = tiled_mma.make_fragment_A(sA) # (MMA, MMA_N, MMA_K, STAGE) tCrB = tiled_mma.make_fragment_B(sB) # (MMA, MMA_M, MMA_N) acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) # (MMA, MMA_M, MMA_N, STAGE) tCtAcc_fake = tiled_mma.make_fragment_C( cute.append(acc_shape, self.num_acc_stage) ) # # Cluster wait before tensor memory alloc # pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn) # # Get tensormap buffer address # grid_dim = cute.arch.grid_dim() tensormap_workspace_idx = ( bidz * grid_dim[1] * grid_dim[0] + bidy * grid_dim[0] + bidx ) tensormap_manager = utils.TensorMapManager( utils.TensorMapUpdateMode.SMEM, Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap, ) tensormap_a_gmem_ptr = tensormap_manager.get_tensormap_ptr( tensormaps[(tensormap_workspace_idx, 0, None)].iterator ) tensormap_b_gmem_ptr = tensormap_manager.get_tensormap_ptr( tensormaps[(tensormap_workspace_idx, 1, None)].iterator ) tensormap_sfa_gmem_ptr = tensormap_manager.get_tensormap_ptr( tensormaps[(tensormap_workspace_idx, 2, None)].iterator ) tensormap_sfb_gmem_ptr = tensormap_manager.get_tensormap_ptr( tensormaps[(tensormap_workspace_idx, 3, None)].iterator ) tensormap_c_gmem_ptr = tensormap_manager.get_tensormap_ptr( tensormaps[(tensormap_workspace_idx, 4, None)].iterator ) # # Persistent tile scheduling loop # # When the problem shapes are on device, we launch one CTA per SM. # The if condition later prevents the warps from extra CTAs from doing any work. tile_sched = utils.StaticPersistentGroupTileScheduler.create( tile_sched_params, cute.arch.block_idx(), grid_dim, self.cluster_tile_shape_mnk, utils.create_initial_search_state(), group_count, problem_sizes_mnkl, ) initial_work_tile_info = tile_sched.initial_work_tile_info() # # Specialized TMA load warp # if warp_idx == self.tma_warp_id and initial_work_tile_info.is_valid_tile: # # Persistent tile scheduling loop # work_tile = initial_work_tile_info tensormap_init_done = cutlass.Boolean(False) # group index of last tile last_group_idx = cutlass.Int32(-1) ab_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.num_ab_stage ) while work_tile.is_valid_tile: grouped_gemm_cta_tile_info = work_tile.group_search_result cur_k_tile_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k cur_group_idx = grouped_gemm_cta_tile_info.group_idx is_k_tile_cnt_zero = cur_k_tile_cnt == 0 # Do not load any data if cur_k_tile_cnt is 0 if not is_k_tile_cnt_zero: is_group_changed = cur_group_idx != last_group_idx # skip tensormap update if we're working on the same group if is_group_changed: real_tensor_a = self.make_tensor_abc_for_tensormap_update( cur_group_idx, self.a_dtype, ( grouped_gemm_cta_tile_info.problem_shape_m, grouped_gemm_cta_tile_info.problem_shape_n, grouped_gemm_cta_tile_info.problem_shape_k, ), strides_abc, ptrs_abc, 0, # 0 for tensor A ) real_tensor_b = self.make_tensor_abc_for_tensormap_update( cur_group_idx, self.b_dtype, ( grouped_gemm_cta_tile_info.problem_shape_m, grouped_gemm_cta_tile_info.problem_shape_n, grouped_gemm_cta_tile_info.problem_shape_k, ), strides_abc, ptrs_abc, 1, # 1 for tensor B ) real_tensor_sfa = self.make_tensor_sfasfb_for_tensormap_update( cur_group_idx, self.sf_dtype, ( grouped_gemm_cta_tile_info.problem_shape_m, grouped_gemm_cta_tile_info.problem_shape_n, grouped_gemm_cta_tile_info.problem_shape_k, ), ptrs_sfasfb, 0, # 0 for tensor SFA ) real_tensor_sfb = self.make_tensor_sfasfb_for_tensormap_update( cur_group_idx, self.sf_dtype, ( grouped_gemm_cta_tile_info.problem_shape_m, grouped_gemm_cta_tile_info.problem_shape_n, grouped_gemm_cta_tile_info.problem_shape_k, ), ptrs_sfasfb, 1, # 1 for tensor SFB ) if not tensormap_init_done: # wait tensormap initialization complete self.tensormap_ab_init_barrier.arrive_and_wait() tensormap_init_done = True tensormap_manager.update_tensormap( ( real_tensor_a, real_tensor_b, real_tensor_sfa, real_tensor_sfb, ), (tma_atom_a, tma_atom_b, tma_atom_sfa, tma_atom_sfb), ( tensormap_a_gmem_ptr, tensormap_b_gmem_ptr, tensormap_sfa_gmem_ptr, tensormap_sfb_gmem_ptr, ), self.tma_warp_id, ( tensormap_a_smem_ptr, tensormap_b_smem_ptr, tensormap_sfa_smem_ptr, tensormap_sfb_smem_ptr, ), ) mma_tile_coord_mnl = ( grouped_gemm_cta_tile_info.cta_tile_idx_m // cute.size(tiled_mma.thr_id.shape), grouped_gemm_cta_tile_info.cta_tile_idx_n, 0, ) # # Slice to per mma tile index # # ((atom_v, rest_v), RestK) tAgA_slice = tAgA[ (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) ] # ((atom_v, rest_v), RestK) tBgB_slice = tBgB[ (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) ] # ((atom_v, rest_v), RestK) tAgSFA_slice = tAgSFA[ (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) ] # ((atom_v, rest_v), RestK) tBgSFB_slice = tBgSFB[ (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) ] # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt ab_producer_state.reset_count() peek_ab_empty_status = cutlass.Boolean(1) if ab_producer_state.count < cur_k_tile_cnt: peek_ab_empty_status = ab_pipeline.producer_try_acquire( ab_producer_state ) if is_group_changed: tensormap_manager.fence_tensormap_update(tensormap_a_gmem_ptr) tensormap_manager.fence_tensormap_update(tensormap_b_gmem_ptr) tensormap_manager.fence_tensormap_update(tensormap_sfa_gmem_ptr) tensormap_manager.fence_tensormap_update(tensormap_sfb_gmem_ptr) # # Tma load loop # for k_tile in cutlass.range(0, cur_k_tile_cnt, 1, unroll=1): # Conditionally wait for AB buffer empty ab_pipeline.producer_acquire( ab_producer_state, peek_ab_empty_status ) # TMA load A/B/SFA/SFB cute.copy( tma_atom_a, tAgA_slice[(None, ab_producer_state.count)], tAsA[(None, ab_producer_state.index)], tma_bar_ptr=ab_pipeline.producer_get_barrier( ab_producer_state ), mcast_mask=a_full_mcast_mask, tma_desc_ptr=tensormap_manager.get_tensormap_ptr( tensormap_a_gmem_ptr, cute.AddressSpace.generic, ), ) cute.copy( tma_atom_b, tBgB_slice[(None, ab_producer_state.count)], tBsB[(None, ab_producer_state.index)], tma_bar_ptr=ab_pipeline.producer_get_barrier( ab_producer_state ), mcast_mask=b_full_mcast_mask, tma_desc_ptr=tensormap_manager.get_tensormap_ptr( tensormap_b_gmem_ptr, cute.AddressSpace.generic, ), ) cute.copy( tma_atom_sfa, tAgSFA_slice[(None, ab_producer_state.count)], tAsSFA[(None, ab_producer_state.index)], tma_bar_ptr=ab_pipeline.producer_get_barrier( ab_producer_state ), mcast_mask=sfa_full_mcast_mask, tma_desc_ptr=tensormap_manager.get_tensormap_ptr( tensormap_sfa_gmem_ptr, cute.AddressSpace.generic, ), ) cute.copy( tma_atom_sfb, tBgSFB_slice[(None, ab_producer_state.count)], tBsSFB[(None, ab_producer_state.index)], tma_bar_ptr=ab_pipeline.producer_get_barrier( ab_producer_state ), mcast_mask=sfb_full_mcast_mask, tma_desc_ptr=tensormap_manager.get_tensormap_ptr( tensormap_sfb_gmem_ptr, cute.AddressSpace.generic, ), ) # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 ab_producer_state.advance() peek_ab_empty_status = cutlass.Boolean(1) if ab_producer_state.count < cur_k_tile_cnt: peek_ab_empty_status = ab_pipeline.producer_try_acquire( ab_producer_state ) else: if not tensormap_init_done: # wait tensormap initialization complete self.tensormap_ab_init_barrier.arrive_and_wait() tensormap_init_done = True # # Advance to next tile # tile_sched.advance_to_next_work() work_tile = tile_sched.get_current_work() last_group_idx = cur_group_idx # # Wait A/B buffer empty # ab_pipeline.producer_tail(ab_producer_state) # # Specialized MMA warp # if warp_idx == self.mma_warp_id and initial_work_tile_info.is_valid_tile: # # Initialize tensormaps for A, B, SFA and SFB # tensormap_manager.init_tensormap_from_atom( tma_atom_a, tensormap_a_smem_ptr, self.mma_warp_id ) tensormap_manager.init_tensormap_from_atom( tma_atom_b, tensormap_b_smem_ptr, self.mma_warp_id ) tensormap_manager.init_tensormap_from_atom( tma_atom_sfa, tensormap_sfa_smem_ptr, self.mma_warp_id ) tensormap_manager.init_tensormap_from_atom( tma_atom_sfb, tensormap_sfb_smem_ptr, self.mma_warp_id ) # indicate tensormap initialization has finished self.tensormap_ab_init_barrier.arrive_and_wait() # # Bar sync for retrieve tensor memory ptr from shared mem # self.tmem_alloc_barrier.arrive_and_wait() # # Retrieving tensor memory ptr and make accumulator/SFA/SFB tensor # # Make accumulator tmem tensor acc_tmem_ptr = cute.arch.retrieve_tmem_ptr( self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf_ptr, ) # (MMA, MMA_M, MMA_N, STAGE) tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) # Make SFA tmem tensor sfa_tmem_ptr = cute.recast_ptr( acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base), dtype=self.sf_dtype, ) # (MMA, MMA_M, MMA_K) tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( tiled_mma, self.mma_tiler, self.sf_vec_size, cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), ) tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) # Make SFB tmem tensor sfb_tmem_ptr = cute.recast_ptr( acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + tcgen05.find_tmem_tensor_col_offset(tCtSFA), dtype=self.sf_dtype, ) # (MMA, MMA_N, MMA_K) tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( tiled_mma, self.mma_tiler, self.sf_vec_size, cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), ) tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) # # Partition for S2T copy of SFA/SFB # tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = ( self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA) ) tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = ( self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB) ) # # Persistent tile scheduling loop # work_tile = initial_work_tile_info ab_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.num_ab_stage ) acc_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.num_acc_stage ) while work_tile.is_valid_tile: cur_group_idx = work_tile.group_search_result.group_idx problem_shape_k = work_tile.group_search_result.problem_shape_k # MMA warp is only interested in number of tiles along K dimension cur_k_tile_cnt = ( problem_shape_k + self.cluster_tile_shape_mnk[2] - 1 ) // self.cluster_tile_shape_mnk[2] is_k_tile_cnt_zero = cur_k_tile_cnt == 0 # (MMA, MMA_M, MMA_N) tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] # Peek (try_wait) AB buffer full for k_tile = 0 ab_consumer_state.reset_count() peek_ab_full_status = cutlass.Boolean(1) if ab_consumer_state.count < cur_k_tile_cnt and is_leader_cta: peek_ab_full_status = ab_pipeline.consumer_try_wait( ab_consumer_state ) # # Wait for accumulator buffer empty # if is_leader_cta and not is_k_tile_cnt_zero: acc_pipeline.producer_acquire(acc_producer_state) # # Reset the ACCUMULATE field for each tile # tiled_mma.set(tcgen05.Field.ACCUMULATE, False) # # Mma mainloop # for k_tile in range(cur_k_tile_cnt): if is_leader_cta: # Conditionally wait for AB buffer full ab_pipeline.consumer_wait( ab_consumer_state, peek_ab_full_status ) # Copy SFA/SFB from smem to tmem s2t_stage_coord = ( None, None, None, None, ab_consumer_state.index, ) tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord] tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord] cute.copy( tiled_copy_s2t_sfa, tCsSFA_compact_s2t_staged, tCtSFA_compact_s2t, ) cute.copy( tiled_copy_s2t_sfb, tCsSFB_compact_s2t_staged, tCtSFB_compact_s2t, ) # tCtAcc += tCrA * tCrSFA * tCrB * tCrSFB num_kblocks = cute.size(tCrA, mode=[2]) for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): kblock_coord = ( None, None, kblock_idx, ab_consumer_state.index, ) # Set SFA/SFB tensor to tiled_mma sf_kblock_coord = (None, None, kblock_idx) tiled_mma.set( tcgen05.Field.SFA, tCtSFA[sf_kblock_coord].iterator, ) tiled_mma.set( tcgen05.Field.SFB, tCtSFB[sf_kblock_coord].iterator, ) cute.gemm( tiled_mma, tCtAcc, tCrA[kblock_coord], tCrB[kblock_coord], tCtAcc, ) # Enable accumulate on tCtAcc after first kblock tiled_mma.set(tcgen05.Field.ACCUMULATE, True) # Async arrive AB buffer empty ab_pipeline.consumer_release(ab_consumer_state) # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 ab_consumer_state.advance() peek_ab_full_status = cutlass.Boolean(1) if ab_consumer_state.count < cur_k_tile_cnt: if is_leader_cta: peek_ab_full_status = ab_pipeline.consumer_try_wait( ab_consumer_state ) # # Async arrive accumulator buffer full # if not is_k_tile_cnt_zero: if is_leader_cta: acc_pipeline.producer_commit(acc_producer_state) acc_producer_state.advance() # # Advance to next tile # tile_sched.advance_to_next_work() work_tile = tile_sched.get_current_work() # # Wait for accumulator buffer empty # acc_pipeline.producer_tail(acc_producer_state) # # Specialized epilogue warps # if warp_idx < self.mma_warp_id and initial_work_tile_info.is_valid_tile: # initialize tensorap for C tensormap_manager.init_tensormap_from_atom( tma_atom_c, tensormap_c_smem_ptr, self.epilog_warp_id[0], ) # # Alloc tensor memory buffer # if warp_idx == self.epilog_warp_id[0]: cute.arch.alloc_tmem( self.num_tmem_alloc_cols, tmem_holding_buf_ptr, is_two_cta=use_2cta_instrs, ) # # Bar sync for retrieve tensor memory ptr from shared memory # self.tmem_alloc_barrier.arrive_and_wait() # # Retrieving tensor memory ptr and make accumulator tensor # acc_tmem_ptr = cute.arch.retrieve_tmem_ptr( self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf_ptr, ) # (MMA, MMA_M, MMA_N, STAGE) tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) ### Start from here # # Partition for epilogue # epi_tidx = tidx tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = ( self.epilog_tmem_copy_and_partition( epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs ) ) tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( tiled_copy_t2r, tTR_rC, epi_tidx, sC ) tma_atom_c, bSG_sC, bSG_gC_partitioned = ( self.epilog_gmem_copy_and_partition( epi_tidx, tma_atom_c, tCgC, epi_tile, sC ) ) # # Persistent tile scheduling loop # work_tile = initial_work_tile_info acc_consumer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.num_acc_stage ) # Threads/warps participating in tma store pipeline c_producer_group = pipeline.CooperativeGroup( pipeline.Agent.Thread, 32 * len(self.epilog_warp_id), ) c_pipeline = pipeline.PipelineTmaStore.create( num_stages=self.num_c_stage, producer_group=c_producer_group, ) # group index to start searching last_group_idx = cutlass.Int32(-1) while work_tile.is_valid_tile: grouped_gemm_cta_tile_info = work_tile.group_search_result cur_group_idx = grouped_gemm_cta_tile_info.group_idx cur_k_tile_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k is_k_tile_cnt_zero = cur_k_tile_cnt == 0 is_group_changed = cur_group_idx != last_group_idx # We still need to store 0s when k_tile_cnt is 0 if is_group_changed: # construct tensor c based on real shape, stride information real_tensor_c = self.make_tensor_abc_for_tensormap_update( cur_group_idx, self.c_dtype, ( grouped_gemm_cta_tile_info.problem_shape_m, grouped_gemm_cta_tile_info.problem_shape_n, grouped_gemm_cta_tile_info.problem_shape_k, ), strides_abc, ptrs_abc, 2, # 2 for tensor C ) tensormap_manager.update_tensormap( ((real_tensor_c),), ((tma_atom_c),), ((tensormap_c_gmem_ptr),), self.epilog_warp_id[0], (tensormap_c_smem_ptr,), ) mma_tile_coord_mnl = ( grouped_gemm_cta_tile_info.cta_tile_idx_m // cute.size(tiled_mma.thr_id.shape), grouped_gemm_cta_tile_info.cta_tile_idx_n, 0, ) # # Slice to per mma tile index # # ((ATOM_V, REST_V), EPI_M, EPI_N) bSG_gC = bSG_gC_partitioned[ ( None, None, None, *mma_tile_coord_mnl, ) ] # Set tensor memory buffer for current tile # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) tTR_tAcc = tTR_tAcc_base[ (None, None, None, None, None, acc_consumer_state.index) ] # # Wait for accumulator buffer full # if not is_k_tile_cnt_zero: acc_pipeline.consumer_wait(acc_consumer_state) tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) if is_group_changed: if warp_idx == self.epilog_warp_id[0]: tensormap_manager.fence_tensormap_update(tensormap_c_gmem_ptr) # # Store accumulator to global memory in subtiles # subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt for subtile_idx in range(subtile_cnt): if not is_k_tile_cnt_zero: # # Load accumulator from tensor memory buffer to register # tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) # # Convert to C type # acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() tRS_rC.store(acc_vec.to(self.c_dtype)) else: if cutlass.const_expr(self.is_nvfp4_output): zeros_i8 = cute.make_rmem_tensor( cute.recast_layout( cutlass.Int8.width, self.c_dtype.width, tRS_rC.layout, ), cutlass.Int8, ) zeros_i8.fill(0) tRS_rC.store( cute.recast_tensor(zeros_i8, self.c_dtype).load() ) else: tRS_rC.fill(0) # # Store C to shared memory # c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage cute.copy( tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)], ) # Fence and barrier to make sure shared memory store is visible to TMA store cute.arch.fence_proxy( "async.shared", space="cta", ) self.epilog_sync_barrier.arrive_and_wait() # # TMA store C to global memory # if warp_idx == self.epilog_warp_id[0]: cute.copy( tma_atom_c, bSG_sC[(None, c_buffer)], bSG_gC[(None, subtile_idx)], tma_desc_ptr=tensormap_manager.get_tensormap_ptr( tensormap_c_gmem_ptr, cute.AddressSpace.generic, ), ) # Fence and barrier to make sure shared memory store is visible to TMA store c_pipeline.producer_commit() c_pipeline.producer_acquire() self.epilog_sync_barrier.arrive_and_wait() # # Async arrive accumulator buffer empty # if not is_k_tile_cnt_zero: with cute.arch.elect_one(): acc_pipeline.consumer_release(acc_consumer_state) acc_consumer_state.advance() # # Advance to next tile # tile_sched.advance_to_next_work() work_tile = tile_sched.get_current_work() last_group_idx = cur_group_idx # # Dealloc the tensor memory buffer # if warp_idx == self.epilog_warp_id[0]: cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs) self.epilog_sync_barrier.arrive_and_wait() if warp_idx == self.epilog_warp_id[0]: if use_2cta_instrs: cute.arch.mbarrier_arrive( tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1 ) cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) cute.arch.dealloc_tmem( acc_tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs ) # # Wait for C store complete # c_pipeline.producer_tail() @cute.jit def make_tensor_abc_for_tensormap_update( self, group_idx: cutlass.Int32, dtype: Type[cutlass.Numeric], problem_shape_mnk: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32], strides_abc: cute.Tensor, tensor_address_abc: cute.Tensor, tensor_index: int, ): """Extract stride and tensor address for a given group and construct a global tensor for A, B or C. This function is used within the kernel to dynamically create a CUTE tensor representing A, B, or C for the current group being processed, using the group-specific address, shape, and stride information. :param group_idx: The index of the current group within the grouped GEMM. :type group_idx: cutlass.Int32 :param dtype: The data type of the tensor elements (e.g., cutlass.Float16). :type dtype: Type[cutlass.Numeric] :param problem_shape_mnk: The (M, N, K) problem shape for the current group. :type problem_shape_mnk: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32] :param strides_abc: Tensor containing strides for A, B, C for all groups. Layout: (group_count, 3, 2). :type strides_abc: cute.Tensor :param tensor_address_abc: Tensor containing global memory addresses for A, B, C for all groups. Layout: (group_count, 3). :type tensor_address_abc: cute.Tensor :param tensor_index: Specifies which tensor to create: 0 for A, 1 for B, 2 for C. :type tensor_index: int :return: A CUTE tensor representing the requested global memory tensor (A, B, or C) for the specified group. :rtype: cute.Tensor :raises TypeError: If the provided dtype is not a subclass of cutlass.Numeric. """ ptr_i64 = tensor_address_abc[(group_idx, tensor_index)] if cutlass.const_expr( not isclass(dtype) or not issubclass(dtype, cutlass.Numeric) ): raise TypeError( f"dtype must be a type of cutlass.Numeric, got {type(dtype)}" ) tensor_gmem_ptr = cute.make_ptr( dtype, ptr_i64, cute.AddressSpace.gmem, assumed_align=16 ) strides_tensor_gmem = strides_abc[(group_idx, tensor_index, None)] strides_tensor_reg = cute.make_rmem_tensor( cute.make_layout(2), strides_abc.element_type, ) cute.autovec_copy(strides_tensor_gmem, strides_tensor_reg) stride_mn = strides_tensor_reg[0] stride_k = strides_tensor_reg[1] c1 = cutlass.Int32(1) c0 = cutlass.Int32(0) if cutlass.const_expr(tensor_index == 0): # tensor A m = problem_shape_mnk[0] k = problem_shape_mnk[2] return cute.make_tensor( tensor_gmem_ptr, cute.make_layout((m, k, c1), stride=(stride_mn, stride_k, c0)), ) elif cutlass.const_expr(tensor_index == 1): # tensor B n = problem_shape_mnk[1] k = problem_shape_mnk[2] return cute.make_tensor( tensor_gmem_ptr, cute.make_layout((n, k, c1), stride=(stride_mn, stride_k, c0)), ) else: # tensor C m = problem_shape_mnk[0] n = problem_shape_mnk[1] return cute.make_tensor( tensor_gmem_ptr, cute.make_layout((m, n, c1), stride=(stride_mn, stride_k, c0)), ) @cute.jit def make_tensor_sfasfb_for_tensormap_update( self, group_idx: cutlass.Int32, dtype: Type[cutlass.Numeric], problem_shape_mnk: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32], tensor_address_sfasfb: cute.Tensor, tensor_index: int, ): """Extract tensor address for a given group and construct a global tensor for SFA or SFB. This function is used within the kernel to dynamically create a CUTE tensor representing SFA or SFB for the current group being processed, using the group-specific address, shape information. :param group_idx: The index of the current group within the grouped GEMM. :type group_idx: cutlass.Int32 :param dtype: The data type of the tensor elements (e.g., cutlass.Float16). :type dtype: Type[cutlass.Numeric] :param problem_shape_mnk: The (M, N, K) problem shape for the current group. :type problem_shape_mnk: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32] :param tensor_address_sfasfb: Tensor containing global memory addresses for SFA, SFB for all groups. Layout: (group_count, 2). :type tensor_address_sfasfb: cute.Tensor :param tensor_index: Specifies which tensor to create: 0 for SFA, 1 for SFB. :type tensor_index: int :return: A CUTE tensor representing the requested global memory tensor (SFA, SFB) for the specified group. :rtype: cute.Tensor :raises TypeError: If the provided dtype is not a subclass of cutlass.Numeric. """ ptr_i64 = tensor_address_sfasfb[(group_idx, tensor_index)] if cutlass.const_expr( not isclass(dtype) or not issubclass(dtype, cutlass.Numeric) ): raise TypeError( f"dtype must be a type of cutlass.Numeric, got {type(dtype)}" ) tensor_gmem_ptr = cute.make_ptr( dtype, ptr_i64, cute.AddressSpace.gmem, assumed_align=16 ) c1 = cutlass.Int32(1) if cutlass.const_expr(tensor_index == 0): # tensor SFA m = problem_shape_mnk[0] k = problem_shape_mnk[2] sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( (m, k, c1), self.sf_vec_size ) return cute.make_tensor( tensor_gmem_ptr, sfa_layout, ) else: # tensor SFB n = problem_shape_mnk[1] k = problem_shape_mnk[2] sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( (n, k, c1), self.sf_vec_size ) return cute.make_tensor( tensor_gmem_ptr, sfb_layout, ) def mainloop_s2t_copy_and_partition( self, sSF: cute.Tensor, tSF: cute.Tensor, ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: """ Make tiledCopy for smem to tmem load for scale factor tensor, then use it to partition smem memory (source) and tensor memory (destination). :param sSF: The scale factor tensor in smem :type sSF: cute.Tensor :param tSF: The scale factor tensor in tmem :type tSF: cute.Tensor :return: A tuple containing (tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t) where: - tiled_copy_s2t: The tiled copy operation for smem to tmem load for scale factor tensor(s2t) - tCsSF_compact_s2t: The partitioned scale factor tensor in smem - tSF_compact_s2t: The partitioned scale factor tensor in tmem :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] """ # (MMA, MMA_MN, MMA_K, STAGE) tCsSF_compact = cute.filter_zeros(sSF) # (MMA, MMA_MN, MMA_K) tCtSF_compact = cute.filter_zeros(tSF) # Make S2T CopyAtom and tiledCopy copy_atom_s2t = cute.make_copy_atom( tcgen05.Cp4x32x128bOp(self.cta_group), self.sf_dtype, ) tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact) thr_copy_s2t = tiled_copy_s2t.get_slice(0) # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact) # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( tiled_copy_s2t, tCsSF_compact_s2t_ ) # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact) return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t def epilog_tmem_copy_and_partition( self, tidx: cutlass.Int32, tAcc: cute.Tensor, gC_mnl: cute.Tensor, epi_tile: cute.Tile, use_2cta_instrs: Union[cutlass.Boolean, bool], ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: """ Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). :param tidx: The thread index in epilogue warp groups :type tidx: cutlass.Int32 :param tAcc: The accumulator tensor to be copied and partitioned :type tAcc: cute.Tensor :param gC_mnl: The global tensor C :type gC_mnl: cute.Tensor :param epi_tile: The epilogue tiler :type epi_tile: cute.Tile :param use_2cta_instrs: Whether use_2cta_instrs is enabled :type use_2cta_instrs: bool :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) - tTR_tAcc: The partitioned accumulator tensor - tTR_rAcc: The accumulated tensor in register used to hold t2r results :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] """ # Make tiledCopy for tensor memory load copy_atom_t2r = sm100_utils.get_tmem_load_op( self.cta_tile_shape_mnk, self.c_layout, self.c_dtype, self.acc_dtype, epi_tile, use_2cta_instrs, ) # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) tAcc_epi = cute.flat_divide( tAcc[((None, None), 0, 0, None)], epi_tile, ) # (EPI_TILE_M, EPI_TILE_N) tiled_copy_t2r = tcgen05.make_tmem_copy( copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)] ) thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) gC_mnl_epi = cute.flat_divide( gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile ) # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) # (T2R, T2R_M, T2R_N) tTR_rAcc = cute.make_rmem_tensor( tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype ) return tiled_copy_t2r, tTR_tAcc, tTR_rAcc def epilog_smem_copy_and_partition( self, tiled_copy_t2r: cute.TiledCopy, tTR_rC: cute.Tensor, tidx: cutlass.Int32, sC: cute.Tensor, ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: """ Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) :type tiled_copy_t2r: cute.TiledCopy :param tTR_rC: The partitioned accumulator tensor :type tTR_rC: cute.Tensor :param tidx: The thread index in epilogue warp groups :type tidx: cutlass.Int32 :param sC: The shared memory tensor to be copied and partitioned :type sC: cute.Tensor :type sepi: cute.Tensor :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) - tRS_rC: The partitioned tensor C (register source) - tRS_sC: The partitioned tensor C (smem destination) :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] """ copy_atom_r2s = sm100_utils.get_smem_store_op( self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r ) tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) # (R2S, R2S_M, R2S_N, PIPE_D) thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) tRS_sC = thr_copy_r2s.partition_D(sC) # (R2S, R2S_M, R2S_N) tRS_rC = tiled_copy_r2s.retile(tTR_rC) return tiled_copy_r2s, tRS_rC, tRS_sC def epilog_gmem_copy_and_partition( self, tidx: cutlass.Int32, atom: Union[cute.CopyAtom, cute.TiledCopy], gC_mnl: cute.Tensor, epi_tile: cute.Tile, sC: cute.Tensor, ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: """Make tiledCopy for global memory store, then use it to: partition shared memory (source) and global memory (destination) for TMA store version. :param tidx: The thread index in epilogue warp groups :type tidx: cutlass.Int32 :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version :type atom: cute.CopyAtom or cute.TiledCopy :param gC_mnl: The global tensor C :type gC_mnl: cute.Tensor :param epi_tile: The epilogue tiler :type epi_tile: cute.Tile :param sC: The shared memory tensor to be copied and partitioned :type sC: cute.Tensor :return: A tuple containing (tma_atom_c, bSG_sC, bSG_gC) where: - tma_atom_c: The TMA copy atom - bSG_sC: The partitioned shared memory tensor C - bSG_gC: The partitioned global tensor C :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] """ # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) gC_epi = cute.flat_divide( gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile ) tma_atom_c = atom sC_for_tma_partition = cute.group_modes(sC, 0, 2) gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) # ((ATOM_V, REST_V), EPI_M, EPI_N) # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) bSG_sC, bSG_gC = cpasync.tma_partition( tma_atom_c, 0, cute.make_layout(1), sC_for_tma_partition, gC_for_tma_partition, ) return tma_atom_c, bSG_sC, bSG_gC @staticmethod def _compute_stages( tiled_mma: cute.TiledMma, mma_tiler_mnk: Tuple[int, int, int], a_dtype: Type[cutlass.Numeric], b_dtype: Type[cutlass.Numeric], epi_tile: cute.Tile, c_dtype: Type[cutlass.Numeric], c_layout: utils.LayoutEnum, sf_dtype: Type[cutlass.Numeric], sf_vec_size: int, smem_capacity: int, occupancy: int, ) -> Tuple[int, int, int]: """Computes the number of stages for A/B/C operands based on heuristics. :param tiled_mma: The tiled MMA object defining the core computation. :type tiled_mma: cute.TiledMma :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. :type mma_tiler_mnk: tuple[int, int, int] :param a_dtype: Data type of operand A. :type a_dtype: type[cutlass.Numeric] :param b_dtype: Data type of operand B. :type b_dtype: type[cutlass.Numeric] :param epi_tile: The epilogue tile shape. :type epi_tile: cute.Tile :param c_dtype: Data type of operand C (output). :type c_dtype: type[cutlass.Numeric] :param c_layout: Layout enum of operand C. :type c_layout: utils.LayoutEnum :param sf_dtype: Data type of Scale factor. :type sf_dtype: type[cutlass.Numeric] :param sf_vec_size: Scale factor vector size. :type sf_vec_size: int :param smem_capacity: Total available shared memory capacity in bytes. :type smem_capacity: int :param occupancy: Target number of CTAs per SM (occupancy). :type occupancy: int :return: A tuple containing the computed number of stages for: (ACC stages, A/B operand stages, C stages) :rtype: tuple[int, int, int] """ # ACC stages num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2 # Default C stages num_c_stage = 2 # Calculate smem layout and size for one stage of A, B, SFA, SFB and C a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( tiled_mma, mma_tiler_mnk, a_dtype, 1, # a tmp 1 stage is provided ) b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( tiled_mma, mma_tiler_mnk, b_dtype, 1, # a tmp 1 stage is provided ) sfa_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfa( tiled_mma, mma_tiler_mnk, sf_vec_size, 1, # a tmp 1 stage is provided ) sfb_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfb( tiled_mma, mma_tiler_mnk, sf_vec_size, 1, # a tmp 1 stage is provided ) c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( c_dtype, c_layout, epi_tile, 1, ) ab_bytes_per_stage = ( cute.size_in_bytes(a_dtype, a_smem_layout_stage_one) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + cute.size_in_bytes(sf_dtype, sfa_smem_layout_staged_one) + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one) ) mbar_helpers_bytes = 1024 c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) c_bytes = c_bytes_per_stage * num_c_stage # Calculate A/B/SFA/SFB stages: # Start with total smem per CTA (capacity / occupancy) # Subtract reserved bytes and initial C stages bytes # Divide remaining by bytes needed per A/B/SFA/SFB stage num_ab_stage = ( smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes) ) // ab_bytes_per_stage # Refine epilogue stages: # Calculate remaining smem after allocating for A/B/SFA/SFB stages and reserved bytes # Add remaining unused smem to epilogue num_c_stage += ( smem_capacity - occupancy * ab_bytes_per_stage * num_ab_stage - occupancy * (mbar_helpers_bytes + c_bytes) ) // (occupancy * c_bytes_per_stage) return num_acc_stage, num_ab_stage, num_c_stage @staticmethod def _compute_grid( total_num_clusters: int, cluster_shape_mn: tuple[int, int], max_active_clusters: cutlass.Constexpr[int], ) -> tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]]: """Compute tile scheduler parameters and grid shape for grouped GEMM operations. :param total_num_clusters: Total number of clusters to process across all groups. :type total_num_clusters: int :param cluster_shape_mn: Shape of each cluster in M, N dimensions. :type cluster_shape_mn: tuple[int, int] :param max_active_clusters: Maximum number of active clusters. :type max_active_clusters: cutlass.Constexpr[int] :return: A tuple containing: - tile_sched_params: Parameters for the persistent tile scheduler. - grid: Grid shape for kernel launch. :rtype: tuple[utils.PersistentTileSchedulerParams, tuple[int, ...]] """ # Create problem shape with M, N dimensions from cluster shape # and L dimension representing the total number of clusters. problem_shape_ntile_mnl = ( cluster_shape_mn[0], cluster_shape_mn[1], cutlass.Int32(total_num_clusters), ) tile_sched_params = utils.PersistentTileSchedulerParams( problem_shape_ntile_mnl, (*cluster_shape_mn, 1) ) grid = utils.StaticPersistentTileScheduler.get_grid_shape( tile_sched_params, max_active_clusters ) return tile_sched_params, grid @staticmethod def _get_mbar_smem_bytes(**kwargs_stages: int) -> int: """Calculate shared memory consumption for memory barriers based on provided stages. Each stage requires 2 barriers, and each barrier consumes 8 bytes of shared memory. The total consumption is the sum across all provided stages. This function calculates the total shared memory needed for these barriers. :param kwargs_stages: Variable keyword arguments where each key is a stage name (e.g., num_acc_stage, num_ab_stage) and each value is the number of stages of that type. :type kwargs_stages: int :return: Total shared memory bytes required for all memory barriers. :rtype: int """ num_barriers_per_stage = 2 num_bytes_per_barrier = 8 mbar_smem_consumption = sum( [ num_barriers_per_stage * num_bytes_per_barrier * stage for stage in kwargs_stages.values() ] ) return mbar_smem_consumption @staticmethod def is_valid_dtypes_and_scale_factor_vec_size( ab_dtype: Type[cutlass.Numeric], sf_dtype: Type[cutlass.Numeric], sf_vec_size: int, c_dtype: Type[cutlass.Numeric], ) -> bool: """ Check if the dtypes and sf_vec_size are valid combinations :param ab_dtype: The data type of the A and B operands :type ab_dtype: Type[cutlass.Numeric] :param sf_dtype: The data type of the scale factor :type sf_dtype: Type[cutlass.Numeric] :param sf_vec_size: The vector size of the scale factor :type sf_vec_size: int :param c_dtype: The data type of the output tensor :type c_dtype: Type[cutlass.Numeric] :return: True if the dtypes and sf_vec_size are valid, False otherwise :rtype: bool """ is_valid = True # Check valid ab_dtype if ab_dtype not in { cutlass.Float4E2M1FN, cutlass.Float8E5M2, cutlass.Float8E4M3FN, }: is_valid = False # Check valid sf_vec_size if sf_vec_size not in {16, 32}: is_valid = False # Check valid sf_dtype if sf_dtype not in {cutlass.Float8E8M0FNU, cutlass.Float8E4M3FN}: is_valid = False # Check valid sf_dtype and sf_vec_size combinations if sf_dtype == cutlass.Float8E4M3FN and sf_vec_size == 32: is_valid = False if ab_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} and sf_vec_size == 16: is_valid = False # Check valid c_dtype if c_dtype not in { cutlass.Float32, cutlass.Float16, cutlass.BFloat16, cutlass.Float8E5M2, cutlass.Float8E4M3FN, }: is_valid = False return is_valid @staticmethod def is_valid_layouts( ab_dtype: Type[cutlass.Numeric], c_dtype: Type[cutlass.Numeric], a_major: str, b_major: str, c_major: str, ) -> bool: """ Check if layouts and dtypes are valid combinations :param ab_dtype: The data type of the A and B operands :type ab_dtype: Type[cutlass.Numeric] :param c_dtype: The data type of the output tensor :type c_dtype: Type[cutlass.Numeric] :param a_major: The major dimension of the A tensor :type a_major: str :param b_major: The major dimension of the B tensor :type b_major: str :param c_major: The major dimension of the C tensor :type c_major: str :return: True if the layouts are valid, False otherwise :rtype: bool """ is_valid = True if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"): is_valid = False return is_valid @staticmethod def is_valid_mma_tiler_and_cluster_shape( mma_tiler_mn: Tuple[int, int], cluster_shape_mn: Tuple[int, int], ) -> bool: """ Check if the mma tiler and cluster shape are valid :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler :type mma_tiler_mn: Tuple[int, int] :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster :type cluster_shape_mn: Tuple[int, int] :return: True if the mma tiler and cluster shape are valid, False otherwise :rtype: bool """ is_valid = True # Skip invalid mma tile shape if mma_tiler_mn[0] not in [128, 256]: is_valid = False if mma_tiler_mn[1] not in [128, 256]: is_valid = False # Skip illegal cluster shape if cluster_shape_mn[0] % (2 if mma_tiler_mn[0] == 256 else 1) != 0: is_valid = False # Skip invalid cluster shape is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 if ( cluster_shape_mn[0] * cluster_shape_mn[1] > 16 or cluster_shape_mn[0] <= 0 or cluster_shape_mn[1] <= 0 # Special cluster shape check for scale factor multicasts. # Due to limited size of scale factors, we can't multicast among more than 4 CTAs. or cluster_shape_mn[0] > 4 or cluster_shape_mn[1] > 4 or not is_power_of_2(cluster_shape_mn[0]) or not is_power_of_2(cluster_shape_mn[1]) ): is_valid = False return is_valid @staticmethod def is_valid_tensor_alignment( problem_sizes_mnkl: List[Tuple[int, int, int, int]], ab_dtype: Type[cutlass.Numeric], c_dtype: Type[cutlass.Numeric], a_major: str, b_major: str, c_major: str, ) -> bool: """ Check if the tensor alignment is valid :param problem_sizes_mnkl: The problem shape for each group :type problem_sizes_mnkl: List[Tuple[int, int, int, int]] :param ab_dtype: The data type of the A and B operands :type ab_dtype: Type[cutlass.Numeric] :param c_dtype: The data type of the output tensor :type c_dtype: Type[cutlass.Numeric] :param a_major: The major axis of the A tensor :type a_major: str :param b_major: The major axis of the B tensor :type b_major: str :param c_major: The major axis of the C tensor :type c_major: str :return: True if the problem shape is valid, False otherwise :rtype: bool """ is_valid = True def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): major_mode_idx = 0 if is_mode0_major else 1 num_major_elements = tensor_shape[major_mode_idx] num_contiguous_elements = 16 * 8 // dtype.width return num_major_elements % num_contiguous_elements == 0 for m, n, k, l in problem_sizes_mnkl: if ( not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) or not check_contigous_16B_alignment( ab_dtype, b_major == "n", (n, k, l) ) or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) ): is_valid = False return is_valid @staticmethod def can_implement( ab_dtype: Type[cutlass.Numeric], sf_dtype: Type[cutlass.Numeric], sf_vec_size: int, c_dtype: Type[cutlass.Numeric], mma_tiler_mn: Tuple[int, int], cluster_shape_mn: Tuple[int, int], problem_sizes_mnkl: List[Tuple[int, int, int, int]], a_major: str, b_major: str, c_major: str, ) -> bool: """ Check if the gemm can be implemented :param ab_dtype: The data type of the A and B operands :type ab_dtype: Type[cutlass.Numeric] :param sf_dtype: The data type of the scale factor tensor :type sf_dtype: Type[cutlass.Numeric] :param sf_vec_size: The vector size :type sf_vec_size: int :param c_dtype: The data type of the output tensor :type c_dtype: Type[cutlass.Numeric] :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler :type mma_tiler_mn: Tuple[int, int] :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster :type cluster_shape_mn: Tuple[int, int] :param a_major: The major axis of the A tensor :type a_major: str :param b_major: The major axis of the B tensor :type b_major: str :param c_major: The major axis of the C tensor :type c_major: str :return: True if the gemm can be implemented, False otherwise :rtype: bool """ can_implement = True # Skip unsupported types if not Sm100GroupedBlockScaledGemmKernel.is_valid_dtypes_and_scale_factor_vec_size( ab_dtype, sf_dtype, sf_vec_size, c_dtype ): can_implement = False # Skip unsupported layouts if not Sm100GroupedBlockScaledGemmKernel.is_valid_layouts( ab_dtype, c_dtype, a_major, b_major, c_major ): can_implement = False # Skip invalid mma tile shape and cluster shape if not Sm100GroupedBlockScaledGemmKernel.is_valid_mma_tiler_and_cluster_shape( mma_tiler_mn, cluster_shape_mn ): can_implement = False # Skip illegal problem shape for load/store alignment if not Sm100GroupedBlockScaledGemmKernel.is_valid_tensor_alignment( problem_sizes_mnkl, ab_dtype, c_dtype, a_major, b_major, c_major ): can_implement = False return can_implement # Size of smem we reserved for mbarrier, tensor memory management and tensormap update reserved_smem_bytes = 1024 bytes_per_tensormap = 128 num_tensormaps = 5 # size of smem used for tensor memory management tensor_memory_management_bytes = 12 # Create tensor and return the pointer, tensor, and stride def create_tensor_and_stride( l: int, mode0: int, mode1: int, is_mode0_major: bool, dtype: type[cutlass.Numeric], is_dynamic_layout: bool = True, ) -> tuple[int, torch.Tensor, cute.Tensor, torch.Tensor, tuple[int, int]]: """Create GPU tensor from either a new or existing CPU tensor. :param torch_tensor_cpu: Optional existing CPU tensor to reuse. If None, creates a new one. :type torch_tensor_cpu: torch.Tensor, optional """ # Create new CPU tensor torch_tensor_cpu = cutlass_torch.matrix( l, mode0, mode1, is_mode0_major, cutlass.Float32, ) # Create GPU tensor from CPU tensor (new or existing) cute_tensor, torch_tensor = cutlass_torch.cute_tensor_like( torch_tensor_cpu, dtype, is_dynamic_layout, assumed_align=16 ) # omit stride for L mode as it is always 1 stride = (1, mode0) if is_mode0_major else (mode1, 1) return ( torch_tensor.data_ptr(), torch_tensor, cute_tensor, torch_tensor_cpu, stride, ) def create_tensors_abc_for_all_groups( problem_sizes_mnkl: List[tuple[int, int, int, int]], ab_dtype: Type[cutlass.Numeric], c_dtype: Type[cutlass.Numeric], a_major: str, b_major: str, c_major: str, ) -> tuple[ List[List[int]], List[List[torch.Tensor]], List[tuple], List[List[tuple]], List[List[torch.Tensor]], ]: ref_torch_fp32_tensors_abc = [] torch_tensors_abc = [] cute_tensors_abc = [] strides_abc = [] ptrs_abc = [] # Iterate through all groups and create tensors for each group for group_idx, (m, n, k, l) in enumerate(problem_sizes_mnkl): # Create tensors A, B, C ( ptr_a, torch_tensor_a, cute_tensor_a, ref_torch_fp32_tensor_a, stride_mk_a, ) = create_tensor_and_stride(l, m, k, a_major == "m", ab_dtype) ( ptr_b, torch_tensor_b, cute_tensor_b, ref_torch_fp32_tensor_b, stride_nk_b, ) = create_tensor_and_stride(l, n, k, b_major == "n", ab_dtype) ( ptr_c, torch_tensor_c, cute_tensor_c, ref_torch_fp32_tensor_c, stride_mn_c, ) = create_tensor_and_stride(l, m, n, c_major == "m", c_dtype) ref_torch_fp32_tensors_abc.append( [ref_torch_fp32_tensor_a, ref_torch_fp32_tensor_b, ref_torch_fp32_tensor_c] ) ptrs_abc.append([ptr_a, ptr_b, ptr_c]) torch_tensors_abc.append([torch_tensor_a, torch_tensor_b, torch_tensor_c]) strides_abc.append([stride_mk_a, stride_nk_b, stride_mn_c]) cute_tensors_abc.append( ( cute_tensor_a, cute_tensor_b, cute_tensor_c, ) ) return ( ptrs_abc, torch_tensors_abc, cute_tensors_abc, strides_abc, ref_torch_fp32_tensors_abc, ) @cute.jit def cvt_sf_MKL_to_M32x4xrm_K4xrk_L( sf_ref_tensor: cute.Tensor, sf_mma_tensor: cute.Tensor, ): """Convert scale factor tensor from MKL layout to mma specification M(32x4xrest_m)xK(4xrest_k)xL layout""" # sf_mma_tensor has flatten shape (32, 4, rest_m, 4, rest_k, l) # group to ((32, 4, rest_m), (4, rest_k), l) sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3) sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3) for i in cutlass.range(cute.size(sf_ref_tensor)): mkl_coord = sf_ref_tensor.layout.get_hier_coord(i) sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord] # Create scale factor tensor SFA/SFB def create_scale_factor_tensor(l, mn, k, sf_vec_size, dtype): def ceil_div(a, b): return (a + b - 1) // b sf_k = max(1, ceil_div(k, sf_vec_size)) ref_shape = (l, mn, sf_k) atom_m = (32, 4) atom_k = 4 mma_shape = ( l, ceil_div(mn, atom_m[0] * atom_m[1]), ceil_div(sf_k, atom_k), atom_m[0], atom_m[1], atom_k, ) ref_permute_order = (1, 2, 0) mma_permute_order = (3, 4, 1, 5, 2, 0) # Create f32 ref torch tensor (cpu) ref_f32_torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( ref_shape, torch.float32, permute_order=ref_permute_order, init_type=cutlass_torch.TensorInitType.RANDOM, init_config=cutlass_torch.RandomInitConfig( min_val=1, max_val=3, ), ) # Create f32 cute torch tensor (cpu) cute_f32_torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( mma_shape, torch.float32, permute_order=mma_permute_order, init_type=cutlass_torch.TensorInitType.RANDOM, init_config=cutlass_torch.RandomInitConfig( min_val=0, max_val=1, ), ) # convert ref f32 tensor to cute f32 tensor cvt_sf_MKL_to_M32x4xrm_K4xrk_L( from_dlpack(ref_f32_torch_tensor_cpu), from_dlpack(cute_f32_torch_tensor_cpu), ) cute_f32_torch_tensor = cute_f32_torch_tensor_cpu.cuda() # reshape makes memory contiguous ref_f32_torch_tensor_cpu = ( ref_f32_torch_tensor_cpu.permute(2, 0, 1) .unsqueeze(-1) .expand(l, mn, sf_k, sf_vec_size) .reshape(l, mn, sf_k * sf_vec_size) .permute(*ref_permute_order) ) # prune to mkl for reference check. ref_f32_torch_tensor_cpu = ref_f32_torch_tensor_cpu[:, :k, :] # Create dtype cute torch tensor (cpu) cute_tensor, cute_torch_tensor = cutlass_torch.cute_tensor_like( cute_f32_torch_tensor_cpu, dtype, is_dynamic_layout=True, assumed_align=16, ) # Convert f32 cute tensor to dtype cute tensor cute_tensor = cutlass_torch.convert_cute_tensor( cute_f32_torch_tensor, cute_tensor, dtype, is_dynamic_layout=True, ) # get pointer of the tensor ptr = cute_torch_tensor.data_ptr() return ref_f32_torch_tensor_cpu, ptr, cute_tensor, cute_torch_tensor def create_tensors_sfasfb_for_all_groups( problem_sizes_mnkl: List[tuple[int, int, int, int]], sf_dtype: Type[cutlass.Numeric], sf_vec_size: int, ) -> tuple[ List[List[int]], List[List[torch.Tensor]], List[tuple], List[List[torch.Tensor]], ]: ptrs_sfasfb = [] torch_tensors_sfasfb = [] cute_tensors_sfasfb = [] refs_sfasfb = [] # Iterate through all groups and create tensors for each group for group_idx, (m, n, k, l) in enumerate(problem_sizes_mnkl): sfa_ref, ptr_sfa, sfa_tensor, sfa_torch = create_scale_factor_tensor( l, m, k, sf_vec_size, sf_dtype ) sfb_ref, ptr_sfb, sfb_tensor, sfb_torch = create_scale_factor_tensor( l, n, k, sf_vec_size, sf_dtype ) ptrs_sfasfb.append([ptr_sfa, ptr_sfb]) torch_tensors_sfasfb.append([sfa_torch, sfb_torch]) cute_tensors_sfasfb.append( ( sfa_tensor, sfb_tensor, ) ) refs_sfasfb.append([sfa_ref, sfb_ref]) return ( ptrs_sfasfb, torch_tensors_sfasfb, cute_tensors_sfasfb, refs_sfasfb, ) def run( num_groups: int, problem_sizes_mnkl: List[Tuple[int, int, int, int]], host_problem_shape_available: bool, ab_dtype: Type[cutlass.Numeric], sf_dtype: Type[cutlass.Numeric], sf_vec_size: int, c_dtype: Type[cutlass.Numeric], a_major: str, b_major: str, c_major: str, mma_tiler_mn: Tuple[int, int], cluster_shape_mn: Tuple[int, int], tolerance: float = 1e-01, warmup_iterations: int = 0, iterations: int = 1, skip_ref_check: bool = False, use_cold_l2: bool = False, **kwargs, ): """Run SM100 grouped blockscaledGEMM example with specified configurations. :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False :type use_cold_l2: bool, optional :return: Execution time of the GEMM kernel in microseconds :rtype: float """ print("Running Blackwell Grouped GEMM test with:") print(f"{num_groups} groups") for i, (m, n, k, l) in enumerate(problem_sizes_mnkl): print(f"Group {i}: {m}x{n}x{k}x{l}") print(f"AB dtype: {ab_dtype}, SF dtype: {sf_dtype}, SF Vec size: {sf_vec_size}") print(f"C dtype: {c_dtype}") print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}") print(f"Tolerance: {tolerance}") print(f"Warmup iterations: {warmup_iterations}") print(f"Iterations: {iterations}") print(f"Skip reference checking: {skip_ref_check}") print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}") # Skip unsupported testcase if not Sm100GroupedBlockScaledGemmKernel.can_implement( ab_dtype, sf_dtype, sf_vec_size, c_dtype, mma_tiler_mn, cluster_shape_mn, problem_sizes_mnkl, a_major, b_major, c_major, ): raise TypeError( f"Unsupported testcase {ab_dtype}, {sf_dtype}, {sf_vec_size}, {c_dtype}, {mma_tiler_mn}, {cluster_shape_mn}, {problem_sizes_mnkl}, {a_major}, {b_major}, {c_major}" ) if not torch.cuda.is_available(): raise RuntimeError("GPU is required to run this example!") torch.manual_seed(2025) # Create tensors A, B, C for all groups ( ptrs_abc, torch_tensors_abc, cute_tensors_abc, strides_abc, ref_f32_torch_tensors_abc, ) = create_tensors_abc_for_all_groups( problem_sizes_mnkl, ab_dtype, c_dtype, a_major, b_major, c_major, ) # Create tensors SFA, SFB for all groups ( ptrs_sfasfb, torch_tensors_sfasfb, cute_tensors_sfasfb, refs_f32_torch_tensors_sfasfb, ) = create_tensors_sfasfb_for_all_groups( problem_sizes_mnkl, sf_dtype, sf_vec_size, ) # Setup inital tensors for TMA of A,B and C alignment = 16 # 16 bytes aligned divisibility_ab = 32 if ab_dtype == cutlass.Float4E2M1FN else 16 divisibility_c = 32 if c_dtype == cutlass.Float4E2M1FN else 16 divisibility_sf = 32 if sf_dtype == cutlass.Float4E2M1FN else 16 min_ab_size = alignment * 8 // ab_dtype.width # alignment bytes of width div_mul_ab = (divisibility_ab + min_ab_size - 1) // min_ab_size min_ab_size = min_ab_size * div_mul_ab min_c_size = alignment * 8 // c_dtype.width div_mul_c = (divisibility_c + min_c_size - 1) // min_c_size min_c_size = min_c_size * div_mul_c min_sf_size = alignment * 8 // sf_dtype.width div_mul_sf = (divisibility_sf + min_sf_size - 1) // min_sf_size min_sf_size = min_sf_size * div_mul_sf initial_cute_tensors_abc = [ create_tensor_and_stride(1, min_ab_size, min_ab_size, a_major == "m", ab_dtype)[ 2 ], create_tensor_and_stride(1, min_ab_size, min_ab_size, b_major == "n", ab_dtype)[ 2 ], create_tensor_and_stride(1, min_c_size, min_c_size, c_major == "m", c_dtype)[2], ] initial_cute_tensors_sfasfb = [ create_tensor_and_stride(1, min_sf_size, min_sf_size, a_major == "m", sf_dtype)[ 2 ], create_tensor_and_stride(1, min_sf_size, min_sf_size, b_major == "n", sf_dtype)[ 2 ], ] hardware_info = cutlass.utils.HardwareInfo() sm_count = hardware_info.get_max_active_clusters(1) max_active_clusters = hardware_info.get_max_active_clusters( cluster_shape_mn[0] * cluster_shape_mn[1] ) # Prepare tensormap buffer for each SM num_tensormap_buffers = sm_count tensormap_shape = ( num_tensormap_buffers, Sm100GroupedBlockScaledGemmKernel.num_tensormaps, Sm100GroupedBlockScaledGemmKernel.bytes_per_tensormap // 8, ) tensor_of_tensormap, tensor_of_tensormap_torch = cutlass_torch.cute_tensor_like( torch.empty(tensormap_shape, dtype=torch.int64), cutlass.Int64, is_dynamic_layout=False, ) grouped_blockscaled_gemm = Sm100GroupedBlockScaledGemmKernel( sf_vec_size, mma_tiler_mn, cluster_shape_mn, ) # layout (num_groups, 4):(4, 1) ( tensor_of_dim_size_mnkl, tensor_of_dim_size_mnkl_torch, ) = cutlass_torch.cute_tensor_like( torch.tensor(problem_sizes_mnkl, dtype=torch.int32), cutlass.Int32, is_dynamic_layout=False, assumed_align=16, ) # layout (num_groups, 3, 2):(6, 2, 1) tensor_of_strides_abc, tensor_of_strides_abc_torch = cutlass_torch.cute_tensor_like( torch.tensor(strides_abc, dtype=torch.int32), cutlass.Int32, is_dynamic_layout=False, assumed_align=16, ) # layout (num_groups,3):(3, 1) tensor_of_ptrs_abc, tensor_of_ptrs_abc_torch = cutlass_torch.cute_tensor_like( torch.tensor(ptrs_abc, dtype=torch.int64), cutlass.Int64, is_dynamic_layout=False, assumed_align=16, ) # layout (num_groups,2):(2, 1) tensor_of_ptrs_sfasfb, tensor_of_ptrs_sfasfb_torch = cutlass_torch.cute_tensor_like( torch.tensor(ptrs_sfasfb, dtype=torch.int64), cutlass.Int64, is_dynamic_layout=False, assumed_align=16, ) # Compute total number of cluster tiles we need to compute for given grouped GEMM problem def compute_total_num_clusters( problem_sizes_mnkl: List[tuple[int, int, int, int]], cluster_tile_shape_mn: tuple[int, int], ) -> int: total_num_clusters = 0 for m, n, _, _ in problem_sizes_mnkl: num_clusters_mn = tuple( (x + y - 1) // y for x, y in zip((m, n), cluster_tile_shape_mn) ) total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn) return total_num_clusters # Compute cluster tile shape def compute_cluster_tile_shape( mma_tiler_mn: tuple[int, int], cluster_shape_mn: tuple[int, int], ) -> tuple[int, int]: cta_tile_shape_mn = [128, mma_tiler_mn[1]] return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn)) cluster_tile_shape_mn = compute_cluster_tile_shape(mma_tiler_mn, cluster_shape_mn) total_num_clusters = compute_total_num_clusters( problem_sizes_mnkl, cluster_tile_shape_mn ) # Initialize Stream current_stream = cutlass_torch.default_stream() # If the host problem shape is available, we will launch the grid with only # the necessary clusters. The function compute_total_num_clusters() does that. # If the problem shape only exists on device, we will need to launch all active # clusters possible on a device. if host_problem_shape_available: print("Problem shapes available on host and device") total_num_clusters = compute_total_num_clusters( problem_sizes_mnkl, cluster_tile_shape_mn ) else: print("Problem shapes available only on device") total_num_clusters = max_active_clusters # Compile grouped GEMM kernel compiled_grouped_gemm = cute.compile( grouped_blockscaled_gemm, initial_cute_tensors_abc[0], initial_cute_tensors_abc[1], initial_cute_tensors_abc[2], initial_cute_tensors_sfasfb[0], initial_cute_tensors_sfasfb[1], num_groups, tensor_of_dim_size_mnkl, tensor_of_strides_abc, tensor_of_ptrs_abc, tensor_of_ptrs_sfasfb, total_num_clusters, tensor_of_tensormap, max_active_clusters, current_stream, options=f"--opt-level 2", ) # reference check if not skip_ref_check: compiled_grouped_gemm( initial_cute_tensors_abc[0], initial_cute_tensors_abc[1], initial_cute_tensors_abc[2], initial_cute_tensors_sfasfb[0], initial_cute_tensors_sfasfb[1], tensor_of_dim_size_mnkl, tensor_of_strides_abc, tensor_of_ptrs_abc, tensor_of_ptrs_sfasfb, tensor_of_tensormap, current_stream, ) print("Verifying results...") for i, ( (a_ref, b_ref, c_ref), (sfa_ref, sfb_ref), (a_tensor, b_tensor, c_tensor), (m, n, k, l), ) in enumerate( zip( ref_f32_torch_tensors_abc, refs_f32_torch_tensors_sfasfb, cute_tensors_abc, problem_sizes_mnkl, ) ): ref_res_a = torch.einsum("mkl,mkl->mkl", a_ref, sfa_ref) ref_res_b = torch.einsum("nkl,nkl->nkl", b_ref, sfb_ref) ref = torch.einsum("mkl,nkl->mnl", ref_res_a, ref_res_b) print(f"checking group {i}") c_ref_device = c_ref.cuda() cute.testing.convert( c_tensor, from_dlpack(c_ref_device, assumed_align=16).mark_layout_dynamic( leading_dim=(1 if c_major == "n" else 0) ), ) c_ref = c_ref_device.cpu() if c_dtype in (cutlass.Float32, cutlass.Float16, cutlass.BFloat16): torch.testing.assert_close(c_ref, ref, atol=tolerance, rtol=1e-02) elif c_dtype in (cutlass.Float8E5M2, cutlass.Float8E4M3FN): # Convert ref : f32 -> f8 -> f32 ref_f8_ = torch.empty( *(l, m, n), dtype=torch.uint8, device="cuda" ).permute(1, 2, 0) ref_f8 = from_dlpack(ref_f8_, assumed_align=16).mark_layout_dynamic( leading_dim=1 ) ref_f8.element_type = c_dtype ref_device = ref.permute(2, 0, 1).contiguous().permute(1, 2, 0).cuda() ref_tensor = from_dlpack( ref_device, assumed_align=16 ).mark_layout_dynamic(leading_dim=1) cute.testing.convert(ref_tensor, ref_f8) cute.testing.convert(ref_f8, ref_tensor) ref = ref_device.cpu() torch.testing.assert_close(c_ref, ref, atol=tolerance, rtol=1e-02) def generate_tensors(): ( ptrs_abc_workspace, torch_tensors_abc_workspace, cute_tensors_abc_workspace, strides_abc_workspace, _, ) = create_tensors_abc_for_all_groups( problem_sizes_mnkl, ab_dtype, c_dtype, a_major, b_major, c_major, ) ( ptrs_sfasfb_workspace, torch_tensors_sfasfb_workspace, cute_tensors_sfasfb_workspace, _, ) = create_tensors_sfasfb_for_all_groups( problem_sizes_mnkl, sf_dtype, sf_vec_size, ) initial_cute_tensors_abc_workspace = [ create_tensor_and_stride( 1, min_ab_size, min_ab_size, a_major == "m", ab_dtype )[2], create_tensor_and_stride( 1, min_ab_size, min_ab_size, b_major == "n", ab_dtype )[2], create_tensor_and_stride( 1, min_c_size, min_c_size, c_major == "m", c_dtype )[2], ] initial_cute_tensors_sfasfb_workspace = [ create_tensor_and_stride( 1, min_sf_size, min_sf_size, a_major == "m", sf_dtype )[2], create_tensor_and_stride( 1, min_sf_size, min_sf_size, b_major == "n", sf_dtype )[2], ] # Create new tensors for this workspace tensor_of_strides_abc_workspace, _ = cutlass_torch.cute_tensor_like( torch.tensor(strides_abc_workspace, dtype=torch.int32), cutlass.Int32, is_dynamic_layout=False, assumed_align=16, ) tensor_of_ptrs_abc_workspace, _ = cutlass_torch.cute_tensor_like( torch.tensor(ptrs_abc_workspace, dtype=torch.int64), cutlass.Int64, is_dynamic_layout=False, assumed_align=16, ) tensor_of_ptrs_sfasfb_workspace, _ = cutlass_torch.cute_tensor_like( torch.tensor(ptrs_sfasfb_workspace, dtype=torch.int64), cutlass.Int64, is_dynamic_layout=False, assumed_align=16, ) tensormap_workspace, _ = cutlass_torch.cute_tensor_like( torch.empty(tensormap_shape, dtype=torch.int64), cutlass.Int64, is_dynamic_layout=False, ) args = cute.testing.JitArguments( initial_cute_tensors_abc_workspace[0], initial_cute_tensors_abc_workspace[1], initial_cute_tensors_abc_workspace[2], initial_cute_tensors_sfasfb_workspace[0], initial_cute_tensors_sfasfb_workspace[1], tensor_of_dim_size_mnkl, tensor_of_strides_abc_workspace, tensor_of_ptrs_abc_workspace, tensor_of_ptrs_sfasfb_workspace, tensormap_workspace, current_stream, ) args.add_to_scope([torch_tensors_abc_workspace, torch_tensors_sfasfb_workspace]) return args workspace_count = 1 if use_cold_l2: one_workspace_bytes = ( sum( [ sum( [ torch_tensor.numel() * torch_tensor.element_size() for torch_tensor in group_tensors ] ) for group_tensors in torch_tensors_abc + torch_tensors_sfasfb ] ) + # Add size of strides tensor tensor_of_strides_abc_torch.numel() * tensor_of_strides_abc_torch.element_size() + # Add size of ptrs tensor A, B, C tensor_of_ptrs_abc_torch.numel() * tensor_of_ptrs_abc_torch.element_size() + # Add size of ptrs tensor SFA, SFB tensor_of_ptrs_sfasfb_torch.numel() * tensor_of_ptrs_sfasfb_torch.element_size() + # Add size of tensormap tensor tensor_of_tensormap_torch.numel() * tensor_of_tensormap_torch.element_size() ) workspace_count = cute.testing.get_workspace_count( one_workspace_bytes, warmup_iterations, iterations ) exec_time = cute.testing.benchmark( compiled_grouped_gemm, workspace_generator=generate_tensors, workspace_count=workspace_count, stream=current_stream, warmup_iterations=warmup_iterations, iterations=iterations, ) runtime_s = exec_time / 1.0e6 fmas = 0 for group in range(num_groups): [M, N, K, _] = problem_sizes_mnkl[group] fmas += M * N * K flop = 2 * fmas gflop = flop / 1.0e9 gflops = gflop / runtime_s print("Average Runtime : ", exec_time / 1000, "ms") print("GFLOPS : ", gflops) return exec_time # Return execution time in microseconds if __name__ == "__main__": def parse_comma_separated_ints(s: str) -> tuple[int, ...]: try: return tuple(int(x.strip()) for x in s.split(",")) except ValueError: raise argparse.ArgumentTypeError( "Invalid format. Expected comma-separated integers." ) def parse_comma_separated_tuples(s: str) -> List[tuple[int, ...]]: if s.strip().startswith("("): # Split on ),( to separate tuples tuples = s.strip("()").split("),(") result = [] tuple_len = None for t in tuples: # Parse individual tuple nums = [int(x.strip()) for x in t.split(",")] # Validate tuple length consistency if tuple_len is None: tuple_len = len(nums) elif len(nums) != tuple_len: raise argparse.ArgumentTypeError( "All tuples must have the same length" ) result.append(tuple(nums)) return result raise argparse.ArgumentTypeError( "Invalid format. Expected comma-separated integers or list of tuples" ) parser = argparse.ArgumentParser( description="Example of Grouped GEMM on Blackwell." ) parser.add_argument( "--num_groups", type=int, default=2, help="Number of groups", ) parser.add_argument( "--problem_sizes_mnkl", type=parse_comma_separated_tuples, default=((128, 128, 128, 1), (128, 128, 128, 1)), help="a tuple of problem sizes for each group (comma-separated tuples)", ) parser.add_argument( "--mma_tiler_mn", type=parse_comma_separated_ints, default=(128, 128), help="Mma tile shape (comma-separated)", ) parser.add_argument( "--host_problem_shape_available", action="store_true", help="Enable the compute of grid based upon host problem shape", ) parser.add_argument( "--cluster_shape_mn", type=parse_comma_separated_ints, default=(1, 1), help="Cluster shape (comma-separated)", ) parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.Float4E2M1FN) parser.add_argument("--sf_dtype", type=cutlass.dtype, default=cutlass.Float8E8M0FNU) parser.add_argument("--sf_vec_size", type=int, default=16) parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.Float16) parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k") parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") parser.add_argument( "--tolerance", type=float, default=1e-01, help="Tolerance for validation" ) parser.add_argument( "--warmup_iterations", type=int, default=0, help="Warmup iterations" ) parser.add_argument( "--iterations", type=int, default=1, help="Number of iterations to run the kernel", ) parser.add_argument( "--skip_ref_check", action="store_true", help="Skip reference checking" ) parser.add_argument( "--use_cold_l2", action="store_true", default=False, help="Use circular buffer tensor sets to ensure L2 cold cache", ) args = parser.parse_args() if ( len(args.problem_sizes_mnkl) != 0 and len(args.problem_sizes_mnkl) != args.num_groups ): parser.error("--problem_sizes_mnkl must contain exactly num_groups tuples") # l mode must be 1 for all groups for _, _, _, l in args.problem_sizes_mnkl: if l != 1: parser.error("l must be 1 for all groups") if len(args.mma_tiler_mn) != 2: parser.error("--mma_tiler_mn must contain exactly 2 values") if len(args.cluster_shape_mn) != 2: parser.error("--cluster_shape_mn must contain exactly 2 values") run( args.num_groups, args.problem_sizes_mnkl, args.host_problem_shape_available, args.ab_dtype, args.sf_dtype, args.sf_vec_size, args.c_dtype, args.a_major, args.b_major, args.c_major, args.mma_tiler_mn, args.cluster_shape_mn, args.tolerance, args.warmup_iterations, args.iterations, args.skip_ref_check, args.use_cold_l2, ) print("PASS")