Test: enumerate CuTeDSL math API to check available operations

This commit is contained in:
2026-06-01 09:11:29 +00:00
parent 9b86b2b414
commit 4f4ae8febd
2 changed files with 326 additions and 201 deletions

View File

@@ -6,7 +6,7 @@ with fused router epilogue (sqrt(softplus) + e_bias + top-k + renorm).
PRODUCTION KERNEL. No intermediate GMEM buffer. No BF16 fallback.
The GEMM accumulates logits in TMEM, then the epilogue warps process them directly:
1. TMEM -> registers (via paired t2r atom from CUTLASS epilogue helpers)
2. For each logit: sqrt(softplus(logit)) + e_bias -> score; track top-k via min-heap
2. For each logit: sqrt(softplus(logit)) + e_bias -> score; track top-k via sorted insertion
3. After all subtiles: sort, renormalize, write (topk_weights, topk_ids) to GMEM
Warp specialization (6 warps, no scheduler for dense GEMM):
@@ -17,17 +17,21 @@ Warp specialization (6 warps, no scheduler for dense GEMM):
Pipeline structure (2 pipelines):
AB pipeline: TMA (producer) -> MMA (consumer) [PipelineTmaUmma]
Acc pipeline: MMA (producer) -> Epilogue (consumer) [PipelineUmmaAsync]
Architecture reference: FusedSwiGLUScaledGroupedGemmKernel (dsv4/kernels/gemm/fused_swiglu.py)
The blockscaled GEMM mainloop follows the same pattern exactly.
The epilogue is custom: instead of TMA store, we do TMEM->reg top-k reduction.
"""
from __future__ import annotations
from typing import Tuple
import math
from typing import Tuple, Optional, Type, Union
import cuda.bindings.driver as cuda
import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.typing import Pointer
from cutlass.cute.nvgpu import cpasync, tcgen05
import cutlass.utils as utils
import cutlass.pipeline as pipeline
@@ -35,12 +39,20 @@ import cutlass.utils.blackwell_helpers as sm100_utils
import cutlass.utils.blockscaled_layout as blockscaled_utils
from cutlass.utils.gemm.sm100 import (
epilogue_tmem_copy_and_partition,
epilogue_smem_copy_and_partition,
transform_partitioned_tensor_layout,
)
class Nvfp4FusedRouterKernel:
"""
NVFP4 blockscaled GEMM + fused router epilogue.
Dense (non-grouped) GEMM: [M, K] @ [K, E] -> [M, E] with NVFP4 weights.
Custom epilogue: TMEM -> registers -> sqrt(softplus) + e_bias + top-k + renorm -> GMEM.
This follows the FusedSwiGLUScaledGroupedGemmKernel pattern for the
blockscaled GEMM mainloop exactly, with a custom epilogue.
"""
def __init__(
self,
@@ -55,38 +67,48 @@ class Nvfp4FusedRouterKernel:
self.top_k = top_k
self.use_2cta_instrs = mma_tiler_mnk[0] == 256
self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
self.arch = "sm_100"
# 6-warp specialization (no scheduler warp for dense GEMM)
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4
self.tma_warp_id = 5
self.threads_per_warp = 32
self.threads_per_cta = self.threads_per_warp * 6
# Barrier IDs
self.cta_sync_bar_id = 1
self.epilogue_sync_bar_id = 2
self.tmem_alloc_sync_bar_id = 3
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
self.smem_capacity = utils.get_smem_capacity_in_bytes(self.arch)
self.occupancy = 1
self.buffer_align_bytes = 1024
# -----------------------------------------------------------------
# _create_tiled_mma / _create_tiled_mma_sfb
# -----------------------------------------------------------------
def _create_tiled_mma(self, a_dtype, a_major_mode, b_major_mode, sf_dtype):
return sm100_utils.make_blockscaled_trivial_tiled_mma(
a_dtype, a_major_mode, b_major_mode, sf_dtype,
self.sf_vec_size, self.cta_group,
(self.mma_tiler_mnk[0], self.mma_tiler_mnk[1]),
self.mma_inst_shape_mn,
)
def _create_tiled_mma_sfb(self, a_dtype, a_major_mode, b_major_mode, sf_dtype):
mma_inst_shape_mn_sfb = (
self.mma_tiler_mnk[0] // (2 if self.use_2cta_instrs else 1),
cute.round_up(self.mma_tiler_mnk[1], 128),
)
return sm100_utils.make_blockscaled_trivial_tiled_mma(
a_dtype, a_major_mode, b_major_mode, sf_dtype,
self.sf_vec_size, tcgen05.CtaGroup.ONE, mma_inst_shape_mn_sfb,
self.sf_vec_size, tcgen05.CtaGroup.ONE,
self.mma_inst_shape_mn_sfb,
)
# -----------------------------------------------------------------
# _setup_attributes
# -----------------------------------------------------------------
def _setup_attributes(self, tiled_mma, tiled_mma_sfb, a_dtype, b_dtype, sf_dtype):
"""Set up kernel attributes. Mirrors FusedSwiGLUScaledGroupedGemmKernel._setup_attributes."""
self.mma_inst_shape_mn = (self.mma_tiler_mnk[0], self.mma_tiler_mnk[1])
self.mma_inst_shape_mn_sfb = (
self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1),
@@ -108,11 +130,13 @@ class Nvfp4FusedRouterKernel:
)
self.cta_tile_shape_mnk = (
self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
self.mma_tiler[1], self.mma_tiler[2],
self.mma_tiler[1],
self.mma_tiler[2],
)
self.cta_tile_shape_mnk_sfb = (
self.mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape),
self.mma_tiler_sfb[1], self.mma_tiler_sfb[2],
self.mma_tiler_sfb[1],
self.mma_tiler_sfb[2],
)
self.cluster_layout_vmnk = cute.tiled_divide(
@@ -129,29 +153,21 @@ class Nvfp4FusedRouterKernel:
self.is_b_mcast = self.num_mcast_ctas_b > 1
self.is_sfb_mcast = self.num_mcast_ctas_sfb > 1
# Epilogue tile: for router, we process all N columns (expert dimension).
# Use epi_tile = (128, 32) as the subtile for t2r copy.
# This determines how many columns are loaded from TMEM per subtile.
self.epi_tile = (
cute.make_layout(self.cta_tile_shape_mnk[0]),
cute.make_layout(self.cta_tile_shape_mnk[1]),
cute.make_layout((32, 1), stride=(1, 32)),
)
self.epi_tile_n = cute.size(self.epi_tile[1])
self.overlapping_accum = self.cta_tile_shape_mnk[1] == 256
self.num_acc_stage = 1 if self.overlapping_accum else 2
# Stage counts
self.num_acc_stage = 2
self.num_ab_stage = 2
self.num_c_stage = 2 # not used for TMA store, but needed for stage computation
sf_atom_mn = 32
self.num_sfa_tmem_cols = (self.cta_tile_shape_mnk[0] // sf_atom_mn) * mma_inst_tile_k
self.num_sfb_tmem_cols = (self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn) * mma_inst_tile_k
self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols
if self.overlapping_accum:
self.num_accumulator_tmem_cols = self.cta_tile_shape_mnk[1] * 2 - self.num_sf_tmem_cols
else:
self.num_accumulator_tmem_cols = self.cta_tile_shape_mnk[1] * self.num_acc_stage
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake)
# Compute SMEM layouts for A, B, SFA, SFB
self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
tiled_mma, self.mma_tiler, a_dtype, self.num_ab_stage)
self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
@@ -161,6 +177,28 @@ class Nvfp4FusedRouterKernel:
self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb(
tiled_mma, self.mma_tiler, self.sf_vec_size, self.num_ab_stage)
# Overlapping accumulator (N=256 case)
self.overlapping_accum = self.cta_tile_shape_mnk[1] == 256
if self.overlapping_accum:
self.num_acc_pipeline_stages = 1
else:
self.num_acc_pipeline_stages = self.num_acc_stage
# TMEM column counts
sf_atom_mn = 32
self.num_sfa_tmem_cols = (self.cta_tile_shape_mnk[0] // sf_atom_mn) * mma_inst_tile_k
self.num_sfb_tmem_cols = (self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn) * mma_inst_tile_k
self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols
self.num_accumulator_tmem_cols = self.cta_tile_shape_mnk[1] * self.num_acc_stage - (
self.num_sf_tmem_cols if self.overlapping_accum else 0
)
# Only when overlapping_accum, release accumulator buffer early in epilogue
self.iter_acc_early_release_in_epilogue = (
self.num_sf_tmem_cols // self.epi_tile_n
)
# TMA load bytes
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
a_smem_0 = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
b_smem_0 = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
@@ -173,42 +211,90 @@ class Nvfp4FusedRouterKernel:
cute.size_in_bytes(sf_dtype, sfb_smem_0)
) * atom_thr_size
self.iter_acc_early_release = self.num_sf_tmem_cols // self.epi_tile_n
# TMEM allocation size
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake)
# -----------------------------------------------------------------
# mainloop_s2t_copy_and_partition (same as fused_swiglu.py)
# -----------------------------------------------------------------
def mainloop_s2t_copy_and_partition(self, sSF, tSF, cta_group):
"""Make tiledCopy for SMEM -> TMEM load of a scale factor tensor."""
tCsSF_compact = cute.filter_zeros(sSF)
tCtSF_compact = cute.filter_zeros(tSF)
copy_atom_s2t = cute.make_copy_atom(
tcgen05.Cp4x32x128bOp(cta_group),
self.sf_dtype,
)
tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact)
thr_copy_s2t = tiled_copy_s2t.get_slice(0)
tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact)
tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(
tiled_copy_s2t, tCsSF_compact_s2t_
)
tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact)
return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t
# -----------------------------------------------------------------
# run() — Python entry point
# -----------------------------------------------------------------
def run(self, mat_a, mat_b, scale_a, scale_b, e_bias, out_weights, out_ids,
M, N, K, routed_scaling_factor, top_k, stream=None):
if stream is None:
stream = cuda.CUstream(0)
a_dtype = cutlass.Float4E2M1FN
b_dtype = cutlass.Float4E2M1FN
sf_dtype = cutlass.Float8E4M3FN
# Infer dtypes and major modes from tensors
a_dtype = mat_a.element_type
b_dtype = mat_b.element_type
sf_dtype = scale_a.element_type
a_major_mode = utils.LayoutEnum.from_tensor(mat_a).mma_major_mode()
b_major_mode = utils.LayoutEnum.from_tensor(mat_b).mma_major_mode()
# Save for kernel use
self.a_dtype = a_dtype
self.b_dtype = b_dtype
self.sf_dtype = sf_dtype
self.a_major_mode = a_major_mode
self.b_major_mode = b_major_mode
tiled_mma = self._create_tiled_mma(a_dtype, a_major_mode, b_major_mode, sf_dtype)
tiled_mma_sfb = self._create_tiled_mma_sfb(a_dtype, a_major_mode, b_major_mode, sf_dtype)
self._setup_attributes(tiled_mma, tiled_mma_sfb, a_dtype, b_dtype, sf_dtype)
a_smem_0 = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
# TMA atoms — following fused_swiglu.py exactly
# A
a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
a_op, mat_a, a_smem_0, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
a_op, mat_a, a_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
b_smem_0 = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
# B
b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id)
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
b_op, mat_b, b_smem_0, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
b_op, mat_b, b_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
sfa_smem_0 = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0))
# SFA
sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0))
tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A(
a_op, scale_a, sfa_smem_0, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
sfa_op, scale_a, sfa_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape,
internal_type=cutlass.Uint64)
sfb_smem_0 = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0))
sfb_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma_sfb.thr_id)
# SFB
sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id)
sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0))
tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B(
sfb_op, scale_b, sfb_smem_0, self.mma_tiler_sfb, tiled_mma_sfb, self.cluster_layout_sfb_vmnk.shape)
sfb_op, scale_b, sfb_smem_layout, self.mma_tiler_sfb, tiled_mma_sfb,
self.cluster_layout_sfb_vmnk.shape, internal_type=cutlass.Uint64)
# Grid: dense GEMM, one CTA per (M_tile, N_tile)
num_M_tiles = cute.ceil_div(M, self.cta_tile_shape_mnk[0])
num_N_tiles = cute.ceil_div(N, self.cta_tile_shape_mnk[1])
L = 1
@@ -239,6 +325,10 @@ class Nvfp4FusedRouterKernel:
cute.compile(_compiled_fn, mat_a, mat_b, scale_a, scale_b, e_bias, out_weights, out_ids)
# -----------------------------------------------------------------
# GPU kernel
# -----------------------------------------------------------------
@cute.kernel
def _kernel(self, tiled_mma, tiled_mma_sfb,
tma_atom_a, mA_mkl, tma_atom_b, mB_nkl,
@@ -262,32 +352,32 @@ class Nvfp4FusedRouterKernel:
block_coord = cluster_layout_vmnk.get_flat_coord(cta_rank)
acc_dtype = cutlass.Float32
sf_dtype = cutlass.Float8E4M3FN
# Reconstruct SMEM layout slices (same as fused_swiglu kernel)
a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0))
b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, None, 0))
sfa_smem_layout = cute.slice_(sfa_smem_layout_staged, (None, None, None, 0))
sfb_smem_layout = cute.slice_(sfb_smem_layout_staged, (None, None, None, 0))
# ============================================================
# Shared storage
# ============================================================
@cute.struct
class SharedStorage:
ab_full_mbar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
ab_empty_mbar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
acc_full_mbar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
acc_empty_mbar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
ab_full_mbar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
acc_full_mbar: cute.struct.MemRange[cutlass.Int64, self.num_acc_pipeline_stages * 2]
tmem_dealloc_mbar: cutlass.Int64
tmem_holding: cutlass.Int32
merge_scores: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 128*self.top_k], 128]
merge_indices: cute.struct.Align[cute.struct.MemRange[cutlass.Int32, 128*self.top_k], 128]
merge_acts: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 128*self.top_k], 128]
sA: cute.struct.Align[cute.struct.MemRange[cutlass.Float4E2M1FN, cute.cosize(a_smem_layout_staged.outer)], self.buffer_align_bytes]
sB: cute.struct.Align[cute.struct.MemRange[cutlass.Float4E2M1FN, cute.cosize(b_smem_layout_staged.outer)], self.buffer_align_bytes]
sSFA: cute.struct.Align[cute.struct.MemRange[cutlass.Float8E4M3FN, cute.cosize(sfa_smem_layout_staged.outer)], self.buffer_align_bytes]
sSFB: cute.struct.Align[cute.struct.MemRange[cutlass.Float8E4M3FN, cute.cosize(sfb_smem_layout_staged.outer)], self.buffer_align_bytes]
# SMEM for top-k merge: 128 threads × top_k entries
merge_scores: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 128 * self.top_k], 128]
merge_indices: cute.struct.Align[cute.struct.MemRange[cutlass.Int32, 128 * self.top_k], 128]
merge_acts: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 128 * self.top_k], 128]
smem = utils.SmemAllocator()
storage = smem.allocate(SharedStorage)
# ============================================================
# Pipelines
# Pipelines (following fused_swiglu.py exactly)
# ============================================================
ab_pipeline = pipeline.PipelineTmaUmma.create(
barrier_storage=storage.ab_full_mbar.data_ptr(),
@@ -298,15 +388,17 @@ class Nvfp4FusedRouterKernel:
self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1),
tx_count=self.num_tma_load_bytes,
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
)
num_acc_cons = self.threads_per_warp * len(self.epilogue_warp_id) * (2 if use_2cta else 1)
acc_pipeline = pipeline.PipelineUmmaAsync.create(
barrier_storage=storage.acc_full_mbar.data_ptr(),
num_stages=self.num_acc_stage,
num_stages=self.num_acc_pipeline_stages,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, num_acc_cons),
cta_layout_vmnk=cluster_layout_vmnk,
defer_sync=True,
)
tmem = utils.TmemAllocator(
@@ -324,12 +416,32 @@ class Nvfp4FusedRouterKernel:
self.threads_per_warp * len(self.epilogue_warp_id))
# ============================================================
# SMEM tensors
# SMEM tensors (following fused_swiglu.py pattern)
# A/B use swizzled layouts (ComposedLayout: .outer + .inner)
# SFA/SFB use plain layouts (not Composed)
# ============================================================
sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner)
sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner)
sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged.outer, swizzle=sfa_smem_layout_staged.inner)
sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged.outer, swizzle=sfb_smem_layout_staged.inner)
sA = smem.allocate_tensor(
element_type=self.a_dtype,
layout=a_smem_layout_staged.outer,
byte_alignment=128,
swizzle=a_smem_layout_staged.inner,
)
sB = smem.allocate_tensor(
element_type=self.b_dtype,
layout=b_smem_layout_staged.outer,
byte_alignment=128,
swizzle=b_smem_layout_staged.inner,
)
sSFA = smem.allocate_tensor(
element_type=self.sf_dtype,
layout=sfa_smem_layout_staged,
byte_alignment=128,
)
sSFB = smem.allocate_tensor(
element_type=self.sf_dtype,
layout=sfb_smem_layout_staged,
byte_alignment=128,
)
# ============================================================
# Multicast masks
@@ -342,7 +454,7 @@ class Nvfp4FusedRouterKernel:
sfb_mcast = cpasync.create_tma_multicast_mask(cluster_layout_sfb_vmnk, block_coord, mcast_mode=1)
# ============================================================
# Partition global tensors
# Partition global tensors (same as fused_swiglu TMA warp setup)
# ============================================================
gA = cute.local_tile(mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None))
gB = cute.local_tile(mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None))
@@ -357,7 +469,7 @@ class Nvfp4FusedRouterKernel:
thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_v)
tCgSFB = thr_mma_sfb.partition_B(gSFB)
# TMA partitions for A/B
# TMA partitions for A/B (following fused_swiglu)
a_cta_l = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_atom_a, block_coord[2], a_cta_l,
cute.group_modes(sA, 0, 3), cute.group_modes(tCgA, 0, 3))
@@ -365,28 +477,32 @@ class Nvfp4FusedRouterKernel:
tBsB, tBgB = cpasync.tma_partition(tma_atom_b, block_coord[1], b_cta_l,
cute.group_modes(sB, 0, 3), cute.group_modes(tCgB, 0, 3))
# TMA partitions for SFA/SFB
# TMA partitions for SFA/SFB (following fused_swiglu)
tAsSFA, tAgSFA = cpasync.tma_partition(tma_atom_sfa, block_coord[2], a_cta_l,
cute.group_modes(sSFA, 0, 3), cute.group_modes(tCgSFA, 0, 3))
tAsSFA = cute.filter_zeros(tAsSFA)
tAgSFA = cute.filter_zeros(tAgSFA)
sfb_cta_l = cute.make_layout(cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape)
block_coord_sfb = cluster_layout_sfb_vmnk.get_flat_coord(cta_rank)
sfb_cta_l = cute.make_layout(cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape)
tBsSFB, tBgSFB = cpasync.tma_partition(tma_atom_sfb, block_coord_sfb[1], sfb_cta_l,
cute.group_modes(sSFB, 0, 3), cute.group_modes(tCgSFB, 0, 3))
tBsSFB = cute.filter_zeros(tBsSFB)
tBgSFB = cute.filter_zeros(tBgSFB)
# TMEM accumulator shape
# TMEM accumulator
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
# Cluster arrive (before any TMA or pipeline ops)
if cute.size(self.cluster_shape_mn) > 1:
cute.arch.cluster_arrive_relaxed()
else:
cta_bar.arrive_and_wait()
# ============================================================
# TMA WARP
# TMA WARP — Load A, B, SFA, SFB from GMEM to SMEM
# (follows fused_swiglu TMA warp exactly)
# ============================================================
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_atom_a)
@@ -444,7 +560,8 @@ class Nvfp4FusedRouterKernel:
wt = tsched.get_current_work()
# ============================================================
# MMA WARP — blockscaled GEMM: (A * SFA) @ (B * SFB) -> TMEM
# MMA WARP — Blockscaled GEMM: (A * SFA) @ (B * SFB) -> TMEM
# (follows fused_swiglu MMA warp exactly)
# ============================================================
if warp_idx == self.mma_warp_id:
# Wait for cluster sync
@@ -463,7 +580,6 @@ class Nvfp4FusedRouterKernel:
tCrB = tiled_mma.make_fragment_B(sB)
# S2T copies for SFA: SMEM -> TMEM
# The SFA tmem region starts after the accumulator columns
sfa_tmem_ptr = acc_tmem_ptr
tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa(
tiled_mma, self.mma_tiler, self.sf_vec_size,
@@ -477,7 +593,7 @@ class Nvfp4FusedRouterKernel:
cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)))
tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout)
# S2T copy atoms
# S2T copy atoms (using fused_swiglu helper)
tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = \
self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA, self.cta_group)
tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = \
@@ -490,9 +606,7 @@ class Nvfp4FusedRouterKernel:
ab_cs = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_ab_stage)
acc_ps = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num_acc_stage)
num_tiles_executed = cutlass.Int32(0)
pipeline.PipelineUserType.Producer, self.num_acc_pipeline_stages)
while wt.is_valid_tile:
# Wait for accumulator buffer empty
@@ -516,30 +630,28 @@ class Nvfp4FusedRouterKernel:
if ab_cs.count < k_tiles and is_leader_cta:
peek_ab_full = ab_pipeline.consumer_try_wait(ab_cs)
# Mainloop: K-tiles
# Mainloop: K-tiles (following fused_swiglu exactly)
for kt in cutlass.range(0, k_tiles, 1, unroll=1):
if is_leader_cta:
ab_pipeline.consumer_wait(ab_cs, peek_ab_full)
# Copy SFA/SFB from SMEM to TMEM
s2t_stage = (
None, None, None, None, ab_cs.index,
)
s2t_stage_coord = (None, None, None, None, ab_cs.index)
cute.copy(tiled_copy_s2t_sfa,
tCsSFA_compact_s2t[s2t_stage],
tCsSFA_compact_s2t[s2t_stage_coord],
tCtSFA_compact_s2t)
cute.copy(tiled_copy_s2t_sfb,
tCsSFB_compact_s2t[s2t_stage],
tCsSFB_compact_s2t[s2t_stage_coord],
tCtSFB_compact_s2t)
# Set SFA/SFB for MMA
# GEMM: (A * SFA) @ (B * SFB) -> Acc
num_kblocks = cute.size(tCrA, mode=[2])
for kblock_idx in cutlass.range(num_kblocks, unroll=1):
sf_kblock = (None, None, kblock_idx)
sf_kblock_coord = (None, None, kblock_idx)
tiled_mma.set(tcgen05.Field.SFA,
tCtSFA[sf_kblock].iterator)
tCtSFA[sf_kblock_coord].iterator)
tiled_mma.set(tcgen05.Field.SFB,
tCtSFB[sf_kblock].iterator)
tCtSFB[sf_kblock_coord].iterator)
kb_coord = (None, None, kblock_idx, ab_cs.index)
cute.gemm(tiled_mma, tCrA[kb_coord], tCrB[kb_coord],
@@ -558,12 +670,11 @@ class Nvfp4FusedRouterKernel:
if is_leader_cta:
acc_pipeline.producer_commit(acc_ps)
acc_ps.advance()
num_tiles_executed += cutlass.Int32(1)
tsched.advance_to_next_work()
wt = tsched.get_current_work()
# Wait for accumulator buffer empty
# Wait for accumulator buffer empty (tail)
if is_leader_cta:
acc_pipeline.producer_tail(acc_ps)
@@ -576,15 +687,13 @@ class Nvfp4FusedRouterKernel:
#
# Strategy:
# 1. Read TMEM accumulator into registers via paired t2r copy
# (using epilogue_tmem_copy_and_partition from CUTLASS)
# 2. For each element: compute act = sqrt(softplus(logit)),
# score = act + e_bias[expert_idx]
# 3. Insert into per-thread running top-6 (sorted, fully unrolled)
# 4. After all tiles: write local top-6 to SMEM, one thread merges,
# sorts, renormalizes, writes to GMEM
#
# The top-6 is maintained in DESCENDING order:
# s0 >= s1 >= s2 >= s3 >= s4 >= s5
# Insertion uses fully unrolled comparisons — no dynamic indexing.
# 3. Maintain per-thread running top-k via sorted insertion
# (fully unrolled for top_k=6, descending order)
# 4. After all tiles: write local top-k to SMEM,
# thread 0 merges, sorts, renormalizes, writes to GMEM
#
if warp_idx in self.epilogue_warp_id:
if cute.size(self.cluster_shape_mn) > 1:
@@ -597,61 +706,29 @@ class Nvfp4FusedRouterKernel:
tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout)
# TMEM → register copy (paired atoms from CUTLASS)
epi_n = self.epi_tile_n
tiled_copy_t2r, tTR_tAcc_base = epilogue_tmem_copy_and_partition(
tCtAcc_base, epi_tile, self.epilogue_warp_id, acc_dtype, use_2cta)
tTR_rAcc = tiled_copy_t2r.fragments_slice(tiled_copy_t2r, tTR_tAcc_base)
# Identity tensor for (row, col) coordinates
cAcc = cute.make_identity_tensor(
(self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1]))
tCcAcc = thr_mma.partition_C(cAcc)
cFlat = cute.flatten(tCcAcc)
# Merge SMEM tensors (for cross-thread top-k merge)
s_merge_s = cute.make_tensor(
storage.merge_scores.data_ptr(),
cute.make_layout((128, TK)))
s_merge_i = cute.make_tensor(
storage.merge_indices.data_ptr(),
cute.make_layout((128, TK)))
s_merge_a = cute.make_tensor(
storage.merge_acts.data_ptr(),
cute.make_layout((128, TK)))
# ------------------------------------------------------------------
# Running top-6 per thread — individual scalar variables
# Per-thread running top-k — individual scalar variables
# Stored in DESCENDING order: s0 >= s1 >= s2 >= s3 >= s4 >= s5
# ------------------------------------------------------------------
s0 = cutlass.Float32(-1e30)
s1 = cutlass.Float32(-1e30)
s2 = cutlass.Float32(-1e30)
s3 = cutlass.Float32(-1e30)
s4 = cutlass.Float32(-1e30)
s5 = cutlass.Float32(-1e30)
i0 = cutlass.Int32(-1)
i1 = cutlass.Int32(-1)
i2 = cutlass.Int32(-1)
i3 = cutlass.Int32(-1)
i4 = cutlass.Int32(-1)
i5 = cutlass.Int32(-1)
a0 = cutlass.Float32(0.0)
a1 = cutlass.Float32(0.0)
a2 = cutlass.Float32(0.0)
a3 = cutlass.Float32(0.0)
a4 = cutlass.Float32(0.0)
a5 = cutlass.Float32(0.0)
TK = self.top_k
s0 = cutlass.Float32(-1e30); s1 = cutlass.Float32(-1e30)
s2 = cutlass.Float32(-1e30); s3 = cutlass.Float32(-1e30)
s4 = cutlass.Float32(-1e30); s5 = cutlass.Float32(-1e30)
i0 = cutlass.Int32(-1); i1 = cutlass.Int32(-1)
i2 = cutlass.Int32(-1); i3 = cutlass.Int32(-1)
i4 = cutlass.Int32(-1); i5 = cutlass.Int32(-1)
a0 = cutlass.Float32(0.0); a1 = cutlass.Float32(0.0)
a2 = cutlass.Float32(0.0); a3 = cutlass.Float32(0.0)
a4 = cutlass.Float32(0.0); a5 = cutlass.Float32(0.0)
# Tile scheduler + pipeline states
tsched = utils.StaticPersistentTileScheduler.create(
tile_sched_params, bidx, cute.arch.grid_dim())
wt = tsched.initial_work_tile_info()
acc_cs = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage)
# Track which row we're computing top-k for (row 0 of each M-tile)
current_row = cutlass.Int32(-1)
num_tiles_done = cutlass.Int32(0)
pipeline.PipelineUserType.Consumer, self.num_acc_pipeline_stages)
while wt.is_valid_tile:
acc_pipeline.consumer_wait(acc_cs)
@@ -661,16 +738,11 @@ class Nvfp4FusedRouterKernel:
else:
acc_stage_index = acc_cs.index
# Get tile N offset (which 128-expert slice this tile covers)
# Get tile N offset (which expert slice this tile covers)
tc = wt.tile_idx
tile_n_offset = tc[1] * self.cta_tile_shape_mnk[1]
tile_m_base = tc[0] // cute.size(tiled_mma.thr_id.shape) * self.cta_tile_shape_mnk[0]
# If this is a new row, the running top-6 is already accumulated
# For the first tile of a row, we just continue accumulating
if num_tiles_done == cutlass.Int32(0):
current_row = tile_m_base
tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_stage_index)]
tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
@@ -684,33 +756,27 @@ class Nvfp4FusedRouterKernel:
# Early release accumulator for overlapping case
if cutlass.const_expr(self.overlapping_accum):
if subtile_idx == self.iter_acc_early_release:
if subtile_idx == self.iter_acc_early_release_in_epilogue:
with cute.arch.elect_one():
acc_pipeline.consumer_release(acc_cs)
acc_cs.advance()
# Process each element in the register fragment
rFlat = cute.flatten(tTR_rAcc)
elem_cnt = cute.size(rFlat)
rAcc = tTR_rAcc.load()
elem_cnt = cute.size(rAcc)
for e in cutlass.range(elem_cnt, unroll=4):
logit = rFlat[e]
coord = cFlat[e]
row = coord[0]
col = coord[1]
logit = rAcc[e]
# Expert index = subtile_offset + e
e_idx = cutlass.Int32(tile_n_offset) + cutlass.Int32(subtile_idx * self.epi_tile_n) + cutlass.Int32(e)
# Expert index = col + tile_n_offset + (subtile_idx * epi_n)
e_idx = col + tile_n_offset + (subtile_idx * epi_n)
# Only process row 0 (the actual token row)
# For M=1 padded to 128, only row 0 has valid data
if row == 0:
# Only process if expert index is valid
if e_idx < cutlass.Int32(N):
# sqrt(softplus(logit))
# softplus(x) = max(x, 0) + log(1 + exp(-|x|))
abs_x = cute.math.absf(logit)
pos = cute.math.fmax(logit, cutlass.Float32(0.0))
exp_neg = cute.math.exp(-abs_x)
one_plus = cutlass.Float32(1.0) + exp_neg
sp = pos + cute.math.log(one_plus)
sp = pos + cute.math.log(cutlass.Float32(1.0) + exp_neg)
act = cute.math.sqrt(sp)
# score = act + e_bias (for selection only)
@@ -718,10 +784,8 @@ class Nvfp4FusedRouterKernel:
# Sorted insertion into descending top-6
# s0 >= s1 >= s2 >= s3 >= s4 >= s5
# If score <= s5, skip
if score > s5:
if score > s4:
# Shift s4 → s5
s5 = s4; i5 = i4; a5 = a4
if score > s3:
s4 = s3; i4 = i3; a4 = a3
@@ -749,8 +813,6 @@ class Nvfp4FusedRouterKernel:
acc_pipeline.consumer_release(acc_cs)
acc_cs.advance()
num_tiles_done += cutlass.Int32(1)
tsched.advance_to_next_work()
wt = tsched.get_current_work()
@@ -759,16 +821,17 @@ class Nvfp4FusedRouterKernel:
# ==================================================================
# Each thread writes its running top-6 to SMEM
tid = warp_idx * 32 + tidx
s_merge_s[tid, 0] = s0; s_merge_s[tid, 1] = s1; s_merge_s[tid, 2] = s2
s_merge_s[tid, 3] = s3; s_merge_s[tid, 4] = s4; s_merge_s[tid, 5] = s5
s_merge_i[tid, 0] = i0; s_merge_i[tid, 1] = i1; s_merge_i[tid, 2] = i2
s_merge_i[tid, 3] = i3; s_merge_i[tid, 4] = i4; s_merge_i[tid, 5] = i5
s_merge_a[tid, 0] = a0; s_merge_a[tid, 1] = a1; s_merge_a[tid, 2] = a2
s_merge_a[tid, 3] = a3; s_merge_a[tid, 4] = a4; s_merge_a[tid, 5] = a5
for k_idx in cutlass.range(TK, unroll=1):
s_val = s0 if k_idx == 0 else (s1 if k_idx == 1 else (s2 if k_idx == 2 else (s3 if k_idx == 3 else (s4 if k_idx == 4 else s5))))
i_val = i0 if k_idx == 0 else (i1 if k_idx == 1 else (i2 if k_idx == 2 else (i3 if k_idx == 3 else (i4 if k_idx == 4 else i5))))
a_val = a0 if k_idx == 0 else (a1 if k_idx == 1 else (a2 if k_idx == 2 else (a3 if k_idx == 3 else (a4 if k_idx == 4 else a5))))
storage.merge_scores.data_ptr()[tid * TK + k_idx] = s_val
storage.merge_indices.data_ptr()[tid * TK + k_idx] = i_val
storage.merge_acts.data_ptr()[tid * TK + k_idx] = a_val
epi_bar.arrive_and_wait()
# Thread 0 merges all 128 threads' top-6 into final result
# Thread 0 of warp 0 does the final merge + store
if warp_idx == 0 and tidx == 0:
# Initialize final top-6 from thread 0's data
fs0 = s0; fs1 = s1; fs2 = s2; fs3 = s3; fs4 = s4; fs5 = s5
@@ -777,10 +840,10 @@ class Nvfp4FusedRouterKernel:
# Merge all other threads (1..127)
for t in cutlass.range(1, 128, unroll=1):
for k in cutlass.range(TK, unroll=1):
cs = s_merge_s[t, k]
ci = s_merge_i[t, k]
ca = s_merge_a[t, k]
for k_idx in cutlass.range(TK, unroll=1):
cs = storage.merge_scores.data_ptr()[t * TK + k_idx]
ci = storage.merge_indices.data_ptr()[t * TK + k_idx]
ca = storage.merge_acts.data_ptr()[t * TK + k_idx]
# Only merge if this is a valid entry (index >= 0)
if ci >= cutlass.Int32(0):
# Sorted insertion into final top-6 (descending)
@@ -812,20 +875,19 @@ class Nvfp4FusedRouterKernel:
inv_sum = cutlass.Float32(1.0) / act_sum
sc = cutlass.Float32(routed_scaling_factor)
# Store to GMEM (row 0 of the M-tile)
row_idx = cutlass.Int32(0)
out_w_tensor[row_idx, 0] = fa0 * inv_sum * sc
out_w_tensor[row_idx, 1] = fa1 * inv_sum * sc
out_w_tensor[row_idx, 2] = fa2 * inv_sum * sc
out_w_tensor[row_idx, 3] = fa3 * inv_sum * sc
out_w_tensor[row_idx, 4] = fa4 * inv_sum * sc
out_w_tensor[row_idx, 5] = fa5 * inv_sum * sc
out_id_tensor[row_idx, 0] = fi0
out_id_tensor[row_idx, 1] = fi1
out_id_tensor[row_idx, 2] = fi2
out_id_tensor[row_idx, 3] = fi3
out_id_tensor[row_idx, 4] = fi4
out_id_tensor[row_idx, 5] = fi5
# Store to GMEM
out_w_tensor[0, 0] = fa0 * inv_sum * sc
out_w_tensor[0, 1] = fa1 * inv_sum * sc
out_w_tensor[0, 2] = fa2 * inv_sum * sc
out_w_tensor[0, 3] = fa3 * inv_sum * sc
out_w_tensor[0, 4] = fa4 * inv_sum * sc
out_w_tensor[0, 5] = fa5 * inv_sum * sc
out_id_tensor[0, 0] = fi0
out_id_tensor[0, 1] = fi1
out_id_tensor[0, 2] = fi2
out_id_tensor[0, 3] = fi3
out_id_tensor[0, 4] = fi4
out_id_tensor[0, 5] = fi5
epi_bar.arrive_and_wait()
@@ -853,21 +915,6 @@ def run_nvfp4_fused_router(
Single-kernel: NVFP4 block-scaled GEMM + fused router epilogue.
No intermediate GMEM buffer. No BF16 fallback.
Parameters
----------
hidden_states : [M, K] BF16 — raw activation
mat_b : CuTe tensor — gate weight in NVFP4 blockscaled layout
scale_b : CuTe tensor — gate weight scale factors in blockscaled layout
gsa : activation global scale (scalar)
gsb_val : weight global scale (float)
e_bias : [E] FP32 — per-expert selection bias
routed_scaling_factor : float
top_k : int (default 6)
Returns
-------
(topk_weights, topk_ids) — [M, top_k] FP32 and [M, top_k] int32
"""
import cutlass.torch as cutlass_torch
from dsv4.ops.quantize import quantize_activation_nvfp4

View File

@@ -0,0 +1,78 @@
"""Test: check what CuTeDSL math operations are available."""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
def test_cute_math_api():
"""Enumerate available CuTeDSL math/arch operations."""
import cutlass
import cutlass.cute as cute
# Check cute.math module
print("=== cute.math attributes ===")
if hasattr(cute, 'math'):
for attr in sorted(dir(cute.math)):
if not attr.startswith('_'):
print(f" cute.math.{attr}")
else:
print(" cute.math does not exist")
# Check cute.arch module for math
print("\n=== cute.arch math-related attributes ===")
if hasattr(cute, 'arch'):
for attr in sorted(dir(cute.arch)):
if any(k in attr.lower() for k in ['sqrt', 'log', 'exp', 'abs', 'sin', 'cos', 'rsqrt', 'rcp', 'fma', 'div']):
print(f" cute.arch.{attr}")
# Check cute directly for math
print("\n=== cute math-related attributes ===")
for attr in sorted(dir(cute)):
if any(k in attr.lower() for k in ['sqrt', 'log', 'exp', 'abs', 'sin', 'cos', 'rsqrt', 'rcp']):
print(f" cute.{attr}")
# Check cutlass module for math
print("\n=== cutlass math-related attributes ===")
for attr in sorted(dir(cutlass)):
if any(k in attr.lower() for k in ['sqrt', 'log', 'exp', 'abs', 'rsqrt', 'rcp']):
print(f" cutlass.{attr}")
# Check if cute.exp exists
print(f"\n=== Key functions ===")
print(f" cute.exp exists: {hasattr(cute, 'exp')}")
print(f" cute.log exists: {hasattr(cute, 'log')}")
print(f" cute.sqrt exists: {hasattr(cute, 'sqrt')}")
print(f" cute.math exists: {hasattr(cute, 'math')}")
if hasattr(cute, 'math'):
print(f" cute.math.fmax exists: {hasattr(cute.math, 'fmax')}")
print(f" cute.math.fmin exists: {hasattr(cute.math, 'fmin')}")
print(f" cute.math.absf exists: {hasattr(cute.math, 'absf')}")
print(f" cute.math.sqrt exists: {hasattr(cute.math, 'sqrt')}")
print(f" cute.math.log exists: {hasattr(cute.math, 'log')}")
print(f" cute.math.exp exists: {hasattr(cute.math, 'exp')}")
print(f" cute.math.rsqrt exists: {hasattr(cute.math, 'rsqrt')}")
print(f" cute.math.rcp exists: {hasattr(cute.math, 'rcp')}")
print(f" cute.math.sin exists: {hasattr(cute.math, 'sin')}")
print(f" cute.math.cos exists: {hasattr(cute.math, 'cos')}")
print(f" cute.math.copysign exists: {hasattr(cute.math, 'copysign')}")
print(f" cute.math.clamp exists: {hasattr(cute.math, 'clamp')}")
# Check arch operations
print(f"\n cute.arch.fmax exists: {hasattr(cute.arch, 'fmax')}")
print(f" cute.arch.fmin exists: {hasattr(cute.arch, 'fmin')}")
# Try to find math operations in cutlass._mlir_ops or similar
print("\n=== MLIR operations ===")
for mod_name in ['cutlass._mlir_ops', 'cutlass.mlir', 'cutlass.cute._mlir']:
try:
mod = __import__(mod_name, fromlist=[''])
math_attrs = [a for a in dir(mod) if any(k in a.lower() for k in ['sqrt', 'log', 'exp', 'abs', 'rsqrt'])]
if math_attrs:
print(f" {mod_name}: {math_attrs}")
except ImportError:
pass
print("\nDone.")
if __name__ == "__main__":
test_cute_math_api()