Test: enumerate CuTeDSL math API to check available operations
This commit is contained in:
@@ -6,7 +6,7 @@ with fused router epilogue (sqrt(softplus) + e_bias + top-k + renorm).
|
||||
PRODUCTION KERNEL. No intermediate GMEM buffer. No BF16 fallback.
|
||||
The GEMM accumulates logits in TMEM, then the epilogue warps process them directly:
|
||||
1. TMEM -> registers (via paired t2r atom from CUTLASS epilogue helpers)
|
||||
2. For each logit: sqrt(softplus(logit)) + e_bias -> score; track top-k via min-heap
|
||||
2. For each logit: sqrt(softplus(logit)) + e_bias -> score; track top-k via sorted insertion
|
||||
3. After all subtiles: sort, renormalize, write (topk_weights, topk_ids) to GMEM
|
||||
|
||||
Warp specialization (6 warps, no scheduler for dense GEMM):
|
||||
@@ -17,17 +17,21 @@ Warp specialization (6 warps, no scheduler for dense GEMM):
|
||||
Pipeline structure (2 pipelines):
|
||||
AB pipeline: TMA (producer) -> MMA (consumer) [PipelineTmaUmma]
|
||||
Acc pipeline: MMA (producer) -> Epilogue (consumer) [PipelineUmmaAsync]
|
||||
|
||||
Architecture reference: FusedSwiGLUScaledGroupedGemmKernel (dsv4/kernels/gemm/fused_swiglu.py)
|
||||
The blockscaled GEMM mainloop follows the same pattern exactly.
|
||||
The epilogue is custom: instead of TMA store, we do TMEM->reg top-k reduction.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Tuple
|
||||
import math
|
||||
from typing import Tuple, Optional, Type, Union
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
import torch
|
||||
|
||||
import cutlass
|
||||
import cutlass.cute as cute
|
||||
from cutlass.cute.typing import Pointer
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
import cutlass.utils as utils
|
||||
import cutlass.pipeline as pipeline
|
||||
@@ -35,12 +39,20 @@ import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
import cutlass.utils.blockscaled_layout as blockscaled_utils
|
||||
from cutlass.utils.gemm.sm100 import (
|
||||
epilogue_tmem_copy_and_partition,
|
||||
epilogue_smem_copy_and_partition,
|
||||
transform_partitioned_tensor_layout,
|
||||
)
|
||||
|
||||
|
||||
class Nvfp4FusedRouterKernel:
|
||||
"""
|
||||
NVFP4 blockscaled GEMM + fused router epilogue.
|
||||
|
||||
Dense (non-grouped) GEMM: [M, K] @ [K, E] -> [M, E] with NVFP4 weights.
|
||||
Custom epilogue: TMEM -> registers -> sqrt(softplus) + e_bias + top-k + renorm -> GMEM.
|
||||
|
||||
This follows the FusedSwiGLUScaledGroupedGemmKernel pattern for the
|
||||
blockscaled GEMM mainloop exactly, with a custom epilogue.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -55,38 +67,48 @@ class Nvfp4FusedRouterKernel:
|
||||
self.top_k = top_k
|
||||
self.use_2cta_instrs = mma_tiler_mnk[0] == 256
|
||||
self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
self.arch = "sm_100"
|
||||
|
||||
# 6-warp specialization (no scheduler warp for dense GEMM)
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4
|
||||
self.tma_warp_id = 5
|
||||
self.threads_per_warp = 32
|
||||
self.threads_per_cta = self.threads_per_warp * 6
|
||||
|
||||
# Barrier IDs
|
||||
self.cta_sync_bar_id = 1
|
||||
self.epilogue_sync_bar_id = 2
|
||||
self.tmem_alloc_sync_bar_id = 3
|
||||
self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")
|
||||
|
||||
self.smem_capacity = utils.get_smem_capacity_in_bytes(self.arch)
|
||||
self.occupancy = 1
|
||||
self.buffer_align_bytes = 1024
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# _create_tiled_mma / _create_tiled_mma_sfb
|
||||
# -----------------------------------------------------------------
|
||||
|
||||
def _create_tiled_mma(self, a_dtype, a_major_mode, b_major_mode, sf_dtype):
|
||||
return sm100_utils.make_blockscaled_trivial_tiled_mma(
|
||||
a_dtype, a_major_mode, b_major_mode, sf_dtype,
|
||||
self.sf_vec_size, self.cta_group,
|
||||
(self.mma_tiler_mnk[0], self.mma_tiler_mnk[1]),
|
||||
self.mma_inst_shape_mn,
|
||||
)
|
||||
|
||||
def _create_tiled_mma_sfb(self, a_dtype, a_major_mode, b_major_mode, sf_dtype):
|
||||
mma_inst_shape_mn_sfb = (
|
||||
self.mma_tiler_mnk[0] // (2 if self.use_2cta_instrs else 1),
|
||||
cute.round_up(self.mma_tiler_mnk[1], 128),
|
||||
)
|
||||
return sm100_utils.make_blockscaled_trivial_tiled_mma(
|
||||
a_dtype, a_major_mode, b_major_mode, sf_dtype,
|
||||
self.sf_vec_size, tcgen05.CtaGroup.ONE, mma_inst_shape_mn_sfb,
|
||||
self.sf_vec_size, tcgen05.CtaGroup.ONE,
|
||||
self.mma_inst_shape_mn_sfb,
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# _setup_attributes
|
||||
# -----------------------------------------------------------------
|
||||
|
||||
def _setup_attributes(self, tiled_mma, tiled_mma_sfb, a_dtype, b_dtype, sf_dtype):
|
||||
"""Set up kernel attributes. Mirrors FusedSwiGLUScaledGroupedGemmKernel._setup_attributes."""
|
||||
self.mma_inst_shape_mn = (self.mma_tiler_mnk[0], self.mma_tiler_mnk[1])
|
||||
self.mma_inst_shape_mn_sfb = (
|
||||
self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1),
|
||||
@@ -108,11 +130,13 @@ class Nvfp4FusedRouterKernel:
|
||||
)
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
|
||||
self.mma_tiler[1], self.mma_tiler[2],
|
||||
self.mma_tiler[1],
|
||||
self.mma_tiler[2],
|
||||
)
|
||||
self.cta_tile_shape_mnk_sfb = (
|
||||
self.mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape),
|
||||
self.mma_tiler_sfb[1], self.mma_tiler_sfb[2],
|
||||
self.mma_tiler_sfb[1],
|
||||
self.mma_tiler_sfb[2],
|
||||
)
|
||||
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(
|
||||
@@ -129,29 +153,21 @@ class Nvfp4FusedRouterKernel:
|
||||
self.is_b_mcast = self.num_mcast_ctas_b > 1
|
||||
self.is_sfb_mcast = self.num_mcast_ctas_sfb > 1
|
||||
|
||||
# Epilogue tile: for router, we process all N columns (expert dimension).
|
||||
# Use epi_tile = (128, 32) as the subtile for t2r copy.
|
||||
# This determines how many columns are loaded from TMEM per subtile.
|
||||
self.epi_tile = (
|
||||
cute.make_layout(self.cta_tile_shape_mnk[0]),
|
||||
cute.make_layout(self.cta_tile_shape_mnk[1]),
|
||||
cute.make_layout((32, 1), stride=(1, 32)),
|
||||
)
|
||||
self.epi_tile_n = cute.size(self.epi_tile[1])
|
||||
|
||||
self.overlapping_accum = self.cta_tile_shape_mnk[1] == 256
|
||||
self.num_acc_stage = 1 if self.overlapping_accum else 2
|
||||
# Stage counts
|
||||
self.num_acc_stage = 2
|
||||
self.num_ab_stage = 2
|
||||
self.num_c_stage = 2 # not used for TMA store, but needed for stage computation
|
||||
|
||||
sf_atom_mn = 32
|
||||
self.num_sfa_tmem_cols = (self.cta_tile_shape_mnk[0] // sf_atom_mn) * mma_inst_tile_k
|
||||
self.num_sfb_tmem_cols = (self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn) * mma_inst_tile_k
|
||||
self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols
|
||||
if self.overlapping_accum:
|
||||
self.num_accumulator_tmem_cols = self.cta_tile_shape_mnk[1] * 2 - self.num_sf_tmem_cols
|
||||
else:
|
||||
self.num_accumulator_tmem_cols = self.cta_tile_shape_mnk[1] * self.num_acc_stage
|
||||
|
||||
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
|
||||
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake)
|
||||
|
||||
# Compute SMEM layouts for A, B, SFA, SFB
|
||||
self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
|
||||
tiled_mma, self.mma_tiler, a_dtype, self.num_ab_stage)
|
||||
self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
|
||||
@@ -161,6 +177,28 @@ class Nvfp4FusedRouterKernel:
|
||||
self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb(
|
||||
tiled_mma, self.mma_tiler, self.sf_vec_size, self.num_ab_stage)
|
||||
|
||||
# Overlapping accumulator (N=256 case)
|
||||
self.overlapping_accum = self.cta_tile_shape_mnk[1] == 256
|
||||
if self.overlapping_accum:
|
||||
self.num_acc_pipeline_stages = 1
|
||||
else:
|
||||
self.num_acc_pipeline_stages = self.num_acc_stage
|
||||
|
||||
# TMEM column counts
|
||||
sf_atom_mn = 32
|
||||
self.num_sfa_tmem_cols = (self.cta_tile_shape_mnk[0] // sf_atom_mn) * mma_inst_tile_k
|
||||
self.num_sfb_tmem_cols = (self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn) * mma_inst_tile_k
|
||||
self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols
|
||||
self.num_accumulator_tmem_cols = self.cta_tile_shape_mnk[1] * self.num_acc_stage - (
|
||||
self.num_sf_tmem_cols if self.overlapping_accum else 0
|
||||
)
|
||||
|
||||
# Epilogue subtile index at which the accumulator buffer is released early (only used when overlapping_accum is set)
|
||||
self.iter_acc_early_release_in_epilogue = (
|
||||
self.num_sf_tmem_cols // self.epi_tile_n
|
||||
)
|
||||
|
||||
# TMA load bytes
|
||||
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
|
||||
a_smem_0 = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
|
||||
b_smem_0 = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
|
||||
@@ -173,42 +211,90 @@ class Nvfp4FusedRouterKernel:
|
||||
cute.size_in_bytes(sf_dtype, sfb_smem_0)
|
||||
) * atom_thr_size
|
||||
|
||||
self.iter_acc_early_release = self.num_sf_tmem_cols // self.epi_tile_n
|
||||
# TMEM allocation size
|
||||
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
|
||||
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake)
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# mainloop_s2t_copy_and_partition (same as fused_swiglu.py)
|
||||
# -----------------------------------------------------------------
|
||||
|
||||
def mainloop_s2t_copy_and_partition(self, sSF, tSF, cta_group):
    """Build the SMEM -> TMEM tiled copy for one scale-factor tensor.

    Reads ``self.sf_dtype`` as the element type of the copy atom.

    Args:
        sSF: Scale-factor tensor used as the copy source.
            NOTE(review): presumably the staged SMEM view — confirm at call site.
        tSF: Scale-factor tensor used as the copy destination.
            NOTE(review): presumably the TMEM view — confirm at call site.
        cta_group: CTA group forwarded to ``tcgen05.Cp4x32x128bOp``.

    Returns:
        A 3-tuple ``(tiled_copy, src_partition, dst_partition)``: the tiled
        copy object, the source partition after conversion through
        ``tcgen05.get_s2t_smem_desc_tensor``, and the destination partition.
    """
    # Both views are passed through filter_zeros before any partitioning.
    src_compact = cute.filter_zeros(sSF)
    dst_compact = cute.filter_zeros(tSF)

    # The tiled copy is derived from the atom and the compacted destination.
    s2t_atom = cute.make_copy_atom(
        tcgen05.Cp4x32x128bOp(cta_group),
        self.sf_dtype,
    )
    s2t_copy = tcgen05.make_s2t_copy(s2t_atom, dst_compact)

    # Partitioning always goes through slice 0 of the tiled copy.
    s2t_slice = s2t_copy.get_slice(0)

    # Source side: partition first, then wrap via the s2t smem-descriptor helper.
    src_partition = tcgen05.get_s2t_smem_desc_tensor(
        s2t_copy, s2t_slice.partition_S(src_compact)
    )
    dst_partition = s2t_slice.partition_D(dst_compact)

    return s2t_copy, src_partition, dst_partition
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# run() — Python entry point
|
||||
# -----------------------------------------------------------------
|
||||
|
||||
def run(self, mat_a, mat_b, scale_a, scale_b, e_bias, out_weights, out_ids,
|
||||
M, N, K, routed_scaling_factor, top_k, stream=None):
|
||||
if stream is None:
|
||||
stream = cuda.CUstream(0)
|
||||
|
||||
a_dtype = cutlass.Float4E2M1FN
|
||||
b_dtype = cutlass.Float4E2M1FN
|
||||
sf_dtype = cutlass.Float8E4M3FN
|
||||
# Infer dtypes and major modes from tensors
|
||||
a_dtype = mat_a.element_type
|
||||
b_dtype = mat_b.element_type
|
||||
sf_dtype = scale_a.element_type
|
||||
a_major_mode = utils.LayoutEnum.from_tensor(mat_a).mma_major_mode()
|
||||
b_major_mode = utils.LayoutEnum.from_tensor(mat_b).mma_major_mode()
|
||||
|
||||
# Save for kernel use
|
||||
self.a_dtype = a_dtype
|
||||
self.b_dtype = b_dtype
|
||||
self.sf_dtype = sf_dtype
|
||||
self.a_major_mode = a_major_mode
|
||||
self.b_major_mode = b_major_mode
|
||||
|
||||
tiled_mma = self._create_tiled_mma(a_dtype, a_major_mode, b_major_mode, sf_dtype)
|
||||
tiled_mma_sfb = self._create_tiled_mma_sfb(a_dtype, a_major_mode, b_major_mode, sf_dtype)
|
||||
self._setup_attributes(tiled_mma, tiled_mma_sfb, a_dtype, b_dtype, sf_dtype)
|
||||
|
||||
a_smem_0 = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
|
||||
# TMA atoms — following fused_swiglu.py exactly
|
||||
# A
|
||||
a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
|
||||
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
|
||||
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
a_op, mat_a, a_smem_0, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
|
||||
a_op, mat_a, a_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
|
||||
|
||||
b_smem_0 = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
|
||||
# B
|
||||
b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id)
|
||||
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
|
||||
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
b_op, mat_b, b_smem_0, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
|
||||
b_op, mat_b, b_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
|
||||
|
||||
sfa_smem_0 = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0))
|
||||
# SFA
|
||||
sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
|
||||
sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0))
|
||||
tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
a_op, scale_a, sfa_smem_0, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
|
||||
sfa_op, scale_a, sfa_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape,
|
||||
internal_type=cutlass.Uint64)
|
||||
|
||||
sfb_smem_0 = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0))
|
||||
sfb_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma_sfb.thr_id)
|
||||
# SFB
|
||||
sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id)
|
||||
sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0))
|
||||
tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
sfb_op, scale_b, sfb_smem_0, self.mma_tiler_sfb, tiled_mma_sfb, self.cluster_layout_sfb_vmnk.shape)
|
||||
sfb_op, scale_b, sfb_smem_layout, self.mma_tiler_sfb, tiled_mma_sfb,
|
||||
self.cluster_layout_sfb_vmnk.shape, internal_type=cutlass.Uint64)
|
||||
|
||||
# Grid: dense GEMM, one CTA per (M_tile, N_tile)
|
||||
num_M_tiles = cute.ceil_div(M, self.cta_tile_shape_mnk[0])
|
||||
num_N_tiles = cute.ceil_div(N, self.cta_tile_shape_mnk[1])
|
||||
L = 1
|
||||
@@ -239,6 +325,10 @@ class Nvfp4FusedRouterKernel:
|
||||
|
||||
cute.compile(_compiled_fn, mat_a, mat_b, scale_a, scale_b, e_bias, out_weights, out_ids)
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# GPU kernel
|
||||
# -----------------------------------------------------------------
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, tiled_mma, tiled_mma_sfb,
|
||||
tma_atom_a, mA_mkl, tma_atom_b, mB_nkl,
|
||||
@@ -262,32 +352,32 @@ class Nvfp4FusedRouterKernel:
|
||||
block_coord = cluster_layout_vmnk.get_flat_coord(cta_rank)
|
||||
|
||||
acc_dtype = cutlass.Float32
|
||||
sf_dtype = cutlass.Float8E4M3FN
|
||||
|
||||
# Reconstruct SMEM layout slices (same as fused_swiglu kernel)
|
||||
a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, None, 0))
|
||||
b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, None, 0))
|
||||
sfa_smem_layout = cute.slice_(sfa_smem_layout_staged, (None, None, None, 0))
|
||||
sfb_smem_layout = cute.slice_(sfb_smem_layout_staged, (None, None, None, 0))
|
||||
|
||||
# ============================================================
|
||||
# Shared storage
|
||||
# ============================================================
|
||||
@cute.struct
|
||||
class SharedStorage:
|
||||
ab_full_mbar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
|
||||
ab_empty_mbar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage]
|
||||
acc_full_mbar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
|
||||
acc_empty_mbar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
|
||||
ab_full_mbar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
acc_full_mbar: cute.struct.MemRange[cutlass.Int64, self.num_acc_pipeline_stages * 2]
|
||||
tmem_dealloc_mbar: cutlass.Int64
|
||||
tmem_holding: cutlass.Int32
|
||||
merge_scores: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 128*self.top_k], 128]
|
||||
merge_indices: cute.struct.Align[cute.struct.MemRange[cutlass.Int32, 128*self.top_k], 128]
|
||||
merge_acts: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 128*self.top_k], 128]
|
||||
sA: cute.struct.Align[cute.struct.MemRange[cutlass.Float4E2M1FN, cute.cosize(a_smem_layout_staged.outer)], self.buffer_align_bytes]
|
||||
sB: cute.struct.Align[cute.struct.MemRange[cutlass.Float4E2M1FN, cute.cosize(b_smem_layout_staged.outer)], self.buffer_align_bytes]
|
||||
sSFA: cute.struct.Align[cute.struct.MemRange[cutlass.Float8E4M3FN, cute.cosize(sfa_smem_layout_staged.outer)], self.buffer_align_bytes]
|
||||
sSFB: cute.struct.Align[cute.struct.MemRange[cutlass.Float8E4M3FN, cute.cosize(sfb_smem_layout_staged.outer)], self.buffer_align_bytes]
|
||||
# SMEM for top-k merge: 128 threads × top_k entries
|
||||
merge_scores: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 128 * self.top_k], 128]
|
||||
merge_indices: cute.struct.Align[cute.struct.MemRange[cutlass.Int32, 128 * self.top_k], 128]
|
||||
merge_acts: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 128 * self.top_k], 128]
|
||||
|
||||
smem = utils.SmemAllocator()
|
||||
storage = smem.allocate(SharedStorage)
|
||||
|
||||
# ============================================================
|
||||
# Pipelines
|
||||
# Pipelines (following fused_swiglu.py exactly)
|
||||
# ============================================================
|
||||
ab_pipeline = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=storage.ab_full_mbar.data_ptr(),
|
||||
@@ -298,15 +388,17 @@ class Nvfp4FusedRouterKernel:
|
||||
self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1),
|
||||
tx_count=self.num_tma_load_bytes,
|
||||
cta_layout_vmnk=cluster_layout_vmnk,
|
||||
defer_sync=True,
|
||||
)
|
||||
|
||||
num_acc_cons = self.threads_per_warp * len(self.epilogue_warp_id) * (2 if use_2cta else 1)
|
||||
acc_pipeline = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=storage.acc_full_mbar.data_ptr(),
|
||||
num_stages=self.num_acc_stage,
|
||||
num_stages=self.num_acc_pipeline_stages,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, num_acc_cons),
|
||||
cta_layout_vmnk=cluster_layout_vmnk,
|
||||
defer_sync=True,
|
||||
)
|
||||
|
||||
tmem = utils.TmemAllocator(
|
||||
@@ -324,12 +416,32 @@ class Nvfp4FusedRouterKernel:
|
||||
self.threads_per_warp * len(self.epilogue_warp_id))
|
||||
|
||||
# ============================================================
|
||||
# SMEM tensors
|
||||
# SMEM tensors (following fused_swiglu.py pattern)
|
||||
# A/B use swizzled layouts (ComposedLayout: .outer + .inner)
|
||||
# SFA/SFB use plain layouts (not Composed)
|
||||
# ============================================================
|
||||
sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner)
|
||||
sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner)
|
||||
sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged.outer, swizzle=sfa_smem_layout_staged.inner)
|
||||
sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged.outer, swizzle=sfb_smem_layout_staged.inner)
|
||||
sA = smem.allocate_tensor(
|
||||
element_type=self.a_dtype,
|
||||
layout=a_smem_layout_staged.outer,
|
||||
byte_alignment=128,
|
||||
swizzle=a_smem_layout_staged.inner,
|
||||
)
|
||||
sB = smem.allocate_tensor(
|
||||
element_type=self.b_dtype,
|
||||
layout=b_smem_layout_staged.outer,
|
||||
byte_alignment=128,
|
||||
swizzle=b_smem_layout_staged.inner,
|
||||
)
|
||||
sSFA = smem.allocate_tensor(
|
||||
element_type=self.sf_dtype,
|
||||
layout=sfa_smem_layout_staged,
|
||||
byte_alignment=128,
|
||||
)
|
||||
sSFB = smem.allocate_tensor(
|
||||
element_type=self.sf_dtype,
|
||||
layout=sfb_smem_layout_staged,
|
||||
byte_alignment=128,
|
||||
)
|
||||
|
||||
# ============================================================
|
||||
# Multicast masks
|
||||
@@ -342,7 +454,7 @@ class Nvfp4FusedRouterKernel:
|
||||
sfb_mcast = cpasync.create_tma_multicast_mask(cluster_layout_sfb_vmnk, block_coord, mcast_mode=1)
|
||||
|
||||
# ============================================================
|
||||
# Partition global tensors
|
||||
# Partition global tensors (same as fused_swiglu TMA warp setup)
|
||||
# ============================================================
|
||||
gA = cute.local_tile(mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None))
|
||||
gB = cute.local_tile(mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None))
|
||||
@@ -357,7 +469,7 @@ class Nvfp4FusedRouterKernel:
|
||||
thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_v)
|
||||
tCgSFB = thr_mma_sfb.partition_B(gSFB)
|
||||
|
||||
# TMA partitions for A/B
|
||||
# TMA partitions for A/B (following fused_swiglu)
|
||||
a_cta_l = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_atom_a, block_coord[2], a_cta_l,
|
||||
cute.group_modes(sA, 0, 3), cute.group_modes(tCgA, 0, 3))
|
||||
@@ -365,28 +477,32 @@ class Nvfp4FusedRouterKernel:
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_atom_b, block_coord[1], b_cta_l,
|
||||
cute.group_modes(sB, 0, 3), cute.group_modes(tCgB, 0, 3))
|
||||
|
||||
# TMA partitions for SFA/SFB
|
||||
# TMA partitions for SFA/SFB (following fused_swiglu)
|
||||
tAsSFA, tAgSFA = cpasync.tma_partition(tma_atom_sfa, block_coord[2], a_cta_l,
|
||||
cute.group_modes(sSFA, 0, 3), cute.group_modes(tCgSFA, 0, 3))
|
||||
tAsSFA = cute.filter_zeros(tAsSFA)
|
||||
tAgSFA = cute.filter_zeros(tAgSFA)
|
||||
|
||||
sfb_cta_l = cute.make_layout(cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape)
|
||||
block_coord_sfb = cluster_layout_sfb_vmnk.get_flat_coord(cta_rank)
|
||||
sfb_cta_l = cute.make_layout(cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape)
|
||||
tBsSFB, tBgSFB = cpasync.tma_partition(tma_atom_sfb, block_coord_sfb[1], sfb_cta_l,
|
||||
cute.group_modes(sSFB, 0, 3), cute.group_modes(tCgSFB, 0, 3))
|
||||
tBsSFB = cute.filter_zeros(tBsSFB)
|
||||
tBgSFB = cute.filter_zeros(tBgSFB)
|
||||
|
||||
# TMEM accumulator shape
|
||||
# TMEM accumulator
|
||||
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
|
||||
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
|
||||
|
||||
# Cluster arrive (before any TMA or pipeline ops)
|
||||
if cute.size(self.cluster_shape_mn) > 1:
|
||||
cute.arch.cluster_arrive_relaxed()
|
||||
else:
|
||||
cta_bar.arrive_and_wait()
|
||||
|
||||
# ============================================================
|
||||
# TMA WARP
|
||||
# TMA WARP — Load A, B, SFA, SFB from GMEM to SMEM
|
||||
# (follows fused_swiglu TMA warp exactly)
|
||||
# ============================================================
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_atom_a)
|
||||
@@ -444,7 +560,8 @@ class Nvfp4FusedRouterKernel:
|
||||
wt = tsched.get_current_work()
|
||||
|
||||
# ============================================================
|
||||
# MMA WARP — blockscaled GEMM: (A * SFA) @ (B * SFB) -> TMEM
|
||||
# MMA WARP — Blockscaled GEMM: (A * SFA) @ (B * SFB) -> TMEM
|
||||
# (follows fused_swiglu MMA warp exactly)
|
||||
# ============================================================
|
||||
if warp_idx == self.mma_warp_id:
|
||||
# Wait for cluster sync
|
||||
@@ -463,7 +580,6 @@ class Nvfp4FusedRouterKernel:
|
||||
tCrB = tiled_mma.make_fragment_B(sB)
|
||||
|
||||
# S2T copies for SFA: SMEM -> TMEM
|
||||
# The SFA tmem region starts after the accumulator columns
|
||||
sfa_tmem_ptr = acc_tmem_ptr
|
||||
tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa(
|
||||
tiled_mma, self.mma_tiler, self.sf_vec_size,
|
||||
@@ -477,7 +593,7 @@ class Nvfp4FusedRouterKernel:
|
||||
cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)))
|
||||
tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout)
|
||||
|
||||
# S2T copy atoms
|
||||
# S2T copy atoms (using fused_swiglu helper)
|
||||
tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = \
|
||||
self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA, self.cta_group)
|
||||
tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = \
|
||||
@@ -490,9 +606,7 @@ class Nvfp4FusedRouterKernel:
|
||||
ab_cs = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Consumer, self.num_ab_stage)
|
||||
acc_ps = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
|
||||
num_tiles_executed = cutlass.Int32(0)
|
||||
pipeline.PipelineUserType.Producer, self.num_acc_pipeline_stages)
|
||||
|
||||
while wt.is_valid_tile:
|
||||
# Wait for accumulator buffer empty
|
||||
@@ -516,30 +630,28 @@ class Nvfp4FusedRouterKernel:
|
||||
if ab_cs.count < k_tiles and is_leader_cta:
|
||||
peek_ab_full = ab_pipeline.consumer_try_wait(ab_cs)
|
||||
|
||||
# Mainloop: K-tiles
|
||||
# Mainloop: K-tiles (following fused_swiglu exactly)
|
||||
for kt in cutlass.range(0, k_tiles, 1, unroll=1):
|
||||
if is_leader_cta:
|
||||
ab_pipeline.consumer_wait(ab_cs, peek_ab_full)
|
||||
|
||||
# Copy SFA/SFB from SMEM to TMEM
|
||||
s2t_stage = (
|
||||
None, None, None, None, ab_cs.index,
|
||||
)
|
||||
s2t_stage_coord = (None, None, None, None, ab_cs.index)
|
||||
cute.copy(tiled_copy_s2t_sfa,
|
||||
tCsSFA_compact_s2t[s2t_stage],
|
||||
tCsSFA_compact_s2t[s2t_stage_coord],
|
||||
tCtSFA_compact_s2t)
|
||||
cute.copy(tiled_copy_s2t_sfb,
|
||||
tCsSFB_compact_s2t[s2t_stage],
|
||||
tCsSFB_compact_s2t[s2t_stage_coord],
|
||||
tCtSFB_compact_s2t)
|
||||
|
||||
# Set SFA/SFB for MMA
|
||||
# GEMM: (A * SFA) @ (B * SFB) -> Acc
|
||||
num_kblocks = cute.size(tCrA, mode=[2])
|
||||
for kblock_idx in cutlass.range(num_kblocks, unroll=1):
|
||||
sf_kblock = (None, None, kblock_idx)
|
||||
sf_kblock_coord = (None, None, kblock_idx)
|
||||
tiled_mma.set(tcgen05.Field.SFA,
|
||||
tCtSFA[sf_kblock].iterator)
|
||||
tCtSFA[sf_kblock_coord].iterator)
|
||||
tiled_mma.set(tcgen05.Field.SFB,
|
||||
tCtSFB[sf_kblock].iterator)
|
||||
tCtSFB[sf_kblock_coord].iterator)
|
||||
|
||||
kb_coord = (None, None, kblock_idx, ab_cs.index)
|
||||
cute.gemm(tiled_mma, tCrA[kb_coord], tCrB[kb_coord],
|
||||
@@ -558,12 +670,11 @@ class Nvfp4FusedRouterKernel:
|
||||
if is_leader_cta:
|
||||
acc_pipeline.producer_commit(acc_ps)
|
||||
acc_ps.advance()
|
||||
num_tiles_executed += cutlass.Int32(1)
|
||||
|
||||
tsched.advance_to_next_work()
|
||||
wt = tsched.get_current_work()
|
||||
|
||||
# Wait for accumulator buffer empty
|
||||
# Wait for accumulator buffer empty (tail)
|
||||
if is_leader_cta:
|
||||
acc_pipeline.producer_tail(acc_ps)
|
||||
|
||||
@@ -576,15 +687,13 @@ class Nvfp4FusedRouterKernel:
|
||||
#
|
||||
# Strategy:
|
||||
# 1. Read TMEM accumulator into registers via paired t2r copy
|
||||
# (using epilogue_tmem_copy_and_partition from CUTLASS)
|
||||
# 2. For each element: compute act = sqrt(softplus(logit)),
|
||||
# score = act + e_bias[expert_idx]
|
||||
# 3. Insert into per-thread running top-6 (sorted, fully unrolled)
|
||||
# 4. After all tiles: write local top-6 to SMEM, one thread merges,
|
||||
# sorts, renormalizes, writes to GMEM
|
||||
#
|
||||
# The top-6 is maintained in DESCENDING order:
|
||||
# s0 >= s1 >= s2 >= s3 >= s4 >= s5
|
||||
# Insertion uses fully unrolled comparisons — no dynamic indexing.
|
||||
# 3. Maintain per-thread running top-k via sorted insertion
|
||||
# (fully unrolled for top_k=6, descending order)
|
||||
# 4. After all tiles: write local top-k to SMEM,
|
||||
# thread 0 merges, sorts, renormalizes, writes to GMEM
|
||||
#
|
||||
if warp_idx in self.epilogue_warp_id:
|
||||
if cute.size(self.cluster_shape_mn) > 1:
|
||||
@@ -597,61 +706,29 @@ class Nvfp4FusedRouterKernel:
|
||||
tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout)
|
||||
|
||||
# TMEM → register copy (paired atoms from CUTLASS)
|
||||
epi_n = self.epi_tile_n
|
||||
tiled_copy_t2r, tTR_tAcc_base = epilogue_tmem_copy_and_partition(
|
||||
tCtAcc_base, epi_tile, self.epilogue_warp_id, acc_dtype, use_2cta)
|
||||
tTR_rAcc = tiled_copy_t2r.fragments_slice(tiled_copy_t2r, tTR_tAcc_base)
|
||||
|
||||
# Identity tensor for (row, col) coordinates
|
||||
cAcc = cute.make_identity_tensor(
|
||||
(self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1]))
|
||||
tCcAcc = thr_mma.partition_C(cAcc)
|
||||
cFlat = cute.flatten(tCcAcc)
|
||||
|
||||
# Merge SMEM tensors (for cross-thread top-k merge)
|
||||
s_merge_s = cute.make_tensor(
|
||||
storage.merge_scores.data_ptr(),
|
||||
cute.make_layout((128, TK)))
|
||||
s_merge_i = cute.make_tensor(
|
||||
storage.merge_indices.data_ptr(),
|
||||
cute.make_layout((128, TK)))
|
||||
s_merge_a = cute.make_tensor(
|
||||
storage.merge_acts.data_ptr(),
|
||||
cute.make_layout((128, TK)))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Running top-6 per thread — individual scalar variables
|
||||
# Per-thread running top-k — individual scalar variables
|
||||
# Stored in DESCENDING order: s0 >= s1 >= s2 >= s3 >= s4 >= s5
|
||||
# ------------------------------------------------------------------
|
||||
s0 = cutlass.Float32(-1e30)
|
||||
s1 = cutlass.Float32(-1e30)
|
||||
s2 = cutlass.Float32(-1e30)
|
||||
s3 = cutlass.Float32(-1e30)
|
||||
s4 = cutlass.Float32(-1e30)
|
||||
s5 = cutlass.Float32(-1e30)
|
||||
i0 = cutlass.Int32(-1)
|
||||
i1 = cutlass.Int32(-1)
|
||||
i2 = cutlass.Int32(-1)
|
||||
i3 = cutlass.Int32(-1)
|
||||
i4 = cutlass.Int32(-1)
|
||||
i5 = cutlass.Int32(-1)
|
||||
a0 = cutlass.Float32(0.0)
|
||||
a1 = cutlass.Float32(0.0)
|
||||
a2 = cutlass.Float32(0.0)
|
||||
a3 = cutlass.Float32(0.0)
|
||||
a4 = cutlass.Float32(0.0)
|
||||
a5 = cutlass.Float32(0.0)
|
||||
TK = self.top_k
|
||||
s0 = cutlass.Float32(-1e30); s1 = cutlass.Float32(-1e30)
|
||||
s2 = cutlass.Float32(-1e30); s3 = cutlass.Float32(-1e30)
|
||||
s4 = cutlass.Float32(-1e30); s5 = cutlass.Float32(-1e30)
|
||||
i0 = cutlass.Int32(-1); i1 = cutlass.Int32(-1)
|
||||
i2 = cutlass.Int32(-1); i3 = cutlass.Int32(-1)
|
||||
i4 = cutlass.Int32(-1); i5 = cutlass.Int32(-1)
|
||||
a0 = cutlass.Float32(0.0); a1 = cutlass.Float32(0.0)
|
||||
a2 = cutlass.Float32(0.0); a3 = cutlass.Float32(0.0)
|
||||
a4 = cutlass.Float32(0.0); a5 = cutlass.Float32(0.0)
|
||||
|
||||
# Tile scheduler + pipeline states
|
||||
tsched = utils.StaticPersistentTileScheduler.create(
|
||||
tile_sched_params, bidx, cute.arch.grid_dim())
|
||||
wt = tsched.initial_work_tile_info()
|
||||
acc_cs = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
|
||||
# Track which row we're computing top-k for (row 0 of each M-tile)
|
||||
current_row = cutlass.Int32(-1)
|
||||
num_tiles_done = cutlass.Int32(0)
|
||||
pipeline.PipelineUserType.Consumer, self.num_acc_pipeline_stages)
|
||||
|
||||
while wt.is_valid_tile:
|
||||
acc_pipeline.consumer_wait(acc_cs)
|
||||
@@ -661,16 +738,11 @@ class Nvfp4FusedRouterKernel:
|
||||
else:
|
||||
acc_stage_index = acc_cs.index
|
||||
|
||||
# Get tile N offset (which 128-expert slice this tile covers)
|
||||
# Get tile N offset (which expert slice this tile covers)
|
||||
tc = wt.tile_idx
|
||||
tile_n_offset = tc[1] * self.cta_tile_shape_mnk[1]
|
||||
tile_m_base = tc[0] // cute.size(tiled_mma.thr_id.shape) * self.cta_tile_shape_mnk[0]
|
||||
|
||||
# If this is a new row, the running top-6 is already accumulated
|
||||
# For the first tile of a row, we just continue accumulating
|
||||
if num_tiles_done == cutlass.Int32(0):
|
||||
current_row = tile_m_base
|
||||
|
||||
tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_stage_index)]
|
||||
tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
|
||||
|
||||
@@ -684,33 +756,27 @@ class Nvfp4FusedRouterKernel:
|
||||
|
||||
# Early release accumulator for overlapping case
|
||||
if cutlass.const_expr(self.overlapping_accum):
|
||||
if subtile_idx == self.iter_acc_early_release:
|
||||
if subtile_idx == self.iter_acc_early_release_in_epilogue:
|
||||
with cute.arch.elect_one():
|
||||
acc_pipeline.consumer_release(acc_cs)
|
||||
acc_cs.advance()
|
||||
|
||||
# Process each element in the register fragment
|
||||
rFlat = cute.flatten(tTR_rAcc)
|
||||
elem_cnt = cute.size(rFlat)
|
||||
rAcc = tTR_rAcc.load()
|
||||
elem_cnt = cute.size(rAcc)
|
||||
for e in cutlass.range(elem_cnt, unroll=4):
|
||||
logit = rFlat[e]
|
||||
coord = cFlat[e]
|
||||
row = coord[0]
|
||||
col = coord[1]
|
||||
logit = rAcc[e]
|
||||
# Expert index = subtile_offset + e
|
||||
e_idx = cutlass.Int32(tile_n_offset) + cutlass.Int32(subtile_idx * self.epi_tile_n) + cutlass.Int32(e)
|
||||
|
||||
# Expert index = col + tile_n_offset + (subtile_idx * epi_n)
|
||||
e_idx = col + tile_n_offset + (subtile_idx * epi_n)
|
||||
|
||||
# Only process row 0 (the actual token row)
|
||||
# For M=1 padded to 128, only row 0 has valid data
|
||||
if row == 0:
|
||||
# Only process if expert index is valid
|
||||
if e_idx < cutlass.Int32(N):
|
||||
# sqrt(softplus(logit))
|
||||
# softplus(x) = max(x, 0) + log(1 + exp(-|x|))
|
||||
abs_x = cute.math.absf(logit)
|
||||
pos = cute.math.fmax(logit, cutlass.Float32(0.0))
|
||||
exp_neg = cute.math.exp(-abs_x)
|
||||
one_plus = cutlass.Float32(1.0) + exp_neg
|
||||
sp = pos + cute.math.log(one_plus)
|
||||
sp = pos + cute.math.log(cutlass.Float32(1.0) + exp_neg)
|
||||
act = cute.math.sqrt(sp)
|
||||
|
||||
# score = act + e_bias (for selection only)
|
||||
@@ -718,10 +784,8 @@ class Nvfp4FusedRouterKernel:
|
||||
|
||||
# Sorted insertion into descending top-6
|
||||
# s0 >= s1 >= s2 >= s3 >= s4 >= s5
|
||||
# If score <= s5, skip
|
||||
if score > s5:
|
||||
if score > s4:
|
||||
# Shift s4 → s5
|
||||
s5 = s4; i5 = i4; a5 = a4
|
||||
if score > s3:
|
||||
s4 = s3; i4 = i3; a4 = a3
|
||||
@@ -749,8 +813,6 @@ class Nvfp4FusedRouterKernel:
|
||||
acc_pipeline.consumer_release(acc_cs)
|
||||
acc_cs.advance()
|
||||
|
||||
num_tiles_done += cutlass.Int32(1)
|
||||
|
||||
tsched.advance_to_next_work()
|
||||
wt = tsched.get_current_work()
|
||||
|
||||
@@ -759,16 +821,17 @@ class Nvfp4FusedRouterKernel:
|
||||
# ==================================================================
|
||||
# Each thread writes its running top-6 to SMEM
|
||||
tid = warp_idx * 32 + tidx
|
||||
s_merge_s[tid, 0] = s0; s_merge_s[tid, 1] = s1; s_merge_s[tid, 2] = s2
|
||||
s_merge_s[tid, 3] = s3; s_merge_s[tid, 4] = s4; s_merge_s[tid, 5] = s5
|
||||
s_merge_i[tid, 0] = i0; s_merge_i[tid, 1] = i1; s_merge_i[tid, 2] = i2
|
||||
s_merge_i[tid, 3] = i3; s_merge_i[tid, 4] = i4; s_merge_i[tid, 5] = i5
|
||||
s_merge_a[tid, 0] = a0; s_merge_a[tid, 1] = a1; s_merge_a[tid, 2] = a2
|
||||
s_merge_a[tid, 3] = a3; s_merge_a[tid, 4] = a4; s_merge_a[tid, 5] = a5
|
||||
for k_idx in cutlass.range(TK, unroll=1):
|
||||
s_val = s0 if k_idx == 0 else (s1 if k_idx == 1 else (s2 if k_idx == 2 else (s3 if k_idx == 3 else (s4 if k_idx == 4 else s5))))
|
||||
i_val = i0 if k_idx == 0 else (i1 if k_idx == 1 else (i2 if k_idx == 2 else (i3 if k_idx == 3 else (i4 if k_idx == 4 else i5))))
|
||||
a_val = a0 if k_idx == 0 else (a1 if k_idx == 1 else (a2 if k_idx == 2 else (a3 if k_idx == 3 else (a4 if k_idx == 4 else a5))))
|
||||
storage.merge_scores.data_ptr()[tid * TK + k_idx] = s_val
|
||||
storage.merge_indices.data_ptr()[tid * TK + k_idx] = i_val
|
||||
storage.merge_acts.data_ptr()[tid * TK + k_idx] = a_val
|
||||
|
||||
epi_bar.arrive_and_wait()
|
||||
|
||||
# Thread 0 merges all 128 threads' top-6 into final result
|
||||
# Thread 0 of warp 0 does the final merge + store
|
||||
if warp_idx == 0 and tidx == 0:
|
||||
# Initialize final top-6 from thread 0's data
|
||||
fs0 = s0; fs1 = s1; fs2 = s2; fs3 = s3; fs4 = s4; fs5 = s5
|
||||
@@ -777,10 +840,10 @@ class Nvfp4FusedRouterKernel:
|
||||
|
||||
# Merge all other threads (1..127)
|
||||
for t in cutlass.range(1, 128, unroll=1):
|
||||
for k in cutlass.range(TK, unroll=1):
|
||||
cs = s_merge_s[t, k]
|
||||
ci = s_merge_i[t, k]
|
||||
ca = s_merge_a[t, k]
|
||||
for k_idx in cutlass.range(TK, unroll=1):
|
||||
cs = storage.merge_scores.data_ptr()[t * TK + k_idx]
|
||||
ci = storage.merge_indices.data_ptr()[t * TK + k_idx]
|
||||
ca = storage.merge_acts.data_ptr()[t * TK + k_idx]
|
||||
# Only merge if this is a valid entry (index >= 0)
|
||||
if ci >= cutlass.Int32(0):
|
||||
# Sorted insertion into final top-6 (descending)
|
||||
@@ -812,20 +875,19 @@ class Nvfp4FusedRouterKernel:
|
||||
inv_sum = cutlass.Float32(1.0) / act_sum
|
||||
sc = cutlass.Float32(routed_scaling_factor)
|
||||
|
||||
# Store to GMEM (row 0 of the M-tile)
|
||||
row_idx = cutlass.Int32(0)
|
||||
out_w_tensor[row_idx, 0] = fa0 * inv_sum * sc
|
||||
out_w_tensor[row_idx, 1] = fa1 * inv_sum * sc
|
||||
out_w_tensor[row_idx, 2] = fa2 * inv_sum * sc
|
||||
out_w_tensor[row_idx, 3] = fa3 * inv_sum * sc
|
||||
out_w_tensor[row_idx, 4] = fa4 * inv_sum * sc
|
||||
out_w_tensor[row_idx, 5] = fa5 * inv_sum * sc
|
||||
out_id_tensor[row_idx, 0] = fi0
|
||||
out_id_tensor[row_idx, 1] = fi1
|
||||
out_id_tensor[row_idx, 2] = fi2
|
||||
out_id_tensor[row_idx, 3] = fi3
|
||||
out_id_tensor[row_idx, 4] = fi4
|
||||
out_id_tensor[row_idx, 5] = fi5
|
||||
# Store to GMEM
|
||||
out_w_tensor[0, 0] = fa0 * inv_sum * sc
|
||||
out_w_tensor[0, 1] = fa1 * inv_sum * sc
|
||||
out_w_tensor[0, 2] = fa2 * inv_sum * sc
|
||||
out_w_tensor[0, 3] = fa3 * inv_sum * sc
|
||||
out_w_tensor[0, 4] = fa4 * inv_sum * sc
|
||||
out_w_tensor[0, 5] = fa5 * inv_sum * sc
|
||||
out_id_tensor[0, 0] = fi0
|
||||
out_id_tensor[0, 1] = fi1
|
||||
out_id_tensor[0, 2] = fi2
|
||||
out_id_tensor[0, 3] = fi3
|
||||
out_id_tensor[0, 4] = fi4
|
||||
out_id_tensor[0, 5] = fi5
|
||||
|
||||
epi_bar.arrive_and_wait()
|
||||
|
||||
@@ -853,21 +915,6 @@ def run_nvfp4_fused_router(
|
||||
|
||||
Single-kernel: NVFP4 block-scaled GEMM + fused router epilogue.
|
||||
No intermediate GMEM buffer. No BF16 fallback.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
hidden_states : [M, K] BF16 — raw activation
|
||||
mat_b : CuTe tensor — gate weight in NVFP4 blockscaled layout
|
||||
scale_b : CuTe tensor — gate weight scale factors in blockscaled layout
|
||||
gsa : activation global scale (scalar)
|
||||
gsb_val : weight global scale (float)
|
||||
e_bias : [E] FP32 — per-expert selection bias
|
||||
routed_scaling_factor : float
|
||||
top_k : int (default 6)
|
||||
|
||||
Returns
|
||||
-------
|
||||
(topk_weights, topk_ids) — [M, top_k] FP32 and [M, top_k] int32
|
||||
"""
|
||||
import cutlass.torch as cutlass_torch
|
||||
from dsv4.ops.quantize import quantize_activation_nvfp4
|
||||
|
||||
78
tests/unit/test_cute_math_api.py
Normal file
78
tests/unit/test_cute_math_api.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Test: check what CuTeDSL math operations are available."""
|
||||
import sys
|
||||
import os
|
||||
# Put the directory two levels above this file (the repository root for
# tests/unit/) at the front of sys.path before anything else is imported.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
def _names_matching(namespace, keywords):
    """Return the attribute names of *namespace* that contain any keyword.

    Names come from ``dir(namespace)`` in sorted order; matching is a
    case-insensitive substring test.
    """
    return [
        name
        for name in sorted(dir(namespace))
        if any(key in name.lower() for key in keywords)
    ]


def test_cute_math_api():
    """Enumerate available CuTeDSL math/arch operations.

    This is a discovery probe rather than a real test: it prints which
    math-related attributes the installed ``cutlass`` / ``cutlass.cute``
    packages expose and makes no assertions. It raises ``ImportError`` only
    if ``cutlass`` itself cannot be imported.
    """
    import importlib

    import cutlass
    import cutlass.cute as cute

    has_math = hasattr(cute, 'math')
    # ``None`` stands in for a missing ``cute.arch`` so the later ``hasattr``
    # probes report False instead of raising AttributeError.
    arch = getattr(cute, 'arch', None)

    # Public names exported by cute.math
    print("=== cute.math attributes ===")
    if has_math:
        for attr in sorted(dir(cute.math)):
            if not attr.startswith('_'):
                print(f" cute.math.{attr}")
    else:
        print(" cute.math does not exist")

    # Math-looking names in cute.arch
    print("\n=== cute.arch math-related attributes ===")
    if arch is not None:
        arch_keys = ('sqrt', 'log', 'exp', 'abs', 'sin', 'cos',
                     'rsqrt', 'rcp', 'fma', 'div')
        for attr in _names_matching(arch, arch_keys):
            print(f" cute.arch.{attr}")

    # Math-looking names directly on cute
    print("\n=== cute math-related attributes ===")
    cute_keys = ('sqrt', 'log', 'exp', 'abs', 'sin', 'cos', 'rsqrt', 'rcp')
    for attr in _names_matching(cute, cute_keys):
        print(f" cute.{attr}")

    # Math-looking names directly on cutlass
    print("\n=== cutlass math-related attributes ===")
    cutlass_keys = ('sqrt', 'log', 'exp', 'abs', 'rsqrt', 'rcp')
    for attr in _names_matching(cutlass, cutlass_keys):
        print(f" cutlass.{attr}")

    # Explicit yes/no answers for the functions the router epilogue needs
    print("\n=== Key functions ===")
    for name in ('exp', 'log', 'sqrt', 'math'):
        print(f" cute.{name} exists: {hasattr(cute, name)}")

    if has_math:
        math_probes = ('fmax', 'fmin', 'absf', 'sqrt', 'log', 'exp',
                       'rsqrt', 'rcp', 'sin', 'cos', 'copysign', 'clamp')
        for name in math_probes:
            print(f" cute.math.{name} exists: {hasattr(cute.math, name)}")

    # Check arch operations
    print(f"\n cute.arch.fmax exists: {hasattr(arch, 'fmax')}")
    print(f" cute.arch.fmin exists: {hasattr(arch, 'fmin')}")

    # Look for math operations in candidate MLIR binding modules; these
    # module names are guesses, so a failed import is expected and skipped.
    print("\n=== MLIR operations ===")
    mlir_keys = ('sqrt', 'log', 'exp', 'abs', 'rsqrt')
    for mod_name in ('cutlass._mlir_ops', 'cutlass.mlir', 'cutlass.cute._mlir'):
        try:
            mod = importlib.import_module(mod_name)
        except ImportError:
            continue
        math_attrs = _names_matching(mod, mlir_keys)
        if math_attrs:
            print(f" {mod_name}: {math_attrs}")

    print("\nDone.")
|
||||
|
||||
# Script entry point: lets the probe be run directly with the Python
# interpreter instead of through a test runner.
if __name__ == "__main__":
    test_cute_math_api()
|
||||
Reference in New Issue
Block a user