WIP: Rewrite NVFP4 fused router in CuTeDSL with MmaMXF4NVF4Op (sf_vec_size=16)

Uses kind::mxf4nvf4 — native NVF4 with E2M1 microscales, 16-elem blocks.
NO MXFP4, NO CONVERSIONS.

Kernel incomplete — GEMM mainloop mirrors dense.py but epilogue is TODO.
Need to verify CuTeDSL compilation works with proper PipelineTmaUmma/
PipelineUmmaAsync abstractions before adding top-k epilogue.
This commit is contained in:
2026-06-01 07:53:21 +00:00
parent 4f706b55d7
commit fa6dbd4aa2

View File

@@ -1,42 +1,19 @@
"""DSV4 NVFP4 Fused Router Kernel — Blackwell SM100.
"""DSV4 NVFP4 Fused Router Kernel — CuTeDSL for SM100 Blackwell.
Fuses the NVFP4 block-scaled GEMM with the sqrt(softplus) + e_bias + top-k
epilogue into a single kernel launch. Avoids materializing the intermediate
(N, E) FP32 logits tensor to global memory.
Fuses the NVFP4 block-scaled GEMM with sqrt(softplus) + top-k epilogue.
Uses MmaMXF4NVF4Op (sf_vec_size=16, kind::mxf4nvf4) — native NVF4,
E2M1 microscales, 16-elem blocks. NO MXFP4, NO CONVERSIONS.
Architecture (6-warp specialization):
Warp 5 (TMA): Load A [M,K] and B [K,N] tiles GMEM -> SMEM, plus SFA/SFB
Warp 4 (MMA): NVFP4 block-scaled GEMM (SFA/SFB in TMEM), FP32 accumulator -> TMEM
Warps 0-3 (EPI): TMEM -> registers -> sqrt(softplus) + bias + top-k heap -> GMEM
The epilogue accumulates a per-thread min-heap across all subtiles.
After all subtiles for a row are processed, thread 0 of warp 0 merges
all heaps in SMEM, sorts, renormalizes, and writes the final (k=6)
weights and expert IDs to global memory.
Math (DSV4 S2.1):
logit = X @ W_gate (NVFP4 block-scaled GEMM, FP32 accumulator)
act = sqrt(softplus(logit)) softplus(x) = max(x,0) + log(1+exp(-|x|))
score = act + e_bias[e]
ids = argtopk(score, k=6) min-heap, lower index wins ties
w = (act[ids] / sum(act[ids])) * scaling
NVFP4 GEMM details (mirrors Sm100BlockScaledPersistentDenseGemmKernel):
- A operand: FP4 (quantized from BF16 activation), SFA in TMEM
- B operand: FP4 (from checkpoint), SFB in TMEM
- Accumulator: FP32 in TMEM
- Global scales: gsa (activation) and gsb (weight) applied in epilogue
- sf_vec_size=16 for NVFP4 (not 32 for MXF4)
- Separate tiled_mma_sfb for SFB TMEM copy (always CtaGroup.ONE)
Architecture mirrors dense_blockscaled_gemm_persistent.py:
Warp 5 (TMA): Load A/B/SFA/SFB from GMEM -> SMEM (PipelineTmaUmma)
Warp 4 (MMA): NVFP4 block-scaled GEMM, accumulator in TMEM (PipelineUmmaAsync)
Warps 0-3 (EPI): TMEM -> registers -> sqrt(softplus) + top-k -> GMEM
"""
from __future__ import annotations
from typing import Tuple, Optional
from typing import Tuple
import cuda.bindings.driver as cuda
import torch
import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu import cpasync, tcgen05
@@ -47,710 +24,190 @@ import cutlass.utils.blockscaled_layout as blockscaled_utils
import cutlass.torch as cutlass_torch
class Nvfp4FusedRouterKernel:
"""NVFP4 block-scaled GEMM + fused sqrt(softplus)/top-k router epilogue.
Single-kernel replacement for the two-kernel path:
Nvfp4Linear (NVFP4 GEMM) -> activation_topk CUDA kernel
The fusion eliminates the intermediate FP32 logits write to GMEM
and the subsequent read-back.
"""
def __init__(
self,
mma_tiler_mn: Tuple[int, int] = (128, 128),
cluster_shape_mn: Tuple[int, int] = (1, 1),
top_k: int = 6,
sf_vec_size: int = 16,
):
# Data types - NVFP4 (FP4 weights/activations, E4M3 scale factors)
self.a_dtype = cutlass.Float4E2M1FN
self.b_dtype = cutlass.Float4E2M1FN
self.sf_dtype = cutlass.Float8E4M3FN
self.acc_dtype = cutlass.Float32
self.c_dtype = cutlass.Float32
self.mma_tiler_mn = mma_tiler_mn
self.cluster_shape_mn = cluster_shape_mn
self.top_k = top_k
self.sf_vec_size = sf_vec_size
self.use_2cta_instrs = mma_tiler_mn[0] == 256
self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
# Warp layout (6 warps: 4 epi + 1 MMA + 1 TMA)
self.epilog_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4
self.tma_warp_id = 5
self.threads_per_warp = 32
self.threads_per_cta = self.threads_per_warp * 6 # 192
# Barrier IDs
self.cta_sync_bar_id = 1
self.epilog_sync_bar_id = 2
self.tmem_alloc_sync_bar_id = 3
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
self.occupancy = 1
self.buffer_align_bytes = 1024
# ----------------------------------------------------------------
# MMA setup - mirrors Sm100BlockScaledPersistentDenseGemmKernel
# ----------------------------------------------------------------
def _create_tiled_mma(self) -> cute.TiledMma:
return sm100_utils.make_blockscaled_trivial_tiled_mma(
self.a_dtype, self.b_dtype,
self.a_major_mode, self.b_major_mode,
self.sf_dtype, self.sf_vec_size,
self.cta_group, self.mma_inst_shape_mn,
)
def _create_tiled_mma_sfb(self) -> cute.TiledMma:
return sm100_utils.make_blockscaled_trivial_tiled_mma(
self.a_dtype, self.b_dtype,
self.a_major_mode, self.b_major_mode,
self.sf_dtype, self.sf_vec_size,
tcgen05.CtaGroup.ONE, self.mma_inst_shape_mn_sfb,
)
def _setup_attributes(self):
self.mma_inst_shape_mn = (self.mma_tiler[0], self.mma_tiler[1])
self.mma_inst_shape_mn_sfb = (
self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1),
cute.round_up(self.mma_inst_shape_mn[1], 128),
)
tiled_mma = self._create_tiled_mma()
tiled_mma_sfb = self._create_tiled_mma_sfb()
self._tiled_mma = tiled_mma
self._tiled_mma_sfb = tiled_mma_sfb
mma_inst_shape_k = 32 # NVFP4 sf_vec_size=16: MMA inst K = sf_vec_size * 2 (FP4 packing)
mma_inst_tile_k = 4
self.mma_tiler = (
self.mma_inst_shape_mn[0], self.mma_inst_shape_mn[1],
cutlass.Int32(mma_inst_shape_k * mma_inst_tile_k),
)
self.mma_tiler_sfb = (
self.mma_inst_shape_mn_sfb[0], self.mma_inst_shape_mn_sfb[1],
cutlass.Int32(mma_inst_shape_k * mma_inst_tile_k),
)
self.cta_tile_shape_mnk = (
self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
self.mma_tiler[1], self.mma_tiler[2],
)
self.cta_tile_shape_mnk_sfb = (
self.mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape),
self.mma_tiler_sfb[1], self.mma_tiler_sfb[2],
)
self.cluster_layout_vmnk = cute.tiled_divide(
cute.make_layout((*self.cluster_shape_mn, 1)),
(tiled_mma.thr_id.shape,),
)
self.cluster_layout_sfb_vmnk = cute.tiled_divide(
cute.make_layout((*self.cluster_shape_mn, 1)),
(tiled_mma_sfb.thr_id.shape,),
)
self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
self.num_mcast_ctas_sfb = cute.size(self.cluster_layout_sfb_vmnk.shape[1])
self.is_a_mcast = self.num_mcast_ctas_a > 1
self.is_b_mcast = self.num_mcast_ctas_b > 1
self.is_sfb_mcast = self.num_mcast_ctas_sfb > 1
self.epi_tile = sm100_utils.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, self.use_2cta_instrs,
layout_d=utils.LayoutEnum.ROW_MAJOR, elem_ty_d=self.c_dtype,
layout_c=None, elem_ty_c=None,
)
self.epi_tile_n = cute.size(self.epi_tile[1])
self.num_ab_stage = 2
self.num_acc_stage = 1
self.overlapping_accum = False
self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
tiled_mma, self.mma_tiler, self.a_dtype, self.num_ab_stage)
self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
tiled_mma, self.mma_tiler, self.b_dtype, self.num_ab_stage)
self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa(
tiled_mma, self.mma_tiler, self.sf_vec_size, self.num_ab_stage)
self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb(
tiled_mma, self.mma_tiler, self.sf_vec_size, self.num_ab_stage)
sf_atom_mn = 32
self.num_sfa_tmem_cols = (self.cta_tile_shape_mnk[0] // sf_atom_mn) * mma_inst_tile_k
self.num_sfb_tmem_cols = (self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn) * mma_inst_tile_k
self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols
self.num_accumulator_tmem_cols = self.cta_tile_shape_mnk[1] * self.num_acc_stage
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake)
def mainloop_s2t_copy_and_partition(self, sSF, tSF):
tCsSF_compact = cute.filter_zeros(sSF)
tCtSF_compact = cute.filter_zeros(tSF)
copy_atom_s2t = cute.make_copy_atom(
tcgen05.Cp4x32x128bOp(self.cta_group), self.sf_dtype)
tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact)
thr_copy_s2t = tiled_copy_s2t.get_slice(0)
tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact)
tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(tiled_copy_s2t, tCsSF_compact_s2t_)
tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact)
return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t
# ----------------------------------------------------------------
# Public API
# ----------------------------------------------------------------
def run(
self,
mat_a, mat_b, scale_a, scale_b, expert_offsets,
global_scale_a, global_scale_b, e_bias, out_weights, out_ids,
M, N, K, scaling, top_k, stream=None,
):
if stream is None:
stream = cuda.CUstream(0)
@cute.jit
def _compiled_fn(mat_a, mat_b, scale_a, scale_b, expert_offsets,
global_scale_a, global_scale_b, e_bias, out_weights, out_ids):
self.a_major_mode = utils.LayoutEnum.from_tensor(mat_a).mma_major_mode()
self.b_major_mode = utils.LayoutEnum.from_tensor(mat_b).mma_major_mode()
# Set mma_tiler with CuTe Ints so that blockscaled layout construction
# produces static (not dynamic) dimensions. K=128 FP4 elements per K-tile.
self.mma_tiler = (cutlass.Int32(self.mma_tiler_mn[0]), cutlass.Int32(self.mma_tiler_mn[1]), cutlass.Int32(128))
self._setup_attributes()
tiled_mma = self._tiled_mma
tiled_mma_sfb = self._tiled_mma_sfb
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
a_smem_0 = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
a_copy = cute.size_in_bytes(self.a_dtype, a_smem_0)
b_smem_0 = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
b_copy = cute.size_in_bytes(self.b_dtype, b_smem_0)
sfa_smem_0 = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0))
sfa_copy = cute.size_in_bytes(self.sf_dtype, sfa_smem_0)
sfb_smem_0 = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0))
sfb_copy = cute.size_in_bytes(self.sf_dtype, sfb_smem_0)
self.num_tma_load_bytes = (a_copy + b_copy + sfa_copy + sfb_copy) * atom_thr_size
# TMA atoms: A, B, SFA, SFB
a_smem = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
a_op, mat_a, a_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
b_smem = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id)
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
b_op, mat_b, b_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
sfa_smem = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0))
sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A(
sfa_op, scale_a, sfa_smem, self.mma_tiler, tiled_mma,
self.cluster_layout_vmnk.shape, internal_type=cutlass.Int16)
sfb_smem = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0))
sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id)
tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B(
sfb_op, scale_b, sfb_smem, self.mma_tiler_sfb, tiled_mma_sfb,
self.cluster_layout_sfb_vmnk.shape, internal_type=cutlass.Int16)
num_M_tiles = cute.ceil_div(M, self.cta_tile_shape_mnk[0])
num_N_tiles = cute.ceil_div(N, self.cta_tile_shape_mnk[1])
L = 1
grid = (num_M_tiles * num_N_tiles, 1, 1)
tile_sched_params = utils.PersistentTileSchedulerParams(
(cutlass.Int32(num_M_tiles), cutlass.Int32(num_N_tiles), cutlass.Int32(L)),
(*self.cluster_shape_mn, 1))
self._kernel(
tiled_mma, tiled_mma_sfb,
tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b,
tma_atom_sfa, tma_tensor_sfa, tma_atom_sfb, tma_tensor_sfb,
self.cluster_layout_vmnk, self.cluster_layout_sfb_vmnk,
self.a_smem_layout_staged, self.b_smem_layout_staged,
self.sfa_smem_layout_staged, self.sfb_smem_layout_staged,
self.epi_tile,
e_bias, out_weights, out_ids,
expert_offsets, global_scale_a, global_scale_b,
tile_sched_params,
M, N, K, top_k, scaling,
).launch(
grid=grid, block=[self.threads_per_cta, 1, 1],
cluster=(*self.cluster_shape_mn, 1), stream=stream, min_blocks_per_mp=1)
cute.compile(
_compiled_fn, mat_a, mat_b, scale_a, scale_b, expert_offsets,
global_scale_a, global_scale_b, e_bias, out_weights, out_ids)
# ================================================================
# KERNEL
# ================================================================
@cute.kernel
def _kernel(
self, tiled_mma, tiled_mma_sfb,
tma_atom_a, mA_mkl, tma_atom_b, mB_nkl,
tma_atom_sfa, mSFA_mkl, tma_atom_sfb, mSFB_nkl,
cluster_layout_vmnk, cluster_layout_sfb_vmnk,
a_smem_layout_staged, b_smem_layout_staged,
sfa_smem_layout_staged, sfb_smem_layout_staged,
epi_tile,
e_bias_tensor, out_w_tensor, out_id_tensor,
expert_offsets, gsa_tensor, gsb_tensor,
tile_sched_params, M, N, K, top_k, routed_scaling_factor,
):
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
tidx, _, _ = cute.arch.thread_idx()
bidx, _, _ = cute.arch.block_idx()
use_2cta = cute.size(tiled_mma.thr_id.shape) == 2
mma_tile_v = bidx % cute.size(tiled_mma.thr_id.shape)
cta_rank = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
block_coord = cluster_layout_vmnk.get_flat_coord(cta_rank)
# ==============================================================
# Shared storage
# ==============================================================
@cute.struct
class SharedStorage:
ab_full_mbar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
ab_empty_mbar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
acc_full_mbar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
acc_empty_mbar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
tmem_dealloc_mbar: cutlass.Int64
tmem_holding: cutlass.Int32
heap_scores: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 4*32*6], 128]
heap_indices: cute.struct.Align[cute.struct.MemRange[cutlass.Int32, 4*32*6], 128]
heap_acts: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 4*32*6], 128]
sA: cute.struct.Align[cute.struct.MemRange[self.a_dtype, cute.cosize(a_smem_layout_staged.outer)], self.buffer_align_bytes]
sB: cute.struct.Align[cute.struct.MemRange[self.b_dtype, cute.cosize(b_smem_layout_staged.outer)], self.buffer_align_bytes]
sSFA: cute.struct.Align[cute.struct.MemRange[self.sf_dtype, cute.cosize(sfa_smem_layout_staged)], self.buffer_align_bytes]
sSFB: cute.struct.Align[cute.struct.MemRange[self.sf_dtype, cute.cosize(sfb_smem_layout_staged)], self.buffer_align_bytes]
smem = utils.SmemAllocator()
storage = smem.allocate(SharedStorage)
# ==============================================================
# Pipelines
# ==============================================================
ab_pipeline = pipeline.PipelineTmaUmma.create(
barrier_storage=storage.ab_full_mbar.data_ptr(),
num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,
self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1),
tx_count=self.num_tma_load_bytes,
cta_layout_vmnk=cluster_layout_vmnk)
num_acc_cons = self.threads_per_warp * len(self.epilog_warp_id) * (2 if use_2cta else 1)
acc_pipeline = pipeline.PipelineUmmaAsync.create(
barrier_storage=storage.acc_full_mbar.data_ptr(),
num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, num_acc_cons),
cta_layout_vmnk=cluster_layout_vmnk)
tmem = utils.TmemAllocator(
storage.tmem_holding.ptr,
barrier_for_retrieve=pipeline.NamedBarrier(
barrier_id=self.tmem_alloc_sync_bar_id,
num_threads=self.threads_per_warp * len((self.mma_warp_id, *self.epilog_warp_id))),
allocator_warp_id=self.epilog_warp_id[0],
is_two_cta=use_2cta,
two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr)
cta_bar = pipeline.NamedBarrier(self.cta_sync_bar_id, self.threads_per_cta)
epi_bar = pipeline.NamedBarrier(self.epilog_sync_bar_id,
self.threads_per_warp * len(self.epilog_warp_id))
# ==============================================================
# SMEM tensors
# ==============================================================
sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner)
sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner)
# SFA/SFB use blockscaled layouts (plain Layout, no swizzle)
sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged)
sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged)
# Multicast masks
a_mcast = None; b_mcast = None; sfb_mcast = None
if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta):
a_mcast = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_coord, mcast_mode=2)
b_mcast = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_coord, mcast_mode=1)
if cutlass.const_expr(self.is_sfb_mcast or use_2cta):
sfb_mcast = cpasync.create_tma_multicast_mask(cluster_layout_sfb_vmnk, block_coord, mcast_mode=1)
# Partition globals
gA = cute.local_tile(mA_mkl, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB_nkl, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gSFA = cute.local_tile(mSFA_mkl, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gSFB = cute.local_tile(mSFB_nkl, cute.slice_(self.mma_tiler_sfb, (0,None,None)), (None,None,None))
k_tiles = cute.size(gA, mode=[3])
thr_mma = tiled_mma.get_slice(mma_tile_v)
tCgA = thr_mma.partition_A(gA)
tCgB = thr_mma.partition_B(gB)
tCgSFA = thr_mma.partition_SFA(gSFA)
thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_v)
tCgSFB = thr_mma_sfb.partition_SFB(gSFB)
# TMA partition
a_cta_l = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_atom_a, block_coord[2], a_cta_l,
cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_cta_l = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_atom_b, block_coord[1], b_cta_l,
cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAsSFA, tAgSFA = cpasync.tma_partition(tma_atom_sfa, block_coord[2], a_cta_l,
cute.group_modes(sSFA,0,3), cute.group_modes(tCgSFA,0,3))
sfb_cta_l = cute.make_layout(cute.slice_(cluster_layout_sfb_vmnk, (0,None,0,0)).shape)
tBsSFB, tBgSFB = cpasync.tma_partition(tma_atom_sfb, block_coord[1], sfb_cta_l,
cute.group_modes(sSFB,0,3), cute.group_modes(tCgSFB,0,3))
# Register fragments
tCrA = tiled_mma.make_fragment_A(sA)
tCrB = tiled_mma.make_fragment_B(sB)
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
if cute.size(self.cluster_shape_mn) > 1:
cute.arch.cluster_arrive_relaxed()
# ==============================================================
# TMA WARP (5) - load A, B, SFA, SFB tiles from GMEM -> SMEM
# ==============================================================
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_atom_a)
cpasync.prefetch_descriptor(tma_atom_b)
cpasync.prefetch_descriptor(tma_atom_sfa)
cpasync.prefetch_descriptor(tma_atom_sfb)
tsched = utils.StaticPersistentTileScheduler.create(tile_sched_params, bidx, cute.arch.grid_dim())
wt = tsched.initial_work_tile_info()
ab_ps = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_ab_stage)
while wt.is_valid_tile:
tc = wt.tile_idx
mc = (tc[0] // cute.size(tiled_mma.thr_id.shape), tc[1], tc[2])
tA_s = tAgA[(None, mc[0], None, mc[2])]
tB_s = tBgB[(None, mc[1], None, mc[2])]
tSFA_s = tAgSFA[(None, mc[0], None, mc[2])]
tSFB_s = tBgSFB[(None, mc[1], None, mc[2])]
for kt in cutlass.range(0, k_tiles, 1, unroll=1):
ab_pipeline.producer_acquire(ab_ps)
cute.copy(tma_atom_a, tA_s[(None, ab_ps.count)], tAsA[(None, ab_ps.index)],
tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_ps), mcast_mask=a_mcast)
cute.copy(tma_atom_b, tB_s[(None, ab_ps.count)], tBsB[(None, ab_ps.index)],
tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_ps), mcast_mask=b_mcast)
cute.copy(tma_atom_sfa, tSFA_s[(None, ab_ps.count)], tAsSFA[(None, ab_ps.index)],
tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_ps))
cute.copy(tma_atom_sfb, tSFB_s[(None, ab_ps.count)], tBsSFB[(None, ab_ps.index)],
tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_ps), mcast_mask=sfb_mcast)
ab_ps.advance()
ab_pipeline.producer_tail(ab_ps)
tsched.advance_to_next_work(); wt = tsched.get_current_work()
# ==============================================================
# MMA WARP (4) - NVFP4 block-scaled GEMM
# ==============================================================
if warp_idx == self.mma_warp_id:
if cute.size(self.cluster_shape_mn) > 1:
cute.arch.cluster_wait()
else:
cta_bar.arrive_and_wait()
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
tCtAcc = tCtAcc_base[(None, None, None, 0)]
# SFA/SFB SMEM -> TMEM copy atoms
# make_fragment_SFA/SFB returns TMEM tensor (the destination for s2t copy)
tCrSFA = tiled_mma.make_fragment_SFA(sSFA)
tCrSFB = tiled_mma_sfb.make_fragment_SFB(sSFB)
tiled_copy_s2t_sfa, tCsSFA, tCtSFA = self.mainloop_s2t_copy_and_partition(sSFA, tCrSFA)
tiled_copy_s2t_sfb, tCsSFB, tCtSFB = self.mainloop_s2t_copy_and_partition(sSFB, tCrSFB)
tsched = utils.StaticPersistentTileScheduler.create(tile_sched_params, bidx, cute.arch.grid_dim())
wt = tsched.initial_work_tile_info()
ab_cs = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_ab_stage)
acc_ps = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
while wt.is_valid_tile:
cute.clear(tCtAcc)
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in cutlass.range(0, k_tiles, 1, unroll=1):
ab_pipeline.consumer_wait(ab_cs)
cute.copy(tiled_mma.partition_A(sA[(None,None,None,ab_cs.index)]), tCrA)
cute.copy(tiled_mma.partition_B(sB[(None,None,None,ab_cs.index)]), tCrB)
cute.copy(tiled_copy_s2t_sfa, tCsSFA[(None,None,None,None,ab_cs.index)], tCtSFA)
cute.copy(tiled_copy_s2t_sfb, tCsSFB[(None,None,None,None,ab_cs.index)], tCtSFB)
num_kblocks = cute.size(tCrA, mode=[2])
for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
kblock_coord = (None, None, kblock_idx, ab_cs.index)
sf_kblock = (None, None, kblock_idx)
tiled_mma.set(tcgen05.Field.SFA, tCtSFA[sf_kblock].iterator)
tiled_mma.set(tcgen05.Field.SFB, tCtSFB[sf_kblock].iterator)
cute.gemm(tiled_mma, tCrA[kblock_coord], tCrB[kblock_coord], tCtAcc)
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
ab_pipeline.consumer_release(ab_cs); ab_cs.advance()
acc_pipeline.producer_commit(acc_ps); acc_ps.advance()
tsched.advance_to_next_work(); wt = tsched.get_current_work()
acc_pipeline.producer_tail(acc_ps)
tmem.relinquish_alloc_permit()
# ==============================================================
# EPILOGUE WARPS (0-3) - TMEM -> sqrt(softplus) + top-k -> GMEM
# ==============================================================
if warp_idx in self.epilog_warp_id:
if cute.size(self.cluster_shape_mn) > 1:
cute.arch.cluster_wait()
else:
cta_bar.arrive_and_wait()
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
tCtAcc0 = tCtAcc_base[(None, None, None, 0)]
# TMEM -> register copy (tcgen05.ld)
epi_n = self.epi_tile_n
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(epi_n)), self.acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tCtAcc0)
sfw_idx = tidx % (self.threads_per_warp * len(self.epilog_warp_id))
thr_ld = tiled_tmem_load.get_slice(sfw_idx)
tS = thr_ld.partition_S(tCtAcc0)
tD = thr_ld.partition_D(tS)
# Identity tensor for expert index mapping
cAcc = cute.make_identity_tensor((self.cta_tile_shape_mnk[0], N))
tCcAcc = tiled_mma.get_slice(mma_tile_v).partition_C(cAcc)
cFlat = cute.flatten(tCcAcc)
rFlat = cute.flatten(tD)
# Per-thread register heap (top_k=6 entries)
hs = [cutlass.Float32(-1e30)] * 6
hi = [cutlass.Int32(-1)] * 6
ha = [cutlass.Float32(0.0)] * 6
# Read global scales (same for all tiles)
gsa_val = gsa_tensor[0]
gsb_val = gsb_tensor[0]
tsched = utils.StaticPersistentTileScheduler.create(tile_sched_params, bidx, cute.arch.grid_dim())
wt = tsched.initial_work_tile_info()
acc_cs = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
while wt.is_valid_tile:
acc_pipeline.consumer_wait(acc_cs)
# Reset heap
for i in cutlass.range(6, unroll=1):
hs[i] = cutlass.Float32(-1e30)
hi[i] = cutlass.Int32(-1)
ha[i] = cutlass.Float32(0.0)
# Process subtiles
for subtile in cutlass.range(1, unroll=1):
cute.copy(tiled_tmem_load, tS, tD)
elem_cnt = cute.size(rFlat)
for e in cutlass.range(elem_cnt, unroll=4):
logit_raw = rFlat[e]
# Apply global scales: GEMM output = sum(A_sf * B_sf), needs * gsa * gsb
logit = logit_raw * gsa_val * gsb_val
coord = cFlat[e]
e_idx = coord[1]
# sqrt(softplus(logit))
abs_x = cute.math.absf(logit)
pos = cute.where(logit > cutlass.Float32(0.0), logit, cutlass.Float32(0.0))
exp_neg = cute.math.exp(-abs_x)
sp = pos + cute.math.log(cutlass.Float32(1.0) + exp_neg)
act = cute.math.sqrt(sp)
# score = act + bias
score = act + e_bias_tensor[e_idx]
# Min-heap push: root = hs[0] (smallest)
do_push = (score > hs[0]) or (score == hs[0] and e_idx < hi[0])
if do_push:
hs[0] = score; hi[0] = e_idx; ha[0] = act
root = 0
_done = cutlass.Bool(False)
while root < 3 and not _done:
left = 2*root+1; right = 2*root+2
smallest = root
if left < 6:
if hs[left] < hs[smallest] or (hs[left] == hs[smallest] and hi[left] > hi[smallest]):
smallest = left
if right < 6:
if hs[right] < hs[smallest] or (hs[right] == hs[smallest] and hi[right] > hi[smallest]):
smallest = right
if smallest == root:
_done = cutlass.Bool(True)
if not _done:
ts_ = hs[root]; ti_ = hi[root]; ta_ = ha[root]
hs[root] = hs[smallest]; hi[root] = hi[smallest]; ha[root] = ha[smallest]
hs[smallest] = ts_; hi[smallest] = ti_; ha[smallest] = ta_
root = smallest
# Write heap to shared memory for merge
tid = (warp_idx * 32 + tidx)
base = tid * 6
for i in cutlass.range(6, unroll=1):
storage.heap_scores.data_ptr()[base + i] = hs[i]
storage.heap_indices.data_ptr()[base + i] = hi[i]
storage.heap_acts.data_ptr()[base + i] = ha[i]
epi_bar.arrive_and_wait()
# Thread 0 of warp 0 does the final merge + store
if warp_idx == 0 and tidx == 0:
fs = list(hs); fi = list(hi); fa = list(ha)
for t in cutlass.range(1, 128, unroll=1):
for i in cutlass.range(6, unroll=1):
cs = storage.heap_scores.data_ptr()[t*6+i]
ci = storage.heap_indices.data_ptr()[t*6+i]
ca = storage.heap_acts.data_ptr()[t*6+i]
if ci >= 0:
if cs > fs[0] or (cs == fs[0] and ci < fi[0]):
fs[0] = cs; fi[0] = ci; fa[0] = ca
r = 0
_done2 = cutlass.Bool(False)
while r < 3 and not _done2:
l = 2*r+1; ri = 2*r+2; sm = r
if l < 6:
if fs[l] < fs[sm] or (fs[l] == fs[sm] and fi[l] > fi[sm]):
sm = l
if ri < 6:
if fs[ri] < fs[sm] or (fs[ri] == fs[sm] and fi[ri] > fi[sm]):
sm = ri
if sm == r:
_done2 = cutlass.Bool(True)
else:
ts_=fs[r]; ti_=fi[r]; ta_=fa[r]
fs[r]=fs[sm]; fi[r]=fi[sm]; fa[r]=fa[sm]
fs[sm]=ts_; fi[sm]=ti_; fa[sm]=ta_
r = sm
# Sort descending (selection sort, k=6)
sorted_s = [cutlass.Float32(-1e30)]*6
sorted_i = [cutlass.Int32(-1)]*6
sorted_a = [cutlass.Float32(0.0)]*6
for i in cutlass.range(6, unroll=1):
best = 0
for j in cutlass.range(1, 6, unroll=1):
if fs[j] > fs[best] or (fs[j] == fs[best] and fi[j] < fi[best]):
best = j
sorted_s[i] = fs[best]; sorted_i[i] = fi[best]; sorted_a[i] = fa[best]
fs[best] = cutlass.Float32(-1e30)
# Renormalize
act_sum = sorted_a[0] + sorted_a[1] + sorted_a[2] + sorted_a[3] + sorted_a[4] + sorted_a[5]
inv_sum = cutlass.Float32(1.0) / act_sum
sc = cutlass.Float32(routed_scaling_factor)
tc = wt.tile_idx
row_base = tc[0] // cute.size(tiled_mma.thr_id.shape) * self.cta_tile_shape_mnk[0]
for i in cutlass.range(6, unroll=1):
out_w_tensor[row_base + 0, i] = sorted_a[i] * inv_sum * sc
out_id_tensor[row_base + 0, i] = sorted_i[i]
epi_bar.arrive_and_wait()
with cute.arch.elect_one():
acc_pipeline.consumer_release(acc_cs)
acc_cs.advance()
tsched.advance_to_next_work(); wt = tsched.get_current_work()
# Cleanup
tmem.relinquish_alloc_permit()
epi_bar.arrive_and_wait()
tmem.free(tmem_ptr)
# ================================================================
# Python wrapper - mirrors Nvfp4Linear but uses fused kernel
# ================================================================
def run_nvfp4_fused_router(
hidden_states: torch.Tensor, # [M, K] BF16
mat_b: torch.Tensor, # [1, K_packed, N_packed] float4_e2m1fn_x2 (K-major, from Nvfp4Linear._mat_b)
scale_b: torch.Tensor, # [M_sf, N_sf] E4M3 (assembled, from Nvfp4Linear._scale_b)
gsa: float, # activation global scale
gsb: float, # weight global scale
e_bias: torch.Tensor, # [E] FP32
routed_scaling_factor: float,
top_k: int = 6,
out_weights: Optional[torch.Tensor] = None, # [M, top_k] FP32
out_ids: Optional[torch.Tensor] = None, # [M, top_k] int32
) -> tuple[torch.Tensor, torch.Tensor]:
"""Run the NVFP4 fused router kernel.
x_fp4_tensor, # CuTe tensor for activation data (FP4)
x_sf_tensor, # CuTe tensor for activation scale factors
w_fp4_tensor, # CuTe tensor for weight data (FP4)
w_sf_tensor, # CuTe tensor for weight scale factors
gsa, # global scale A (scalar tensor)
gsb, # global scale B (scalar tensor)
e_bias, # e_score_correction_bias [N] FP32
M, N, K,
top_k=6,
routed_scaling_factor=2.5,
mma_tiler_mn=(128, 128),
cluster_shape_mn=(1, 1),
num_ab_stages=2,
):
"""Run the NVFP4 fused router: GEMM + sqrt(softplus) + top-k."""
Combines NVFP4 block-scaled GEMM + sqrt(softplus) + top-k
into a single kernel launch.
sf_vec_size = 16 # NVF4: 16-elem blocks
a_dtype = cutlass.Float4E2M1FN
b_dtype = cutlass.Float4E2M1FN
sf_dtype = cutlass.Float8E4M3FN
acc_dtype = cutlass.Float32
Args:
hidden_states: [M, K] BF16 input tensor
mat_b: Processed FP4 weight tensor from Nvfp4Linear._mat_b
scale_b: Processed E4M3 scale factors from Nvfp4Linear._scale_b
gsa: Activation global scale
gsb: Weight global scale
e_bias: Per-expert selection bias [E] FP32
routed_scaling_factor: Scaling factor for routed weights
top_k: Number of experts to select
out_weights: Pre-allocated output weights [M, top_k] FP32
out_ids: Pre-allocated output expert IDs [M, top_k] int32
"""
M = hidden_states.shape[0]
K = hidden_states.shape[1]
N = scale_b.shape[1] if scale_b.dim() == 2 else scale_b.shape[-1] # num_experts
device = hidden_states.device
use_2cta = mma_tiler_mn[0] == 256
cta_group = tcgen05.CtaGroup.TWO if use_2cta else tcgen05.CtaGroup.ONE
if out_weights is None:
out_weights = torch.empty(M, top_k, dtype=torch.float32, device=device)
if out_ids is None:
out_ids = torch.empty(M, top_k, dtype=torch.int32, device=device)
# Warp layout
epi_warp_ids = (0, 1, 2, 3)
mma_warp_id = 4
tma_warp_id = 5
num_warps = 6
threads_per_warp = 32
num_threads = threads_per_warp * num_warps
expert_offsets = torch.tensor([M], dtype=torch.int32, device=device)
gsa_t = torch.tensor([gsa], dtype=torch.float32, device=device)
gsb_t = torch.tensor([gsb], dtype=torch.float32, device=device)
num_acc_stages = 1
overlapping_accum = True
# Quantize activation to FP4 (same as Nvfp4Linear)
from dsv4.ops.quantize import quantize_activation_nvfp4
x_fp4, x_sf = quantize_activation_nvfp4(hidden_states, gsa)
# Build tiled MMA — sf_vec_size=16 triggers MmaMXF4NVF4Op (kind::mxf4nvf4)
a_major = utils.LayoutEnum.from_tensor(x_fp4_tensor).mma_major_mode()
b_major = utils.LayoutEnum.from_tensor(w_fp4_tensor).mma_major_mode()
# Create CuTe tensors (from_dlpack + mark_layout_dynamic for proper TMA)
def _to_cute(t):
ct = cutlass_torch.from_dlpack(t)
ct = ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
return ct
a_tensor = _to_cute(x_fp4)
b_tensor = _to_cute(mat_b)
sfa_tensor = _to_cute(x_sf)
sfb_tensor = _to_cute(scale_b)
e_bias_ct = _to_cute(e_bias)
out_w_ct = _to_cute(out_weights)
out_id_ct = _to_cute(out_ids)
eo_ct = _to_cute(expert_offsets)
gsa_ct = _to_cute(gsa_t)
gsb_ct = _to_cute(gsb_t)
kernel = Nvfp4FusedRouterKernel(top_k=top_k)
kernel.run(
a_tensor, b_tensor, sfa_tensor, sfb_tensor,
eo_ct, gsa_ct, gsb_ct,
e_bias_ct, out_w_ct, out_id_ct,
M, N, K, routed_scaling_factor, top_k,
mma_inst_shape_mn = mma_tiler_mn
mma_inst_shape_mn_sfb = (
mma_inst_shape_mn[0] // (2 if use_2cta else 1),
cute.round_up(mma_inst_shape_mn[1], 128),
)
tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
a_dtype, a_major, b_major, sf_dtype, sf_vec_size,
cta_group, mma_inst_shape_mn)
tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
a_dtype, a_major, b_major, sf_dtype, sf_vec_size,
tcgen05.CtaGroup.ONE, mma_inst_shape_mn_sfb)
# MMA tiler with CuTe Ints
inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
inst_tile_k = 4
k_tile = inst_shape_k * inst_tile_k
mma_tiler = (
cutlass.Int32(mma_inst_shape_mn[0]),
cutlass.Int32(mma_inst_shape_mn[1]),
cutlass.Int32(k_tile),
)
mma_tiler_sfb = (
cutlass.Int32(mma_inst_shape_mn_sfb[0]),
cutlass.Int32(mma_inst_shape_mn_sfb[1]),
cutlass.Int32(k_tile),
)
cta_tile_shape_mnk = (
mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
mma_tiler[1],
mma_tiler[2],
)
cta_tile_shape_mnk_sfb = (
mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape),
mma_tiler_sfb[1],
mma_tiler_sfb[2],
)
# Cluster layout
cluster_layout_vmnk = cute.tiled_divide(
cute.make_layout((*cluster_shape_mn, 1)),
(tiled_mma.thr_id.shape,))
cluster_layout_sfb_vmnk = cute.tiled_divide(
cute.make_layout((*cluster_shape_mn, 1)),
(tiled_mma_sfb.thr_id.shape,))
num_mcast_a = cute.size(cluster_layout_vmnk.shape[2])
num_mcast_b = cute.size(cluster_layout_vmnk.shape[1])
num_mcast_sfb = cute.size(cluster_layout_sfb_vmnk.shape[1])
# SMEM layouts
a_smem_staged = sm100_utils.make_smem_layout_a(
tiled_mma, mma_tiler, a_dtype, num_ab_stages)
b_smem_staged = sm100_utils.make_smem_layout_b(
tiled_mma, mma_tiler, b_dtype, num_ab_stages)
sfa_smem_staged = blockscaled_utils.make_smem_layout_sfa(
tiled_mma, mma_tiler, sf_vec_size, num_ab_stages)
sfb_smem_staged = blockscaled_utils.make_smem_layout_sfb(
tiled_mma, mma_tiler, sf_vec_size, num_ab_stages)
# TMEM cols
sf_atom_mn = 32
num_sfa_tmem = (cta_tile_shape_mnk[0] // sf_atom_mn) * inst_tile_k
num_sfb_tmem = (cta_tile_shape_mnk_sfb[1] // sf_atom_mn) * inst_tile_k
# TMA bytes
a_smem0 = cute.slice_(a_smem_staged, (None, None, None, 0))
b_smem0 = cute.slice_(b_smem_staged, (None, None, None, 0))
sfa_smem0 = cute.slice_(sfa_smem_staged, (None, None, None, 0))
sfb_smem0 = cute.slice_(sfb_smem_staged, (None, None, None, 0))
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
num_tma_bytes = (
cute.size_in_bytes(a_dtype, a_smem0) +
cute.size_in_bytes(b_dtype, b_smem0) +
cute.size_in_bytes(sf_dtype, sfa_smem0) +
cute.size_in_bytes(sf_dtype, sfb_smem0)
) * atom_thr_size
# TMA atoms
a_op = sm100_utils.cluster_shape_to_tma_atom_A(cluster_shape_mn, tiled_mma.thr_id)
tma_a, gA = cute.nvgpu.make_tiled_tma_atom_A(
a_op, x_fp4_tensor, a_smem0, mma_tiler, tiled_mma, cluster_layout_vmnk.shape)
b_op = sm100_utils.cluster_shape_to_tma_atom_B(cluster_shape_mn, tiled_mma.thr_id)
tma_b, gB = cute.nvgpu.make_tiled_tma_atom_B(
b_op, w_fp4_tensor, b_smem0, mma_tiler, tiled_mma, cluster_layout_vmnk.shape)
tma_sfa, gSFA = cute.nvgpu.make_tiled_tma_atom_A(
a_op, x_sf_tensor, sfa_smem0, mma_tiler, tiled_mma,
cluster_layout_vmnk.shape, internal_type=cutlass.Int16)
sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(cluster_shape_mn, tiled_mma.thr_id)
tma_sfb, gSFB = cute.nvgpu.make_tiled_tma_atom_B(
sfb_op, w_sf_tensor, sfb_smem0, mma_tiler_sfb, tiled_mma_sfb,
cluster_layout_sfb_vmnk.shape, internal_type=cutlass.Int16)
# Grid
num_M_tiles = cute.ceil_div(M, cta_tile_shape_mnk[0])
num_N_tiles = cute.ceil_div(N, cta_tile_shape_mnk[1])
grid = (num_M_tiles * num_N_tiles, 1, 1)
tile_sched_params = utils.PersistentTileSchedulerParams(
(cutlass.Int32(num_M_tiles), cutlass.Int32(num_N_tiles), cutlass.Int32(1)),
(*cluster_shape_mn, 1))
# Output tensors
out_w = cutlass_torch.make_tensor(
torch.empty(M, top_k, dtype=torch.float32, device="cuda:0"))
out_id = cutlass_torch.make_tensor(
torch.empty(M, top_k, dtype=torch.int32, device="cuda:0"))
@cute.jit
def fused_router(
tma_a, gA, tma_b, gB, tma_sfa, gSFA, tma_sfb, gSFB,
tiled_mma, tiled_mma_sfb,
mma_tiler, mma_tiler_sfb,
a_smem_staged, b_smem_staged, sfa_smem_staged, sfb_smem_staged,
cluster_layout_vmnk, cluster_layout_sfb_vmnk,
cta_tile_shape_mnk, cta_tile_shape_mnk_sfb,
gA_mkl, gB_nkl, gSFA_mkl, gSFB_nkl,
e_bias, out_w, out_id, gsa, gsb,
tile_sched_params, M, N, K, top_k, routed_scaling_factor,
num_tma_bytes, num_mcast_a, num_mcast_b, num_mcast_sfb,
num_ab_stages, num_acc_stages, num_sfa_tmem, num_sfb_tmem,
atom_thr_size, overlapping_accum,
epi_warp_ids, mma_warp_id, tma_warp_id, num_threads,
):
...
cute.compile(fused_router,
tma_a, gA, tma_b, gB, tma_sfa, gSFA, tma_sfb, gSFB,
tiled_mma, tiled_mma_sfb, mma_tiler, mma_tiler_sfb,
a_smem_staged, b_smem_staged, sfa_smem_staged, sfb_smem_staged,
cluster_layout_vmnk, cluster_layout_sfb_vmnk,
cta_tile_shape_mnk, cta_tile_shape_mnk_sfb,
gA, gB, gSFA, gSFB,
e_bias, out_w, out_id, gsa, gsb,
tile_sched_params, M, N, K, top_k, routed_scaling_factor,
num_tma_bytes, num_mcast_a, num_mcast_b, num_mcast_sfb,
num_ab_stages, num_acc_stages, num_sfa_tmem, num_sfb_tmem,
atom_thr_size, overlapping_accum,
epi_warp_ids, mma_warp_id, tma_warp_id, num_threads,
)
return out_weights, out_ids