diff --git a/dsv4/kernels/router/nvfp4_fused_router_kernel.py b/dsv4/kernels/router/nvfp4_fused_router_kernel.py index fef7a02b..5439574a 100644 --- a/dsv4/kernels/router/nvfp4_fused_router_kernel.py +++ b/dsv4/kernels/router/nvfp4_fused_router_kernel.py @@ -1,19 +1,31 @@ -"""DSV4 NVFP4 Fused Router Kernel — CuTeDSL for SM100 Blackwell. +"""DSV4 NVFP4 Fused Router Kernel — Block-scaled GEMM + Router Epilogue. -Fuses the NVFP4 block-scaled GEMM with sqrt(softplus) + top-k epilogue. -Uses MmaMXF4NVF4Op (sf_vec_size=16, kind::mxf4nvf4) — native NVF4, -E2M1 microscales, 16-elem blocks. NO MXFP4, NO CONVERSIONS. +Single-kernel path: NVFP4 block-scaled GEMM (A: activation FP4, B: gate weight FP4) +with fused router epilogue (sqrt(softplus) + e_bias + top-k + renorm). -Architecture mirrors dense_blockscaled_gemm_persistent.py: - Warp 5 (TMA): Load A/B/SFA/SFB from GMEM -> SMEM (PipelineTmaUmma) - Warp 4 (MMA): NVFP4 block-scaled GEMM, accumulator in TMEM (PipelineUmmaAsync) - Warps 0-3 (EPI): TMEM -> registers -> sqrt(softplus) + top-k -> GMEM +PRODUCTION KERNEL. No intermediate GMEM buffer. No BF16 fallback. +The GEMM accumulates logits in TMEM, then the epilogue warps process them directly: + 1. TMEM -> registers (via paired t2r atom from CUTLASS epilogue helpers) + 2. For each logit: sqrt(softplus(logit)) + e_bias -> score; track top-k via min-heap + 3. 
After all subtiles: sort, renormalize, write (topk_weights, topk_ids) to GMEM + +Warp specialization (6 warps, no scheduler for dense GEMM): + Warps 0-3: Epilogue (TMEM -> register -> router logic -> GMEM) + Warp 4: MMA (tcgen05.mma.block_scale with SFA/SFB in TMEM) + Warp 5: TMA load (A, B, SFA, SFB from GMEM -> SMEM) + +Pipeline structure (2 pipelines): + AB pipeline: TMA (producer) -> MMA (consumer) [PipelineTmaUmma] + Acc pipeline: MMA (producer) -> Epilogue (consumer) [PipelineUmmaAsync] """ from __future__ import annotations from typing import Tuple +import math + import cuda.bindings.driver as cuda import torch + import cutlass import cutlass.cute as cute from cutlass.cute.nvgpu import cpasync, tcgen05 @@ -21,193 +33,830 @@ import cutlass.utils as utils import cutlass.pipeline as pipeline import cutlass.utils.blackwell_helpers as sm100_utils import cutlass.utils.blockscaled_layout as blockscaled_utils -import cutlass.torch as cutlass_torch +from cutlass.utils.gemm.sm100 import ( + epilogue_tmem_copy_and_partition, + epilogue_smem_copy_and_partition, + transform_partitioned_tensor_layout, +) -def run_nvfp4_fused_router( - x_fp4_tensor, # CuTe tensor for activation data (FP4) - x_sf_tensor, # CuTe tensor for activation scale factors - w_fp4_tensor, # CuTe tensor for weight data (FP4) - w_sf_tensor, # CuTe tensor for weight scale factors - gsa, # global scale A (scalar tensor) - gsb, # global scale B (scalar tensor) - e_bias, # e_score_correction_bias [N] FP32 - M, N, K, - top_k=6, - routed_scaling_factor=2.5, - mma_tiler_mn=(128, 128), - cluster_shape_mn=(1, 1), - num_ab_stages=2, -): - """Run the NVFP4 fused router: GEMM + sqrt(softplus) + top-k.""" +class Nvfp4FusedRouterKernel: - sf_vec_size = 16 # NVF4: 16-elem blocks - a_dtype = cutlass.Float4E2M1FN - b_dtype = cutlass.Float4E2M1FN - sf_dtype = cutlass.Float8E4M3FN - acc_dtype = cutlass.Float32 - - use_2cta = mma_tiler_mn[0] == 256 - cta_group = tcgen05.CtaGroup.TWO if use_2cta else tcgen05.CtaGroup.ONE - 
- # Warp layout - epi_warp_ids = (0, 1, 2, 3) - mma_warp_id = 4 - tma_warp_id = 5 - num_warps = 6 - threads_per_warp = 32 - num_threads = threads_per_warp * num_warps - - num_acc_stages = 1 - overlapping_accum = True - - # Build tiled MMA — sf_vec_size=16 triggers MmaMXF4NVF4Op (kind::mxf4nvf4) - a_major = utils.LayoutEnum.from_tensor(x_fp4_tensor).mma_major_mode() - b_major = utils.LayoutEnum.from_tensor(w_fp4_tensor).mma_major_mode() - - mma_inst_shape_mn = mma_tiler_mn - mma_inst_shape_mn_sfb = ( - mma_inst_shape_mn[0] // (2 if use_2cta else 1), - cute.round_up(mma_inst_shape_mn[1], 128), - ) - - tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( - a_dtype, a_major, b_major, sf_dtype, sf_vec_size, - cta_group, mma_inst_shape_mn) - - tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( - a_dtype, a_major, b_major, sf_dtype, sf_vec_size, - tcgen05.CtaGroup.ONE, mma_inst_shape_mn_sfb) - - # MMA tiler with CuTe Ints - inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) - inst_tile_k = 4 - k_tile = inst_shape_k * inst_tile_k - - mma_tiler = ( - cutlass.Int32(mma_inst_shape_mn[0]), - cutlass.Int32(mma_inst_shape_mn[1]), - cutlass.Int32(k_tile), - ) - mma_tiler_sfb = ( - cutlass.Int32(mma_inst_shape_mn_sfb[0]), - cutlass.Int32(mma_inst_shape_mn_sfb[1]), - cutlass.Int32(k_tile), - ) - - cta_tile_shape_mnk = ( - mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), - mma_tiler[1], - mma_tiler[2], - ) - cta_tile_shape_mnk_sfb = ( - mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape), - mma_tiler_sfb[1], - mma_tiler_sfb[2], - ) - - # Cluster layout - cluster_layout_vmnk = cute.tiled_divide( - cute.make_layout((*cluster_shape_mn, 1)), - (tiled_mma.thr_id.shape,)) - cluster_layout_sfb_vmnk = cute.tiled_divide( - cute.make_layout((*cluster_shape_mn, 1)), - (tiled_mma_sfb.thr_id.shape,)) - - num_mcast_a = cute.size(cluster_layout_vmnk.shape[2]) - num_mcast_b = cute.size(cluster_layout_vmnk.shape[1]) - num_mcast_sfb = 
cute.size(cluster_layout_sfb_vmnk.shape[1]) - - # SMEM layouts - a_smem_staged = sm100_utils.make_smem_layout_a( - tiled_mma, mma_tiler, a_dtype, num_ab_stages) - b_smem_staged = sm100_utils.make_smem_layout_b( - tiled_mma, mma_tiler, b_dtype, num_ab_stages) - sfa_smem_staged = blockscaled_utils.make_smem_layout_sfa( - tiled_mma, mma_tiler, sf_vec_size, num_ab_stages) - sfb_smem_staged = blockscaled_utils.make_smem_layout_sfb( - tiled_mma, mma_tiler, sf_vec_size, num_ab_stages) - - # TMEM cols - sf_atom_mn = 32 - num_sfa_tmem = (cta_tile_shape_mnk[0] // sf_atom_mn) * inst_tile_k - num_sfb_tmem = (cta_tile_shape_mnk_sfb[1] // sf_atom_mn) * inst_tile_k - - # TMA bytes - a_smem0 = cute.slice_(a_smem_staged, (None, None, None, 0)) - b_smem0 = cute.slice_(b_smem_staged, (None, None, None, 0)) - sfa_smem0 = cute.slice_(sfa_smem_staged, (None, None, None, 0)) - sfb_smem0 = cute.slice_(sfb_smem_staged, (None, None, None, 0)) - atom_thr_size = cute.size(tiled_mma.thr_id.shape) - num_tma_bytes = ( - cute.size_in_bytes(a_dtype, a_smem0) + - cute.size_in_bytes(b_dtype, b_smem0) + - cute.size_in_bytes(sf_dtype, sfa_smem0) + - cute.size_in_bytes(sf_dtype, sfb_smem0) - ) * atom_thr_size - - # TMA atoms - a_op = sm100_utils.cluster_shape_to_tma_atom_A(cluster_shape_mn, tiled_mma.thr_id) - tma_a, gA = cute.nvgpu.make_tiled_tma_atom_A( - a_op, x_fp4_tensor, a_smem0, mma_tiler, tiled_mma, cluster_layout_vmnk.shape) - - b_op = sm100_utils.cluster_shape_to_tma_atom_B(cluster_shape_mn, tiled_mma.thr_id) - tma_b, gB = cute.nvgpu.make_tiled_tma_atom_B( - b_op, w_fp4_tensor, b_smem0, mma_tiler, tiled_mma, cluster_layout_vmnk.shape) - - tma_sfa, gSFA = cute.nvgpu.make_tiled_tma_atom_A( - a_op, x_sf_tensor, sfa_smem0, mma_tiler, tiled_mma, - cluster_layout_vmnk.shape, internal_type=cutlass.Int16) - - sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(cluster_shape_mn, tiled_mma.thr_id) - tma_sfb, gSFB = cute.nvgpu.make_tiled_tma_atom_B( - sfb_op, w_sf_tensor, sfb_smem0, mma_tiler_sfb, 
tiled_mma_sfb, - cluster_layout_sfb_vmnk.shape, internal_type=cutlass.Int16) - - # Grid - num_M_tiles = cute.ceil_div(M, cta_tile_shape_mnk[0]) - num_N_tiles = cute.ceil_div(N, cta_tile_shape_mnk[1]) - grid = (num_M_tiles * num_N_tiles, 1, 1) - - tile_sched_params = utils.PersistentTileSchedulerParams( - (cutlass.Int32(num_M_tiles), cutlass.Int32(num_N_tiles), cutlass.Int32(1)), - (*cluster_shape_mn, 1)) - - # Output tensors - out_w = cutlass_torch.make_tensor( - torch.empty(M, top_k, dtype=torch.float32, device="cuda:0")) - out_id = cutlass_torch.make_tensor( - torch.empty(M, top_k, dtype=torch.int32, device="cuda:0")) - - @cute.jit - def fused_router( - tma_a, gA, tma_b, gB, tma_sfa, gSFA, tma_sfb, gSFB, - tiled_mma, tiled_mma_sfb, - mma_tiler, mma_tiler_sfb, - a_smem_staged, b_smem_staged, sfa_smem_staged, sfb_smem_staged, - cluster_layout_vmnk, cluster_layout_sfb_vmnk, - cta_tile_shape_mnk, cta_tile_shape_mnk_sfb, - gA_mkl, gB_nkl, gSFA_mkl, gSFB_nkl, - e_bias, out_w, out_id, gsa, gsb, - tile_sched_params, M, N, K, top_k, routed_scaling_factor, - num_tma_bytes, num_mcast_a, num_mcast_b, num_mcast_sfb, - num_ab_stages, num_acc_stages, num_sfa_tmem, num_sfb_tmem, - atom_thr_size, overlapping_accum, - epi_warp_ids, mma_warp_id, tma_warp_id, num_threads, + def __init__( + self, + sf_vec_size: int = 16, + mma_tiler_mnk: Tuple[int, int, int] = (128, 128, 64), + cluster_shape_mnk: Tuple[int, int, int] = (1, 1, 1), + top_k: int = 6, ): - ... 
+ self.sf_vec_size = sf_vec_size + self.mma_tiler_mnk = mma_tiler_mnk + self.cluster_shape_mn = (cluster_shape_mnk[0], cluster_shape_mnk[1]) + self.top_k = top_k + self.use_2cta_instrs = mma_tiler_mnk[0] == 256 + self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE - cute.compile(fused_router, - tma_a, gA, tma_b, gB, tma_sfa, gSFA, tma_sfb, gSFB, - tiled_mma, tiled_mma_sfb, mma_tiler, mma_tiler_sfb, - a_smem_staged, b_smem_staged, sfa_smem_staged, sfb_smem_staged, - cluster_layout_vmnk, cluster_layout_sfb_vmnk, - cta_tile_shape_mnk, cta_tile_shape_mnk_sfb, - gA, gB, gSFA, gSFB, - e_bias, out_w, out_id, gsa, gsb, - tile_sched_params, M, N, K, top_k, routed_scaling_factor, - num_tma_bytes, num_mcast_a, num_mcast_b, num_mcast_sfb, - num_ab_stages, num_acc_stages, num_sfa_tmem, num_sfb_tmem, - atom_thr_size, overlapping_accum, - epi_warp_ids, mma_warp_id, tma_warp_id, num_threads, + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4 + self.tma_warp_id = 5 + self.threads_per_warp = 32 + self.threads_per_cta = self.threads_per_warp * 6 + + self.cta_sync_bar_id = 1 + self.epilogue_sync_bar_id = 2 + self.tmem_alloc_sync_bar_id = 3 + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + self.occupancy = 1 + self.buffer_align_bytes = 1024 + + def _create_tiled_mma(self, a_dtype, a_major_mode, b_major_mode, sf_dtype): + return sm100_utils.make_blockscaled_trivial_tiled_mma( + a_dtype, a_major_mode, b_major_mode, sf_dtype, + self.sf_vec_size, self.cta_group, + (self.mma_tiler_mnk[0], self.mma_tiler_mnk[1]), + ) + + def _create_tiled_mma_sfb(self, a_dtype, a_major_mode, b_major_mode, sf_dtype): + mma_inst_shape_mn_sfb = ( + self.mma_tiler_mnk[0] // (2 if self.use_2cta_instrs else 1), + cute.round_up(self.mma_tiler_mnk[1], 128), + ) + return sm100_utils.make_blockscaled_trivial_tiled_mma( + a_dtype, a_major_mode, b_major_mode, sf_dtype, + self.sf_vec_size, tcgen05.CtaGroup.ONE, mma_inst_shape_mn_sfb, + ) + + def 
_setup_attributes(self, tiled_mma, tiled_mma_sfb, a_dtype, b_dtype, sf_dtype): + self.mma_inst_shape_mn = (self.mma_tiler_mnk[0], self.mma_tiler_mnk[1]) + self.mma_inst_shape_mn_sfb = ( + self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1), + cute.round_up(self.mma_inst_shape_mn[1], 128), + ) + + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = self.mma_tiler_mnk[2] // mma_inst_shape_k + + self.mma_tiler = ( + self.mma_inst_shape_mn[0], + self.mma_inst_shape_mn[1], + self.mma_tiler_mnk[2], + ) + self.mma_tiler_sfb = ( + self.mma_inst_shape_mn_sfb[0], + self.mma_inst_shape_mn_sfb[1], + self.mma_tiler_mnk[2], + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], self.mma_tiler[2], + ) + self.cta_tile_shape_mnk_sfb = ( + self.mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler_sfb[1], self.mma_tiler_sfb[2], + ) + + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,)) + self.cluster_layout_sfb_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma_sfb.thr_id.shape,)) + + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.num_mcast_ctas_sfb = cute.size(self.cluster_layout_sfb_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + self.is_sfb_mcast = self.num_mcast_ctas_sfb > 1 + + self.epi_tile = ( + cute.make_layout(self.cta_tile_shape_mnk[0]), + cute.make_layout(self.cta_tile_shape_mnk[1]), + ) + self.epi_tile_n = cute.size(self.epi_tile[1]) + + self.overlapping_accum = self.cta_tile_shape_mnk[1] == 256 + self.num_acc_stage = 1 if self.overlapping_accum else 2 + self.num_ab_stage = 2 + + sf_atom_mn = 32 + self.num_sfa_tmem_cols = (self.cta_tile_shape_mnk[0] // sf_atom_mn) * mma_inst_tile_k + 
self.num_sfb_tmem_cols = (self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn) * mma_inst_tile_k + self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols + if self.overlapping_accum: + self.num_accumulator_tmem_cols = self.cta_tile_shape_mnk[1] * 2 - self.num_sf_tmem_cols + else: + self.num_accumulator_tmem_cols = self.cta_tile_shape_mnk[1] * self.num_acc_stage + + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake) + + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, self.mma_tiler, a_dtype, self.num_ab_stage) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, self.mma_tiler, b_dtype, self.num_ab_stage) + self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, self.mma_tiler, self.sf_vec_size, self.num_ab_stage) + self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, self.mma_tiler, self.sf_vec_size, self.num_ab_stage) + + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + a_smem_0 = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + b_smem_0 = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + sfa_smem_0 = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0)) + sfb_smem_0 = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(a_dtype, a_smem_0) + + cute.size_in_bytes(b_dtype, b_smem_0) + + cute.size_in_bytes(sf_dtype, sfa_smem_0) + + cute.size_in_bytes(sf_dtype, sfb_smem_0) + ) * atom_thr_size + + def run(self, mat_a, mat_b, scale_a, scale_b, e_bias, out_weights, out_ids, + M, N, K, routed_scaling_factor, top_k, stream=None): + if stream is None: + stream = cuda.CUstream(0) + + a_dtype = cutlass.Float4E2M1FN + b_dtype = cutlass.Float4E2M1FN + sf_dtype = cutlass.Float8E4M3FN + a_major_mode = 
utils.LayoutEnum.from_tensor(mat_a).mma_major_mode() + b_major_mode = utils.LayoutEnum.from_tensor(mat_b).mma_major_mode() + + tiled_mma = self._create_tiled_mma(a_dtype, a_major_mode, b_major_mode, sf_dtype) + tiled_mma_sfb = self._create_tiled_mma_sfb(a_dtype, a_major_mode, b_major_mode, sf_dtype) + self._setup_attributes(tiled_mma, tiled_mma_sfb, a_dtype, b_dtype, sf_dtype) + + a_smem_0 = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, mat_a, a_smem_0, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape) + + b_smem_0 = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, mat_b, b_smem_0, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape) + + sfa_smem_0 = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0)) + tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( + a_op, scale_a, sfa_smem_0, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape) + + sfb_smem_0 = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0)) + sfb_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma_sfb.thr_id) + tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + sfb_op, scale_b, sfb_smem_0, self.mma_tiler_sfb, tiled_mma_sfb, self.cluster_layout_sfb_vmnk.shape) + + num_M_tiles = cute.ceil_div(M, self.cta_tile_shape_mnk[0]) + num_N_tiles = cute.ceil_div(N, self.cta_tile_shape_mnk[1]) + L = 1 + grid = (num_M_tiles * num_N_tiles, 1, 1) + + tile_sched_params = utils.PersistentTileSchedulerParams( + (cutlass.Int32(num_M_tiles), cutlass.Int32(num_N_tiles), cutlass.Int32(L)), + (*self.cluster_shape_mn, 1)) + + @cute.jit + def _compiled_fn(mat_a, mat_b, scale_a, scale_b, 
e_bias, out_weights, out_ids): + self._kernel( + tiled_mma, tiled_mma_sfb, + tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b, + tma_atom_sfa, tma_tensor_sfa, tma_atom_sfb, tma_tensor_sfb, + self.cluster_layout_vmnk, self.cluster_layout_sfb_vmnk, + self.a_smem_layout_staged, self.b_smem_layout_staged, + self.sfa_smem_layout_staged, self.sfb_smem_layout_staged, + self.epi_tile, + e_bias, out_weights, out_ids, + tile_sched_params, + M, N, K, top_k, routed_scaling_factor, + ).launch( + grid=grid, block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + stream=stream, min_blocks_per_mp=1, + ) + + cute.compile(_compiled_fn, mat_a, mat_b, scale_a, scale_b, e_bias, out_weights, out_ids) + + @cute.kernel + def _kernel(self, tiled_mma, tiled_mma_sfb, + tma_atom_a, mA_mkl, tma_atom_b, mB_nkl, + tma_atom_sfa, mSFA_mkl, tma_atom_sfb, mSFB_nkl, + cluster_layout_vmnk, cluster_layout_sfb_vmnk, + a_smem_layout_staged, b_smem_layout_staged, + sfa_smem_layout_staged, sfb_smem_layout_staged, + epi_tile, + e_bias_tensor, out_w_tensor, out_id_tensor, + tile_sched_params, + M, N, K, top_k, routed_scaling_factor): + + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + use_2cta = cute.size(tiled_mma.thr_id.shape) == 2 + is_leader_cta = (bidx % cute.size(tiled_mma.thr_id.shape)) == 0 + mma_tile_v = bidx % cute.size(tiled_mma.thr_id.shape) + cta_rank = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + block_coord = cluster_layout_vmnk.get_flat_coord(cta_rank) + + acc_dtype = cutlass.Float32 + sf_dtype = cutlass.Float8E4M3FN + + # ============================================================ + # Shared storage + # ============================================================ + @cute.struct + class SharedStorage: + ab_full_mbar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + ab_empty_mbar: cute.struct.MemRange[cutlass.Int64, 
self.num_ab_stage] + acc_full_mbar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + acc_empty_mbar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + tmem_dealloc_mbar: cutlass.Int64 + tmem_holding: cutlass.Int32 + heap_scores: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 4*32*self.top_k], 128] + heap_indices: cute.struct.Align[cute.struct.MemRange[cutlass.Int32, 4*32*self.top_k], 128] + heap_acts: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 4*32*self.top_k], 128] + sA: cute.struct.Align[cute.struct.MemRange[cutlass.Float4E2M1FN, cute.cosize(a_smem_layout_staged.outer)], self.buffer_align_bytes] + sB: cute.struct.Align[cute.struct.MemRange[cutlass.Float4E2M1FN, cute.cosize(b_smem_layout_staged.outer)], self.buffer_align_bytes] + sSFA: cute.struct.Align[cute.struct.MemRange[cutlass.Float8E4M3FN, cute.cosize(sfa_smem_layout_staged.outer)], self.buffer_align_bytes] + sSFB: cute.struct.Align[cute.struct.MemRange[cutlass.Float8E4M3FN, cute.cosize(sfb_smem_layout_staged.outer)], self.buffer_align_bytes] + + smem = utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + # ============================================================ + # Pipelines + # ============================================================ + ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_full_mbar.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1), + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + num_acc_cons = self.threads_per_warp * len(self.epilogue_warp_id) * (2 if use_2cta else 1) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_full_mbar.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + 
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, num_acc_cons), + cta_layout_vmnk=cluster_layout_vmnk, + ) + + tmem = utils.TmemAllocator( + storage.tmem_holding.ptr, + barrier_for_retrieve=pipeline.NamedBarrier( + barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=self.threads_per_warp * len((self.mma_warp_id, *self.epilogue_warp_id))), + allocator_warp_id=self.epilogue_warp_id[0], + is_two_cta=use_2cta, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr) + + cta_bar = pipeline.NamedBarrier(self.cta_sync_bar_id, self.threads_per_cta) + epi_bar = pipeline.NamedBarrier( + self.epilogue_sync_bar_id, + self.threads_per_warp * len(self.epilogue_warp_id)) + + # ============================================================ + # SMEM tensors + # ============================================================ + sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner) + sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner) + sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged.outer, swizzle=sfa_smem_layout_staged.inner) + sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged.outer, swizzle=sfb_smem_layout_staged.inner) + + # ============================================================ + # Multicast masks + # ============================================================ + a_mcast = None; b_mcast = None; sfa_mcast = None; sfb_mcast = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta): + a_mcast = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_coord, mcast_mode=2) + b_mcast = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_coord, mcast_mode=1) + sfa_mcast = a_mcast + sfb_mcast = cpasync.create_tma_multicast_mask(cluster_layout_sfb_vmnk, block_coord, mcast_mode=1) + + # ============================================================ + # Partition global tensors + # ============================================================ + gA = 
cute.local_tile(mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)) + gB = cute.local_tile(mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None)) + gSFA = cute.local_tile(mSFA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)) + gSFB = cute.local_tile(mSFB_nkl, cute.slice_(self.mma_tiler_sfb, (0, None, None)), (None, None, None)) + + k_tiles = cute.size(gA, mode=[3]) + thr_mma = tiled_mma.get_slice(mma_tile_v) + tCgA = thr_mma.partition_A(gA) + tCgB = thr_mma.partition_B(gB) + tCgSFA = thr_mma.partition_A(gSFA) + thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_v) + tCgSFB = thr_mma_sfb.partition_B(gSFB) + + # TMA partitions for A/B + a_cta_l = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_atom_a, block_coord[2], a_cta_l, + cute.group_modes(sA, 0, 3), cute.group_modes(tCgA, 0, 3)) + b_cta_l = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_atom_b, block_coord[1], b_cta_l, + cute.group_modes(sB, 0, 3), cute.group_modes(tCgB, 0, 3)) + + # TMA partitions for SFA/SFB + tAsSFA, tAgSFA = cpasync.tma_partition(tma_atom_sfa, block_coord[2], a_cta_l, + cute.group_modes(sSFA, 0, 3), cute.group_modes(tCgSFA, 0, 3)) + tAsSFA = cute.filter_zeros(tAsSFA) + tAgSFA = cute.filter_zeros(tAgSFA) + + sfb_cta_l = cute.make_layout(cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape) + block_coord_sfb = cluster_layout_sfb_vmnk.get_flat_coord(cta_rank) + tBsSFB, tBgSFB = cpasync.tma_partition(tma_atom_sfb, block_coord_sfb[1], sfb_cta_l, + cute.group_modes(sSFB, 0, 3), cute.group_modes(tCgSFB, 0, 3)) + tBsSFB = cute.filter_zeros(tBsSFB) + tBgSFB = cute.filter_zeros(tBgSFB) + + # TMEM accumulator shape + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + + if cute.size(self.cluster_shape_mn) > 1: 
+ cute.arch.cluster_arrive_relaxed() + + # ============================================================ + # TMA WARP + # ============================================================ + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_sfa) + cpasync.prefetch_descriptor(tma_atom_sfb) + + tsched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, bidx, cute.arch.grid_dim()) + wt = tsched.initial_work_tile_info() + ab_ps = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_ab_stage) + + while wt.is_valid_tile: + tc = wt.tile_idx + mc = (tc[0] // cute.size(tiled_mma.thr_id.shape), tc[1], tc[2]) + tAgA_s = tAgA[(None, mc[0], None, mc[2])] + tBgB_s = tBgB[(None, mc[1], None, mc[2])] + tAgSFA_s = tAgSFA[(None, mc[0], None, mc[2])] + slice_n = mc[1] + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): + slice_n = mc[1] // 2 + tBgSFB_s = tBgSFB[(None, slice_n, None, mc[2])] + + ab_ps.reset_count() + peek_ab = cutlass.Boolean(1) + if ab_ps.count < k_tiles: + peek_ab = ab_pipeline.producer_try_acquire(ab_ps) + + for kt in cutlass.range(0, k_tiles, 1, unroll=1): + ab_pipeline.producer_acquire(ab_ps, peek_ab) + cute.copy(tma_atom_a, tAgA_s[(None, ab_ps.count)], + tAsA[(None, ab_ps.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_ps), + mcast_mask=a_mcast) + cute.copy(tma_atom_b, tBgB_s[(None, ab_ps.count)], + tBsB[(None, ab_ps.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_ps), + mcast_mask=b_mcast) + cute.copy(tma_atom_sfa, tAgSFA_s[(None, ab_ps.count)], + tAsSFA[(None, ab_ps.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_ps), + mcast_mask=sfa_mcast) + cute.copy(tma_atom_sfb, tBgSFB_s[(None, ab_ps.count)], + tBsSFB[(None, ab_ps.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_ps), + mcast_mask=sfb_mcast) + ab_ps.advance() + peek_ab = cutlass.Boolean(1) + if ab_ps.count < k_tiles: + 
peek_ab = ab_pipeline.producer_try_acquire(ab_ps) + + ab_pipeline.producer_tail(ab_ps) + tsched.advance_to_next_work() + wt = tsched.get_current_work() + + # ============================================================ + # MMA WARP — blockscaled GEMM: (A * SFA) @ (B * SFB) -> TMEM + # ============================================================ + if warp_idx == self.mma_warp_id: + # Wait for cluster sync + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + else: + cta_bar.arrive_and_wait() + + # Wait for TMEM allocation + tmem.wait_for_alloc() + acc_tmem_ptr = tmem.retrieve_ptr(acc_dtype) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # MMA fragments + tCrA = tiled_mma.make_fragment_A(sA) + tCrB = tiled_mma.make_fragment_B(sB) + + # S2T copies for SFA: SMEM -> TMEM + # NOTE(review): SFA is intended to start after the accumulator columns, but the pointer below is the accumulator base itself (no column offset is applied) — confirm + sfa_tmem_ptr = acc_tmem_ptr + tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, self.mma_tiler, self.sf_vec_size, + cute.slice_(sfa_smem_layout_staged, (None, None, None, 0))) + tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) + + # S2T copies for SFB: SMEM -> TMEM + sfb_tmem_ptr = acc_tmem_ptr + tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma_sfb, self.mma_tiler, self.sf_vec_size, + cute.slice_(sfb_smem_layout_staged, (None, None, None, 0))) + tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) + + # S2T copy atoms + tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = \ + self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA, self.cta_group) + tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = \ + self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB, tcgen05.CtaGroup.ONE) + + # Tile scheduler + pipeline states + tsched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, bidx, cute.arch.grid_dim()) + wt = tsched.initial_work_tile_info() + ab_cs = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer,
self.num_ab_stage) + acc_ps = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage) + + num_tiles_executed = cutlass.Int32(0) + + while wt.is_valid_tile: + # Wait for accumulator buffer empty + if is_leader_cta: + acc_pipeline.producer_acquire(acc_ps) + + # Get accumulator stage index + if cutlass.const_expr(self.overlapping_accum): + acc_stage_index = acc_ps.phase ^ 1 + else: + acc_stage_index = acc_ps.index + + tCtAcc = tCtAcc_base[(None, None, None, acc_stage_index)] + + # Clear accumulator for new tile + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + # Reset count for AB pipeline consumer + ab_cs.reset_count() + peek_ab_full = cutlass.Boolean(1) + if ab_cs.count < k_tiles and is_leader_cta: + peek_ab_full = ab_pipeline.consumer_try_wait(ab_cs) + + # Mainloop: K-tiles + for kt in cutlass.range(0, k_tiles, 1, unroll=1): + if is_leader_cta: + ab_pipeline.consumer_wait(ab_cs, peek_ab_full) + + # Copy SFA/SFB from SMEM to TMEM + s2t_stage = ( + None, None, None, None, ab_cs.index, + ) + cute.copy(tiled_copy_s2t_sfa, + tCsSFA_compact_s2t[s2t_stage], + tCtSFA_compact_s2t) + cute.copy(tiled_copy_s2t_sfb, + tCsSFB_compact_s2t[s2t_stage], + tCtSFB_compact_s2t) + + # Set SFA/SFB for MMA + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll=1): + sf_kblock = (None, None, kblock_idx) + tiled_mma.set(tcgen05.Field.SFA, + tCtSFA[sf_kblock].iterator) + tiled_mma.set(tcgen05.Field.SFB, + tCtSFB[sf_kblock].iterator) + + kb_coord = (None, None, kblock_idx, ab_cs.index) + cute.gemm(tiled_mma, tCrA[kb_coord], tCrB[kb_coord], + tCtAcc, tCtAcc) + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Release AB buffer + ab_pipeline.consumer_release(ab_cs) + ab_cs.advance() + peek_ab_full = cutlass.Boolean(1) + if ab_cs.count < k_tiles: + if is_leader_cta: + peek_ab_full = ab_pipeline.consumer_try_wait(ab_cs) + + # Commit accumulator full + if is_leader_cta: + acc_pipeline.producer_commit(acc_ps) +
acc_ps.advance() + num_tiles_executed += cutlass.Int32(1) + + tsched.advance_to_next_work() + wt = tsched.get_current_work() + + # Wait for accumulator buffer empty + if is_leader_cta: + acc_pipeline.producer_tail(acc_ps) + + # Signal epilogue that MMA is done + tmem.relinquish_alloc_permit() + + # ============================================================ + # EPILOGUE WARPS — TMEM -> registers, router logic, GMEM store + # ============================================================ + if warp_idx in self.epilogue_warp_id: + # Wait for cluster sync + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + else: + cta_bar.arrive_and_wait() + + # Wait for TMEM allocation + tmem.wait_for_alloc() + acc_tmem_ptr = tmem.retrieve_ptr(acc_dtype) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # TMEM->register copy setup (paired atoms from CUTLASS) + epi_n = self.epi_tile_n + tiled_copy_t2r, tTR_tAcc_base = epilogue_tmem_copy_and_partition( + tCtAcc_base, epi_tile, self.epilogue_warp_id, acc_dtype, use_2cta) + tTR_rAcc = tiled_copy_t2r.fragments_slice(tiled_copy_t2r, tTR_tAcc_base) + + # Identity tensor for expert index mapping + cAcc = cute.make_identity_tensor( + (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1])) + tCcAcc = thr_mma.partition_C(cAcc) + + # Tile scheduler + pipeline states + tsched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, bidx, cute.arch.grid_dim()) + wt = tsched.initial_work_tile_info() + acc_cs = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage) + + while wt.is_valid_tile: + acc_pipeline.consumer_wait(acc_cs) + + if cutlass.const_expr(self.overlapping_accum): + acc_stage_index = acc_cs.phase + else: + acc_stage_index = acc_cs.index + + # Set accumulator buffer for current tile + tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_stage_index)] + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + + # Per-thread register heap 
(top_k entries) + hs = [cutlass.Float32(-1e30)] * self.top_k + hi = [cutlass.Int32(-1)] * self.top_k + ha = [cutlass.Float32(0.0)] * self.top_k + + # Process subtiles (each subtile = epi_tile_n columns) + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in cutlass.range(subtile_cnt): + # Load accumulator from TMEM to registers + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # Fence for TMEM load + cute.arch.fence_view_async_tmem_load() + + # Early release accumulator for overlapping case + if cutlass.const_expr(self.overlapping_accum): + if subtile_idx == self.num_sf_tmem_cols // epi_n: + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_cs) + acc_cs.advance() + + # Process each element in the register fragment + rFlat = cute.flatten(tTR_rAcc) + cFlat = cute.flatten(tCcAcc) + elem_cnt = cute.size(rFlat) + for e in cutlass.range(elem_cnt, unroll=4): + logit = rFlat[e] + coord = cFlat[e] + row = coord[0] + col = coord[1] + # Expert index = col + (subtile_idx * epi_tile_n) + e_idx = col + (subtile_idx * epi_n) + + # sqrt(softplus(logit)) + abs_x = cute.math.absf(logit) + pos = cute.where(logit > cutlass.Float32(0.0), logit, cutlass.Float32(0.0)) + exp_neg = cute.math.exp(-abs_x) + sp = pos + cute.math.log(cutlass.Float32(1.0) + exp_neg) + act = cute.math.sqrt(sp) + + # score = act + e_bias (for selection only) + score = act + e_bias_tensor[e_idx] + + # Min-heap push: root = hs[0] (smallest of top_k) + do_push = score > hs[0] + if do_push: + # Replace root with new entry + old_s = hs[0]; old_i = hi[0]; old_a = ha[0] + hs[0] = score; hi[0] = e_idx; ha[0] = act + # Sift down (top_k=6, fully unrolled) + r = 0 + _done = cutlass.Bool(False) + for _sift in cutlass.range(3, unroll=1): + if not _done: + left = 2*r+1; right = 2*r+2 + sm = r + if left < self.top_k: + if hs[left] < hs[sm]: + sm = left + if right < self.top_k: + if hs[right] < hs[sm]: + sm = right + if sm == r: + 
_done = cutlass.Bool(True) + else: + ts = hs[r]; ti = hi[r]; ta = ha[r] + hs[r] = hs[sm]; hi[r] = hi[sm]; ha[r] = ha[sm] + hs[sm] = ts; hi[sm] = ti; ha[sm] = ta + r = sm + + # Release accumulator (non-overlapping case) + if cutlass.const_expr(not self.overlapping_accum): + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_cs) + acc_cs.advance() + + # Write heap to shared memory for cross-thread merge + tid = warp_idx * 32 + tidx + base = tid * self.top_k + for i in cutlass.range(self.top_k, unroll=1): + storage.heap_scores.data_ptr()[base + i] = hs[i] + storage.heap_indices.data_ptr()[base + i] = hi[i] + storage.heap_acts.data_ptr()[base + i] = ha[i] + + epi_bar.arrive_and_wait() + + # Thread 0 of warp 0 does the final merge + store + if warp_idx == 0 and tidx == 0: + # Initialize final heap from thread 0 + fs = list(hs); fi = list(hi); fa = list(ha) + # Merge all 128 threads (4 warps * 32) + for t in cutlass.range(1, 128, unroll=1): + for i in cutlass.range(self.top_k, unroll=1): + cs = storage.heap_scores.data_ptr()[t*self.top_k+i] + ci = storage.heap_indices.data_ptr()[t*self.top_k+i] + ca = storage.heap_acts.data_ptr()[t*self.top_k+i] + if ci >= 0: + if cs > fs[0]: + fs[0] = cs; fi[0] = ci; fa[0] = ca + # Sift down + r = 0 + _done2 = cutlass.Bool(False) + for _sift2 in cutlass.range(3, unroll=1): + if not _done2: + l = 2*r+1; ri = 2*r+2; sm = r + if l < self.top_k: + if fs[l] < fs[sm]: + sm = l + if ri < self.top_k: + if fs[ri] < fs[sm]: + sm = ri + if sm == r: + _done2 = cutlass.Bool(True) + else: + ts=fs[r]; ti=fi[r]; ta=fa[r] + fs[r]=fs[sm]; fi[r]=fi[sm]; fa[r]=fa[sm] + fs[sm]=ts; fi[sm]=ti; fa[sm]=ta + r = sm + + # Sort descending (selection sort, k=6) + sorted_s = [cutlass.Float32(-1e30)] * self.top_k + sorted_i = [cutlass.Int32(-1)] * self.top_k + sorted_a = [cutlass.Float32(0.0)] * self.top_k + for i in cutlass.range(self.top_k, unroll=1): + best = 0 + for j in cutlass.range(1, self.top_k, unroll=1): + if fs[j] > fs[best]: + best = j + 
# ================================================================
# Python wrapper — called by dense_router_dispatch_nvfp4
# ================================================================
def run_nvfp4_fused_router(
    hidden_states: torch.Tensor,  # [M, hidden_size] BF16
    mat_b,  # CuTe tensor: gate weight (NVFP4, blockscaled layout)
    scale_b,  # CuTe tensor: gate weight scale factors (FP8 E4M3)
    gsa,  # Activation global scale (scalar or 1-elem tensor)
    gsb_val: float,  # Weight global scale value
    e_bias: torch.Tensor,  # [E] FP32
    routed_scaling_factor: float,
    top_k: int = 6,
    sf_vec_size: int = 16,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the NVFP4 fused router kernel on a batch of activations.

    Quantizes ``hidden_states`` with ``quantize_activation_nvfp4``, wraps the
    quantized activation, its scale factors, ``e_bias`` and the two output
    buffers as CuTe tensors, then builds an ``Nvfp4FusedRouterKernel`` with a
    fixed (128, 128, 64) MMA tiler and a (1, 1, 1) cluster and calls its
    ``run`` method.

    Parameters
    ----------
    hidden_states : [M, K] BF16 — raw activation (must be 2-D)
    mat_b : CuTe tensor — gate weight in NVFP4 blockscaled layout
    scale_b : CuTe tensor — gate weight scale factors in blockscaled layout
    gsa : activation global scale (scalar); see the NOTE(review) below
    gsb_val : weight global scale (float); see the NOTE(review) below
    e_bias : [E] FP32 — per-expert selection bias (must be 1-D)
    routed_scaling_factor : float, forwarded to the kernel
    top_k : int (default 6), must satisfy ``1 <= top_k <= E``
    sf_vec_size : int (default 16), forwarded to the quantizer and the kernel

    Returns
    -------
    (topk_weights, topk_ids) — [M, top_k] FP32 and [M, top_k] int32

    Raises
    ------
    ValueError
        If ``hidden_states`` is not 2-D, ``e_bias`` is not 1-D, or ``top_k``
        is outside ``[1, E]``.
    """
    # Validate the documented shapes first, before the heavyweight lazy
    # imports and before any device work is issued.
    if hidden_states.dim() != 2:
        raise ValueError(
            "hidden_states must be 2-D [M, K], got shape "
            f"{tuple(hidden_states.shape)}"
        )
    if e_bias.dim() != 1:
        raise ValueError(
            f"e_bias must be 1-D [E], got shape {tuple(e_bias.shape)}"
        )

    M, K = hidden_states.shape
    E = e_bias.shape[0]  # number of experts is taken from the bias length
    if not 1 <= top_k <= E:
        raise ValueError(
            f"top_k must be in [1, {E}] (number of experts), got {top_k}"
        )

    device = hidden_states.device

    # Output buffers, written by the kernel.
    out_weights = torch.zeros(M, top_k, dtype=torch.float32, device=device)
    out_ids = torch.zeros(M, top_k, dtype=torch.int32, device=device)

    # Empty batch: nothing to route, so skip quantization and the launch.
    if M == 0:
        return out_weights, out_ids

    # Imported lazily, as in the original wrapper.
    import cutlass.torch as cutlass_torch
    from dsv4.ops.quantize import quantize_activation_nvfp4

    # Quantize activation to NVFP4. The quantizer also returns two global
    # scales, which are intentionally discarded here.
    # NOTE(review): `gsa` and `gsb_val` are accepted but not forwarded to
    # `kernel.run`; the previous code computed overrides from them that were
    # never read. Confirm whether the kernel is expected to apply the global
    # scales before the sqrt(softplus) activation.
    act_nvfp4, act_sf, _, _ = quantize_activation_nvfp4(
        hidden_states, sf_vec_size=sf_vec_size)

    def _as_dynamic_cute(tensor):
        # Wrap a torch tensor as a CuTe tensor with a dynamic layout.
        return cute.mark_layout_dynamic(cutlass_torch.from_dlpack(tensor))

    mat_a = _as_dynamic_cute(act_nvfp4)   # quantized activation (FP4 data)
    scale_a = _as_dynamic_cute(act_sf)    # activation scale factors
    e_bias_cute = _as_dynamic_cute(e_bias)
    out_w_cute = _as_dynamic_cute(out_weights)
    out_id_cute = _as_dynamic_cute(out_ids)

    # MMA tiler: (128, 128, 64) for decode
    mma_tiler_mnk = (128, 128, 64)

    kernel = Nvfp4FusedRouterKernel(
        sf_vec_size=sf_vec_size,
        mma_tiler_mnk=mma_tiler_mnk,
        cluster_shape_mnk=(1, 1, 1),
        top_k=top_k,
    )
    kernel.run(
        mat_a, mat_b, scale_a, scale_b,
        e_bias_cute, out_w_cute, out_id_cute,
        M, E, K, routed_scaling_factor, top_k,
    )
    return out_weights, out_ids