From 4f4ae8febd69a8b3debb1c1d6eb8d0d1f692fdd5 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 1 Jun 2026 09:11:29 +0000 Subject: [PATCH] Test: enumerate CuTeDSL math API to check available operations --- .../router/nvfp4_fused_router_kernel.py | 449 ++++++++++-------- tests/unit/test_cute_math_api.py | 78 +++ 2 files changed, 326 insertions(+), 201 deletions(-) create mode 100644 tests/unit/test_cute_math_api.py diff --git a/dsv4/kernels/router/nvfp4_fused_router_kernel.py b/dsv4/kernels/router/nvfp4_fused_router_kernel.py index b6efc3c7..097c9186 100644 --- a/dsv4/kernels/router/nvfp4_fused_router_kernel.py +++ b/dsv4/kernels/router/nvfp4_fused_router_kernel.py @@ -6,7 +6,7 @@ with fused router epilogue (sqrt(softplus) + e_bias + top-k + renorm). PRODUCTION KERNEL. No intermediate GMEM buffer. No BF16 fallback. The GEMM accumulates logits in TMEM, then the epilogue warps process them directly: 1. TMEM -> registers (via paired t2r atom from CUTLASS epilogue helpers) - 2. For each logit: sqrt(softplus(logit)) + e_bias -> score; track top-k via min-heap + 2. For each logit: sqrt(softplus(logit)) + e_bias -> score; track top-k via sorted insertion 3. After all subtiles: sort, renormalize, write (topk_weights, topk_ids) to GMEM Warp specialization (6 warps, no scheduler for dense GEMM): @@ -17,17 +17,21 @@ Warp specialization (6 warps, no scheduler for dense GEMM): Pipeline structure (2 pipelines): AB pipeline: TMA (producer) -> MMA (consumer) [PipelineTmaUmma] Acc pipeline: MMA (producer) -> Epilogue (consumer) [PipelineUmmaAsync] + +Architecture reference: FusedSwiGLUScaledGroupedGemmKernel (dsv4/kernels/gemm/fused_swiglu.py) +The blockscaled GEMM mainloop follows the same pattern exactly. +The epilogue is custom: instead of TMA store, we do TMEM->reg top-k reduction. """ from __future__ import annotations -from typing import Tuple -import math +from typing import Tuple, Optional, Type, Union import cuda.bindings.driver as cuda import torch import cutlass import cutlass.cute as cute +from cutlass.cute.typing import Pointer from cutlass.cute.nvgpu import cpasync, tcgen05 import cutlass.utils as utils import cutlass.pipeline as pipeline @@ -35,12 +39,20 @@ import cutlass.utils.blackwell_helpers as sm100_utils import cutlass.utils.blockscaled_layout as blockscaled_utils from cutlass.utils.gemm.sm100 import ( epilogue_tmem_copy_and_partition, - epilogue_smem_copy_and_partition, transform_partitioned_tensor_layout, ) class Nvfp4FusedRouterKernel: + """ + NVFP4 blockscaled GEMM + fused router epilogue. + + Dense (non-grouped) GEMM: [M, K] @ [K, E] -> [M, E] with NVFP4 weights. + Custom epilogue: TMEM -> registers -> sqrt(softplus) + e_bias + top-k + renorm -> GMEM. + + This follows the FusedSwiGLUScaledGroupedGemmKernel pattern for the + blockscaled GEMM mainloop exactly, with a custom epilogue. + """ def __init__( self, @@ -55,38 +67,48 @@ class Nvfp4FusedRouterKernel: self.top_k = top_k self.use_2cta_instrs = mma_tiler_mnk[0] == 256 self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + self.arch = "sm_100" + # 6-warp specialization (no scheduler warp for dense GEMM) self.epilogue_warp_id = (0, 1, 2, 3) self.mma_warp_id = 4 self.tma_warp_id = 5 self.threads_per_warp = 32 self.threads_per_cta = self.threads_per_warp * 6 + # Barrier IDs self.cta_sync_bar_id = 1 self.epilogue_sync_bar_id = 2 self.tmem_alloc_sync_bar_id = 3 - self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + + self.smem_capacity = utils.get_smem_capacity_in_bytes(self.arch) self.occupancy = 1 self.buffer_align_bytes = 1024 + # ----------------------------------------------------------------- + # _create_tiled_mma / _create_tiled_mma_sfb + # ----------------------------------------------------------------- + def _create_tiled_mma(self, a_dtype, a_major_mode, b_major_mode, sf_dtype): return sm100_utils.make_blockscaled_trivial_tiled_mma( a_dtype, a_major_mode, b_major_mode, sf_dtype, self.sf_vec_size, self.cta_group, - (self.mma_tiler_mnk[0], self.mma_tiler_mnk[1]), + self.mma_inst_shape_mn, ) def _create_tiled_mma_sfb(self, a_dtype, a_major_mode, b_major_mode, sf_dtype): - mma_inst_shape_mn_sfb = ( - self.mma_tiler_mnk[0] // (2 if self.use_2cta_instrs else 1), - cute.round_up(self.mma_tiler_mnk[1], 128), - ) return sm100_utils.make_blockscaled_trivial_tiled_mma( a_dtype, a_major_mode, b_major_mode, sf_dtype, - self.sf_vec_size, tcgen05.CtaGroup.ONE, mma_inst_shape_mn_sfb, + self.sf_vec_size, tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mn_sfb, ) + # ----------------------------------------------------------------- + # _setup_attributes + # ----------------------------------------------------------------- + def _setup_attributes(self, tiled_mma, tiled_mma_sfb, a_dtype, b_dtype, sf_dtype): + """Set up kernel attributes. Mirrors FusedSwiGLUScaledGroupedGemmKernel._setup_attributes.""" self.mma_inst_shape_mn = (self.mma_tiler_mnk[0], self.mma_tiler_mnk[1]) self.mma_inst_shape_mn_sfb = ( self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1), @@ -108,11 +130,13 @@ class Nvfp4FusedRouterKernel: ) self.cta_tile_shape_mnk = ( self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), - self.mma_tiler[1], self.mma_tiler[2], + self.mma_tiler[1], + self.mma_tiler[2], ) self.cta_tile_shape_mnk_sfb = ( self.mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape), - self.mma_tiler_sfb[1], self.mma_tiler_sfb[2], + self.mma_tiler_sfb[1], + self.mma_tiler_sfb[2], ) self.cluster_layout_vmnk = cute.tiled_divide( @@ -129,29 +153,21 @@ class Nvfp4FusedRouterKernel: self.is_b_mcast = self.num_mcast_ctas_b > 1 self.is_sfb_mcast = self.num_mcast_ctas_sfb > 1 + # Epilogue tile: for router, we process all N columns (expert dimension). + # Use epi_tile = (128, 32) as the subtile for t2r copy. + # This determines how many columns are loaded from TMEM per subtile. self.epi_tile = ( cute.make_layout(self.cta_tile_shape_mnk[0]), - cute.make_layout(self.cta_tile_shape_mnk[1]), + cute.make_layout((32, 1), stride=(1, 32)), ) self.epi_tile_n = cute.size(self.epi_tile[1]) - self.overlapping_accum = self.cta_tile_shape_mnk[1] == 256 - self.num_acc_stage = 1 if self.overlapping_accum else 2 + # Stage counts + self.num_acc_stage = 2 self.num_ab_stage = 2 + self.num_c_stage = 2 # not used for TMA store, but needed for stage computation - sf_atom_mn = 32 - self.num_sfa_tmem_cols = (self.cta_tile_shape_mnk[0] // sf_atom_mn) * mma_inst_tile_k - self.num_sfb_tmem_cols = (self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn) * mma_inst_tile_k - self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols - if self.overlapping_accum: - self.num_accumulator_tmem_cols = self.cta_tile_shape_mnk[1] * 2 - self.num_sf_tmem_cols - else: - self.num_accumulator_tmem_cols = self.cta_tile_shape_mnk[1] * self.num_acc_stage - - acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) - tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) - self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake) - + # Compute SMEM layouts for A, B, SFA, SFB self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( tiled_mma, self.mma_tiler, a_dtype, self.num_ab_stage) self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( @@ -161,6 +177,28 @@ class Nvfp4FusedRouterKernel: self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( tiled_mma, self.mma_tiler, self.sf_vec_size, self.num_ab_stage) + # Overlapping accumulator (N=256 case) + self.overlapping_accum = self.cta_tile_shape_mnk[1] == 256 + if self.overlapping_accum: + self.num_acc_pipeline_stages = 1 + else: + self.num_acc_pipeline_stages = self.num_acc_stage + + # TMEM column counts + sf_atom_mn = 32 + self.num_sfa_tmem_cols = (self.cta_tile_shape_mnk[0] // sf_atom_mn) * mma_inst_tile_k + self.num_sfb_tmem_cols = (self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn) * mma_inst_tile_k + self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols + self.num_accumulator_tmem_cols = self.cta_tile_shape_mnk[1] * self.num_acc_stage - ( + self.num_sf_tmem_cols if self.overlapping_accum else 0 + ) + + # Only when overlapping_accum, release accumulator buffer early in epilogue + self.iter_acc_early_release_in_epilogue = ( + self.num_sf_tmem_cols // self.epi_tile_n + ) + + # TMA load bytes atom_thr_size = cute.size(tiled_mma.thr_id.shape) a_smem_0 = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) b_smem_0 = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) @@ -173,42 +211,90 @@ class Nvfp4FusedRouterKernel: cute.size_in_bytes(sf_dtype, sfb_smem_0) ) * atom_thr_size - self.iter_acc_early_release = self.num_sf_tmem_cols // self.epi_tile_n + # TMEM allocation size + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake) + + # ----------------------------------------------------------------- + # mainloop_s2t_copy_and_partition (same as fused_swiglu.py) + # ----------------------------------------------------------------- + + def mainloop_s2t_copy_and_partition(self, sSF, tSF, cta_group): + """Make tiledCopy for SMEM -> TMEM load of a scale factor tensor.""" + tCsSF_compact = cute.filter_zeros(sSF) + tCtSF_compact = cute.filter_zeros(tSF) + + copy_atom_s2t = cute.make_copy_atom( + tcgen05.Cp4x32x128bOp(cta_group), + self.sf_dtype, + ) + tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact) + thr_copy_s2t = tiled_copy_s2t.get_slice(0) + + tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact) + tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t, tCsSF_compact_s2t_ + ) + tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact) + + return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t + + # ----------------------------------------------------------------- + # run() — Python entry point + # ----------------------------------------------------------------- def run(self, mat_a, mat_b, scale_a, scale_b, e_bias, out_weights, out_ids, M, N, K, routed_scaling_factor, top_k, stream=None): if stream is None: stream = cuda.CUstream(0) - a_dtype = cutlass.Float4E2M1FN - b_dtype = cutlass.Float4E2M1FN - sf_dtype = cutlass.Float8E4M3FN + # Infer dtypes and major modes from tensors + a_dtype = mat_a.element_type + b_dtype = mat_b.element_type + sf_dtype = scale_a.element_type a_major_mode = utils.LayoutEnum.from_tensor(mat_a).mma_major_mode() b_major_mode = utils.LayoutEnum.from_tensor(mat_b).mma_major_mode() + # Save for kernel use + self.a_dtype = a_dtype + self.b_dtype = b_dtype + self.sf_dtype = sf_dtype + self.a_major_mode = a_major_mode + self.b_major_mode = b_major_mode + tiled_mma = self._create_tiled_mma(a_dtype, a_major_mode, b_major_mode, sf_dtype) tiled_mma_sfb = self._create_tiled_mma_sfb(a_dtype, a_major_mode, b_major_mode, sf_dtype) self._setup_attributes(tiled_mma, tiled_mma_sfb, a_dtype, b_dtype, sf_dtype) - a_smem_0 = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + # TMA atoms — following fused_swiglu.py exactly + # A a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( - a_op, mat_a, a_smem_0, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape) + a_op, mat_a, a_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape) - b_smem_0 = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + # B b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( - b_op, mat_b, b_smem_0, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape) + b_op, mat_b, b_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape) - sfa_smem_0 = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0)) + # SFA + sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) + sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0)) tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( - a_op, scale_a, sfa_smem_0, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape) + sfa_op, scale_a, sfa_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape, + internal_type=cutlass.Uint64) - sfb_smem_0 = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0)) - sfb_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma_sfb.thr_id) + # SFB + sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id) + sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0)) tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( - sfb_op, scale_b, sfb_smem_0, self.mma_tiler_sfb, tiled_mma_sfb, self.cluster_layout_sfb_vmnk.shape) + sfb_op, scale_b, sfb_smem_layout, self.mma_tiler_sfb, tiled_mma_sfb, + self.cluster_layout_sfb_vmnk.shape, internal_type=cutlass.Uint64) + # Grid: dense GEMM, one CTA per (M_tile, N_tile) num_M_tiles = cute.ceil_div(M, self.cta_tile_shape_mnk[0]) num_N_tiles = cute.ceil_div(N, self.cta_tile_shape_mnk[1]) L = 1 @@ -239,6 +325,10 @@ class Nvfp4FusedRouterKernel: cute.compile(_compiled_fn, mat_a, mat_b, scale_a, scale_b, e_bias, out_weights, out_ids) + # ----------------------------------------------------------------- + # GPU kernel + # ----------------------------------------------------------------- + @cute.kernel def _kernel(self, tiled_mma, tiled_mma_sfb, tma_atom_a, mA_mkl, tma_atom_b, mB_nkl, @@ -262,32 +352,32 @@ class Nvfp4FusedRouterKernel: block_coord = cluster_layout_vmnk.get_flat_coord(cta_rank) acc_dtype = cutlass.Float32 - sf_dtype = cutlass.Float8E4M3FN + + # Reconstruct SMEM layout slices (same as fused_swiglu kernel) + a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0)) + b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, None, 0)) + sfa_smem_layout = cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)) + sfb_smem_layout = cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)) # ============================================================ # Shared storage # ============================================================ @cute.struct class SharedStorage: - ab_full_mbar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] - ab_empty_mbar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] - acc_full_mbar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] - acc_empty_mbar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + ab_full_mbar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + acc_full_mbar: cute.struct.MemRange[cutlass.Int64, self.num_acc_pipeline_stages * 2] tmem_dealloc_mbar: cutlass.Int64 tmem_holding: cutlass.Int32 - merge_scores: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 128*self.top_k], 128] - merge_indices: cute.struct.Align[cute.struct.MemRange[cutlass.Int32, 128*self.top_k], 128] - merge_acts: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 128*self.top_k], 128] - sA: cute.struct.Align[cute.struct.MemRange[cutlass.Float4E2M1FN, cute.cosize(a_smem_layout_staged.outer)], self.buffer_align_bytes] - sB: cute.struct.Align[cute.struct.MemRange[cutlass.Float4E2M1FN, cute.cosize(b_smem_layout_staged.outer)], self.buffer_align_bytes] - sSFA: cute.struct.Align[cute.struct.MemRange[cutlass.Float8E4M3FN, cute.cosize(sfa_smem_layout_staged.outer)], self.buffer_align_bytes] - sSFB: cute.struct.Align[cute.struct.MemRange[cutlass.Float8E4M3FN, cute.cosize(sfb_smem_layout_staged.outer)], self.buffer_align_bytes] + # SMEM for top-k merge: 128 threads × top_k entries + merge_scores: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 128 * self.top_k], 128] + merge_indices: cute.struct.Align[cute.struct.MemRange[cutlass.Int32, 128 * self.top_k], 128] + merge_acts: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 128 * self.top_k], 128] smem = utils.SmemAllocator() storage = smem.allocate(SharedStorage) # ============================================================ - # Pipelines + # Pipelines (following fused_swiglu.py exactly) # ============================================================ ab_pipeline = pipeline.PipelineTmaUmma.create( barrier_storage=storage.ab_full_mbar.data_ptr(), @@ -298,15 +388,17 @@ class Nvfp4FusedRouterKernel: self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1), tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, ) num_acc_cons = self.threads_per_warp * len(self.epilogue_warp_id) * (2 if use_2cta else 1) acc_pipeline = pipeline.PipelineUmmaAsync.create( barrier_storage=storage.acc_full_mbar.data_ptr(), - num_stages=self.num_acc_stage, + num_stages=self.num_acc_pipeline_stages, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, num_acc_cons), cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, ) tmem = utils.TmemAllocator( @@ -324,12 +416,32 @@ class Nvfp4FusedRouterKernel: self.threads_per_warp * len(self.epilogue_warp_id)) # ============================================================ - # SMEM tensors + # SMEM tensors (following fused_swiglu.py pattern) + # A/B use swizzled layouts (ComposedLayout: .outer + .inner) + # SFA/SFB use plain layouts (not Composed) # ============================================================ - sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner) - sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner) - sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged.outer, swizzle=sfa_smem_layout_staged.inner) - sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged.outer, swizzle=sfb_smem_layout_staged.inner) + sA = smem.allocate_tensor( + element_type=self.a_dtype, + layout=a_smem_layout_staged.outer, + byte_alignment=128, + swizzle=a_smem_layout_staged.inner, + ) + sB = smem.allocate_tensor( + element_type=self.b_dtype, + layout=b_smem_layout_staged.outer, + byte_alignment=128, + swizzle=b_smem_layout_staged.inner, + ) + sSFA = smem.allocate_tensor( + element_type=self.sf_dtype, + layout=sfa_smem_layout_staged, + byte_alignment=128, + ) + sSFB = smem.allocate_tensor( + element_type=self.sf_dtype, + layout=sfb_smem_layout_staged, + byte_alignment=128, + ) # ============================================================ # Multicast masks @@ -342,7 +454,7 @@ class Nvfp4FusedRouterKernel: sfb_mcast = cpasync.create_tma_multicast_mask(cluster_layout_sfb_vmnk, block_coord, mcast_mode=1) # ============================================================ - # Partition global tensors + # Partition global tensors (same as fused_swiglu TMA warp setup) # ============================================================ gA = cute.local_tile(mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None)) gB = cute.local_tile(mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None)) @@ -357,7 +469,7 @@ class Nvfp4FusedRouterKernel: thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_v) tCgSFB = thr_mma_sfb.partition_B(gSFB) - # TMA partitions for A/B + # TMA partitions for A/B (following fused_swiglu) a_cta_l = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape) tAsA, tAgA = cpasync.tma_partition(tma_atom_a, block_coord[2], a_cta_l, cute.group_modes(sA, 0, 3), cute.group_modes(tCgA, 0, 3)) @@ -365,28 +477,32 @@ class Nvfp4FusedRouterKernel: tBsB, tBgB = cpasync.tma_partition(tma_atom_b, block_coord[1], b_cta_l, cute.group_modes(sB, 0, 3), cute.group_modes(tCgB, 0, 3)) - # TMA partitions for SFA/SFB + # TMA partitions for SFA/SFB (following fused_swiglu) tAsSFA, tAgSFA = cpasync.tma_partition(tma_atom_sfa, block_coord[2], a_cta_l, cute.group_modes(sSFA, 0, 3), cute.group_modes(tCgSFA, 0, 3)) tAsSFA = cute.filter_zeros(tAsSFA) tAgSFA = cute.filter_zeros(tAgSFA) - sfb_cta_l = cute.make_layout(cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape) block_coord_sfb = cluster_layout_sfb_vmnk.get_flat_coord(cta_rank) + sfb_cta_l = cute.make_layout(cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape) tBsSFB, tBgSFB = cpasync.tma_partition(tma_atom_sfb, block_coord_sfb[1], sfb_cta_l, cute.group_modes(sSFB, 0, 3), cute.group_modes(tCgSFB, 0, 3)) tBsSFB = cute.filter_zeros(tBsSFB) tBgSFB = cute.filter_zeros(tBgSFB) - # TMEM accumulator shape + # TMEM accumulator acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + # Cluster arrive (before any TMA or pipeline ops) if cute.size(self.cluster_shape_mn) > 1: cute.arch.cluster_arrive_relaxed() + else: + cta_bar.arrive_and_wait() # ============================================================ - # TMA WARP + # TMA WARP — Load A, B, SFA, SFB from GMEM to SMEM + # (follows fused_swiglu TMA warp exactly) # ============================================================ if warp_idx == self.tma_warp_id: cpasync.prefetch_descriptor(tma_atom_a) @@ -444,7 +560,8 @@ class Nvfp4FusedRouterKernel: wt = tsched.get_current_work() # ============================================================ - # MMA WARP — blockscaled GEMM: (A * SFA) @ (B * SFB) -> TMEM + # MMA WARP — Blockscaled GEMM: (A * SFA) @ (B * SFB) -> TMEM + # (follows fused_swiglu MMA warp exactly) # ============================================================ if warp_idx == self.mma_warp_id: # Wait for cluster sync @@ -463,7 +580,6 @@ class Nvfp4FusedRouterKernel: tCrB = tiled_mma.make_fragment_B(sB) # S2T copies for SFA: SMEM -> TMEM - # The SFA tmem region starts after the accumulator columns sfa_tmem_ptr = acc_tmem_ptr tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( tiled_mma, self.mma_tiler, self.sf_vec_size, @@ -477,7 +593,7 @@ class Nvfp4FusedRouterKernel: cute.slice_(sfb_smem_layout_staged, (None, None, None, 0))) tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) - # S2T copy atoms + # S2T copy atoms (using fused_swiglu helper) tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = \ self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA, self.cta_group) tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = \ @@ -490,9 +606,7 @@ class Nvfp4FusedRouterKernel: ab_cs = pipeline.make_pipeline_state( pipeline.PipelineUserType.Consumer, self.num_ab_stage) acc_ps = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, self.num_acc_stage) - - num_tiles_executed = cutlass.Int32(0) + pipeline.PipelineUserType.Producer, self.num_acc_pipeline_stages) while wt.is_valid_tile: # Wait for accumulator buffer empty @@ -516,30 +630,28 @@ class Nvfp4FusedRouterKernel: if ab_cs.count < k_tiles and is_leader_cta: peek_ab_full = ab_pipeline.consumer_try_wait(ab_cs) - # Mainloop: K-tiles + # Mainloop: K-tiles (following fused_swiglu exactly) for kt in cutlass.range(0, k_tiles, 1, unroll=1): if is_leader_cta: ab_pipeline.consumer_wait(ab_cs, peek_ab_full) # Copy SFA/SFB from SMEM to TMEM - s2t_stage = ( - None, None, None, None, ab_cs.index, - ) + s2t_stage_coord = (None, None, None, None, ab_cs.index) cute.copy(tiled_copy_s2t_sfa, - tCsSFA_compact_s2t[s2t_stage], + tCsSFA_compact_s2t[s2t_stage_coord], tCtSFA_compact_s2t) cute.copy(tiled_copy_s2t_sfb, - tCsSFB_compact_s2t[s2t_stage], + tCsSFB_compact_s2t[s2t_stage_coord], tCtSFB_compact_s2t) - # Set SFA/SFB for MMA + # GEMM: (A * SFA) @ (B * SFB) -> Acc num_kblocks = cute.size(tCrA, mode=[2]) for kblock_idx in cutlass.range(num_kblocks, unroll=1): - sf_kblock = (None, None, kblock_idx) + sf_kblock_coord = (None, None, kblock_idx) tiled_mma.set(tcgen05.Field.SFA, - tCtSFA[sf_kblock].iterator) + tCtSFA[sf_kblock_coord].iterator) tiled_mma.set(tcgen05.Field.SFB, - tCtSFB[sf_kblock].iterator) + tCtSFB[sf_kblock_coord].iterator) kb_coord = (None, None, kblock_idx, ab_cs.index) cute.gemm(tiled_mma, tCrA[kb_coord], tCrB[kb_coord], @@ -558,12 +670,11 @@ class Nvfp4FusedRouterKernel: if is_leader_cta: acc_pipeline.producer_commit(acc_ps) acc_ps.advance() - num_tiles_executed += cutlass.Int32(1) tsched.advance_to_next_work() wt = tsched.get_current_work() - # Wait for accumulator buffer empty + # Wait for accumulator buffer empty (tail) if is_leader_cta: acc_pipeline.producer_tail(acc_ps) @@ -576,15 +687,13 @@ class Nvfp4FusedRouterKernel: # # Strategy: # 1. Read TMEM accumulator into registers via paired t2r copy + # (using epilogue_tmem_copy_and_partition from CUTLASS) # 2. For each element: compute act = sqrt(softplus(logit)), # score = act + e_bias[expert_idx] - # 3. Insert into per-thread running top-6 (sorted, fully unrolled) - # 4. After all tiles: write local top-6 to SMEM, one thread merges, - # sorts, renormalizes, writes to GMEM - # - # The top-6 is maintained in DESCENDING order: - # s0 >= s1 >= s2 >= s3 >= s4 >= s5 - # Insertion uses fully unrolled comparisons — no dynamic indexing. + # 3. Maintain per-thread running top-k via sorted insertion + # (fully unrolled for top_k=6, descending order) + # 4. After all tiles: write local top-k to SMEM, + # thread 0 merges, sorts, renormalizes, writes to GMEM # if warp_idx in self.epilogue_warp_id: if cute.size(self.cluster_shape_mn) > 1: @@ -597,61 +706,29 @@ class Nvfp4FusedRouterKernel: tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) # TMEM → register copy (paired atoms from CUTLASS) - epi_n = self.epi_tile_n tiled_copy_t2r, tTR_tAcc_base = epilogue_tmem_copy_and_partition( tCtAcc_base, epi_tile, self.epilogue_warp_id, acc_dtype, use_2cta) tTR_rAcc = tiled_copy_t2r.fragments_slice(tiled_copy_t2r, tTR_tAcc_base) - # Identity tensor for (row, col) coordinates - cAcc = cute.make_identity_tensor( - (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1])) - tCcAcc = thr_mma.partition_C(cAcc) - cFlat = cute.flatten(tCcAcc) - - # Merge SMEM tensors (for cross-thread top-k merge) - s_merge_s = cute.make_tensor( - storage.merge_scores.data_ptr(), - cute.make_layout((128, TK))) - s_merge_i = cute.make_tensor( - storage.merge_indices.data_ptr(), - cute.make_layout((128, TK))) - s_merge_a = cute.make_tensor( - storage.merge_acts.data_ptr(), - cute.make_layout((128, TK))) - - # ------------------------------------------------------------------ - # Running top-6 per thread — individual scalar variables + # Per-thread running top-k — individual scalar variables # Stored in DESCENDING order: s0 >= s1 >= s2 >= s3 >= s4 >= s5 - # ------------------------------------------------------------------ - s0 = cutlass.Float32(-1e30) - s1 = cutlass.Float32(-1e30) - s2 = cutlass.Float32(-1e30) - s3 = cutlass.Float32(-1e30) - s4 = cutlass.Float32(-1e30) - s5 = cutlass.Float32(-1e30) - i0 = cutlass.Int32(-1) - i1 = cutlass.Int32(-1) - i2 = cutlass.Int32(-1) - i3 = cutlass.Int32(-1) - i4 = cutlass.Int32(-1) - i5 = cutlass.Int32(-1) - a0 = cutlass.Float32(0.0) - a1 = cutlass.Float32(0.0) - a2 = cutlass.Float32(0.0) - a3 = cutlass.Float32(0.0) - a4 = cutlass.Float32(0.0) - a5 = cutlass.Float32(0.0) + TK = self.top_k + s0 = cutlass.Float32(-1e30); s1 = cutlass.Float32(-1e30) + s2 = cutlass.Float32(-1e30); s3 = cutlass.Float32(-1e30) + s4 = cutlass.Float32(-1e30); s5 = cutlass.Float32(-1e30) + i0 = cutlass.Int32(-1); i1 = cutlass.Int32(-1) + i2 = cutlass.Int32(-1); i3 = cutlass.Int32(-1) + i4 = cutlass.Int32(-1); i5 = cutlass.Int32(-1) + a0 = cutlass.Float32(0.0); a1 = cutlass.Float32(0.0) + a2 = cutlass.Float32(0.0); a3 = cutlass.Float32(0.0) + a4 = cutlass.Float32(0.0); a5 = cutlass.Float32(0.0) # Tile scheduler + pipeline states tsched = utils.StaticPersistentTileScheduler.create( tile_sched_params, bidx, cute.arch.grid_dim()) wt = tsched.initial_work_tile_info() acc_cs = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.num_acc_stage) - - # Track which row we're computing top-k for (row 0 of each M-tile) - current_row = cutlass.Int32(-1) - num_tiles_done = cutlass.Int32(0) + pipeline.PipelineUserType.Consumer, self.num_acc_pipeline_stages) while wt.is_valid_tile: acc_pipeline.consumer_wait(acc_cs) @@ -661,16 +738,11 @@ class Nvfp4FusedRouterKernel: else: acc_stage_index = acc_cs.index - # Get tile N offset (which 128-expert slice this tile covers) + # Get tile N offset (which expert slice this tile covers) tc = wt.tile_idx tile_n_offset = tc[1] * self.cta_tile_shape_mnk[1] tile_m_base = tc[0] // cute.size(tiled_mma.thr_id.shape) * self.cta_tile_shape_mnk[0] - # If this is a new row, the running top-6 is already accumulated - # For the first tile of a row, we just continue accumulating - if num_tiles_done == cutlass.Int32(0): - current_row = tile_m_base - tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_stage_index)] tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) @@ -684,33 +756,27 @@ class Nvfp4FusedRouterKernel: # Early release accumulator for overlapping case if cutlass.const_expr(self.overlapping_accum): - if subtile_idx == self.iter_acc_early_release: + if subtile_idx == self.iter_acc_early_release_in_epilogue: with cute.arch.elect_one(): acc_pipeline.consumer_release(acc_cs) acc_cs.advance() # Process each element in the register fragment - rFlat = cute.flatten(tTR_rAcc) - elem_cnt = cute.size(rFlat) + rAcc = tTR_rAcc.load() + elem_cnt = cute.size(rAcc) for e in cutlass.range(elem_cnt, unroll=4): - logit = rFlat[e] - coord = cFlat[e] - row = coord[0] - col = coord[1] + logit = rAcc[e] + # Expert index = subtile_offset + e + e_idx = cutlass.Int32(tile_n_offset) + cutlass.Int32(subtile_idx * self.epi_tile_n) + cutlass.Int32(e) - # Expert index = col + tile_n_offset + (subtile_idx * epi_n) - e_idx = col + tile_n_offset + (subtile_idx * epi_n) - - # Only process row 0 (the actual token row) - # For M=1 padded to 128, only row 0 has valid data - if row == 0: + # Only process if expert index is valid + if e_idx < cutlass.Int32(N): # sqrt(softplus(logit)) # softplus(x) = max(x, 0) + log(1 + exp(-|x|)) abs_x = cute.math.absf(logit) pos = cute.math.fmax(logit, cutlass.Float32(0.0)) exp_neg = cute.math.exp(-abs_x) - one_plus = cutlass.Float32(1.0) + exp_neg - sp = pos + cute.math.log(one_plus) + sp = pos + cute.math.log(cutlass.Float32(1.0) + exp_neg) act = cute.math.sqrt(sp) # score = act + e_bias (for selection only) @@ -718,10 +784,8 @@ class Nvfp4FusedRouterKernel: # Sorted insertion into descending top-6 # s0 >= s1 >= s2 >= s3 >= s4 >= s5 - # If score <= s5, skip if score > s5: if score > s4: - # Shift s4 → s5 s5 = s4; i5 = i4; a5 = a4 if score > s3: s4 = s3; i4 = i3; a4 = a3 @@ -749,8 +813,6 @@ class Nvfp4FusedRouterKernel: acc_pipeline.consumer_release(acc_cs) acc_cs.advance() - num_tiles_done += cutlass.Int32(1) - tsched.advance_to_next_work() wt = tsched.get_current_work() @@ -759,16 +821,17 @@ class Nvfp4FusedRouterKernel: # ================================================================== # Each thread writes its running top-6 to SMEM tid = warp_idx * 32 + tidx - s_merge_s[tid, 0] = s0; s_merge_s[tid, 1] = s1; s_merge_s[tid, 2] = s2 - s_merge_s[tid, 3] = s3; s_merge_s[tid, 4] = s4; s_merge_s[tid, 5] = s5 - s_merge_i[tid, 0] = i0; s_merge_i[tid, 1] = i1; s_merge_i[tid, 2] = i2 - s_merge_i[tid, 3] = i3; s_merge_i[tid, 4] = i4; s_merge_i[tid, 5] = i5 - s_merge_a[tid, 0] = a0; s_merge_a[tid, 1] = a1; s_merge_a[tid, 2] = a2 - s_merge_a[tid, 3] = a3; s_merge_a[tid, 4] = a4; s_merge_a[tid, 5] = a5 + for k_idx in cutlass.range(TK, unroll=1): + s_val = s0 if k_idx == 0 else (s1 if k_idx == 1 else (s2 if k_idx == 2 else (s3 if k_idx == 3 else (s4 if k_idx == 4 else s5)))) + i_val = i0 if k_idx == 0 else (i1 if k_idx == 1 else (i2 if k_idx == 2 else (i3 if k_idx == 3 else (i4 if k_idx == 4 else i5)))) + a_val = a0 if k_idx == 0 else (a1 if k_idx == 1 else (a2 if k_idx == 2 else (a3 if k_idx == 3 else (a4 if k_idx == 4 else a5)))) + storage.merge_scores.data_ptr()[tid * TK + k_idx] = s_val + storage.merge_indices.data_ptr()[tid * TK + k_idx] = i_val + storage.merge_acts.data_ptr()[tid * TK + k_idx] = a_val epi_bar.arrive_and_wait() - # Thread 0 merges all 128 threads' top-6 into final result + # Thread 0 of warp 0 does the final merge + store if warp_idx == 0 and tidx == 0: # Initialize final top-6 from thread 0's data fs0 = s0; fs1 = s1; fs2 = s2; fs3 = s3; fs4 = s4; fs5 = s5 @@ -777,10 +840,10 @@ class Nvfp4FusedRouterKernel: # Merge all other threads (1..127) for t in cutlass.range(1, 128, unroll=1): - for k in cutlass.range(TK, unroll=1): - cs = s_merge_s[t, k] - ci = s_merge_i[t, k] - ca = s_merge_a[t, k] + for k_idx in cutlass.range(TK, unroll=1): + cs = storage.merge_scores.data_ptr()[t * TK + k_idx] + ci = storage.merge_indices.data_ptr()[t * TK + k_idx] + ca = storage.merge_acts.data_ptr()[t * TK + k_idx] # Only merge if this is a valid entry (index >= 0) if ci >= cutlass.Int32(0): # Sorted insertion into final top-6 (descending) @@ -812,20 +875,19 @@ class Nvfp4FusedRouterKernel: inv_sum = cutlass.Float32(1.0) / act_sum sc = cutlass.Float32(routed_scaling_factor) - # Store to GMEM (row 0 of the M-tile) - row_idx = cutlass.Int32(0) - out_w_tensor[row_idx, 0] = fa0 * inv_sum * sc - out_w_tensor[row_idx, 1] = fa1 * inv_sum * sc - out_w_tensor[row_idx, 2] = fa2 * inv_sum * sc - out_w_tensor[row_idx, 3] = fa3 * inv_sum * sc - out_w_tensor[row_idx, 4] = fa4 * inv_sum * sc - out_w_tensor[row_idx, 5] = fa5 * inv_sum * sc - out_id_tensor[row_idx, 0] = fi0 - out_id_tensor[row_idx, 1] = fi1 - out_id_tensor[row_idx, 2] = fi2 - out_id_tensor[row_idx, 3] = fi3 - out_id_tensor[row_idx, 4] = fi4 - out_id_tensor[row_idx, 5] = fi5 + # Store to GMEM + out_w_tensor[0, 0] = fa0 * inv_sum * sc + out_w_tensor[0, 1] = fa1 * inv_sum * sc + out_w_tensor[0, 2] = fa2 * inv_sum * sc + out_w_tensor[0, 3] = fa3 * inv_sum * sc + out_w_tensor[0, 4] = fa4 * inv_sum * sc + out_w_tensor[0, 5] = fa5 * inv_sum * sc + out_id_tensor[0, 0] = fi0 + out_id_tensor[0, 1] = fi1 + out_id_tensor[0, 2] = fi2 + out_id_tensor[0, 3] = fi3 + out_id_tensor[0, 4] = fi4 + out_id_tensor[0, 5] = fi5 epi_bar.arrive_and_wait() @@ -853,21 +915,6 @@ def run_nvfp4_fused_router( Single-kernel: NVFP4 block-scaled GEMM + fused router epilogue. No intermediate GMEM buffer. No BF16 fallback. - - Parameters - ---------- - hidden_states : [M, K] BF16 — raw activation - mat_b : CuTe tensor — gate weight in NVFP4 blockscaled layout - scale_b : CuTe tensor — gate weight scale factors in blockscaled layout - gsa : activation global scale (scalar) - gsb_val : weight global scale (float) - e_bias : [E] FP32 — per-expert selection bias - routed_scaling_factor : float - top_k : int (default 6) - - Returns - ------- - (topk_weights, topk_ids) — [M, top_k] FP32 and [M, top_k] int32 """ import cutlass.torch as cutlass_torch from dsv4.ops.quantize import quantize_activation_nvfp4 diff --git a/tests/unit/test_cute_math_api.py b/tests/unit/test_cute_math_api.py new file mode 100644 index 00000000..4328c913 --- /dev/null +++ b/tests/unit/test_cute_math_api.py @@ -0,0 +1,78 @@ +"""Test: check what CuTeDSL math operations are available.""" +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +def test_cute_math_api(): + """Enumerate available CuTeDSL math/arch operations.""" + import cutlass + import cutlass.cute as cute + + # Check cute.math module + print("=== cute.math attributes ===") + if hasattr(cute, 'math'): + for attr in sorted(dir(cute.math)): + if not attr.startswith('_'): + print(f" cute.math.{attr}") + else: + print(" cute.math does not exist") + + # Check cute.arch module for math + print("\n=== cute.arch math-related attributes ===") + if hasattr(cute, 'arch'): + for attr in sorted(dir(cute.arch)): + if any(k in attr.lower() for k in ['sqrt', 'log', 'exp', 'abs', 'sin', 'cos', 'rsqrt', 'rcp', 'fma', 'div']): + print(f" cute.arch.{attr}") + + # Check cute directly for math + print("\n=== cute math-related attributes ===") + for attr in sorted(dir(cute)): + if any(k in attr.lower() for k in ['sqrt', 'log', 'exp', 'abs', 'sin', 'cos', 'rsqrt', 'rcp']): + print(f" cute.{attr}") + + # Check cutlass module for math + print("\n=== cutlass math-related attributes ===") + for attr in sorted(dir(cutlass)): + if any(k in attr.lower() for k in ['sqrt', 'log', 'exp', 'abs', 'rsqrt', 'rcp']): + print(f" cutlass.{attr}") + + # Check if cute.exp exists + print(f"\n=== Key functions ===") + print(f" cute.exp exists: {hasattr(cute, 'exp')}") + print(f" cute.log exists: {hasattr(cute, 'log')}") + print(f" cute.sqrt exists: {hasattr(cute, 'sqrt')}") + print(f" cute.math exists: {hasattr(cute, 'math')}") + + if hasattr(cute, 'math'): + print(f" cute.math.fmax exists: {hasattr(cute.math, 'fmax')}") + print(f" cute.math.fmin exists: {hasattr(cute.math, 'fmin')}") + print(f" cute.math.absf exists: {hasattr(cute.math, 'absf')}") + print(f" cute.math.sqrt exists: {hasattr(cute.math, 'sqrt')}") + print(f" cute.math.log exists: {hasattr(cute.math, 'log')}") + print(f" cute.math.exp exists: {hasattr(cute.math, 'exp')}") + print(f" cute.math.rsqrt exists: {hasattr(cute.math, 'rsqrt')}") + print(f" cute.math.rcp exists: {hasattr(cute.math, 'rcp')}") + print(f" cute.math.sin exists: {hasattr(cute.math, 'sin')}") + print(f" cute.math.cos exists: {hasattr(cute.math, 'cos')}") + print(f" cute.math.copysign exists: {hasattr(cute.math, 'copysign')}") + print(f" cute.math.clamp exists: {hasattr(cute.math, 'clamp')}") + + # Check arch operations + print(f"\n cute.arch.fmax exists: {hasattr(cute.arch, 'fmax')}") + print(f" cute.arch.fmin exists: {hasattr(cute.arch, 'fmin')}") + + # Try to find math operations in cutlass._mlir_ops or similar + print("\n=== MLIR operations ===") + for mod_name in ['cutlass._mlir_ops', 'cutlass.mlir', 'cutlass.cute._mlir']: + try: + mod = __import__(mod_name, fromlist=['']) + math_attrs = [a for a in dir(mod) if any(k in a.lower() for k in ['sqrt', 'log', 'exp', 'abs', 'rsqrt'])] + if math_attrs: + print(f" {mod_name}: {math_attrs}") + except ImportError: + pass + + print("\nDone.") + +if __name__ == "__main__": + test_cute_math_api()