WIP: Rewrite NVFP4 fused router in CuTeDSL with MmaMXF4NVF4Op (sf_vec_size=16)
Uses kind::mxf4nvf4 — native NVF4: E2M1 (FP4) elements with E4M3 microscales over 16-element blocks. No MXFP4, no conversions. The kernel is incomplete — the GEMM mainloop mirrors dense.py, but the epilogue is still TODO. Need to verify that CuTeDSL compilation works with the proper PipelineTmaUmma/PipelineUmmaAsync abstractions before adding the top-k epilogue.
This commit is contained in:
@@ -1,42 +1,19 @@
|
||||
"""DSV4 NVFP4 Fused Router Kernel — Blackwell SM100.
|
||||
"""DSV4 NVFP4 Fused Router Kernel — CuTeDSL for SM100 Blackwell.
|
||||
|
||||
Fuses the NVFP4 block-scaled GEMM with the sqrt(softplus) + e_bias + top-k
|
||||
epilogue into a single kernel launch. Avoids materializing the intermediate
|
||||
(N, E) FP32 logits tensor to global memory.
|
||||
Fuses the NVFP4 block-scaled GEMM with sqrt(softplus) + top-k epilogue.
|
||||
Uses MmaMXF4NVF4Op (sf_vec_size=16, kind::mxf4nvf4) — native NVF4,
|
||||
E2M1 elements with E4M3 microscales, 16-elem blocks. NO MXFP4, NO CONVERSIONS.
|
||||
|
||||
Architecture (6-warp specialization):
|
||||
Warp 5 (TMA): Load A [M,K] and B [K,N] tiles GMEM -> SMEM, plus SFA/SFB
|
||||
Warp 4 (MMA): NVFP4 block-scaled GEMM (SFA/SFB in TMEM), FP32 accumulator -> TMEM
|
||||
Warps 0-3 (EPI): TMEM -> registers -> sqrt(softplus) + bias + top-k heap -> GMEM
|
||||
|
||||
The epilogue accumulates a per-thread min-heap across all subtiles.
|
||||
After all subtiles for a row are processed, thread 0 of warp 0 merges
|
||||
all heaps in SMEM, sorts, renormalizes, and writes the final (k=6)
|
||||
weights and expert IDs to global memory.
|
||||
|
||||
Math (DSV4 S2.1):
|
||||
logit = X @ W_gate (NVFP4 block-scaled GEMM, FP32 accumulator)
|
||||
act = sqrt(softplus(logit)) softplus(x) = max(x,0) + log(1+exp(-|x|))
|
||||
score = act + e_bias[e]
|
||||
ids = argtopk(score, k=6) min-heap, lower index wins ties
|
||||
w = (act[ids] / sum(act[ids])) * scaling
|
||||
|
||||
NVFP4 GEMM details (mirrors Sm100BlockScaledPersistentDenseGemmKernel):
|
||||
- A operand: FP4 (quantized from BF16 activation), SFA in TMEM
|
||||
- B operand: FP4 (from checkpoint), SFB in TMEM
|
||||
- Accumulator: FP32 in TMEM
|
||||
- Global scales: gsa (activation) and gsb (weight) applied in epilogue
|
||||
- sf_vec_size=16 for NVFP4 (not 32 for MXF4)
|
||||
- Separate tiled_mma_sfb for SFB TMEM copy (always CtaGroup.ONE)
|
||||
Architecture mirrors dense_blockscaled_gemm_persistent.py:
|
||||
Warp 5 (TMA): Load A/B/SFA/SFB from GMEM -> SMEM (PipelineTmaUmma)
|
||||
Warp 4 (MMA): NVFP4 block-scaled GEMM, accumulator in TMEM (PipelineUmmaAsync)
|
||||
Warps 0-3 (EPI): TMEM -> registers -> sqrt(softplus) + top-k -> GMEM
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Tuple, Optional
|
||||
|
||||
from typing import Tuple
|
||||
import cuda.bindings.driver as cuda
|
||||
import torch
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
@@ -47,710 +24,190 @@ import cutlass.utils.blockscaled_layout as blockscaled_utils
|
||||
import cutlass.torch as cutlass_torch
|
||||
|
||||
|
||||
class Nvfp4FusedRouterKernel:
|
||||
"""NVFP4 block-scaled GEMM + fused sqrt(softplus)/top-k router epilogue.
|
||||
|
||||
Single-kernel replacement for the two-kernel path:
|
||||
Nvfp4Linear (NVFP4 GEMM) -> activation_topk CUDA kernel
|
||||
|
||||
The fusion eliminates the intermediate FP32 logits write to GMEM
|
||||
and the subsequent read-back.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    mma_tiler_mn: Tuple[int, int] = (128, 128),
    cluster_shape_mn: Tuple[int, int] = (1, 1),
    top_k: int = 6,
    sf_vec_size: int = 16,
):
    """Store the kernel configuration and the fixed warp/barrier layout.

    Args:
        mma_tiler_mn: (M, N) extent of the MMA tile. An M extent of 256
            selects the two-CTA instruction group.
        cluster_shape_mn: (M, N) shape of the launch cluster.
        top_k: Number of experts to select; stored as ``self.top_k``.
        sf_vec_size: Number of elements that share one scale factor.
    """
    # Caller-selected configuration.
    self.mma_tiler_mn = mma_tiler_mn
    self.cluster_shape_mn = cluster_shape_mn
    self.top_k = top_k
    self.sf_vec_size = sf_vec_size

    # Operand formats: FP4 (E2M1) operands, E4M3 scale factors,
    # FP32 accumulator and output.
    self.a_dtype = cutlass.Float4E2M1FN
    self.b_dtype = cutlass.Float4E2M1FN
    self.sf_dtype = cutlass.Float8E4M3FN
    self.acc_dtype = cutlass.Float32
    self.c_dtype = cutlass.Float32

    # A 256-row MMA tile is issued as a two-CTA instruction.
    two_cta = mma_tiler_mn[0] == 256
    self.use_2cta_instrs = two_cta
    if two_cta:
        self.cta_group = tcgen05.CtaGroup.TWO
    else:
        self.cta_group = tcgen05.CtaGroup.ONE

    # Warp roles: warps 0-3 epilogue, warp 4 MMA, warp 5 TMA.
    self.epilog_warp_id = tuple(range(4))
    self.mma_warp_id = 4
    self.tma_warp_id = 5
    self.threads_per_warp = 32
    warps_per_cta = len(self.epilog_warp_id) + 2
    self.threads_per_cta = warps_per_cta * self.threads_per_warp  # 192

    # Named-barrier IDs used inside the kernel.
    self.cta_sync_bar_id = 1
    self.epilog_sync_bar_id = 2
    self.tmem_alloc_sync_bar_id = 3

    # Resource limits and SMEM buffer alignment.
    self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
    self.occupancy = 1
    self.buffer_align_bytes = 1024
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# MMA setup - mirrors Sm100BlockScaledPersistentDenseGemmKernel
|
||||
# ----------------------------------------------------------------
|
||||
def _create_tiled_mma(self) -> cute.TiledMma:
    """Build the block-scaled tiled MMA used for the A/B operands.

    Reads the operand dtypes, major modes, scale-factor settings, CTA group
    and ``self.mma_inst_shape_mn`` that were stored on the instance earlier.
    """
    operand_dtypes = (self.a_dtype, self.b_dtype)
    operand_majors = (self.a_major_mode, self.b_major_mode)
    return sm100_utils.make_blockscaled_trivial_tiled_mma(
        *operand_dtypes,
        *operand_majors,
        self.sf_dtype,
        self.sf_vec_size,
        self.cta_group,
        self.mma_inst_shape_mn,
    )
|
||||
|
||||
def _create_tiled_mma_sfb(self) -> cute.TiledMma:
    """Build the tiled MMA that is used only to partition and copy SFB.

    Identical to ``_create_tiled_mma`` except that the CTA group is always
    ``CtaGroup.ONE`` and the instruction shape is
    ``self.mma_inst_shape_mn_sfb``.
    """
    operand_dtypes = (self.a_dtype, self.b_dtype)
    operand_majors = (self.a_major_mode, self.b_major_mode)
    return sm100_utils.make_blockscaled_trivial_tiled_mma(
        *operand_dtypes,
        *operand_majors,
        self.sf_dtype,
        self.sf_vec_size,
        tcgen05.CtaGroup.ONE,
        self.mma_inst_shape_mn_sfb,
    )
|
||||
|
||||
def _setup_attributes(self):
    """Derive tile shapes, cluster layouts, SMEM layouts and TMEM sizes.

    Expects ``self.mma_tiler``, ``self.a_major_mode`` and
    ``self.b_major_mode`` to be set already (``run`` assigns them before
    calling this). ``self.mma_tiler`` is overwritten here with the final
    (M, N, K) tiler.
    """
    # Instruction shapes for the main MMA and for the SFB-only MMA.
    self.mma_inst_shape_mn = (self.mma_tiler[0], self.mma_tiler[1])
    sfb_m_divisor = 2 if self.use_2cta_instrs else 1
    self.mma_inst_shape_mn_sfb = (
        self.mma_inst_shape_mn[0] // sfb_m_divisor,
        cute.round_up(self.mma_inst_shape_mn[1], 128),
    )

    mma = self._create_tiled_mma()
    mma_sfb = self._create_tiled_mma_sfb()
    self._tiled_mma = mma
    self._tiled_mma_sfb = mma_sfb

    # K extent of one tile = instruction K x instructions per tile.
    # NOTE(review): the instruction K is hard-coded to 32 rather than being
    # queried from the tiled MMA — confirm it matches the NVFP4 instruction.
    inst_k = 32
    insts_per_k_tile = 4
    self.mma_tiler = (
        *self.mma_inst_shape_mn,
        cutlass.Int32(inst_k * insts_per_k_tile),
    )
    self.mma_tiler_sfb = (
        *self.mma_inst_shape_mn_sfb,
        cutlass.Int32(inst_k * insts_per_k_tile),
    )

    # Per-CTA tile: the M extent is split across the CTAs of one MMA atom.
    ctas_per_atom = cute.size(mma.thr_id.shape)
    self.cta_tile_shape_mnk = (
        self.mma_tiler[0] // ctas_per_atom,
        self.mma_tiler[1],
        self.mma_tiler[2],
    )
    self.cta_tile_shape_mnk_sfb = (
        self.mma_tiler_sfb[0] // ctas_per_atom,
        self.mma_tiler_sfb[1],
        self.mma_tiler_sfb[2],
    )

    # Cluster layouts divided by the MMA atom's thread-ID shape.
    cluster_shape_mnk = (*self.cluster_shape_mn, 1)
    self.cluster_layout_vmnk = cute.tiled_divide(
        cute.make_layout(cluster_shape_mnk), (mma.thr_id.shape,)
    )
    self.cluster_layout_sfb_vmnk = cute.tiled_divide(
        cute.make_layout(cluster_shape_mnk), (mma_sfb.thr_id.shape,)
    )

    # Multicast fan-out for each operand.
    self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
    self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
    self.num_mcast_ctas_sfb = cute.size(self.cluster_layout_sfb_vmnk.shape[1])
    self.is_a_mcast = self.num_mcast_ctas_a > 1
    self.is_b_mcast = self.num_mcast_ctas_b > 1
    self.is_sfb_mcast = self.num_mcast_ctas_sfb > 1

    # Epilogue tile for a row-major FP32 output with no C operand.
    self.epi_tile = sm100_utils.compute_epilogue_tile_shape(
        self.cta_tile_shape_mnk,
        self.use_2cta_instrs,
        layout_d=utils.LayoutEnum.ROW_MAJOR,
        elem_ty_d=self.c_dtype,
        layout_c=None,
        elem_ty_c=None,
    )
    self.epi_tile_n = cute.size(self.epi_tile[1])

    # Pipeline depths.
    self.num_ab_stage = 2
    self.num_acc_stage = 1
    self.overlapping_accum = False

    # Staged SMEM layouts for the operands and their scale factors.
    stages = self.num_ab_stage
    self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
        mma, self.mma_tiler, self.a_dtype, stages
    )
    self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
        mma, self.mma_tiler, self.b_dtype, stages
    )
    self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa(
        mma, self.mma_tiler, self.sf_vec_size, stages
    )
    self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb(
        mma, self.mma_tiler, self.sf_vec_size, stages
    )

    # TMEM column bookkeeping for scale factors and the accumulator.
    sf_atom_mn = 32
    sfa_atoms = self.cta_tile_shape_mnk[0] // sf_atom_mn
    sfb_atoms = self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn
    self.num_sfa_tmem_cols = sfa_atoms * insts_per_k_tile
    self.num_sfb_tmem_cols = sfb_atoms * insts_per_k_tile
    self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols
    self.num_accumulator_tmem_cols = self.cta_tile_shape_mnk[1] * self.num_acc_stage

    # Size the TMEM allocation from a staged accumulator fragment.
    acc_shape = mma.partition_shape_C(self.mma_tiler[:2])
    staged_acc = mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
    self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(staged_acc)
|
||||
|
||||
def mainloop_s2t_copy_and_partition(self, sSF, tSF):
    """Create the SMEM -> TMEM scale-factor copy and its partitioned operands.

    Args:
        sSF: Scale-factor tensor in shared memory (copy source).
        tSF: Scale-factor tensor in tensor memory (copy destination).

    Returns:
        A tuple ``(tiled_copy, smem_desc_source, tmem_destination)`` ready to
        be passed to ``cute.copy``.
    """
    # Both views go through cute.filter_zeros before partitioning.
    smem_sf = cute.filter_zeros(sSF)
    tmem_sf = cute.filter_zeros(tSF)

    s2t_atom = cute.make_copy_atom(
        tcgen05.Cp4x32x128bOp(self.cta_group), self.sf_dtype
    )
    s2t_copy = tcgen05.make_s2t_copy(s2t_atom, tmem_sf)
    s2t_slice = s2t_copy.get_slice(0)

    # The source side is consumed as an SMEM descriptor tensor.
    smem_src = tcgen05.get_s2t_smem_desc_tensor(
        s2t_copy, s2t_slice.partition_S(smem_sf)
    )
    tmem_dst = s2t_slice.partition_D(tmem_sf)
    return s2t_copy, smem_src, tmem_dst
|
||||
|
||||
# ----------------------------------------------------------------
|
||||
# Public API
|
||||
# ----------------------------------------------------------------
|
||||
def run(
    self,
    mat_a, mat_b, scale_a, scale_b, expert_offsets,
    global_scale_a, global_scale_b, e_bias, out_weights, out_ids,
    M, N, K, scaling, top_k, stream=None,
):
    """JIT-compile the fused router for this problem and launch it.

    The compiled function returned by ``cute.compile`` is now invoked;
    previously it was discarded, so nothing was ever launched.

    Args:
        mat_a, mat_b: A and B operand tensors; their layouts determine the
            MMA major modes.
        scale_a, scale_b: Scale-factor tensors for A and B.
        expert_offsets: Forwarded to the kernel unchanged.
        global_scale_a, global_scale_b: Global scale tensors, forwarded to
            the kernel.
        e_bias: Per-expert bias tensor, forwarded to the kernel.
        out_weights, out_ids: Output tensors written by the kernel.
        M, N, K: Problem extents. M and N size the launch grid.
        scaling: Passed to the kernel as ``routed_scaling_factor``.
        top_k: Passed to the kernel as ``top_k``.
        stream: CUDA stream for the launch; defaults to ``CUstream(0)``.
    """
    if stream is None:
        stream = cuda.CUstream(0)

    @cute.jit
    def _compiled_fn(mat_a, mat_b, scale_a, scale_b, expert_offsets,
                     global_scale_a, global_scale_b, e_bias, out_weights, out_ids):
        self.a_major_mode = utils.LayoutEnum.from_tensor(mat_a).mma_major_mode()
        self.b_major_mode = utils.LayoutEnum.from_tensor(mat_b).mma_major_mode()

        # Set mma_tiler with CuTe Ints so that blockscaled layout construction
        # produces static (not dynamic) dimensions. K=128 FP4 elements per K-tile.
        self.mma_tiler = (
            cutlass.Int32(self.mma_tiler_mn[0]),
            cutlass.Int32(self.mma_tiler_mn[1]),
            cutlass.Int32(128),
        )

        self._setup_attributes()
        tiled_mma = self._tiled_mma
        tiled_mma_sfb = self._tiled_mma_sfb
        atom_thr_size = cute.size(tiled_mma.thr_id.shape)

        # Single-stage view of every staged SMEM layout. Each slice is used
        # both to size the TMA transaction and to build the TMA atom.
        stage0 = (None, None, None, 0)
        a_smem = cute.slice_(self.a_smem_layout_staged, stage0)
        b_smem = cute.slice_(self.b_smem_layout_staged, stage0)
        sfa_smem = cute.slice_(self.sfa_smem_layout_staged, stage0)
        sfb_smem = cute.slice_(self.sfb_smem_layout_staged, stage0)

        # Bytes delivered per pipeline stage by the four TMA loads.
        a_copy = cute.size_in_bytes(self.a_dtype, a_smem)
        b_copy = cute.size_in_bytes(self.b_dtype, b_smem)
        sfa_copy = cute.size_in_bytes(self.sf_dtype, sfa_smem)
        sfb_copy = cute.size_in_bytes(self.sf_dtype, sfb_smem)
        self.num_tma_load_bytes = (a_copy + b_copy + sfa_copy + sfb_copy) * atom_thr_size

        # TMA atoms: A, B, SFA, SFB
        a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
        tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
            a_op, mat_a, a_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)

        b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id)
        tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
            b_op, mat_b, b_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)

        sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
        tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A(
            sfa_op, scale_a, sfa_smem, self.mma_tiler, tiled_mma,
            self.cluster_layout_vmnk.shape, internal_type=cutlass.Int16)

        # SFB uses its own tiler, tiled MMA and cluster layout.
        sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id)
        tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B(
            sfb_op, scale_b, sfb_smem, self.mma_tiler_sfb, tiled_mma_sfb,
            self.cluster_layout_sfb_vmnk.shape, internal_type=cutlass.Int16)

        # One CTA per output tile.
        num_M_tiles = cute.ceil_div(M, self.cta_tile_shape_mnk[0])
        num_N_tiles = cute.ceil_div(N, self.cta_tile_shape_mnk[1])
        L = 1
        grid = (num_M_tiles * num_N_tiles, 1, 1)

        tile_sched_params = utils.PersistentTileSchedulerParams(
            (cutlass.Int32(num_M_tiles), cutlass.Int32(num_N_tiles), cutlass.Int32(L)),
            (*self.cluster_shape_mn, 1))

        self._kernel(
            tiled_mma, tiled_mma_sfb,
            tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b,
            tma_atom_sfa, tma_tensor_sfa, tma_atom_sfb, tma_tensor_sfb,
            self.cluster_layout_vmnk, self.cluster_layout_sfb_vmnk,
            self.a_smem_layout_staged, self.b_smem_layout_staged,
            self.sfa_smem_layout_staged, self.sfb_smem_layout_staged,
            self.epi_tile,
            e_bias, out_weights, out_ids,
            expert_offsets, global_scale_a, global_scale_b,
            tile_sched_params,
            M, N, K, top_k, scaling,
        ).launch(
            grid=grid, block=[self.threads_per_cta, 1, 1],
            cluster=(*self.cluster_shape_mn, 1), stream=stream, min_blocks_per_mp=1)

    # NOTE(review): M/N/K, scaling, top_k and stream are captured by closure,
    # so this recompiles on every call — consider caching per configuration.
    compiled_fn = cute.compile(
        _compiled_fn, mat_a, mat_b, scale_a, scale_b, expert_offsets,
        global_scale_a, global_scale_b, e_bias, out_weights, out_ids)

    # cute.compile only builds the executable; it must be called to launch.
    compiled_fn(
        mat_a, mat_b, scale_a, scale_b, expert_offsets,
        global_scale_a, global_scale_b, e_bias, out_weights, out_ids)
|
||||
|
||||
# ================================================================
|
||||
# KERNEL
|
||||
# ================================================================
|
||||
@cute.kernel
def _kernel(
    self, tiled_mma, tiled_mma_sfb,
    tma_atom_a, mA_mkl, tma_atom_b, mB_nkl,
    tma_atom_sfa, mSFA_mkl, tma_atom_sfb, mSFB_nkl,
    cluster_layout_vmnk, cluster_layout_sfb_vmnk,
    a_smem_layout_staged, b_smem_layout_staged,
    sfa_smem_layout_staged, sfb_smem_layout_staged,
    epi_tile,
    e_bias_tensor, out_w_tensor, out_id_tensor,
    expert_offsets, gsa_tensor, gsb_tensor,
    tile_sched_params, M, N, K, top_k, routed_scaling_factor,
):
    """Device kernel: block-scaled GEMM followed by the router epilogue.

    Warp roles follow ``__init__``: warp 5 issues the TMA loads, warp 4
    issues the MMAs, and warps 0-3 read the accumulator, apply
    sqrt(softplus) plus the expert bias, and keep a 6-entry heap.

    ``M``, ``K``, ``top_k``, ``expert_offsets`` and ``epi_tile`` are
    accepted but not read in this body; the heap size is hard-coded to 6.

    Changes relative to the previous revision:
      * the init barrier is taken by every warp before the role branches
        (the TMA warp never arrived on the 192-thread barrier);
      * tensor memory is actually allocated before it is waited on;
      * each epilogue thread's SMEM heap slot is its index within the
        epilogue group (the old ``warp_idx * 32 + tidx`` ran past the
        scratch arrays for warps 1-3).
    """
    warp_idx = cute.arch.warp_idx()
    warp_idx = cute.arch.make_warp_uniform(warp_idx)
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()

    use_2cta = cute.size(tiled_mma.thr_id.shape) == 2
    mma_tile_v = bidx % cute.size(tiled_mma.thr_id.shape)
    cta_rank = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster())
    block_coord = cluster_layout_vmnk.get_flat_coord(cta_rank)

    # ==============================================================
    # Shared storage
    # ==============================================================
    @cute.struct
    class SharedStorage:
        ab_full_mbar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
        ab_empty_mbar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
        acc_full_mbar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
        acc_empty_mbar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
        tmem_dealloc_mbar: cutlass.Int64
        tmem_holding: cutlass.Int32
        # Heap scratch: 4 epilogue warps x 32 lanes x 6 entries each.
        heap_scores: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 4*32*6], 128]
        heap_indices: cute.struct.Align[cute.struct.MemRange[cutlass.Int32, 4*32*6], 128]
        heap_acts: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 4*32*6], 128]
        sA: cute.struct.Align[cute.struct.MemRange[self.a_dtype, cute.cosize(a_smem_layout_staged.outer)], self.buffer_align_bytes]
        sB: cute.struct.Align[cute.struct.MemRange[self.b_dtype, cute.cosize(b_smem_layout_staged.outer)], self.buffer_align_bytes]
        sSFA: cute.struct.Align[cute.struct.MemRange[self.sf_dtype, cute.cosize(sfa_smem_layout_staged)], self.buffer_align_bytes]
        sSFB: cute.struct.Align[cute.struct.MemRange[self.sf_dtype, cute.cosize(sfb_smem_layout_staged)], self.buffer_align_bytes]

    smem = utils.SmemAllocator()
    storage = smem.allocate(SharedStorage)

    # ==============================================================
    # Pipelines
    # ==============================================================
    ab_pipeline = pipeline.PipelineTmaUmma.create(
        barrier_storage=storage.ab_full_mbar.data_ptr(),
        num_stages=self.num_ab_stage,
        producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
        consumer_group=pipeline.CooperativeGroup(
            pipeline.Agent.Thread,
            self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1),
        tx_count=self.num_tma_load_bytes,
        cta_layout_vmnk=cluster_layout_vmnk)

    # NOTE(review): the consumer count is per-thread (128 for one CTA) but
    # the epilogue releases through elect_one() — confirm the arrival count
    # expected by PipelineUmmaAsync matches.
    num_acc_cons = self.threads_per_warp * len(self.epilog_warp_id) * (2 if use_2cta else 1)
    acc_pipeline = pipeline.PipelineUmmaAsync.create(
        barrier_storage=storage.acc_full_mbar.data_ptr(),
        num_stages=self.num_acc_stage,
        producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
        consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, num_acc_cons),
        cta_layout_vmnk=cluster_layout_vmnk)

    # TMEM allocator: epilogue warp 0 allocates; the MMA warp and all
    # epilogue warps meet on the retrieve barrier.
    tmem = utils.TmemAllocator(
        storage.tmem_holding.ptr,
        barrier_for_retrieve=pipeline.NamedBarrier(
            barrier_id=self.tmem_alloc_sync_bar_id,
            num_threads=self.threads_per_warp * len((self.mma_warp_id, *self.epilog_warp_id))),
        allocator_warp_id=self.epilog_warp_id[0],
        is_two_cta=use_2cta,
        two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar.ptr)

    cta_bar = pipeline.NamedBarrier(self.cta_sync_bar_id, self.threads_per_cta)
    epi_bar = pipeline.NamedBarrier(
        self.epilog_sync_bar_id,
        self.threads_per_warp * len(self.epilog_warp_id))

    # ==============================================================
    # SMEM tensors
    # ==============================================================
    sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner)
    sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner)
    # SFA/SFB use blockscaled layouts (plain Layout, no swizzle)
    sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged)
    sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged)

    # Multicast masks
    a_mcast = None
    b_mcast = None
    sfb_mcast = None
    if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta):
        a_mcast = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_coord, mcast_mode=2)
        b_mcast = cpasync.create_tma_multicast_mask(cluster_layout_vmnk, block_coord, mcast_mode=1)
    if cutlass.const_expr(self.is_sfb_mcast or use_2cta):
        sfb_mcast = cpasync.create_tma_multicast_mask(cluster_layout_sfb_vmnk, block_coord, mcast_mode=1)

    # Partition globals
    gA = cute.local_tile(mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None))
    gB = cute.local_tile(mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None))
    gSFA = cute.local_tile(mSFA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None))
    gSFB = cute.local_tile(mSFB_nkl, cute.slice_(self.mma_tiler_sfb, (0, None, None)), (None, None, None))
    k_tiles = cute.size(gA, mode=[3])

    thr_mma = tiled_mma.get_slice(mma_tile_v)
    tCgA = thr_mma.partition_A(gA)
    tCgB = thr_mma.partition_B(gB)
    tCgSFA = thr_mma.partition_SFA(gSFA)

    thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_v)
    tCgSFB = thr_mma_sfb.partition_SFB(gSFB)

    # TMA partition
    a_cta_l = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape)
    tAsA, tAgA = cpasync.tma_partition(
        tma_atom_a, block_coord[2], a_cta_l,
        cute.group_modes(sA, 0, 3), cute.group_modes(tCgA, 0, 3))
    b_cta_l = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape)
    tBsB, tBgB = cpasync.tma_partition(
        tma_atom_b, block_coord[1], b_cta_l,
        cute.group_modes(sB, 0, 3), cute.group_modes(tCgB, 0, 3))
    tAsSFA, tAgSFA = cpasync.tma_partition(
        tma_atom_sfa, block_coord[2], a_cta_l,
        cute.group_modes(sSFA, 0, 3), cute.group_modes(tCgSFA, 0, 3))
    sfb_cta_l = cute.make_layout(cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape)
    tBsSFB, tBgSFB = cpasync.tma_partition(
        tma_atom_sfb, block_coord[1], sfb_cta_l,
        cute.group_modes(sSFB, 0, 3), cute.group_modes(tCgSFB, 0, 3))

    # Register fragments
    tCrA = tiled_mma.make_fragment_A(sA)
    tCrB = tiled_mma.make_fragment_B(sB)

    acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
    tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))

    if cute.size(self.cluster_shape_mn) > 1:
        cute.arch.cluster_arrive_relaxed()

    # Every warp waits here, before the role branches, so pipeline barrier
    # setup is complete before any warp uses it. cta_bar counts all
    # threads_per_cta threads, so the TMA warp must arrive as well;
    # otherwise the single-CTA path never completes.
    if cute.size(self.cluster_shape_mn) > 1:
        cute.arch.cluster_wait()
    else:
        cta_bar.arrive_and_wait()

    # ==============================================================
    # TMA WARP (5) - load A, B, SFA, SFB tiles from GMEM -> SMEM
    # ==============================================================
    if warp_idx == self.tma_warp_id:
        cpasync.prefetch_descriptor(tma_atom_a)
        cpasync.prefetch_descriptor(tma_atom_b)
        cpasync.prefetch_descriptor(tma_atom_sfa)
        cpasync.prefetch_descriptor(tma_atom_sfb)
        tsched = utils.StaticPersistentTileScheduler.create(tile_sched_params, bidx, cute.arch.grid_dim())
        wt = tsched.initial_work_tile_info()
        ab_ps = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_ab_stage)
        while wt.is_valid_tile:
            tc = wt.tile_idx
            mc = (tc[0] // cute.size(tiled_mma.thr_id.shape), tc[1], tc[2])
            tA_s = tAgA[(None, mc[0], None, mc[2])]
            tB_s = tBgB[(None, mc[1], None, mc[2])]
            tSFA_s = tAgSFA[(None, mc[0], None, mc[2])]
            tSFB_s = tBgSFB[(None, mc[1], None, mc[2])]
            for kt in cutlass.range(0, k_tiles, 1, unroll=1):
                ab_pipeline.producer_acquire(ab_ps)
                # All four loads of a stage signal the same barrier.
                cute.copy(tma_atom_a, tA_s[(None, ab_ps.count)], tAsA[(None, ab_ps.index)],
                          tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_ps), mcast_mask=a_mcast)
                cute.copy(tma_atom_b, tB_s[(None, ab_ps.count)], tBsB[(None, ab_ps.index)],
                          tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_ps), mcast_mask=b_mcast)
                cute.copy(tma_atom_sfa, tSFA_s[(None, ab_ps.count)], tAsSFA[(None, ab_ps.index)],
                          tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_ps))
                cute.copy(tma_atom_sfb, tSFB_s[(None, ab_ps.count)], tBsSFB[(None, ab_ps.index)],
                          tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_ps), mcast_mask=sfb_mcast)
                ab_ps.advance()
            ab_pipeline.producer_tail(ab_ps)
            tsched.advance_to_next_work()
            wt = tsched.get_current_work()

    # ==============================================================
    # MMA WARP (4) - NVFP4 block-scaled GEMM
    # ==============================================================
    if warp_idx == self.mma_warp_id:
        tmem.wait_for_alloc()
        tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
        tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
        tCtAcc = tCtAcc_base[(None, None, None, 0)]

        # SFA/SFB SMEM -> TMEM copy atoms
        # make_fragment_SFA/SFB returns TMEM tensor (the destination for s2t copy)
        tCrSFA = tiled_mma.make_fragment_SFA(sSFA)
        tCrSFB = tiled_mma_sfb.make_fragment_SFB(sSFB)

        tiled_copy_s2t_sfa, tCsSFA, tCtSFA = self.mainloop_s2t_copy_and_partition(sSFA, tCrSFA)
        tiled_copy_s2t_sfb, tCsSFB, tCtSFB = self.mainloop_s2t_copy_and_partition(sSFB, tCrSFB)

        tsched = utils.StaticPersistentTileScheduler.create(tile_sched_params, bidx, cute.arch.grid_dim())
        wt = tsched.initial_work_tile_info()
        ab_cs = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_ab_stage)
        acc_ps = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
        while wt.is_valid_tile:
            # NOTE(review): ACCUMULATE=False on the first k-block already
            # overwrites the accumulator; confirm cute.clear is valid on a
            # TMEM tensor. The operand cute.copy calls below also look
            # unnecessary for tcgen05 — TODO confirm.
            cute.clear(tCtAcc)
            tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
            for kt in cutlass.range(0, k_tiles, 1, unroll=1):
                ab_pipeline.consumer_wait(ab_cs)
                cute.copy(tiled_mma.partition_A(sA[(None, None, None, ab_cs.index)]), tCrA)
                cute.copy(tiled_mma.partition_B(sB[(None, None, None, ab_cs.index)]), tCrB)
                cute.copy(tiled_copy_s2t_sfa, tCsSFA[(None, None, None, None, ab_cs.index)], tCtSFA)
                cute.copy(tiled_copy_s2t_sfb, tCsSFB[(None, None, None, None, ab_cs.index)], tCtSFB)
                num_kblocks = cute.size(tCrA, mode=[2])
                for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
                    kblock_coord = (None, None, kblock_idx, ab_cs.index)
                    sf_kblock = (None, None, kblock_idx)
                    tiled_mma.set(tcgen05.Field.SFA, tCtSFA[sf_kblock].iterator)
                    tiled_mma.set(tcgen05.Field.SFB, tCtSFB[sf_kblock].iterator)
                    cute.gemm(tiled_mma, tCrA[kblock_coord], tCrB[kblock_coord], tCtAcc)
                    # Later k-blocks accumulate into the existing result.
                    tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
                ab_pipeline.consumer_release(ab_cs)
                ab_cs.advance()
            acc_pipeline.producer_commit(acc_ps)
            acc_ps.advance()
            tsched.advance_to_next_work()
            wt = tsched.get_current_work()
        acc_pipeline.producer_tail(acc_ps)
        tmem.relinquish_alloc_permit()

    # ==============================================================
    # EPILOGUE WARPS (0-3) - TMEM -> sqrt(softplus) + top-k -> GMEM
    # ==============================================================
    if warp_idx in self.epilog_warp_id:
        # Request the TMEM columns before anyone waits on the allocation
        # barrier; the allocator was built with epilogue warp 0 as owner.
        tmem.allocate(self.num_tmem_alloc_cols)
        tmem.wait_for_alloc()
        tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
        tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
        tCtAcc0 = tCtAcc_base[(None, None, None, 0)]

        # TMEM -> register copy (tcgen05.ld)
        epi_n = self.epi_tile_n
        tmem_load_atom = cute.make_copy_atom(
            tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(epi_n)), self.acc_dtype)
        tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tCtAcc0)
        # Index of this thread within the 128-thread epilogue group.
        sfw_idx = tidx % (self.threads_per_warp * len(self.epilog_warp_id))
        thr_ld = tiled_tmem_load.get_slice(sfw_idx)
        tS = thr_ld.partition_S(tCtAcc0)
        tD = thr_ld.partition_D(tS)

        # Identity tensor for expert index mapping
        cAcc = cute.make_identity_tensor((self.cta_tile_shape_mnk[0], N))
        tCcAcc = tiled_mma.get_slice(mma_tile_v).partition_C(cAcc)
        cFlat = cute.flatten(tCcAcc)
        rFlat = cute.flatten(tD)

        # Per-thread register heap (top_k=6 entries)
        hs = [cutlass.Float32(-1e30)] * 6
        hi = [cutlass.Int32(-1)] * 6
        ha = [cutlass.Float32(0.0)] * 6

        # Read global scales (same for all tiles)
        gsa_val = gsa_tensor[0]
        gsb_val = gsb_tensor[0]

        tsched = utils.StaticPersistentTileScheduler.create(tile_sched_params, bidx, cute.arch.grid_dim())
        wt = tsched.initial_work_tile_info()
        acc_cs = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)

        while wt.is_valid_tile:
            acc_pipeline.consumer_wait(acc_cs)

            # Reset heap
            for i in cutlass.range(6, unroll=1):
                hs[i] = cutlass.Float32(-1e30)
                hi[i] = cutlass.Int32(-1)
                ha[i] = cutlass.Float32(0.0)

            # Process subtiles
            for subtile in cutlass.range(1, unroll=1):
                cute.copy(tiled_tmem_load, tS, tD)

                elem_cnt = cute.size(rFlat)
                for e in cutlass.range(elem_cnt, unroll=4):
                    logit_raw = rFlat[e]
                    # Apply global scales: GEMM output = sum(A_sf * B_sf), needs * gsa * gsb
                    logit = logit_raw * gsa_val * gsb_val
                    coord = cFlat[e]
                    e_idx = coord[1]

                    # sqrt(softplus(logit)), with
                    # softplus(x) = max(x, 0) + log(1 + exp(-|x|))
                    abs_x = cute.math.absf(logit)
                    pos = cute.where(logit > cutlass.Float32(0.0), logit, cutlass.Float32(0.0))
                    exp_neg = cute.math.exp(-abs_x)
                    sp = pos + cute.math.log(cutlass.Float32(1.0) + exp_neg)
                    act = cute.math.sqrt(sp)

                    # score = act + bias
                    score = act + e_bias_tensor[e_idx]

                    # Min-heap push: root = hs[0] (smallest). On equal
                    # scores the lower expert index wins.
                    do_push = (score > hs[0]) or (score == hs[0] and e_idx < hi[0])
                    if do_push:
                        hs[0] = score
                        hi[0] = e_idx
                        ha[0] = act
                        # Sift the new root down; only nodes 0-2 have
                        # children in a 6-entry heap.
                        root = 0
                        _done = cutlass.Bool(False)
                        while root < 3 and not _done:
                            left = 2*root+1
                            right = 2*root+2
                            smallest = root
                            if left < 6:
                                if hs[left] < hs[smallest] or (hs[left] == hs[smallest] and hi[left] > hi[smallest]):
                                    smallest = left
                            if right < 6:
                                if hs[right] < hs[smallest] or (hs[right] == hs[smallest] and hi[right] > hi[smallest]):
                                    smallest = right
                            if smallest == root:
                                _done = cutlass.Bool(True)
                            if not _done:
                                ts_ = hs[root]
                                ti_ = hi[root]
                                ta_ = ha[root]
                                hs[root] = hs[smallest]
                                hi[root] = hi[smallest]
                                ha[root] = ha[smallest]
                                hs[smallest] = ts_
                                hi[smallest] = ti_
                                ha[smallest] = ta_
                                root = smallest

            # Write heap to shared memory for merge.
            # thread_idx() is already CTA-wide, so the slot is the thread's
            # index within the epilogue group (0..127). This keeps writes
            # inside the 4*32*6 scratch arrays and matches the t*6+i
            # addressing of the merge loop below.
            tid = sfw_idx
            base = tid * 6
            for i in cutlass.range(6, unroll=1):
                storage.heap_scores.data_ptr()[base + i] = hs[i]
                storage.heap_indices.data_ptr()[base + i] = hi[i]
                storage.heap_acts.data_ptr()[base + i] = ha[i]

            epi_bar.arrive_and_wait()

            # Thread 0 of warp 0 does the final merge + store
            # NOTE(review): this folds the heaps of all 128 epilogue threads
            # into one top-6 and writes only row `row_base + 0`. Confirm
            # whether each thread is meant to own a distinct output row.
            if warp_idx == 0 and tidx == 0:
                fs = list(hs)
                fi = list(hi)
                fa = list(ha)
                for t in cutlass.range(1, 128, unroll=1):
                    for i in cutlass.range(6, unroll=1):
                        cs = storage.heap_scores.data_ptr()[t*6+i]
                        ci = storage.heap_indices.data_ptr()[t*6+i]
                        ca = storage.heap_acts.data_ptr()[t*6+i]
                        # Index -1 marks an unused heap entry.
                        if ci >= 0:
                            if cs > fs[0] or (cs == fs[0] and ci < fi[0]):
                                fs[0] = cs
                                fi[0] = ci
                                fa[0] = ca
                                r = 0
                                _done2 = cutlass.Bool(False)
                                while r < 3 and not _done2:
                                    l = 2*r+1
                                    ri = 2*r+2
                                    sm = r
                                    if l < 6:
                                        if fs[l] < fs[sm] or (fs[l] == fs[sm] and fi[l] > fi[sm]):
                                            sm = l
                                    if ri < 6:
                                        if fs[ri] < fs[sm] or (fs[ri] == fs[sm] and fi[ri] > fi[sm]):
                                            sm = ri
                                    if sm == r:
                                        _done2 = cutlass.Bool(True)
                                    else:
                                        ts_ = fs[r]
                                        ti_ = fi[r]
                                        ta_ = fa[r]
                                        fs[r] = fs[sm]
                                        fi[r] = fi[sm]
                                        fa[r] = fa[sm]
                                        fs[sm] = ts_
                                        fi[sm] = ti_
                                        fa[sm] = ta_
                                        r = sm

                # Sort descending (selection sort, k=6)
                sorted_s = [cutlass.Float32(-1e30)]*6
                sorted_i = [cutlass.Int32(-1)]*6
                sorted_a = [cutlass.Float32(0.0)]*6
                for i in cutlass.range(6, unroll=1):
                    best = 0
                    for j in cutlass.range(1, 6, unroll=1):
                        if fs[j] > fs[best] or (fs[j] == fs[best] and fi[j] < fi[best]):
                            best = j
                    sorted_s[i] = fs[best]
                    sorted_i[i] = fi[best]
                    sorted_a[i] = fa[best]
                    # Retire the chosen entry so it is not picked again.
                    fs[best] = cutlass.Float32(-1e30)

                # Renormalize: weights are act / sum(act) * scaling (the
                # bias affects selection only).
                act_sum = sorted_a[0] + sorted_a[1] + sorted_a[2] + sorted_a[3] + sorted_a[4] + sorted_a[5]
                inv_sum = cutlass.Float32(1.0) / act_sum
                sc = cutlass.Float32(routed_scaling_factor)

                tc = wt.tile_idx
                row_base = tc[0] // cute.size(tiled_mma.thr_id.shape) * self.cta_tile_shape_mnk[0]

                for i in cutlass.range(6, unroll=1):
                    out_w_tensor[row_base + 0, i] = sorted_a[i] * inv_sum * sc
                    out_id_tensor[row_base + 0, i] = sorted_i[i]

            # Scratch must not be overwritten for the next tile until the
            # merge has finished reading it.
            epi_bar.arrive_and_wait()

            with cute.arch.elect_one():
                acc_pipeline.consumer_release(acc_cs)
            acc_cs.advance()
            tsched.advance_to_next_work()
            wt = tsched.get_current_work()

        # Cleanup
        tmem.relinquish_alloc_permit()
        epi_bar.arrive_and_wait()
        tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
# ================================================================
|
||||
# Python wrapper - mirrors Nvfp4Linear but uses fused kernel
|
||||
# ================================================================
|
||||
def run_nvfp4_fused_router(
|
||||
hidden_states: torch.Tensor, # [M, K] BF16
|
||||
mat_b: torch.Tensor, # [1, K_packed, N_packed] float4_e2m1fn_x2 (K-major, from Nvfp4Linear._mat_b)
|
||||
scale_b: torch.Tensor, # [M_sf, N_sf] E4M3 (assembled, from Nvfp4Linear._scale_b)
|
||||
gsa: float, # activation global scale
|
||||
gsb: float, # weight global scale
|
||||
e_bias: torch.Tensor, # [E] FP32
|
||||
routed_scaling_factor: float,
|
||||
top_k: int = 6,
|
||||
out_weights: Optional[torch.Tensor] = None, # [M, top_k] FP32
|
||||
out_ids: Optional[torch.Tensor] = None, # [M, top_k] int32
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Run the NVFP4 fused router kernel.
|
||||
x_fp4_tensor, # CuTe tensor for activation data (FP4)
|
||||
x_sf_tensor, # CuTe tensor for activation scale factors
|
||||
w_fp4_tensor, # CuTe tensor for weight data (FP4)
|
||||
w_sf_tensor, # CuTe tensor for weight scale factors
|
||||
gsa, # global scale A (scalar tensor)
|
||||
gsb, # global scale B (scalar tensor)
|
||||
e_bias, # e_score_correction_bias [N] FP32
|
||||
M, N, K,
|
||||
top_k=6,
|
||||
routed_scaling_factor=2.5,
|
||||
mma_tiler_mn=(128, 128),
|
||||
cluster_shape_mn=(1, 1),
|
||||
num_ab_stages=2,
|
||||
):
|
||||
"""Run the NVFP4 fused router: GEMM + sqrt(softplus) + top-k."""
|
||||
|
||||
Combines NVFP4 block-scaled GEMM + sqrt(softplus) + top-k
|
||||
into a single kernel launch.
|
||||
sf_vec_size = 16 # NVF4: 16-elem blocks
|
||||
a_dtype = cutlass.Float4E2M1FN
|
||||
b_dtype = cutlass.Float4E2M1FN
|
||||
sf_dtype = cutlass.Float8E4M3FN
|
||||
acc_dtype = cutlass.Float32
|
||||
|
||||
Args:
|
||||
hidden_states: [M, K] BF16 input tensor
|
||||
mat_b: Processed FP4 weight tensor from Nvfp4Linear._mat_b
|
||||
scale_b: Processed E4M3 scale factors from Nvfp4Linear._scale_b
|
||||
gsa: Activation global scale
|
||||
gsb: Weight global scale
|
||||
e_bias: Per-expert selection bias [E] FP32
|
||||
routed_scaling_factor: Scaling factor for routed weights
|
||||
top_k: Number of experts to select
|
||||
out_weights: Pre-allocated output weights [M, top_k] FP32
|
||||
out_ids: Pre-allocated output expert IDs [M, top_k] int32
|
||||
"""
|
||||
M = hidden_states.shape[0]
|
||||
K = hidden_states.shape[1]
|
||||
N = scale_b.shape[1] if scale_b.dim() == 2 else scale_b.shape[-1] # num_experts
|
||||
device = hidden_states.device
|
||||
use_2cta = mma_tiler_mn[0] == 256
|
||||
cta_group = tcgen05.CtaGroup.TWO if use_2cta else tcgen05.CtaGroup.ONE
|
||||
|
||||
if out_weights is None:
|
||||
out_weights = torch.empty(M, top_k, dtype=torch.float32, device=device)
|
||||
if out_ids is None:
|
||||
out_ids = torch.empty(M, top_k, dtype=torch.int32, device=device)
|
||||
# Warp layout
|
||||
epi_warp_ids = (0, 1, 2, 3)
|
||||
mma_warp_id = 4
|
||||
tma_warp_id = 5
|
||||
num_warps = 6
|
||||
threads_per_warp = 32
|
||||
num_threads = threads_per_warp * num_warps
|
||||
|
||||
expert_offsets = torch.tensor([M], dtype=torch.int32, device=device)
|
||||
gsa_t = torch.tensor([gsa], dtype=torch.float32, device=device)
|
||||
gsb_t = torch.tensor([gsb], dtype=torch.float32, device=device)
|
||||
num_acc_stages = 1
|
||||
overlapping_accum = True
|
||||
|
||||
# Quantize activation to FP4 (same as Nvfp4Linear)
|
||||
from dsv4.ops.quantize import quantize_activation_nvfp4
|
||||
x_fp4, x_sf = quantize_activation_nvfp4(hidden_states, gsa)
|
||||
# Build tiled MMA — sf_vec_size=16 triggers MmaMXF4NVF4Op (kind::mxf4nvf4)
|
||||
a_major = utils.LayoutEnum.from_tensor(x_fp4_tensor).mma_major_mode()
|
||||
b_major = utils.LayoutEnum.from_tensor(w_fp4_tensor).mma_major_mode()
|
||||
|
||||
# Create CuTe tensors (from_dlpack + mark_layout_dynamic for proper TMA)
|
||||
def _to_cute(t):
|
||||
ct = cutlass_torch.from_dlpack(t)
|
||||
ct = ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
|
||||
return ct
|
||||
|
||||
a_tensor = _to_cute(x_fp4)
|
||||
b_tensor = _to_cute(mat_b)
|
||||
sfa_tensor = _to_cute(x_sf)
|
||||
sfb_tensor = _to_cute(scale_b)
|
||||
e_bias_ct = _to_cute(e_bias)
|
||||
out_w_ct = _to_cute(out_weights)
|
||||
out_id_ct = _to_cute(out_ids)
|
||||
eo_ct = _to_cute(expert_offsets)
|
||||
gsa_ct = _to_cute(gsa_t)
|
||||
gsb_ct = _to_cute(gsb_t)
|
||||
|
||||
kernel = Nvfp4FusedRouterKernel(top_k=top_k)
|
||||
kernel.run(
|
||||
a_tensor, b_tensor, sfa_tensor, sfb_tensor,
|
||||
eo_ct, gsa_ct, gsb_ct,
|
||||
e_bias_ct, out_w_ct, out_id_ct,
|
||||
M, N, K, routed_scaling_factor, top_k,
|
||||
mma_inst_shape_mn = mma_tiler_mn
|
||||
mma_inst_shape_mn_sfb = (
|
||||
mma_inst_shape_mn[0] // (2 if use_2cta else 1),
|
||||
cute.round_up(mma_inst_shape_mn[1], 128),
|
||||
)
|
||||
|
||||
tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma(
|
||||
a_dtype, a_major, b_major, sf_dtype, sf_vec_size,
|
||||
cta_group, mma_inst_shape_mn)
|
||||
|
||||
tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma(
|
||||
a_dtype, a_major, b_major, sf_dtype, sf_vec_size,
|
||||
tcgen05.CtaGroup.ONE, mma_inst_shape_mn_sfb)
|
||||
|
||||
# MMA tiler with CuTe Ints
|
||||
inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
|
||||
inst_tile_k = 4
|
||||
k_tile = inst_shape_k * inst_tile_k
|
||||
|
||||
mma_tiler = (
|
||||
cutlass.Int32(mma_inst_shape_mn[0]),
|
||||
cutlass.Int32(mma_inst_shape_mn[1]),
|
||||
cutlass.Int32(k_tile),
|
||||
)
|
||||
mma_tiler_sfb = (
|
||||
cutlass.Int32(mma_inst_shape_mn_sfb[0]),
|
||||
cutlass.Int32(mma_inst_shape_mn_sfb[1]),
|
||||
cutlass.Int32(k_tile),
|
||||
)
|
||||
|
||||
cta_tile_shape_mnk = (
|
||||
mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
|
||||
mma_tiler[1],
|
||||
mma_tiler[2],
|
||||
)
|
||||
cta_tile_shape_mnk_sfb = (
|
||||
mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape),
|
||||
mma_tiler_sfb[1],
|
||||
mma_tiler_sfb[2],
|
||||
)
|
||||
|
||||
# Cluster layout
|
||||
cluster_layout_vmnk = cute.tiled_divide(
|
||||
cute.make_layout((*cluster_shape_mn, 1)),
|
||||
(tiled_mma.thr_id.shape,))
|
||||
cluster_layout_sfb_vmnk = cute.tiled_divide(
|
||||
cute.make_layout((*cluster_shape_mn, 1)),
|
||||
(tiled_mma_sfb.thr_id.shape,))
|
||||
|
||||
num_mcast_a = cute.size(cluster_layout_vmnk.shape[2])
|
||||
num_mcast_b = cute.size(cluster_layout_vmnk.shape[1])
|
||||
num_mcast_sfb = cute.size(cluster_layout_sfb_vmnk.shape[1])
|
||||
|
||||
# SMEM layouts
|
||||
a_smem_staged = sm100_utils.make_smem_layout_a(
|
||||
tiled_mma, mma_tiler, a_dtype, num_ab_stages)
|
||||
b_smem_staged = sm100_utils.make_smem_layout_b(
|
||||
tiled_mma, mma_tiler, b_dtype, num_ab_stages)
|
||||
sfa_smem_staged = blockscaled_utils.make_smem_layout_sfa(
|
||||
tiled_mma, mma_tiler, sf_vec_size, num_ab_stages)
|
||||
sfb_smem_staged = blockscaled_utils.make_smem_layout_sfb(
|
||||
tiled_mma, mma_tiler, sf_vec_size, num_ab_stages)
|
||||
|
||||
# TMEM cols
|
||||
sf_atom_mn = 32
|
||||
num_sfa_tmem = (cta_tile_shape_mnk[0] // sf_atom_mn) * inst_tile_k
|
||||
num_sfb_tmem = (cta_tile_shape_mnk_sfb[1] // sf_atom_mn) * inst_tile_k
|
||||
|
||||
# TMA bytes
|
||||
a_smem0 = cute.slice_(a_smem_staged, (None, None, None, 0))
|
||||
b_smem0 = cute.slice_(b_smem_staged, (None, None, None, 0))
|
||||
sfa_smem0 = cute.slice_(sfa_smem_staged, (None, None, None, 0))
|
||||
sfb_smem0 = cute.slice_(sfb_smem_staged, (None, None, None, 0))
|
||||
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
|
||||
num_tma_bytes = (
|
||||
cute.size_in_bytes(a_dtype, a_smem0) +
|
||||
cute.size_in_bytes(b_dtype, b_smem0) +
|
||||
cute.size_in_bytes(sf_dtype, sfa_smem0) +
|
||||
cute.size_in_bytes(sf_dtype, sfb_smem0)
|
||||
) * atom_thr_size
|
||||
|
||||
# TMA atoms
|
||||
a_op = sm100_utils.cluster_shape_to_tma_atom_A(cluster_shape_mn, tiled_mma.thr_id)
|
||||
tma_a, gA = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
a_op, x_fp4_tensor, a_smem0, mma_tiler, tiled_mma, cluster_layout_vmnk.shape)
|
||||
|
||||
b_op = sm100_utils.cluster_shape_to_tma_atom_B(cluster_shape_mn, tiled_mma.thr_id)
|
||||
tma_b, gB = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
b_op, w_fp4_tensor, b_smem0, mma_tiler, tiled_mma, cluster_layout_vmnk.shape)
|
||||
|
||||
tma_sfa, gSFA = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
a_op, x_sf_tensor, sfa_smem0, mma_tiler, tiled_mma,
|
||||
cluster_layout_vmnk.shape, internal_type=cutlass.Int16)
|
||||
|
||||
sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(cluster_shape_mn, tiled_mma.thr_id)
|
||||
tma_sfb, gSFB = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
sfb_op, w_sf_tensor, sfb_smem0, mma_tiler_sfb, tiled_mma_sfb,
|
||||
cluster_layout_sfb_vmnk.shape, internal_type=cutlass.Int16)
|
||||
|
||||
# Grid
|
||||
num_M_tiles = cute.ceil_div(M, cta_tile_shape_mnk[0])
|
||||
num_N_tiles = cute.ceil_div(N, cta_tile_shape_mnk[1])
|
||||
grid = (num_M_tiles * num_N_tiles, 1, 1)
|
||||
|
||||
tile_sched_params = utils.PersistentTileSchedulerParams(
|
||||
(cutlass.Int32(num_M_tiles), cutlass.Int32(num_N_tiles), cutlass.Int32(1)),
|
||||
(*cluster_shape_mn, 1))
|
||||
|
||||
# Output tensors
|
||||
out_w = cutlass_torch.make_tensor(
|
||||
torch.empty(M, top_k, dtype=torch.float32, device="cuda:0"))
|
||||
out_id = cutlass_torch.make_tensor(
|
||||
torch.empty(M, top_k, dtype=torch.int32, device="cuda:0"))
|
||||
|
||||
@cute.jit
|
||||
def fused_router(
|
||||
tma_a, gA, tma_b, gB, tma_sfa, gSFA, tma_sfb, gSFB,
|
||||
tiled_mma, tiled_mma_sfb,
|
||||
mma_tiler, mma_tiler_sfb,
|
||||
a_smem_staged, b_smem_staged, sfa_smem_staged, sfb_smem_staged,
|
||||
cluster_layout_vmnk, cluster_layout_sfb_vmnk,
|
||||
cta_tile_shape_mnk, cta_tile_shape_mnk_sfb,
|
||||
gA_mkl, gB_nkl, gSFA_mkl, gSFB_nkl,
|
||||
e_bias, out_w, out_id, gsa, gsb,
|
||||
tile_sched_params, M, N, K, top_k, routed_scaling_factor,
|
||||
num_tma_bytes, num_mcast_a, num_mcast_b, num_mcast_sfb,
|
||||
num_ab_stages, num_acc_stages, num_sfa_tmem, num_sfb_tmem,
|
||||
atom_thr_size, overlapping_accum,
|
||||
epi_warp_ids, mma_warp_id, tma_warp_id, num_threads,
|
||||
):
|
||||
...
|
||||
|
||||
cute.compile(fused_router,
|
||||
tma_a, gA, tma_b, gB, tma_sfa, gSFA, tma_sfb, gSFB,
|
||||
tiled_mma, tiled_mma_sfb, mma_tiler, mma_tiler_sfb,
|
||||
a_smem_staged, b_smem_staged, sfa_smem_staged, sfb_smem_staged,
|
||||
cluster_layout_vmnk, cluster_layout_sfb_vmnk,
|
||||
cta_tile_shape_mnk, cta_tile_shape_mnk_sfb,
|
||||
gA, gB, gSFA, gSFB,
|
||||
e_bias, out_w, out_id, gsa, gsb,
|
||||
tile_sched_params, M, N, K, top_k, routed_scaling_factor,
|
||||
num_tma_bytes, num_mcast_a, num_mcast_b, num_mcast_sfb,
|
||||
num_ab_stages, num_acc_stages, num_sfa_tmem, num_sfb_tmem,
|
||||
atom_thr_size, overlapping_accum,
|
||||
epi_warp_ids, mma_warp_id, tma_warp_id, num_threads,
|
||||
)
|
||||
return out_weights, out_ids
|
||||
|
||||
Reference in New Issue
Block a user