From fa6dbd4aa26f7d75cea2b7b389257616ab71c419 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 1 Jun 2026 07:53:21 +0000 Subject: [PATCH] WIP: Rewrite NVFP4 fused router in CuTeDSL with MmaMXF4NVF4Op (sf_vec_size=16) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Uses kind::mxf4nvf4 — native NVFP4 (E2M1 elements, E4M3 block scales), 16-elem blocks. NO MXFP4, NO CONVERSIONS. Kernel incomplete — GEMM mainloop mirrors dense.py but epilogue is TODO. Need to verify CuTeDSL compilation works with proper PipelineTmaUmma/ PipelineUmmaAsync abstractions before adding top-k epilogue. --- .../router/nvfp4_fused_router_kernel.py | 919 ++++-------------- 1 file changed, 188 insertions(+), 731 deletions(-) diff --git a/dsv4/kernels/router/nvfp4_fused_router_kernel.py b/dsv4/kernels/router/nvfp4_fused_router_kernel.py index 6f945207..fef7a02b 100644 --- a/dsv4/kernels/router/nvfp4_fused_router_kernel.py +++ b/dsv4/kernels/router/nvfp4_fused_router_kernel.py @@ -1,42 +1,19 @@ -"""DSV4 NVFP4 Fused Router Kernel — Blackwell SM100. +"""DSV4 NVFP4 Fused Router Kernel — CuTeDSL for SM100 Blackwell. -Fuses the NVFP4 block-scaled GEMM with the sqrt(softplus) + e_bias + top-k -epilogue into a single kernel launch. Avoids materializing the intermediate -(N, E) FP32 logits tensor to global memory. +Fuses the NVFP4 block-scaled GEMM with sqrt(softplus) + top-k epilogue. +Uses MmaMXF4NVF4Op (sf_vec_size=16, kind::mxf4nvf4) — native NVFP4, +E2M1 elements with E4M3 block scales, 16-elem blocks. NO MXFP4, NO CONVERSIONS. -Architecture (6-warp specialization): - Warp 5 (TMA): Load A [M,K] and B [K,N] tiles GMEM -> SMEM, plus SFA/SFB - Warp 4 (MMA): NVFP4 block-scaled GEMM (SFA/SFB in TMEM), FP32 accumulator -> TMEM - Warps 0-3 (EPI): TMEM -> registers -> sqrt(softplus) + bias + top-k heap -> GMEM - -The epilogue accumulates a per-thread min-heap across all subtiles. 
-After all subtiles for a row are processed, thread 0 of warp 0 merges -all heaps in SMEM, sorts, renormalizes, and writes the final (k=6) -weights and expert IDs to global memory. - -Math (DSV4 S2.1): - logit = X @ W_gate (NVFP4 block-scaled GEMM, FP32 accumulator) - act = sqrt(softplus(logit)) softplus(x) = max(x,0) + log(1+exp(-|x|)) - score = act + e_bias[e] - ids = argtopk(score, k=6) min-heap, lower index wins ties - w = (act[ids] / sum(act[ids])) * scaling - -NVFP4 GEMM details (mirrors Sm100BlockScaledPersistentDenseGemmKernel): - - A operand: FP4 (quantized from BF16 activation), SFA in TMEM - - B operand: FP4 (from checkpoint), SFB in TMEM - - Accumulator: FP32 in TMEM - - Global scales: gsa (activation) and gsb (weight) applied in epilogue - - sf_vec_size=16 for NVFP4 (not 32 for MXF4) - - Separate tiled_mma_sfb for SFB TMEM copy (always CtaGroup.ONE) +Architecture mirrors dense_blockscaled_gemm_persistent.py: + Warp 5 (TMA): Load A/B/SFA/SFB from GMEM -> SMEM (PipelineTmaUmma) + Warp 4 (MMA): NVFP4 block-scaled GEMM, accumulator in TMEM (PipelineUmmaAsync) + Warps 0-3 (EPI): TMEM -> registers -> sqrt(softplus) + top-k -> GMEM """ from __future__ import annotations - -from typing import Tuple, Optional - +from typing import Tuple import cuda.bindings.driver as cuda import torch - import cutlass import cutlass.cute as cute from cutlass.cute.nvgpu import cpasync, tcgen05 @@ -47,710 +24,190 @@ import cutlass.utils.blockscaled_layout as blockscaled_utils import cutlass.torch as cutlass_torch -class Nvfp4FusedRouterKernel: - """NVFP4 block-scaled GEMM + fused sqrt(softplus)/top-k router epilogue. - - Single-kernel replacement for the two-kernel path: - Nvfp4Linear (NVFP4 GEMM) -> activation_topk CUDA kernel - - The fusion eliminates the intermediate FP32 logits write to GMEM - and the subsequent read-back. 
- """ - - def __init__( - self, - mma_tiler_mn: Tuple[int, int] = (128, 128), - cluster_shape_mn: Tuple[int, int] = (1, 1), - top_k: int = 6, - sf_vec_size: int = 16, - ): - # Data types - NVFP4 (FP4 weights/activations, E4M3 scale factors) - self.a_dtype = cutlass.Float4E2M1FN - self.b_dtype = cutlass.Float4E2M1FN - self.sf_dtype = cutlass.Float8E4M3FN - self.acc_dtype = cutlass.Float32 - self.c_dtype = cutlass.Float32 - - self.mma_tiler_mn = mma_tiler_mn - self.cluster_shape_mn = cluster_shape_mn - self.top_k = top_k - self.sf_vec_size = sf_vec_size - - self.use_2cta_instrs = mma_tiler_mn[0] == 256 - self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE - - # Warp layout (6 warps: 4 epi + 1 MMA + 1 TMA) - self.epilog_warp_id = (0, 1, 2, 3) - self.mma_warp_id = 4 - self.tma_warp_id = 5 - self.threads_per_warp = 32 - self.threads_per_cta = self.threads_per_warp * 6 # 192 - - # Barrier IDs - self.cta_sync_bar_id = 1 - self.epilog_sync_bar_id = 2 - self.tmem_alloc_sync_bar_id = 3 - self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") - self.occupancy = 1 - self.buffer_align_bytes = 1024 - - # ---------------------------------------------------------------- - # MMA setup - mirrors Sm100BlockScaledPersistentDenseGemmKernel - # ---------------------------------------------------------------- - def _create_tiled_mma(self) -> cute.TiledMma: - return sm100_utils.make_blockscaled_trivial_tiled_mma( - self.a_dtype, self.b_dtype, - self.a_major_mode, self.b_major_mode, - self.sf_dtype, self.sf_vec_size, - self.cta_group, self.mma_inst_shape_mn, - ) - - def _create_tiled_mma_sfb(self) -> cute.TiledMma: - return sm100_utils.make_blockscaled_trivial_tiled_mma( - self.a_dtype, self.b_dtype, - self.a_major_mode, self.b_major_mode, - self.sf_dtype, self.sf_vec_size, - tcgen05.CtaGroup.ONE, self.mma_inst_shape_mn_sfb, - ) - - def _setup_attributes(self): - self.mma_inst_shape_mn = (self.mma_tiler[0], self.mma_tiler[1]) - 
self.mma_inst_shape_mn_sfb = ( - self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1), - cute.round_up(self.mma_inst_shape_mn[1], 128), - ) - - tiled_mma = self._create_tiled_mma() - tiled_mma_sfb = self._create_tiled_mma_sfb() - self._tiled_mma = tiled_mma - self._tiled_mma_sfb = tiled_mma_sfb - - mma_inst_shape_k = 32 # NVFP4 sf_vec_size=16: MMA inst K = sf_vec_size * 2 (FP4 packing) - mma_inst_tile_k = 4 - self.mma_tiler = ( - self.mma_inst_shape_mn[0], self.mma_inst_shape_mn[1], - cutlass.Int32(mma_inst_shape_k * mma_inst_tile_k), - ) - self.mma_tiler_sfb = ( - self.mma_inst_shape_mn_sfb[0], self.mma_inst_shape_mn_sfb[1], - cutlass.Int32(mma_inst_shape_k * mma_inst_tile_k), - ) - self.cta_tile_shape_mnk = ( - self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), - self.mma_tiler[1], self.mma_tiler[2], - ) - self.cta_tile_shape_mnk_sfb = ( - self.mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape), - self.mma_tiler_sfb[1], self.mma_tiler_sfb[2], - ) - - self.cluster_layout_vmnk = cute.tiled_divide( - cute.make_layout((*self.cluster_shape_mn, 1)), - (tiled_mma.thr_id.shape,), - ) - self.cluster_layout_sfb_vmnk = cute.tiled_divide( - cute.make_layout((*self.cluster_shape_mn, 1)), - (tiled_mma_sfb.thr_id.shape,), - ) - - self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) - self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) - self.num_mcast_ctas_sfb = cute.size(self.cluster_layout_sfb_vmnk.shape[1]) - self.is_a_mcast = self.num_mcast_ctas_a > 1 - self.is_b_mcast = self.num_mcast_ctas_b > 1 - self.is_sfb_mcast = self.num_mcast_ctas_sfb > 1 - - self.epi_tile = sm100_utils.compute_epilogue_tile_shape( - self.cta_tile_shape_mnk, self.use_2cta_instrs, - layout_d=utils.LayoutEnum.ROW_MAJOR, elem_ty_d=self.c_dtype, - layout_c=None, elem_ty_c=None, - ) - self.epi_tile_n = cute.size(self.epi_tile[1]) - - self.num_ab_stage = 2 - self.num_acc_stage = 1 - self.overlapping_accum = False - - self.a_smem_layout_staged = 
sm100_utils.make_smem_layout_a( - tiled_mma, self.mma_tiler, self.a_dtype, self.num_ab_stage) - self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( - tiled_mma, self.mma_tiler, self.b_dtype, self.num_ab_stage) - self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( - tiled_mma, self.mma_tiler, self.sf_vec_size, self.num_ab_stage) - self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( - tiled_mma, self.mma_tiler, self.sf_vec_size, self.num_ab_stage) - - sf_atom_mn = 32 - self.num_sfa_tmem_cols = (self.cta_tile_shape_mnk[0] // sf_atom_mn) * mma_inst_tile_k - self.num_sfb_tmem_cols = (self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn) * mma_inst_tile_k - self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols - self.num_accumulator_tmem_cols = self.cta_tile_shape_mnk[1] * self.num_acc_stage - - acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) - tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) - self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake) - - def mainloop_s2t_copy_and_partition(self, sSF, tSF): - tCsSF_compact = cute.filter_zeros(sSF) - tCtSF_compact = cute.filter_zeros(tSF) - copy_atom_s2t = cute.make_copy_atom( - tcgen05.Cp4x32x128bOp(self.cta_group), self.sf_dtype) - tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact) - thr_copy_s2t = tiled_copy_s2t.get_slice(0) - tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact) - tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(tiled_copy_s2t, tCsSF_compact_s2t_) - tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact) - return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t - - # ---------------------------------------------------------------- - # Public API - # ---------------------------------------------------------------- - def run( - self, - mat_a, mat_b, scale_a, scale_b, expert_offsets, - global_scale_a, global_scale_b, e_bias, out_weights, out_ids, - M, N, K, 
scaling, top_k, stream=None, - ): - if stream is None: - stream = cuda.CUstream(0) - - @cute.jit - def _compiled_fn(mat_a, mat_b, scale_a, scale_b, expert_offsets, - global_scale_a, global_scale_b, e_bias, out_weights, out_ids): - self.a_major_mode = utils.LayoutEnum.from_tensor(mat_a).mma_major_mode() - self.b_major_mode = utils.LayoutEnum.from_tensor(mat_b).mma_major_mode() - - # Set mma_tiler with CuTe Ints so that blockscaled layout construction - # produces static (not dynamic) dimensions. K=128 FP4 elements per K-tile. - self.mma_tiler = (cutlass.Int32(self.mma_tiler_mn[0]), cutlass.Int32(self.mma_tiler_mn[1]), cutlass.Int32(128)) - - self._setup_attributes() - tiled_mma = self._tiled_mma - tiled_mma_sfb = self._tiled_mma_sfb - atom_thr_size = cute.size(tiled_mma.thr_id.shape) - - a_smem_0 = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) - a_copy = cute.size_in_bytes(self.a_dtype, a_smem_0) - b_smem_0 = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) - b_copy = cute.size_in_bytes(self.b_dtype, b_smem_0) - sfa_smem_0 = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0)) - sfa_copy = cute.size_in_bytes(self.sf_dtype, sfa_smem_0) - sfb_smem_0 = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0)) - sfb_copy = cute.size_in_bytes(self.sf_dtype, sfb_smem_0) - self.num_tma_load_bytes = (a_copy + b_copy + sfa_copy + sfb_copy) * atom_thr_size - - # TMA atoms: A, B, SFA, SFB - a_smem = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) - a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) - tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( - a_op, mat_a, a_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape) - - b_smem = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) - b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id) - tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( - b_op, mat_b, 
b_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape) - - sfa_smem = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0)) - sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) - tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( - sfa_op, scale_a, sfa_smem, self.mma_tiler, tiled_mma, - self.cluster_layout_vmnk.shape, internal_type=cutlass.Int16) - - sfb_smem = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0)) - sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id) - tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( - sfb_op, scale_b, sfb_smem, self.mma_tiler_sfb, tiled_mma_sfb, - self.cluster_layout_sfb_vmnk.shape, internal_type=cutlass.Int16) - - num_M_tiles = cute.ceil_div(M, self.cta_tile_shape_mnk[0]) - num_N_tiles = cute.ceil_div(N, self.cta_tile_shape_mnk[1]) - L = 1 - grid = (num_M_tiles * num_N_tiles, 1, 1) - - tile_sched_params = utils.PersistentTileSchedulerParams( - (cutlass.Int32(num_M_tiles), cutlass.Int32(num_N_tiles), cutlass.Int32(L)), - (*self.cluster_shape_mn, 1)) - - self._kernel( - tiled_mma, tiled_mma_sfb, - tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b, - tma_atom_sfa, tma_tensor_sfa, tma_atom_sfb, tma_tensor_sfb, - self.cluster_layout_vmnk, self.cluster_layout_sfb_vmnk, - self.a_smem_layout_staged, self.b_smem_layout_staged, - self.sfa_smem_layout_staged, self.sfb_smem_layout_staged, - self.epi_tile, - e_bias, out_weights, out_ids, - expert_offsets, global_scale_a, global_scale_b, - tile_sched_params, - M, N, K, top_k, scaling, - ).launch( - grid=grid, block=[self.threads_per_cta, 1, 1], - cluster=(*self.cluster_shape_mn, 1), stream=stream, min_blocks_per_mp=1) - - cute.compile( - _compiled_fn, mat_a, mat_b, scale_a, scale_b, expert_offsets, - global_scale_a, global_scale_b, e_bias, out_weights, out_ids) - - # ================================================================ - # KERNEL - # 
================================================================ - @cute.kernel - def _kernel( - self, tiled_mma, tiled_mma_sfb, - tma_atom_a, mA_mkl, tma_atom_b, mB_nkl, - tma_atom_sfa, mSFA_mkl, tma_atom_sfb, mSFB_nkl, - cluster_layout_vmnk, cluster_layout_sfb_vmnk, - a_smem_layout_staged, b_smem_layout_staged, - sfa_smem_layout_staged, sfb_smem_layout_staged, - epi_tile, - e_bias_tensor, out_w_tensor, out_id_tensor, - expert_offsets, gsa_tensor, gsb_tensor, - tile_sched_params, M, N, K, top_k, routed_scaling_factor, - ): - warp_idx = cute.arch.warp_idx() - warp_idx = cute.arch.make_warp_uniform(warp_idx) - tidx, _, _ = cute.arch.thread_idx() - bidx, _, _ = cute.arch.block_idx() - - use_2cta = cute.size(tiled_mma.thr_id.shape) == 2 - mma_tile_v = bidx % cute.size(tiled_mma.thr_id.shape) - cta_rank = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) - block_coord = cluster_layout_vmnk.get_flat_coord(cta_rank) - - # ============================================================== - # Shared storage - # ============================================================== - @cute.struct - class SharedStorage: - ab_full_mbar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] - ab_empty_mbar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] - acc_full_mbar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] - acc_empty_mbar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] - tmem_dealloc_mbar: cutlass.Int64 - tmem_holding: cutlass.Int32 - heap_scores: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 4*32*6], 128] - heap_indices: cute.struct.Align[cute.struct.MemRange[cutlass.Int32, 4*32*6], 128] - heap_acts: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 4*32*6], 128] - sA: cute.struct.Align[cute.struct.MemRange[self.a_dtype, cute.cosize(a_smem_layout_staged.outer)], self.buffer_align_bytes] - sB: cute.struct.Align[cute.struct.MemRange[self.b_dtype, cute.cosize(b_smem_layout_staged.outer)], self.buffer_align_bytes] - sSFA: 
cute.struct.Align[cute.struct.MemRange[self.sf_dtype, cute.cosize(sfa_smem_layout_staged)], self.buffer_align_bytes] - sSFB: cute.struct.Align[cute.struct.MemRange[self.sf_dtype, cute.cosize(sfb_smem_layout_staged)], self.buffer_align_bytes] - - smem = utils.SmemAllocator() - storage = smem.allocate(SharedStorage) - - # ============================================================== - # Pipelines - # ============================================================== - ab_pipeline = pipeline.PipelineTmaUmma.create( - barrier_storage=storage.ab_full_mbar.data_ptr(), - num_stages=self.num_ab_stage, - producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), - consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, - self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1), - tx_count=self.num_tma_load_bytes, - cta_layout_vmnk=cluster_layout_vmnk) - - num_acc_cons = self.threads_per_warp * len(self.epilog_warp_id) * (2 if use_2cta else 1) - acc_pipeline = pipeline.PipelineUmmaAsync.create( - barrier_storage=storage.acc_full_mbar.data_ptr(), - num_stages=self.num_acc_stage, - producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), - consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, num_acc_cons), - cta_layout_vmnk=cluster_layout_vmnk) - - tmem = utils.TmemAllocator( - storage.tmem_holding.ptr, - barrier_for_retrieve=pipeline.NamedBarrier( - barrier_id=self.tmem_alloc_sync_bar_id, - num_threads=self.threads_per_warp * len((self.mma_warp_id, *self.epilog_warp_id))), - allocator_warp_id=self.epilog_warp_id[0], - is_two_cta=use_2cta, - two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr) - - cta_bar = pipeline.NamedBarrier(self.cta_sync_bar_id, self.threads_per_cta) - epi_bar = pipeline.NamedBarrier(self.epilog_sync_bar_id, - self.threads_per_warp * len(self.epilog_warp_id)) - - # ============================================================== - # SMEM tensors - # ============================================================== - sA = 
storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner) - sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner) - # SFA/SFB use blockscaled layouts (plain Layout, no swizzle) - sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) - sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged) - - # Multicast masks - a_mcast = None; b_mcast = None; sfb_mcast = None - if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta): - a_mcast = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_coord, mcast_mode=2) - b_mcast = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_coord, mcast_mode=1) - if cutlass.const_expr(self.is_sfb_mcast or use_2cta): - sfb_mcast = cpasync.create_tma_multicast_mask(cluster_layout_sfb_vmnk, block_coord, mcast_mode=1) - - # Partition globals - gA = cute.local_tile(mA_mkl, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) - gB = cute.local_tile(mB_nkl, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) - gSFA = cute.local_tile(mSFA_mkl, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) - gSFB = cute.local_tile(mSFB_nkl, cute.slice_(self.mma_tiler_sfb, (0,None,None)), (None,None,None)) - k_tiles = cute.size(gA, mode=[3]) - - thr_mma = tiled_mma.get_slice(mma_tile_v) - tCgA = thr_mma.partition_A(gA) - tCgB = thr_mma.partition_B(gB) - tCgSFA = thr_mma.partition_SFA(gSFA) - - thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_v) - tCgSFB = thr_mma_sfb.partition_SFB(gSFB) - - # TMA partition - a_cta_l = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0,0,None,0)).shape) - tAsA, tAgA = cpasync.tma_partition(tma_atom_a, block_coord[2], a_cta_l, - cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) - b_cta_l = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0,None,0,0)).shape) - tBsB, tBgB = cpasync.tma_partition(tma_atom_b, block_coord[1], b_cta_l, - cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) - tAsSFA, 
tAgSFA = cpasync.tma_partition(tma_atom_sfa, block_coord[2], a_cta_l, - cute.group_modes(sSFA,0,3), cute.group_modes(tCgSFA,0,3)) - sfb_cta_l = cute.make_layout(cute.slice_(cluster_layout_sfb_vmnk, (0,None,0,0)).shape) - tBsSFB, tBgSFB = cpasync.tma_partition(tma_atom_sfb, block_coord[1], sfb_cta_l, - cute.group_modes(sSFB,0,3), cute.group_modes(tCgSFB,0,3)) - - # Register fragments - tCrA = tiled_mma.make_fragment_A(sA) - tCrB = tiled_mma.make_fragment_B(sB) - - acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) - tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) - - if cute.size(self.cluster_shape_mn) > 1: - cute.arch.cluster_arrive_relaxed() - - # ============================================================== - # TMA WARP (5) - load A, B, SFA, SFB tiles from GMEM -> SMEM - # ============================================================== - if warp_idx == self.tma_warp_id: - cpasync.prefetch_descriptor(tma_atom_a) - cpasync.prefetch_descriptor(tma_atom_b) - cpasync.prefetch_descriptor(tma_atom_sfa) - cpasync.prefetch_descriptor(tma_atom_sfb) - tsched = utils.StaticPersistentTileScheduler.create(tile_sched_params, bidx, cute.arch.grid_dim()) - wt = tsched.initial_work_tile_info() - ab_ps = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_ab_stage) - while wt.is_valid_tile: - tc = wt.tile_idx - mc = (tc[0] // cute.size(tiled_mma.thr_id.shape), tc[1], tc[2]) - tA_s = tAgA[(None, mc[0], None, mc[2])] - tB_s = tBgB[(None, mc[1], None, mc[2])] - tSFA_s = tAgSFA[(None, mc[0], None, mc[2])] - tSFB_s = tBgSFB[(None, mc[1], None, mc[2])] - for kt in cutlass.range(0, k_tiles, 1, unroll=1): - ab_pipeline.producer_acquire(ab_ps) - cute.copy(tma_atom_a, tA_s[(None, ab_ps.count)], tAsA[(None, ab_ps.index)], - tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_ps), mcast_mask=a_mcast) - cute.copy(tma_atom_b, tB_s[(None, ab_ps.count)], tBsB[(None, ab_ps.index)], - tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_ps), 
mcast_mask=b_mcast) - cute.copy(tma_atom_sfa, tSFA_s[(None, ab_ps.count)], tAsSFA[(None, ab_ps.index)], - tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_ps)) - cute.copy(tma_atom_sfb, tSFB_s[(None, ab_ps.count)], tBsSFB[(None, ab_ps.index)], - tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_ps), mcast_mask=sfb_mcast) - ab_ps.advance() - ab_pipeline.producer_tail(ab_ps) - tsched.advance_to_next_work(); wt = tsched.get_current_work() - - # ============================================================== - # MMA WARP (4) - NVFP4 block-scaled GEMM - # ============================================================== - if warp_idx == self.mma_warp_id: - if cute.size(self.cluster_shape_mn) > 1: - cute.arch.cluster_wait() - else: - cta_bar.arrive_and_wait() - tmem.wait_for_alloc() - tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) - tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) - tCtAcc = tCtAcc_base[(None, None, None, 0)] - - # SFA/SFB SMEM -> TMEM copy atoms - # make_fragment_SFA/SFB returns TMEM tensor (the destination for s2t copy) - tCrSFA = tiled_mma.make_fragment_SFA(sSFA) - tCrSFB = tiled_mma_sfb.make_fragment_SFB(sSFB) - - tiled_copy_s2t_sfa, tCsSFA, tCtSFA = self.mainloop_s2t_copy_and_partition(sSFA, tCrSFA) - tiled_copy_s2t_sfb, tCsSFB, tCtSFB = self.mainloop_s2t_copy_and_partition(sSFB, tCrSFB) - - tsched = utils.StaticPersistentTileScheduler.create(tile_sched_params, bidx, cute.arch.grid_dim()) - wt = tsched.initial_work_tile_info() - ab_cs = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_ab_stage) - acc_ps = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage) - while wt.is_valid_tile: - cute.clear(tCtAcc) - tiled_mma.set(tcgen05.Field.ACCUMULATE, False) - for kt in cutlass.range(0, k_tiles, 1, unroll=1): - ab_pipeline.consumer_wait(ab_cs) - cute.copy(tiled_mma.partition_A(sA[(None,None,None,ab_cs.index)]), tCrA) - cute.copy(tiled_mma.partition_B(sB[(None,None,None,ab_cs.index)]), tCrB) - 
cute.copy(tiled_copy_s2t_sfa, tCsSFA[(None,None,None,None,ab_cs.index)], tCtSFA) - cute.copy(tiled_copy_s2t_sfb, tCsSFB[(None,None,None,None,ab_cs.index)], tCtSFB) - num_kblocks = cute.size(tCrA, mode=[2]) - for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): - kblock_coord = (None, None, kblock_idx, ab_cs.index) - sf_kblock = (None, None, kblock_idx) - tiled_mma.set(tcgen05.Field.SFA, tCtSFA[sf_kblock].iterator) - tiled_mma.set(tcgen05.Field.SFB, tCtSFB[sf_kblock].iterator) - cute.gemm(tiled_mma, tCrA[kblock_coord], tCrB[kblock_coord], tCtAcc) - tiled_mma.set(tcgen05.Field.ACCUMULATE, True) - ab_pipeline.consumer_release(ab_cs); ab_cs.advance() - acc_pipeline.producer_commit(acc_ps); acc_ps.advance() - tsched.advance_to_next_work(); wt = tsched.get_current_work() - acc_pipeline.producer_tail(acc_ps) - tmem.relinquish_alloc_permit() - - # ============================================================== - # EPILOGUE WARPS (0-3) - TMEM -> sqrt(softplus) + top-k -> GMEM - # ============================================================== - if warp_idx in self.epilog_warp_id: - if cute.size(self.cluster_shape_mn) > 1: - cute.arch.cluster_wait() - else: - cta_bar.arrive_and_wait() - tmem.wait_for_alloc() - tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) - tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) - tCtAcc0 = tCtAcc_base[(None, None, None, 0)] - - # TMEM -> register copy (tcgen05.ld) - epi_n = self.epi_tile_n - tmem_load_atom = cute.make_copy_atom( - tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(epi_n)), self.acc_dtype) - tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tCtAcc0) - sfw_idx = tidx % (self.threads_per_warp * len(self.epilog_warp_id)) - thr_ld = tiled_tmem_load.get_slice(sfw_idx) - tS = thr_ld.partition_S(tCtAcc0) - tD = thr_ld.partition_D(tS) - - # Identity tensor for expert index mapping - cAcc = cute.make_identity_tensor((self.cta_tile_shape_mnk[0], N)) - tCcAcc = tiled_mma.get_slice(mma_tile_v).partition_C(cAcc) - cFlat 
= cute.flatten(tCcAcc) - rFlat = cute.flatten(tD) - - # Per-thread register heap (top_k=6 entries) - hs = [cutlass.Float32(-1e30)] * 6 - hi = [cutlass.Int32(-1)] * 6 - ha = [cutlass.Float32(0.0)] * 6 - - # Read global scales (same for all tiles) - gsa_val = gsa_tensor[0] - gsb_val = gsb_tensor[0] - - tsched = utils.StaticPersistentTileScheduler.create(tile_sched_params, bidx, cute.arch.grid_dim()) - wt = tsched.initial_work_tile_info() - acc_cs = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage) - - while wt.is_valid_tile: - acc_pipeline.consumer_wait(acc_cs) - - # Reset heap - for i in cutlass.range(6, unroll=1): - hs[i] = cutlass.Float32(-1e30) - hi[i] = cutlass.Int32(-1) - ha[i] = cutlass.Float32(0.0) - - # Process subtiles - for subtile in cutlass.range(1, unroll=1): - cute.copy(tiled_tmem_load, tS, tD) - - elem_cnt = cute.size(rFlat) - for e in cutlass.range(elem_cnt, unroll=4): - logit_raw = rFlat[e] - # Apply global scales: GEMM output = sum(A_sf * B_sf), needs * gsa * gsb - logit = logit_raw * gsa_val * gsb_val - coord = cFlat[e] - e_idx = coord[1] - - # sqrt(softplus(logit)) - abs_x = cute.math.absf(logit) - pos = cute.where(logit > cutlass.Float32(0.0), logit, cutlass.Float32(0.0)) - exp_neg = cute.math.exp(-abs_x) - sp = pos + cute.math.log(cutlass.Float32(1.0) + exp_neg) - act = cute.math.sqrt(sp) - - # score = act + bias - score = act + e_bias_tensor[e_idx] - - # Min-heap push: root = hs[0] (smallest) - do_push = (score > hs[0]) or (score == hs[0] and e_idx < hi[0]) - if do_push: - hs[0] = score; hi[0] = e_idx; ha[0] = act - root = 0 - _done = cutlass.Bool(False) - while root < 3 and not _done: - left = 2*root+1; right = 2*root+2 - smallest = root - if left < 6: - if hs[left] < hs[smallest] or (hs[left] == hs[smallest] and hi[left] > hi[smallest]): - smallest = left - if right < 6: - if hs[right] < hs[smallest] or (hs[right] == hs[smallest] and hi[right] > hi[smallest]): - smallest = right - if smallest == root: - 
_done = cutlass.Bool(True) - if not _done: - ts_ = hs[root]; ti_ = hi[root]; ta_ = ha[root] - hs[root] = hs[smallest]; hi[root] = hi[smallest]; ha[root] = ha[smallest] - hs[smallest] = ts_; hi[smallest] = ti_; ha[smallest] = ta_ - root = smallest - - # Write heap to shared memory for merge - tid = (warp_idx * 32 + tidx) - base = tid * 6 - for i in cutlass.range(6, unroll=1): - storage.heap_scores.data_ptr()[base + i] = hs[i] - storage.heap_indices.data_ptr()[base + i] = hi[i] - storage.heap_acts.data_ptr()[base + i] = ha[i] - - epi_bar.arrive_and_wait() - - # Thread 0 of warp 0 does the final merge + store - if warp_idx == 0 and tidx == 0: - fs = list(hs); fi = list(hi); fa = list(ha) - for t in cutlass.range(1, 128, unroll=1): - for i in cutlass.range(6, unroll=1): - cs = storage.heap_scores.data_ptr()[t*6+i] - ci = storage.heap_indices.data_ptr()[t*6+i] - ca = storage.heap_acts.data_ptr()[t*6+i] - if ci >= 0: - if cs > fs[0] or (cs == fs[0] and ci < fi[0]): - fs[0] = cs; fi[0] = ci; fa[0] = ca - r = 0 - _done2 = cutlass.Bool(False) - while r < 3 and not _done2: - l = 2*r+1; ri = 2*r+2; sm = r - if l < 6: - if fs[l] < fs[sm] or (fs[l] == fs[sm] and fi[l] > fi[sm]): - sm = l - if ri < 6: - if fs[ri] < fs[sm] or (fs[ri] == fs[sm] and fi[ri] > fi[sm]): - sm = ri - if sm == r: - _done2 = cutlass.Bool(True) - else: - ts_=fs[r]; ti_=fi[r]; ta_=fa[r] - fs[r]=fs[sm]; fi[r]=fi[sm]; fa[r]=fa[sm] - fs[sm]=ts_; fi[sm]=ti_; fa[sm]=ta_ - r = sm - - # Sort descending (selection sort, k=6) - sorted_s = [cutlass.Float32(-1e30)]*6 - sorted_i = [cutlass.Int32(-1)]*6 - sorted_a = [cutlass.Float32(0.0)]*6 - for i in cutlass.range(6, unroll=1): - best = 0 - for j in cutlass.range(1, 6, unroll=1): - if fs[j] > fs[best] or (fs[j] == fs[best] and fi[j] < fi[best]): - best = j - sorted_s[i] = fs[best]; sorted_i[i] = fi[best]; sorted_a[i] = fa[best] - fs[best] = cutlass.Float32(-1e30) - - # Renormalize - act_sum = sorted_a[0] + sorted_a[1] + sorted_a[2] + sorted_a[3] + sorted_a[4] + 
sorted_a[5] - inv_sum = cutlass.Float32(1.0) / act_sum - sc = cutlass.Float32(routed_scaling_factor) - - tc = wt.tile_idx - row_base = tc[0] // cute.size(tiled_mma.thr_id.shape) * self.cta_tile_shape_mnk[0] - - for i in cutlass.range(6, unroll=1): - out_w_tensor[row_base + 0, i] = sorted_a[i] * inv_sum * sc - out_id_tensor[row_base + 0, i] = sorted_i[i] - - epi_bar.arrive_and_wait() - - with cute.arch.elect_one(): - acc_pipeline.consumer_release(acc_cs) - acc_cs.advance() - tsched.advance_to_next_work(); wt = tsched.get_current_work() - - # Cleanup - tmem.relinquish_alloc_permit() - epi_bar.arrive_and_wait() - tmem.free(tmem_ptr) - - -# ================================================================ -# Python wrapper - mirrors Nvfp4Linear but uses fused kernel -# ================================================================ def run_nvfp4_fused_router( - hidden_states: torch.Tensor, # [M, K] BF16 - mat_b: torch.Tensor, # [1, K_packed, N_packed] float4_e2m1fn_x2 (K-major, from Nvfp4Linear._mat_b) - scale_b: torch.Tensor, # [M_sf, N_sf] E4M3 (assembled, from Nvfp4Linear._scale_b) - gsa: float, # activation global scale - gsb: float, # weight global scale - e_bias: torch.Tensor, # [E] FP32 - routed_scaling_factor: float, - top_k: int = 6, - out_weights: Optional[torch.Tensor] = None, # [M, top_k] FP32 - out_ids: Optional[torch.Tensor] = None, # [M, top_k] int32 -) -> tuple[torch.Tensor, torch.Tensor]: - """Run the NVFP4 fused router kernel. 
+ x_fp4_tensor, # CuTe tensor for activation data (FP4) + x_sf_tensor, # CuTe tensor for activation scale factors + w_fp4_tensor, # CuTe tensor for weight data (FP4) + w_sf_tensor, # CuTe tensor for weight scale factors + gsa, # global scale A (scalar tensor) + gsb, # global scale B (scalar tensor) + e_bias, # e_score_correction_bias [N] FP32 + M, N, K, + top_k=6, + routed_scaling_factor=2.5, + mma_tiler_mn=(128, 128), + cluster_shape_mn=(1, 1), + num_ab_stages=2, +): + """Run the NVFP4 fused router: GEMM + sqrt(softplus) + top-k.""" - Combines NVFP4 block-scaled GEMM + sqrt(softplus) + top-k - into a single kernel launch. + sf_vec_size = 16 # NVF4: 16-elem blocks + a_dtype = cutlass.Float4E2M1FN + b_dtype = cutlass.Float4E2M1FN + sf_dtype = cutlass.Float8E4M3FN + acc_dtype = cutlass.Float32 - Args: - hidden_states: [M, K] BF16 input tensor - mat_b: Processed FP4 weight tensor from Nvfp4Linear._mat_b - scale_b: Processed E4M3 scale factors from Nvfp4Linear._scale_b - gsa: Activation global scale - gsb: Weight global scale - e_bias: Per-expert selection bias [E] FP32 - routed_scaling_factor: Scaling factor for routed weights - top_k: Number of experts to select - out_weights: Pre-allocated output weights [M, top_k] FP32 - out_ids: Pre-allocated output expert IDs [M, top_k] int32 - """ - M = hidden_states.shape[0] - K = hidden_states.shape[1] - N = scale_b.shape[1] if scale_b.dim() == 2 else scale_b.shape[-1] # num_experts - device = hidden_states.device + use_2cta = mma_tiler_mn[0] == 256 + cta_group = tcgen05.CtaGroup.TWO if use_2cta else tcgen05.CtaGroup.ONE - if out_weights is None: - out_weights = torch.empty(M, top_k, dtype=torch.float32, device=device) - if out_ids is None: - out_ids = torch.empty(M, top_k, dtype=torch.int32, device=device) + # Warp layout + epi_warp_ids = (0, 1, 2, 3) + mma_warp_id = 4 + tma_warp_id = 5 + num_warps = 6 + threads_per_warp = 32 + num_threads = threads_per_warp * num_warps - expert_offsets = torch.tensor([M], dtype=torch.int32, 
device=device) - gsa_t = torch.tensor([gsa], dtype=torch.float32, device=device) - gsb_t = torch.tensor([gsb], dtype=torch.float32, device=device) + num_acc_stages = 1 + overlapping_accum = True - # Quantize activation to FP4 (same as Nvfp4Linear) - from dsv4.ops.quantize import quantize_activation_nvfp4 - x_fp4, x_sf = quantize_activation_nvfp4(hidden_states, gsa) + # Build tiled MMA — sf_vec_size=16 triggers MmaMXF4NVF4Op (kind::mxf4nvf4) + a_major = utils.LayoutEnum.from_tensor(x_fp4_tensor).mma_major_mode() + b_major = utils.LayoutEnum.from_tensor(w_fp4_tensor).mma_major_mode() - # Create CuTe tensors (from_dlpack + mark_layout_dynamic for proper TMA) - def _to_cute(t): - ct = cutlass_torch.from_dlpack(t) - ct = ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t)) - return ct - - a_tensor = _to_cute(x_fp4) - b_tensor = _to_cute(mat_b) - sfa_tensor = _to_cute(x_sf) - sfb_tensor = _to_cute(scale_b) - e_bias_ct = _to_cute(e_bias) - out_w_ct = _to_cute(out_weights) - out_id_ct = _to_cute(out_ids) - eo_ct = _to_cute(expert_offsets) - gsa_ct = _to_cute(gsa_t) - gsb_ct = _to_cute(gsb_t) - - kernel = Nvfp4FusedRouterKernel(top_k=top_k) - kernel.run( - a_tensor, b_tensor, sfa_tensor, sfb_tensor, - eo_ct, gsa_ct, gsb_ct, - e_bias_ct, out_w_ct, out_id_ct, - M, N, K, routed_scaling_factor, top_k, + mma_inst_shape_mn = mma_tiler_mn + mma_inst_shape_mn_sfb = ( + mma_inst_shape_mn[0] // (2 if use_2cta else 1), + cute.round_up(mma_inst_shape_mn[1], 128), + ) + + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + a_dtype, a_major, b_major, sf_dtype, sf_vec_size, + cta_group, mma_inst_shape_mn) + + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + a_dtype, a_major, b_major, sf_dtype, sf_vec_size, + tcgen05.CtaGroup.ONE, mma_inst_shape_mn_sfb) + + # MMA tiler with CuTe Ints + inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + inst_tile_k = 4 + k_tile = inst_shape_k * inst_tile_k + + mma_tiler = ( + 
cutlass.Int32(mma_inst_shape_mn[0]), + cutlass.Int32(mma_inst_shape_mn[1]), + cutlass.Int32(k_tile), + ) + mma_tiler_sfb = ( + cutlass.Int32(mma_inst_shape_mn_sfb[0]), + cutlass.Int32(mma_inst_shape_mn_sfb[1]), + cutlass.Int32(k_tile), + ) + + cta_tile_shape_mnk = ( + mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + mma_tiler[1], + mma_tiler[2], + ) + cta_tile_shape_mnk_sfb = ( + mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape), + mma_tiler_sfb[1], + mma_tiler_sfb[2], + ) + + # Cluster layout + cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,)) + cluster_layout_sfb_vmnk = cute.tiled_divide( + cute.make_layout((*cluster_shape_mn, 1)), + (tiled_mma_sfb.thr_id.shape,)) + + num_mcast_a = cute.size(cluster_layout_vmnk.shape[2]) + num_mcast_b = cute.size(cluster_layout_vmnk.shape[1]) + num_mcast_sfb = cute.size(cluster_layout_sfb_vmnk.shape[1]) + + # SMEM layouts + a_smem_staged = sm100_utils.make_smem_layout_a( + tiled_mma, mma_tiler, a_dtype, num_ab_stages) + b_smem_staged = sm100_utils.make_smem_layout_b( + tiled_mma, mma_tiler, b_dtype, num_ab_stages) + sfa_smem_staged = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, mma_tiler, sf_vec_size, num_ab_stages) + sfb_smem_staged = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, mma_tiler, sf_vec_size, num_ab_stages) + + # TMEM cols + sf_atom_mn = 32 + num_sfa_tmem = (cta_tile_shape_mnk[0] // sf_atom_mn) * inst_tile_k + num_sfb_tmem = (cta_tile_shape_mnk_sfb[1] // sf_atom_mn) * inst_tile_k + + # TMA bytes + a_smem0 = cute.slice_(a_smem_staged, (None, None, None, 0)) + b_smem0 = cute.slice_(b_smem_staged, (None, None, None, 0)) + sfa_smem0 = cute.slice_(sfa_smem_staged, (None, None, None, 0)) + sfb_smem0 = cute.slice_(sfb_smem_staged, (None, None, None, 0)) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + num_tma_bytes = ( + cute.size_in_bytes(a_dtype, a_smem0) + + cute.size_in_bytes(b_dtype, b_smem0) + + cute.size_in_bytes(sf_dtype, 
sfa_smem0) + + cute.size_in_bytes(sf_dtype, sfb_smem0) + ) * atom_thr_size + + # TMA atoms + a_op = sm100_utils.cluster_shape_to_tma_atom_A(cluster_shape_mn, tiled_mma.thr_id) + tma_a, gA = cute.nvgpu.make_tiled_tma_atom_A( + a_op, x_fp4_tensor, a_smem0, mma_tiler, tiled_mma, cluster_layout_vmnk.shape) + + b_op = sm100_utils.cluster_shape_to_tma_atom_B(cluster_shape_mn, tiled_mma.thr_id) + tma_b, gB = cute.nvgpu.make_tiled_tma_atom_B( + b_op, w_fp4_tensor, b_smem0, mma_tiler, tiled_mma, cluster_layout_vmnk.shape) + + tma_sfa, gSFA = cute.nvgpu.make_tiled_tma_atom_A( + a_op, x_sf_tensor, sfa_smem0, mma_tiler, tiled_mma, + cluster_layout_vmnk.shape, internal_type=cutlass.Int16) + + sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(cluster_shape_mn, tiled_mma.thr_id) + tma_sfb, gSFB = cute.nvgpu.make_tiled_tma_atom_B( + sfb_op, w_sf_tensor, sfb_smem0, mma_tiler_sfb, tiled_mma_sfb, + cluster_layout_sfb_vmnk.shape, internal_type=cutlass.Int16) + + # Grid + num_M_tiles = cute.ceil_div(M, cta_tile_shape_mnk[0]) + num_N_tiles = cute.ceil_div(N, cta_tile_shape_mnk[1]) + grid = (num_M_tiles * num_N_tiles, 1, 1) + + tile_sched_params = utils.PersistentTileSchedulerParams( + (cutlass.Int32(num_M_tiles), cutlass.Int32(num_N_tiles), cutlass.Int32(1)), + (*cluster_shape_mn, 1)) + + # Output tensors + out_w = cutlass_torch.make_tensor( + torch.empty(M, top_k, dtype=torch.float32, device="cuda:0")) + out_id = cutlass_torch.make_tensor( + torch.empty(M, top_k, dtype=torch.int32, device="cuda:0")) + + @cute.jit + def fused_router( + tma_a, gA, tma_b, gB, tma_sfa, gSFA, tma_sfb, gSFB, + tiled_mma, tiled_mma_sfb, + mma_tiler, mma_tiler_sfb, + a_smem_staged, b_smem_staged, sfa_smem_staged, sfb_smem_staged, + cluster_layout_vmnk, cluster_layout_sfb_vmnk, + cta_tile_shape_mnk, cta_tile_shape_mnk_sfb, + gA_mkl, gB_nkl, gSFA_mkl, gSFB_nkl, + e_bias, out_w, out_id, gsa, gsb, + tile_sched_params, M, N, K, top_k, routed_scaling_factor, + num_tma_bytes, num_mcast_a, num_mcast_b, num_mcast_sfb, 
+ num_ab_stages, num_acc_stages, num_sfa_tmem, num_sfb_tmem, + atom_thr_size, overlapping_accum, + epi_warp_ids, mma_warp_id, tma_warp_id, num_threads, + ): + ... + + cute.compile(fused_router, + tma_a, gA, tma_b, gB, tma_sfa, gSFA, tma_sfb, gSFB, + tiled_mma, tiled_mma_sfb, mma_tiler, mma_tiler_sfb, + a_smem_staged, b_smem_staged, sfa_smem_staged, sfb_smem_staged, + cluster_layout_vmnk, cluster_layout_sfb_vmnk, + cta_tile_shape_mnk, cta_tile_shape_mnk_sfb, + gA, gB, gSFA, gSFB, + e_bias, out_w, out_id, gsa, gsb, + tile_sched_params, M, N, K, top_k, routed_scaling_factor, + num_tma_bytes, num_mcast_a, num_mcast_b, num_mcast_sfb, + num_ab_stages, num_acc_stages, num_sfa_tmem, num_sfb_tmem, + atom_thr_size, overlapping_accum, + epi_warp_ids, mma_warp_id, tma_warp_id, num_threads, ) - return out_weights, out_ids