From 940f37fb6c59bfe1c5faaf712d7b4bee76012e18 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 1 Jun 2026 07:08:12 +0000 Subject: [PATCH] NVFP4 fused router kernel: full rewrite with proper block-scaled GEMM setup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Major fixes: - Added tiled_mma_sfb creation (always CtaGroup.ONE, rounded N) - Added mma_tiler_sfb, cta_tile_shape_mnk_sfb, cluster_layout_sfb_vmnk - Use blockscaled_utils.make_smem_layout_sfa/sfb (with sf_vec_size) instead of sm100_utils (which doesn't support block-scaled SF layouts) - Proper TMEM column accounting for SFA + SFB + accumulator - Fixed make_blockscaled_trivial_tiled_mma argument order (a_dtype, b_dtype, a_major, b_major, sf_dtype, sf_vec_size, cta_group, mma_inst_shape) - Fixed SFB TMA atom to use tiled_mma_sfb and cluster_layout_sfb_vmnk - Fixed SFB partition_SFB to use tiled_mma_sfb.get_slice - Fixed SFB global tile partitioning to use mma_tiler_sfb - Fixed mainloop_s2t_copy_and_partition to use TMEM fragments (make_fragment_SFA/SFB) as the tSF parameter - Updated run_nvfp4_fused_router wrapper to accept processed weight tensors from Nvfp4Linear._mat_b and _scale_b - Updated test to properly build Nvfp4Linear and use processed weights The old code was a rough sketch that never worked — it was missing the entire tiled_mma_sfb infrastructure, used wrong SMEM layout functions, and had broken TMA atom setup for scale factors. --- .../router/nvfp4_fused_router_kernel.py | 277 +++++++++--------- tests/unit/test_fused_router.py | 240 ++++++++++++--- 2 files changed, 352 insertions(+), 165 deletions(-) diff --git a/dsv4/kernels/router/nvfp4_fused_router_kernel.py b/dsv4/kernels/router/nvfp4_fused_router_kernel.py index dc04b837..b9028993 100644 --- a/dsv4/kernels/router/nvfp4_fused_router_kernel.py +++ b/dsv4/kernels/router/nvfp4_fused_router_kernel.py @@ -5,27 +5,29 @@ epilogue into a single kernel launch. Avoids materializing the intermediate (N, E) FP32 logits tensor to global memory. Architecture (6-warp specialization): - Warp 5 (TMA): Load A [M,K] and B [K,N] tiles GMEM → SMEM, plus scale factors - Warp 4 (MMA): NVFP4 block-scaled GEMM, FP32 accumulator → TMEM - Warps 0-3 (EPI): TMEM → registers → sqrt(softplus) + bias + top-k heap → GMEM + Warp 5 (TMA): Load A [M,K] and B [K,N] tiles GMEM -> SMEM, plus SFA/SFB + Warp 4 (MMA): NVFP4 block-scaled GEMM (SFA/SFB in TMEM), FP32 accumulator -> TMEM + Warps 0-3 (EPI): TMEM -> registers -> sqrt(softplus) + bias + top-k heap -> GMEM The epilogue accumulates a per-thread min-heap across all subtiles. After all subtiles for a row are processed, thread 0 of warp 0 merges all heaps in SMEM, sorts, renormalizes, and writes the final (k=6) weights and expert IDs to global memory. -Math (DSV4 §2.1): +Math (DSV4 S2.1): logit = X @ W_gate (NVFP4 block-scaled GEMM, FP32 accumulator) act = sqrt(softplus(logit)) softplus(x) = max(x,0) + log(1+exp(-|x|)) score = act + e_bias[e] ids = argtopk(score, k=6) min-heap, lower index wins ties w = (act[ids] / sum(act[ids])) * scaling -NVFP4 GEMM details: +NVFP4 GEMM details (mirrors Sm100BlockScaledPersistentDenseGemmKernel): - A operand: FP4 (quantized from BF16 activation), SFA in TMEM - B operand: FP4 (from checkpoint), SFB in TMEM - Accumulator: FP32 in TMEM - Global scales: gsa (activation) and gsb (weight) applied in epilogue + - sf_vec_size=16 for NVFP4 (not 32 for MXF4) + - Separate tiled_mma_sfb for SFB TMEM copy (always CtaGroup.ONE) """ from __future__ import annotations @@ -41,31 +43,33 @@ from cutlass.cute.nvgpu import cpasync, tcgen05 import cutlass.utils as utils import cutlass.pipeline as pipeline import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils import cutlass.torch as cutlass_torch -LOG2_E = 1.44269504088896340736 - - class Nvfp4FusedRouterKernel: """NVFP4 block-scaled GEMM + fused sqrt(softplus)/top-k router epilogue. Single-kernel replacement for the two-kernel path: - Nvfp4Linear (NVFP4 GEMM) → activation_topk CUDA kernel + Nvfp4Linear (NVFP4 GEMM) -> activation_topk CUDA kernel The fusion eliminates the intermediate FP32 logits write to GMEM - and the subsequent read-back. For decode (1 token, 384 experts), - the savings are small (1.5KB), but for large-batch prefill the - bandwidth savings and reduced kernel launch overhead are significant. + and the subsequent read-back. """ - def __init__(self, mma_tiler_mn=(128, 128), cluster_shape_mn=(1, 1), top_k=6, sf_vec_size=16): - # Data types - self.a_dtype = cutlass.Float4E2M1FN # FP4 activation (quantized from BF16) - self.b_dtype = cutlass.Float4E2M1FN # FP4 weight - self.sf_dtype = cutlass.Float8E4M3FN # Scale factors (E4M3) + def __init__( + self, + mma_tiler_mn: Tuple[int, int] = (128, 128), + cluster_shape_mn: Tuple[int, int] = (1, 1), + top_k: int = 6, + sf_vec_size: int = 16, + ): + # Data types - NVFP4 (FP4 weights/activations, E4M3 scale factors) + self.a_dtype = cutlass.Float4E2M1FN + self.b_dtype = cutlass.Float4E2M1FN + self.sf_dtype = cutlass.Float8E4M3FN self.acc_dtype = cutlass.Float32 - self.c_dtype = cutlass.Float32 # Accumulator for topk + self.c_dtype = cutlass.Float32 self.mma_tiler_mn = mma_tiler_mn self.cluster_shape_mn = cluster_shape_mn @@ -74,7 +78,6 @@ class Nvfp4FusedRouterKernel: self.use_2cta_instrs = mma_tiler_mn[0] == 256 self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE - # No mma_kind needed — make_blockscaled_trivial_tiled_mma infers from dtypes # Warp layout (6 warps: 4 epi + 1 MMA + 1 TMA) self.epilog_warp_id = (0, 1, 2, 3) @@ -92,33 +95,64 @@ class Nvfp4FusedRouterKernel: self.buffer_align_bytes = 1024 # ---------------------------------------------------------------- - # MMA setup — mirrors Sm100BlockScaledPersistentDenseGemmKernel + # MMA setup - mirrors Sm100BlockScaledPersistentDenseGemmKernel # ---------------------------------------------------------------- - def _create_tiled_mma(self): - """Create the tiled MMA for NVFP4 block-scaled GEMM.""" + def _create_tiled_mma(self) -> cute.TiledMma: return sm100_utils.make_blockscaled_trivial_tiled_mma( - self.a_dtype, self.a_major_mode, self.b_major_mode, - self.sf_dtype, self.cta_group, self.mma_tiler_mn, - self.sf_vec_size, + self.a_dtype, self.b_dtype, + self.a_major_mode, self.b_major_mode, + self.sf_dtype, self.sf_vec_size, + self.cta_group, self.mma_inst_shape_mn, + ) + + def _create_tiled_mma_sfb(self) -> cute.TiledMma: + return sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, self.b_dtype, + self.a_major_mode, self.b_major_mode, + self.sf_dtype, self.sf_vec_size, + tcgen05.CtaGroup.ONE, self.mma_inst_shape_mn_sfb, ) def _setup_attributes(self): - self._tiled_mma = self._create_tiled_mma() - mma_inst_shape_k = cute.size(self._tiled_mma.shape_mnk, mode=[2]) + self.mma_inst_shape_mn = (self.mma_tiler[0], self.mma_tiler[1]) + self.mma_inst_shape_mn_sfb = ( + self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1), + cute.round_up(self.mma_inst_shape_mn[1], 128), + ) + + tiled_mma = self._create_tiled_mma() + tiled_mma_sfb = self._create_tiled_mma_sfb() + self._tiled_mma = tiled_mma + self._tiled_mma_sfb = tiled_mma_sfb + + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) mma_inst_tile_k = 4 - self.mma_tiler = (*self.mma_tiler_mn, mma_inst_shape_k * mma_inst_tile_k) + self.mma_tiler = ( + self.mma_inst_shape_mn[0], self.mma_inst_shape_mn[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.mma_tiler_sfb = ( + self.mma_inst_shape_mn_sfb[0], self.mma_inst_shape_mn_sfb[1], + mma_inst_shape_k * mma_inst_tile_k, + ) self.cta_tile_shape_mnk = ( - self.mma_tiler[0] // cute.size(self._tiled_mma.thr_id.shape), + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), self.mma_tiler[1], self.mma_tiler[2], ) + self.cta_tile_shape_mnk_sfb = ( + self.mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler_sfb[1], self.mma_tiler_sfb[2], + ) + self.cluster_layout_vmnk = cute.tiled_divide( cute.make_layout((*self.cluster_shape_mn, 1)), - (self._tiled_mma.thr_id.shape,), + (tiled_mma.thr_id.shape,), ) self.cluster_layout_sfb_vmnk = cute.tiled_divide( cute.make_layout((*self.cluster_shape_mn, 1)), - (self._tiled_mma_sfb.thr_id.shape,), + (tiled_mma_sfb.thr_id.shape,), ) + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) self.num_mcast_ctas_sfb = cute.size(self.cluster_layout_sfb_vmnk.shape[1]) @@ -138,24 +172,25 @@ class Nvfp4FusedRouterKernel: self.overlapping_accum = False self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( - self._tiled_mma, self.mma_tiler, self.a_dtype, self.num_ab_stage) + tiled_mma, self.mma_tiler, self.a_dtype, self.num_ab_stage) self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( - self._tiled_mma, self.mma_tiler, self.b_dtype, self.num_ab_stage) + tiled_mma, self.mma_tiler, self.b_dtype, self.num_ab_stage) + self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, self.mma_tiler, self.sf_vec_size, self.num_ab_stage) + self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, self.mma_tiler, self.sf_vec_size, self.num_ab_stage) - # Scale factor SMEM layouts - self.sfa_smem_layout = sm100_utils.make_smem_layout_sfa( - self._tiled_mma, self.mma_tiler, self.sf_dtype) - self.sfb_smem_layout = sm100_utils.make_smem_layout_sfb( - self._tiled_mma, self.mma_tiler, self.sf_dtype) + sf_atom_mn = 32 + self.num_sfa_tmem_cols = (self.cta_tile_shape_mnk[0] // sf_atom_mn) * mma_inst_tile_k + self.num_sfb_tmem_cols = (self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn) * mma_inst_tile_k + self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols + self.num_accumulator_tmem_cols = self.cta_tile_shape_mnk[1] * self.num_acc_stage - acc_shape = self._tiled_mma.partition_shape_C(self.mma_tiler[:2]) - tCtAcc_fake = self._tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake) - def mainloop_s2t_copy_and_partition( - self, sSF: cute.Tensor, tSF: cute.Tensor, - ) -> tuple: - """SMEM → TMEM copy partition for scale factors (mirrors dense.py).""" + def mainloop_s2t_copy_and_partition(self, sSF, tSF): tCsSF_compact = cute.filter_zeros(sSF) tCtSF_compact = cute.filter_zeros(tSF) copy_atom_s2t = cute.make_copy_atom( @@ -167,40 +202,14 @@ class Nvfp4FusedRouterKernel: tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact) return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t - def mainloop_s2t_copy_and_partition_sfb( - self, sSF: cute.Tensor, tSF: cute.Tensor, - ) -> tuple: - """SMEM → TMEM copy partition for SFB (uses tiled_mma_sfb).""" - return self.mainloop_s2t_copy_and_partition(sSF, tSF) - - # ---------------------------------------------------------------- - def epilog_tmem_copy_and_partition(self, epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta): - """TMEM → register copy partition. Same as dense GEMM.""" - epi_thr_idx = epi_tidx % (self.threads_per_warp * len(self.epilog_warp_id)) - tiled_copy_t2r, tTR_tAcc, tTR_rAcc = sm100_utils.epilogue_tmem_copy_and_partition( - epi_thr_idx, tCtAcc_base, epi_tile, self._tiled_mma, self.c_dtype, use_2cta, - ) - return tiled_copy_t2r, tTR_tAcc, tTR_rAcc - # ---------------------------------------------------------------- # Public API # ---------------------------------------------------------------- def run( self, - mat_a, # (M, K//2) FP4 activation (quantized) - mat_b, # (K//2, N) FP4 weight - scale_a, # (M, K//16) E4M3 activation scale factors - scale_b, # (K//16, N) E4M3 weight scale factors - expert_offsets, # (1,) int32 — [M] for single-group - global_scale_a, # (1,) FP32 — gsa - global_scale_b, # (1,) FP32 — gsb - e_bias, # (N,) FP32 — per-expert bias - out_weights, # (M, top_k) FP32 — output weights - out_ids, # (M, top_k) int32 — output expert IDs - M, N, K, # Problem dimensions - scaling, # routed_scaling_factor - top_k, # k=6 - stream=None, + mat_a, mat_b, scale_a, scale_b, expert_offsets, + global_scale_a, global_scale_b, e_bias, out_weights, out_ids, + M, N, K, scaling, top_k, stream=None, ): if stream is None: stream = cuda.CUstream(0) @@ -208,27 +217,29 @@ class Nvfp4FusedRouterKernel: @cute.jit def _compiled_fn(mat_a, mat_b, scale_a, scale_b, expert_offsets, global_scale_a, global_scale_b, e_bias, out_weights, out_ids): - # Infer major modes self.a_major_mode = utils.LayoutEnum.from_tensor(mat_a).mma_major_mode() self.b_major_mode = utils.LayoutEnum.from_tensor(mat_b).mma_major_mode() + + mma_inst_shape_k = 32 + mma_inst_tile_k = 4 + self.mma_tiler = (*self.mma_tiler_mn, mma_inst_shape_k * mma_inst_tile_k) + self._setup_attributes() tiled_mma = self._tiled_mma - + tiled_mma_sfb = self._tiled_mma_sfb atom_thr_size = cute.size(tiled_mma.thr_id.shape) - # Compute TMA load bytes for pipeline setup a_smem_0 = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) a_copy = cute.size_in_bytes(self.a_dtype, a_smem_0) b_smem_0 = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) b_copy = cute.size_in_bytes(self.b_dtype, b_smem_0) - # Scale factor sizes - sfa_smem_0 = cute.slice_(self.sfa_smem_layout, (None, None, 0)) + sfa_smem_0 = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0)) sfa_copy = cute.size_in_bytes(self.sf_dtype, sfa_smem_0) - sfb_smem_0 = cute.slice_(self.sfb_smem_layout, (None, None, 0)) + sfb_smem_0 = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0)) sfb_copy = cute.size_in_bytes(self.sf_dtype, sfb_smem_0) self.num_tma_load_bytes = (a_copy + b_copy + sfa_copy + sfb_copy) * atom_thr_size - # Make TMA atoms for A, B, SFA, SFB + # TMA atoms: A, B, SFA, SFB a_smem = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( @@ -239,19 +250,16 @@ class Nvfp4FusedRouterKernel: tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( b_op, mat_b, b_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape) - # Scale factor TMA atoms (same pattern as dense GEMM) - sfa_op = sm100_utils.cluster_shape_to_tma_atom_A( - self.cluster_shape_mn, tiled_mma.thr_id) sfa_smem = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0)) + sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( sfa_op, scale_a, sfa_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape, internal_type=cutlass.Int16) - sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB( - self.cluster_shape_mn, tiled_mma.thr_id) sfb_smem = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0)) + sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id) tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( - sfb_op, scale_b, sfb_smem, self.mma_tiler, self._tiled_mma_sfb, + sfb_op, scale_b, sfb_smem, self.mma_tiler_sfb, tiled_mma_sfb, self.cluster_layout_sfb_vmnk.shape, internal_type=cutlass.Int16) num_M_tiles = cute.ceil_div(M, self.cta_tile_shape_mnk[0]) @@ -264,7 +272,7 @@ class Nvfp4FusedRouterKernel: (*self.cluster_shape_mn, 1)) self._kernel( - tiled_mma, + tiled_mma, tiled_mma_sfb, tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b, tma_atom_sfa, tma_tensor_sfa, tma_atom_sfb, tma_tensor_sfb, self.cluster_layout_vmnk, self.cluster_layout_sfb_vmnk, @@ -288,7 +296,7 @@ class Nvfp4FusedRouterKernel: # ================================================================ @cute.kernel def _kernel( - self, tiled_mma, + self, tiled_mma, tiled_mma_sfb, tma_atom_a, mA_mkl, tma_atom_b, mB_nkl, tma_atom_sfa, mSFA_mkl, tma_atom_sfb, mSFB_nkl, cluster_layout_vmnk, cluster_layout_sfb_vmnk, @@ -320,14 +328,13 @@ class Nvfp4FusedRouterKernel: acc_empty_mbar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] tmem_dealloc_mbar: cutlass.Int64 tmem_holding: cutlass.Int32 - # Top-k heap SMEM: 128 threads * 6 entries * (score + index + act) heap_scores: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 4*32*6], 128] heap_indices: cute.struct.Align[cute.struct.MemRange[cutlass.Int32, 4*32*6], 128] heap_acts: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 4*32*6], 128] sA: cute.struct.Align[cute.struct.MemRange[self.a_dtype, cute.cosize(a_smem_layout_staged.outer)], self.buffer_align_bytes] sB: cute.struct.Align[cute.struct.MemRange[self.b_dtype, cute.cosize(b_smem_layout_staged.outer)], self.buffer_align_bytes] - sSFA: cute.struct.Align[cute.struct.MemRange[self.sf_dtype, cute.cosize(sfa_smem_layout.outer)], self.buffer_align_bytes] - sSFB: cute.struct.Align[cute.struct.MemRange[self.sf_dtype, cute.cosize(sfb_smem_layout.outer)], self.buffer_align_bytes] + sSFA: cute.struct.Align[cute.struct.MemRange[self.sf_dtype, cute.cosize(sfa_smem_layout_staged.outer)], self.buffer_align_bytes] + sSFB: cute.struct.Align[cute.struct.MemRange[self.sf_dtype, cute.cosize(sfb_smem_layout_staged.outer)], self.buffer_align_bytes] smem = utils.SmemAllocator() storage = smem.allocate(SharedStorage) @@ -370,8 +377,8 @@ class Nvfp4FusedRouterKernel: # ============================================================== sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner) sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner) - sSFA = storage.sSFA.get_tensor(sfa_smem_layout.outer, swizzle=sfa_smem_layout.inner) - sSFB = storage.sSFB.get_tensor(sfb_smem_layout.outer, swizzle=sfb_smem_layout.inner) + sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged.outer, swizzle=sfa_smem_layout_staged.inner) + sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged.outer, swizzle=sfb_smem_layout_staged.inner) # Multicast masks a_mcast = None; b_mcast = None; sfb_mcast = None @@ -385,31 +392,33 @@ class Nvfp4FusedRouterKernel: gA = cute.local_tile(mA_mkl, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) gB = cute.local_tile(mB_nkl, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) gSFA = cute.local_tile(mSFA_mkl, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) - gSFB = cute.local_tile(mSFB_nkl, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gSFB = cute.local_tile(mSFB_nkl, cute.slice_(self.mma_tiler_sfb, (0,None,None)), (None,None,None)) k_tiles = cute.size(gA, mode=[3]) thr_mma = tiled_mma.get_slice(mma_tile_v) - tCgA = thr_mma.partition_A(gA); tCgB = thr_mma.partition_B(gB) - tCgSFA = thr_mma.partition_SFA(gSFA); tCgSFB = thr_mma.partition_SFB(gSFB) + tCgA = thr_mma.partition_A(gA) + tCgB = thr_mma.partition_B(gB) + tCgSFA = thr_mma.partition_SFA(gSFA) + thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_v) + tCgSFB = thr_mma_sfb.partition_SFB(gSFB) + + # TMA partition a_cta_l = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0,0,None,0)).shape) tAsA, tAgA = cpasync.tma_partition(tma_atom_a, block_coord[2], a_cta_l, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) b_cta_l = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0,None,0,0)).shape) tBsB, tBgB = cpasync.tma_partition(tma_atom_b, block_coord[1], b_cta_l, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) - - # SFA/SFB TMA partition (same pattern as dense GEMM) tAsSFA, tAgSFA = cpasync.tma_partition(tma_atom_sfa, block_coord[2], a_cta_l, cute.group_modes(sSFA,0,3), cute.group_modes(tCgSFA,0,3)) sfb_cta_l = cute.make_layout(cute.slice_(cluster_layout_sfb_vmnk, (0,None,0,0)).shape) tBsSFB, tBgSFB = cpasync.tma_partition(tma_atom_sfb, block_coord[1], sfb_cta_l, cute.group_modes(sSFB,0,3), cute.group_modes(tCgSFB,0,3)) + # Register fragments tCrA = tiled_mma.make_fragment_A(sA) tCrB = tiled_mma.make_fragment_B(sB) - tCrSFA = tiled_mma.make_fragment_SFA(sSFA) - tCrSFB = tiled_mma.make_fragment_SFB(sSFB) acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) @@ -418,11 +427,13 @@ class Nvfp4FusedRouterKernel: cute.arch.cluster_arrive_relaxed() # ============================================================== - # TMA WARP (5) — load A, B, SFA, SFB tiles from GMEM → SMEM + # TMA WARP (5) - load A, B, SFA, SFB tiles from GMEM -> SMEM # ============================================================== if warp_idx == self.tma_warp_id: cpasync.prefetch_descriptor(tma_atom_a) cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_sfa) + cpasync.prefetch_descriptor(tma_atom_sfb) tsched = utils.StaticPersistentTileScheduler.create(tile_sched_params, bidx, cute.arch.grid_dim()) wt = tsched.initial_work_tile_info() ab_ps = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_ab_stage) @@ -442,13 +453,13 @@ class Nvfp4FusedRouterKernel: cute.copy(tma_atom_sfa, tSFA_s[(None, ab_ps.count)], tAsSFA[(None, ab_ps.index)], tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_ps)) cute.copy(tma_atom_sfb, tSFB_s[(None, ab_ps.count)], tBsSFB[(None, ab_ps.index)], - tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_ps)) + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_ps), mcast_mask=sfb_mcast) ab_ps.advance() ab_pipeline.producer_tail(ab_ps) tsched.advance_to_next_work(); wt = tsched.get_current_work() # ============================================================== - # MMA WARP (4) — NVFP4 block-scaled GEMM (mirrors dense.py mainloop) + # MMA WARP (4) - NVFP4 block-scaled GEMM # ============================================================== if warp_idx == self.mma_warp_id: if cute.size(self.cluster_shape_mn) > 1: @@ -459,9 +470,14 @@ class Nvfp4FusedRouterKernel: tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) tCtAcc = tCtAcc_base[(None, None, None, 0)] - # SFA/SFB SMEM → TMEM copy atoms - tiled_copy_s2t_sfa, tCsSFA, tCtSFA = self.mainloop_s2t_copy_and_partition(sSFA, self._tiled_mma) - tiled_copy_s2t_sfb, tCsSFB, tCtSFB = self.mainloop_s2t_copy_and_partition_sfb(sSFB, self._tiled_mma_sfb) + + # SFA/SFB SMEM -> TMEM copy atoms + # make_fragment_SFA/SFB returns TMEM tensor (the destination for s2t copy) + tCrSFA = tiled_mma.make_fragment_SFA(sSFA) + tCrSFB = tiled_mma_sfb.make_fragment_SFB(sSFB) + + tiled_copy_s2t_sfa, tCsSFA, tCtSFA = self.mainloop_s2t_copy_and_partition(sSFA, tCrSFA) + tiled_copy_s2t_sfb, tCsSFB, tCtSFB = self.mainloop_s2t_copy_and_partition(sSFB, tCrSFB) tsched = utils.StaticPersistentTileScheduler.create(tile_sched_params, bidx, cute.arch.grid_dim()) wt = tsched.initial_work_tile_info() @@ -472,21 +488,17 @@ class Nvfp4FusedRouterKernel: tiled_mma.set(tcgen05.Field.ACCUMULATE, False) for kt in cutlass.range(0, k_tiles, 1, unroll=1): ab_pipeline.consumer_wait(ab_cs) - # Copy A, B from SMEM → register fragments cute.copy(tiled_mma.partition_A(sA[(None,None,None,ab_cs.index)]), tCrA) cute.copy(tiled_mma.partition_B(sB[(None,None,None,ab_cs.index)]), tCrB) - # Copy SFA, SFB from SMEM → TMEM - s2t_stage = (None, None, None, None, ab_cs.index) - cute.copy(tiled_copy_s2t_sfa, tCsSFA[s2t_stage], tCtSFA) - cute.copy(tiled_copy_s2t_sfb, tCsSFB[s2t_stage], tCtSFB) - # GEMM with block-scaled MMA + cute.copy(tiled_copy_s2t_sfa, tCsSFA[(None,None,None,None,ab_cs.index)], tCtSFA) + cute.copy(tiled_copy_s2t_sfb, tCsSFB[(None,None,None,None,ab_cs.index)], tCtSFB) num_kblocks = cute.size(tCrA, mode=[2]) for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): kblock_coord = (None, None, kblock_idx, ab_cs.index) sf_kblock = (None, None, kblock_idx) tiled_mma.set(tcgen05.Field.SFA, tCtSFA[sf_kblock].iterator) tiled_mma.set(tcgen05.Field.SFB, tCtSFB[sf_kblock].iterator) - cute.gemm(tiled_mma, tCtAcc, tCrA[kblock_coord], tCrB[kblock_coord], tCtAcc) + cute.gemm(tiled_mma, tCrA[kblock_coord], tCrB[kblock_coord], tCtAcc) tiled_mma.set(tcgen05.Field.ACCUMULATE, True) ab_pipeline.consumer_release(ab_cs); ab_cs.advance() acc_pipeline.producer_commit(acc_ps); acc_ps.advance() @@ -495,7 +507,7 @@ class Nvfp4FusedRouterKernel: tmem.relinquish_alloc_permit() # ============================================================== - # EPILOGUE WARPS (0-3) — TMEM → sqrt(softplus) + top-k → GMEM + # EPILOGUE WARPS (0-3) - TMEM -> sqrt(softplus) + top-k -> GMEM # ============================================================== if warp_idx in self.epilog_warp_id: if cute.size(self.cluster_shape_mn) > 1: @@ -507,7 +519,7 @@ class Nvfp4FusedRouterKernel: tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) tCtAcc0 = tCtAcc_base[(None, None, None, 0)] - # TMEM → register copy + # TMEM -> register copy (tcgen05.ld) epi_n = self.epi_tile_n tmem_load_atom = cute.make_copy_atom( tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(epi_n)), self.acc_dtype) @@ -538,6 +550,7 @@ class Nvfp4FusedRouterKernel: while wt.is_valid_tile: acc_pipeline.consumer_wait(acc_cs) + # Reset heap for i in cutlass.range(6, unroll=1): hs[i] = cutlass.Float32(-1e30) @@ -551,8 +564,7 @@ class Nvfp4FusedRouterKernel: elem_cnt = cute.size(rFlat) for e in cutlass.range(elem_cnt, unroll=4): logit_raw = rFlat[e] - # Apply global scales: the GEMM output is - # sum(A_sf * B_sf) which needs * gsa * gsb + # Apply global scales: GEMM output = sum(A_sf * B_sf), needs * gsa * gsb logit = logit_raw * gsa_val * gsb_val coord = cFlat[e] e_idx = coord[1] @@ -571,7 +583,6 @@ class Nvfp4FusedRouterKernel: do_push = (score > hs[0]) or (score == hs[0] and e_idx < hi[0]) if do_push: hs[0] = score; hi[0] = e_idx; ha[0] = act - # Sift down (k=6, fully unrolled) root = 0 _done = cutlass.Bool(False) while root < 3 and not _done: @@ -603,9 +614,7 @@ class Nvfp4FusedRouterKernel: # Thread 0 of warp 0 does the final merge + store if warp_idx == 0 and tidx == 0: - # Initialize final heap from thread 0 fs = list(hs); fi = list(hi); fa = list(ha) - # Merge all 128 threads for t in cutlass.range(1, 128, unroll=1): for i in cutlass.range(6, unroll=1): cs = storage.heap_scores.data_ptr()[t*6+i] @@ -614,7 +623,6 @@ class Nvfp4FusedRouterKernel: if ci >= 0: if cs > fs[0] or (cs == fs[0] and ci < fi[0]): fs[0] = cs; fi[0] = ci; fa[0] = ca - # Sift down r = 0 _done2 = cutlass.Bool(False) while r < 3 and not _done2: @@ -650,11 +658,9 @@ class Nvfp4FusedRouterKernel: inv_sum = cutlass.Float32(1.0) / act_sum sc = cutlass.Float32(routed_scaling_factor) - # Get tile coordinates for output indexing tc = wt.tile_idx row_base = tc[0] // cute.size(tiled_mma.thr_id.shape) * self.cta_tile_shape_mnk[0] - # Store to GMEM for i in cutlass.range(6, unroll=1): out_w_tensor[row_base + 0, i] = sorted_a[i] * inv_sum * sc out_id_tensor[row_base + 0, i] = sorted_i[i] @@ -673,13 +679,12 @@ class Nvfp4FusedRouterKernel: # ================================================================ -# Python wrapper — mirrors Nvfp4Linear but uses fused kernel +# Python wrapper - mirrors Nvfp4Linear but uses fused kernel # ================================================================ def run_nvfp4_fused_router( hidden_states: torch.Tensor, # [M, K] BF16 - mat_b: torch.Tensor, # [K//2, N_packed] FP4 weight - scale_a: torch.Tensor, # [M, K//16] E4M3 activation SF - scale_b: torch.Tensor, # [K//16, N] E4M3 weight SF + mat_b: torch.Tensor, # [1, K_packed, N_packed] float4_e2m1fn_x2 (K-major, from Nvfp4Linear._mat_b) + scale_b: torch.Tensor, # [M_sf, N_sf] E4M3 (assembled, from Nvfp4Linear._scale_b) gsa: float, # activation global scale gsb: float, # weight global scale e_bias: torch.Tensor, # [E] FP32 @@ -692,10 +697,22 @@ def run_nvfp4_fused_router( Combines NVFP4 block-scaled GEMM + sqrt(softplus) + top-k into a single kernel launch. + + Args: + hidden_states: [M, K] BF16 input tensor + mat_b: Processed FP4 weight tensor from Nvfp4Linear._mat_b + scale_b: Processed E4M3 scale factors from Nvfp4Linear._scale_b + gsa: Activation global scale + gsb: Weight global scale + e_bias: Per-expert selection bias [E] FP32 + routed_scaling_factor: Scaling factor for routed weights + top_k: Number of experts to select + out_weights: Pre-allocated output weights [M, top_k] FP32 + out_ids: Pre-allocated output expert IDs [M, top_k] int32 """ M = hidden_states.shape[0] K = hidden_states.shape[1] - N = scale_b.shape[1] # num_experts from weight scale shape + N = scale_b.shape[1] if scale_b.dim() == 2 else scale_b.shape[-1] # num_experts device = hidden_states.device if out_weights is None: @@ -703,9 +720,7 @@ def run_nvfp4_fused_router( if out_ids is None: out_ids = torch.empty(M, top_k, dtype=torch.int32, device=device) - # Expert offsets: single group of M tokens expert_offsets = torch.tensor([M], dtype=torch.int32, device=device) - # Global scales as 1-element tensors gsa_t = torch.tensor([gsa], dtype=torch.float32, device=device) gsb_t = torch.tensor([gsb], dtype=torch.float32, device=device) @@ -732,4 +747,4 @@ def run_nvfp4_fused_router( e_bias_ct, out_w_ct, out_id_ct, M, N, K, routed_scaling_factor, top_k, ) - return out_weights, out_ids \ No newline at end of file + return out_weights, out_ids diff --git a/tests/unit/test_fused_router.py b/tests/unit/test_fused_router.py index e4bf612d..691d3cc5 100644 --- a/tests/unit/test_fused_router.py +++ b/tests/unit/test_fused_router.py @@ -1,27 +1,57 @@ """Test NVFP4 fused router kernel against the reference path. -Phase 1: Verify reference path (BF16 linear + activation_topk) works. -Phase 2: Test CuTeDSL fused kernel (needs B200 for CuTeDSL compilation). +Reference path: Nvfp4Linear (NVFP4 GEMM) -> activation_topk CUDA kernel +Fused path: Nvfp4FusedRouterKernel (single kernel, no intermediate logits) + +Tests: +1. Reference path correctness (BF16 GEMM + activation_topk) +2. NVFP4 fused kernel matches reference (cosine similarity) +3. NVFP4 fused kernel matches at various (M, N, K) shapes """ import sys import os import torch +import math sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) +def _sqrt_softplus(x: torch.Tensor) -> torch.Tensor: + """Reference: sqrt(softplus(x)).""" + abs_x = x.abs() + sp = x.clamp(min=0) + torch.log1p(torch.exp(-abs_x)) + return sp.sqrt() + + +def _reference_topk( + logits: torch.Tensor, # [M, N] FP32 + e_bias: torch.Tensor, # [N] FP32 + routed_scaling_factor: float, + top_k: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Pure PyTorch reference: sqrt(softplus) + bias + topk + renorm.""" + act = _sqrt_softplus(logits) + scores = act + e_bias.unsqueeze(0) + topk_vals, topk_ids = scores.topk(top_k, dim=1) + # Gather activations for the selected experts + selected_act = act.gather(1, topk_ids) + # Renormalize + act_sum = selected_act.sum(dim=1, keepdim=True) + weights = selected_act / act_sum * routed_scaling_factor + return weights, topk_ids + + def test_reference_router(): """Test the reference BF16 linear + activation_topk path.""" torch.manual_seed(42) device = "cuda" - M = 4 # tokens - K = 7168 # hidden size - N = 384 # num experts + M = 4 + K = 7168 + N = 384 top_k = 6 routed_scaling_factor = 0.5 - # Create BF16 hidden states and weight hidden_states = torch.randn(M, K, dtype=torch.bfloat16, device=device) W_gate = torch.randn(K, N, dtype=torch.bfloat16, device=device) e_bias = torch.randn(N, dtype=torch.float32, device=device) @@ -33,53 +63,195 @@ def test_reference_router(): out_ids = torch.empty(M, top_k, dtype=torch.int32, device=device) run_fused_activation_topk(logits, e_bias, routed_scaling_factor, top_k, out_w, out_ids) - # Verify results + # Also compute pure PyTorch reference + ref_w, ref_ids = _reference_topk(logits, e_bias, routed_scaling_factor, top_k) + print(f"Reference router test (M={M}, K={K}, N={N}, top_k={top_k}):") print(f" Top-k IDs (row 0): {out_ids[0].tolist()}") print(f" Top-k weights (row 0): {[f'{w:.4f}' for w in out_w[0].tolist()]}") - # Verify: weights sum to routed_scaling_factor (approximately) + # Verify weights sum to routed_scaling_factor w_sum = out_w.sum(dim=1) - expected = routed_scaling_factor for row in range(M): - diff = abs(w_sum[row].item() - expected) - assert diff < 0.01, f"Row {row}: weight sum {w_sum[row].item():.4f} != {expected:.4f}" - print(f" Weight sums: {[f'{s:.4f}' for s in w_sum.tolist()]} (expected {expected})") + diff = abs(w_sum[row].item() - routed_scaling_factor) + assert diff < 0.01, f"Row {row}: weight sum {w_sum[row].item():.4f} != {routed_scaling_factor:.4f}" + print(f" Weight sums: {[f'{s:.4f}' for s in w_sum.tolist()]} (expected {routed_scaling_factor})") - # Verify: IDs are valid (0 <= id < N) assert (out_ids >= 0).all() and (out_ids < N).all(), "Invalid expert IDs" - print(f" All expert IDs in [0, {N}) ✓") + print(f" All expert IDs in [0, {N}) OK") - # Verify: no duplicate IDs within each row for row in range(M): row_ids = out_ids[row].tolist() assert len(set(row_ids)) == len(row_ids), f"Row {row} has duplicate IDs: {row_ids}" - print(f" No duplicate IDs ✓") + print(f" No duplicate IDs OK") - # Verify: weights are non-negative assert (out_w >= 0).all(), "Negative weights" - print(f" All weights non-negative ✓") + print(f" All weights non-negative OK") - print("Reference router test PASSED ✓") + # Cross-check with pure PyTorch reference + ids_match = (out_ids == ref_ids).all().item() + if ids_match: + print(f" IDs match PyTorch reference OK") + else: + mismatches = (out_ids != ref_ids).sum().item() + print(f" WARNING: {mismatches} ID mismatches vs PyTorch reference") + + print("Reference router test PASSED") -def test_nvfp4_fused_router_import(): - """Test that the fused router kernel module imports without error.""" - try: - from dsv4.kernels.router.nvfp4_fused_router_kernel import Nvfp4FusedRouterKernel - print(f"Nvfp4FusedRouterKernel imported successfully") - kernel = Nvfp4FusedRouterKernel(top_k=6) - print(f" mma_tiler_mn: {kernel.mma_tiler_mn}") - print(f" threads_per_cta: {kernel.threads_per_cta}") - print(f" sf_vec_size: {kernel.sf_vec_size}") - print(f" top_k: {kernel.top_k}") - print("Fused router kernel class construction PASSED ✓") - except Exception as e: - print(f"Fused router kernel import failed: {e}") - print("This is expected if CuTeDSL is not available.") +def test_nvfp4_fused_router(): + """Test the NVFP4 fused router kernel against reference path. + + This test requires B200 (CuTeDSL compilation + NVFP4 tensor cores). + Must be run via fire_b200_test. + """ + torch.manual_seed(42) + device = "cuda" + M = 1 # decode: single token + K = 7168 # hidden size + N = 384 # num experts + top_k = 6 + routed_scaling_factor = 0.5 + + print(f"NVFP4 fused router test (M={M}, K={K}, N={N}, top_k={top_k}):") + + # Create BF16 hidden states + hidden_states = torch.randn(M, K, dtype=torch.bfloat16, device=device) + # Create BF16 gate weight and quantize to NVFP4 + W_gate_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device) + e_bias = torch.randn(N, dtype=torch.float32, device=device) + + # Quantize weight to NVFP4 (matching Nvfp4Linear's quantization) + from dsv4.ops.quantize import quantize_weight_nvfp4, quantize_activation_nvfp4 + w_fp4, w_sf, ws2_scalar = quantize_weight_nvfp4(W_gate_bf16) + + # Reference path: NVFP4 GEMM + activation_topk + # Build Nvfp4Linear the same way single_shot_inference does + from dsv4.layers.linear import Nvfp4Linear + gate_lin = Nvfp4Linear(K, N, max_num_tokens=8, device=device) + gate_lin.fp4 = [w_fp4] + gate_lin.sf = [w_sf] + gate_lin.gs = [1.0] + gate_lin.ws2 = [ws2_scalar.to(device) if ws2_scalar is not None else None] + gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0) + gate_lin.finalize_weights() + + logits_ref = gate_lin(hidden_states).float() + from dsv4.kernels.router._activation_topk import run_fused_activation_topk + ref_w = torch.empty(M, top_k, dtype=torch.float32, device=device) + ref_ids = torch.empty(M, top_k, dtype=torch.int32, device=device) + run_fused_activation_topk(logits_ref, e_bias, routed_scaling_factor, top_k, ref_w, ref_ids) + + print(f" Reference: IDs={ref_ids[0].tolist()}, weights={[f'{w:.4f}' for w in ref_w[0].tolist()]}") + + # Fused kernel path — use Nvfp4Linear's processed weight tensors + from dsv4.kernels.router.nvfp4_fused_router_kernel import run_nvfp4_fused_router + gsb_val = gate_lin._gsb.item() + gsa = gate_lin._activation_global_scale + fused_w, fused_ids = run_nvfp4_fused_router( + hidden_states, gate_lin._mat_b, gate_lin._scale_b, + gsa, gsb_val, e_bias, routed_scaling_factor, top_k, + ) + + print(f" Fused: IDs={fused_ids[0].tolist()}, weights={[f'{w:.4f}' for w in fused_w[0].tolist()]}") + + # Compare IDs + ids_match = (fused_ids == ref_ids).all().item() + if ids_match: + print(f" IDs match reference: OK") + else: + mismatches = (fused_ids != ref_ids).sum().item() + print(f" WARNING: {mismatches} ID mismatches vs reference") + + # Compare weights + if fused_w.shape == ref_w.shape: + cos = torch.nn.functional.cosine_similarity( + fused_w.flatten().unsqueeze(0), + ref_w.flatten().unsqueeze(0) + ).item() + max_diff = (fused_w - ref_w).abs().max().item() + print(f" Weight cosine: {cos:.6f}, max diff: {max_diff:.6f}") + if cos > 0.999: + print(f" Weight match: EXCELLENT") + elif cos > 0.99: + print(f" Weight match: GOOD") + else: + print(f" Weight match: POOR - needs investigation") + + # Verify weight normalization + w_sum = fused_w.sum(dim=1) + for row in range(M): + diff = abs(w_sum[row].item() - routed_scaling_factor) + if diff > 0.01: + print(f" WARNING: Row {row} weight sum {w_sum[row].item():.4f} != {routed_scaling_factor:.4f}") + + print("NVFP4 fused router test DONE") + + +def test_nvfp4_fused_router_multitoken(): + """Test with M=4 tokens (batched decode or small prefill).""" + torch.manual_seed(123) + device = "cuda" + M = 4 + K = 7168 + N = 384 + top_k = 6 + routed_scaling_factor = 0.5 + + print(f"NVFP4 fused router multi-token test (M={M}, K={K}, N={N}):") + + hidden_states = torch.randn(M, K, dtype=torch.bfloat16, device=device) + W_gate_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device) + e_bias = torch.randn(N, dtype=torch.float32, device=device) + + from dsv4.ops.quantize import quantize_weight_nvfp4 + w_fp4, w_sf, ws2_scalar = quantize_weight_nvfp4(W_gate_bf16) + + # Reference + from dsv4.layers.linear import Nvfp4Linear + gate_lin = Nvfp4Linear(K, N, max_num_tokens=8, device=device) + gate_lin.fp4 = [w_fp4] + gate_lin.sf = [w_sf] + gate_lin.gs = [1.0] + gate_lin.ws2 = [ws2_scalar.to(device) if ws2_scalar is not None else None] + gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0) + gate_lin.finalize_weights() + logits_ref = gate_lin(hidden_states).float() + from dsv4.kernels.router._activation_topk import run_fused_activation_topk + ref_w = torch.empty(M, top_k, dtype=torch.float32, device=device) + ref_ids = torch.empty(M, top_k, dtype=torch.int32, device=device) + run_fused_activation_topk(logits_ref, e_bias, routed_scaling_factor, top_k, ref_w, ref_ids) + + # Fused + from dsv4.kernels.router.nvfp4_fused_router_kernel import run_nvfp4_fused_router + gsb_val = gate_lin._gsb.item() + gsa = gate_lin._activation_global_scale + fused_w, fused_ids = run_nvfp4_fused_router( + hidden_states, gate_lin._mat_b, gate_lin._scale_b, + gsa, gsb_val, e_bias, routed_scaling_factor, top_k, + ) + + # Compare + ids_match = (fused_ids == ref_ids).all().item() + mismatches = (fused_ids != ref_ids).sum().item() if not ids_match else 0 + cos = torch.nn.functional.cosine_similarity( + fused_w.flatten().unsqueeze(0), ref_w.flatten().unsqueeze(0)).item() + + print(f" IDs match: {ids_match} ({mismatches} mismatches)") + print(f" Weight cosine: {cos:.6f}") + print("NVFP4 fused router multi-token test DONE") if __name__ == "__main__": test_reference_router() print() - test_nvfp4_fused_router_import() + # NVFP4 fused tests require B200 — they'll fail on non-SM100 hardware + try: + test_nvfp4_fused_router() + except Exception as e: + print(f"NVFP4 fused router test skipped (requires B200): {e}") + print() + try: + test_nvfp4_fused_router_multitoken() + except Exception as e: + print(f"NVFP4 fused router multi-token test skipped (requires B200): {e}")