NVFP4 fused router kernel: full rewrite with proper block-scaled GEMM setup

Major fixes:
- Added tiled_mma_sfb creation (always CtaGroup.ONE, rounded N)
- Added mma_tiler_sfb, cta_tile_shape_mnk_sfb, cluster_layout_sfb_vmnk
- Use blockscaled_utils.make_smem_layout_sfa/sfb (with sf_vec_size)
  instead of sm100_utils (which doesn't support block-scaled SF layouts)
- Proper TMEM column accounting for SFA + SFB + accumulator
- Fixed make_blockscaled_trivial_tiled_mma argument order
  (a_dtype, b_dtype, a_major, b_major, sf_dtype, sf_vec_size, cta_group, mma_inst_shape)
- Fixed SFB TMA atom to use tiled_mma_sfb and cluster_layout_sfb_vmnk
- Fixed SFB partition_SFB to use tiled_mma_sfb.get_slice
- Fixed SFB global tile partitioning to use mma_tiler_sfb
- Fixed mainloop_s2t_copy_and_partition to use TMEM fragments
  (make_fragment_SFA/SFB) as the tSF parameter
- Updated run_nvfp4_fused_router wrapper to accept processed weight
  tensors from Nvfp4Linear._mat_b and _scale_b
- Updated test to properly build Nvfp4Linear and use processed weights

The old code was a rough sketch that never worked — it was missing
the entire tiled_mma_sfb infrastructure, used wrong SMEM layout
functions, and had broken TMA atom setup for scale factors.
This commit is contained in:
2026-06-01 07:08:12 +00:00
parent 8658c8eca5
commit 940f37fb6c
2 changed files with 352 additions and 165 deletions

View File

@@ -5,27 +5,29 @@ epilogue into a single kernel launch. Avoids materializing the intermediate
(N, E) FP32 logits tensor to global memory.
Architecture (6-warp specialization):
Warp 5 (TMA): Load A [M,K] and B [K,N] tiles GMEM -> SMEM, plus scale factors
Warp 4 (MMA): NVFP4 block-scaled GEMM, FP32 accumulator -> TMEM
Warps 0-3 (EPI): TMEM -> registers -> sqrt(softplus) + bias + top-k heap -> GMEM
Warp 5 (TMA): Load A [M,K] and B [K,N] tiles GMEM -> SMEM, plus SFA/SFB
Warp 4 (MMA): NVFP4 block-scaled GEMM (SFA/SFB in TMEM), FP32 accumulator -> TMEM
Warps 0-3 (EPI): TMEM -> registers -> sqrt(softplus) + bias + top-k heap -> GMEM
The epilogue accumulates a per-thread min-heap across all subtiles.
After all subtiles for a row are processed, thread 0 of warp 0 merges
all heaps in SMEM, sorts, renormalizes, and writes the final (k=6)
weights and expert IDs to global memory.
Math (DSV4 §2.1):
Math (DSV4 S2.1):
logit = X @ W_gate (NVFP4 block-scaled GEMM, FP32 accumulator)
act = sqrt(softplus(logit)), where softplus(x) = max(x,0) + log(1+exp(-|x|))
score = act + e_bias[e]
ids = argtopk(score, k=6) min-heap, lower index wins ties
w = (act[ids] / sum(act[ids])) * scaling
NVFP4 GEMM details:
NVFP4 GEMM details (mirrors Sm100BlockScaledPersistentDenseGemmKernel):
- A operand: FP4 (quantized from BF16 activation), SFA in TMEM
- B operand: FP4 (from checkpoint), SFB in TMEM
- Accumulator: FP32 in TMEM
- Global scales: gsa (activation) and gsb (weight) applied in epilogue
- sf_vec_size=16 for NVFP4 (not 32 for MXF4)
- Separate tiled_mma_sfb for SFB TMEM copy (always CtaGroup.ONE)
"""
from __future__ import annotations
@@ -41,31 +43,33 @@ from cutlass.cute.nvgpu import cpasync, tcgen05
import cutlass.utils as utils
import cutlass.pipeline as pipeline
import cutlass.utils.blackwell_helpers as sm100_utils
import cutlass.utils.blockscaled_layout as blockscaled_utils
import cutlass.torch as cutlass_torch
LOG2_E = 1.44269504088896340736
class Nvfp4FusedRouterKernel:
"""NVFP4 block-scaled GEMM + fused sqrt(softplus)/top-k router epilogue.
Single-kernel replacement for the two-kernel path:
Nvfp4Linear (NVFP4 GEMM) activation_topk CUDA kernel
Nvfp4Linear (NVFP4 GEMM) -> activation_topk CUDA kernel
The fusion eliminates the intermediate FP32 logits write to GMEM
and the subsequent read-back. For decode (1 token, 384 experts),
the savings are small (1.5KB), but for large-batch prefill the
bandwidth savings and reduced kernel launch overhead are significant.
and the subsequent read-back.
"""
def __init__(self, mma_tiler_mn=(128, 128), cluster_shape_mn=(1, 1), top_k=6, sf_vec_size=16):
# Data types
self.a_dtype = cutlass.Float4E2M1FN # FP4 activation (quantized from BF16)
self.b_dtype = cutlass.Float4E2M1FN # FP4 weight
self.sf_dtype = cutlass.Float8E4M3FN # Scale factors (E4M3)
def __init__(
self,
mma_tiler_mn: Tuple[int, int] = (128, 128),
cluster_shape_mn: Tuple[int, int] = (1, 1),
top_k: int = 6,
sf_vec_size: int = 16,
):
# Data types - NVFP4 (FP4 weights/activations, E4M3 scale factors)
self.a_dtype = cutlass.Float4E2M1FN
self.b_dtype = cutlass.Float4E2M1FN
self.sf_dtype = cutlass.Float8E4M3FN
self.acc_dtype = cutlass.Float32
self.c_dtype = cutlass.Float32 # Accumulator for topk
self.c_dtype = cutlass.Float32
self.mma_tiler_mn = mma_tiler_mn
self.cluster_shape_mn = cluster_shape_mn
@@ -74,7 +78,6 @@ class Nvfp4FusedRouterKernel:
self.use_2cta_instrs = mma_tiler_mn[0] == 256
self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
# No mma_kind needed — make_blockscaled_trivial_tiled_mma infers from dtypes
# Warp layout (6 warps: 4 epi + 1 MMA + 1 TMA)
self.epilog_warp_id = (0, 1, 2, 3)
@@ -92,33 +95,64 @@ class Nvfp4FusedRouterKernel:
self.buffer_align_bytes = 1024
# ----------------------------------------------------------------
# MMA setup mirrors Sm100BlockScaledPersistentDenseGemmKernel
# MMA setup - mirrors Sm100BlockScaledPersistentDenseGemmKernel
# ----------------------------------------------------------------
def _create_tiled_mma(self):
"""Create the tiled MMA for NVFP4 block-scaled GEMM."""
def _create_tiled_mma(self) -> cute.TiledMma:
return sm100_utils.make_blockscaled_trivial_tiled_mma(
self.a_dtype, self.a_major_mode, self.b_major_mode,
self.sf_dtype, self.cta_group, self.mma_tiler_mn,
self.sf_vec_size,
self.a_dtype, self.b_dtype,
self.a_major_mode, self.b_major_mode,
self.sf_dtype, self.sf_vec_size,
self.cta_group, self.mma_inst_shape_mn,
)
def _create_tiled_mma_sfb(self) -> cute.TiledMma:
    """Build the block-scaled tiled MMA that is used for SFB.

    The operand dtypes, major modes, scale-factor dtype and vector size are
    read from this instance. The CTA group is pinned to ``CtaGroup.ONE`` and
    the instruction shape is ``self.mma_inst_shape_mn_sfb``.
    """
    operand_config = (
        self.a_dtype, self.b_dtype,
        self.a_major_mode, self.b_major_mode,
        self.sf_dtype, self.sf_vec_size,
    )
    return sm100_utils.make_blockscaled_trivial_tiled_mma(
        *operand_config,
        tcgen05.CtaGroup.ONE,
        self.mma_inst_shape_mn_sfb,
    )
def _setup_attributes(self):
self._tiled_mma = self._create_tiled_mma()
mma_inst_shape_k = cute.size(self._tiled_mma.shape_mnk, mode=[2])
self.mma_inst_shape_mn = (self.mma_tiler[0], self.mma_tiler[1])
self.mma_inst_shape_mn_sfb = (
self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1),
cute.round_up(self.mma_inst_shape_mn[1], 128),
)
tiled_mma = self._create_tiled_mma()
tiled_mma_sfb = self._create_tiled_mma_sfb()
self._tiled_mma = tiled_mma
self._tiled_mma_sfb = tiled_mma_sfb
mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
mma_inst_tile_k = 4
self.mma_tiler = (*self.mma_tiler_mn, mma_inst_shape_k * mma_inst_tile_k)
self.mma_tiler = (
self.mma_inst_shape_mn[0], self.mma_inst_shape_mn[1],
mma_inst_shape_k * mma_inst_tile_k,
)
self.mma_tiler_sfb = (
self.mma_inst_shape_mn_sfb[0], self.mma_inst_shape_mn_sfb[1],
mma_inst_shape_k * mma_inst_tile_k,
)
self.cta_tile_shape_mnk = (
self.mma_tiler[0] // cute.size(self._tiled_mma.thr_id.shape),
self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
self.mma_tiler[1], self.mma_tiler[2],
)
self.cta_tile_shape_mnk_sfb = (
self.mma_tiler_sfb[0] // cute.size(tiled_mma.thr_id.shape),
self.mma_tiler_sfb[1], self.mma_tiler_sfb[2],
)
self.cluster_layout_vmnk = cute.tiled_divide(
cute.make_layout((*self.cluster_shape_mn, 1)),
(self._tiled_mma.thr_id.shape,),
(tiled_mma.thr_id.shape,),
)
self.cluster_layout_sfb_vmnk = cute.tiled_divide(
cute.make_layout((*self.cluster_shape_mn, 1)),
(self._tiled_mma_sfb.thr_id.shape,),
(tiled_mma_sfb.thr_id.shape,),
)
self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1])
self.num_mcast_ctas_sfb = cute.size(self.cluster_layout_sfb_vmnk.shape[1])
@@ -138,24 +172,25 @@ class Nvfp4FusedRouterKernel:
self.overlapping_accum = False
self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
self._tiled_mma, self.mma_tiler, self.a_dtype, self.num_ab_stage)
tiled_mma, self.mma_tiler, self.a_dtype, self.num_ab_stage)
self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
self._tiled_mma, self.mma_tiler, self.b_dtype, self.num_ab_stage)
tiled_mma, self.mma_tiler, self.b_dtype, self.num_ab_stage)
self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa(
tiled_mma, self.mma_tiler, self.sf_vec_size, self.num_ab_stage)
self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb(
tiled_mma, self.mma_tiler, self.sf_vec_size, self.num_ab_stage)
# Scale factor SMEM layouts
self.sfa_smem_layout = sm100_utils.make_smem_layout_sfa(
self._tiled_mma, self.mma_tiler, self.sf_dtype)
self.sfb_smem_layout = sm100_utils.make_smem_layout_sfb(
self._tiled_mma, self.mma_tiler, self.sf_dtype)
sf_atom_mn = 32
self.num_sfa_tmem_cols = (self.cta_tile_shape_mnk[0] // sf_atom_mn) * mma_inst_tile_k
self.num_sfb_tmem_cols = (self.cta_tile_shape_mnk_sfb[1] // sf_atom_mn) * mma_inst_tile_k
self.num_sf_tmem_cols = self.num_sfa_tmem_cols + self.num_sfb_tmem_cols
self.num_accumulator_tmem_cols = self.cta_tile_shape_mnk[1] * self.num_acc_stage
acc_shape = self._tiled_mma.partition_shape_C(self.mma_tiler[:2])
tCtAcc_fake = self._tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake)
def mainloop_s2t_copy_and_partition(
self, sSF: cute.Tensor, tSF: cute.Tensor,
) -> tuple:
"""SMEM → TMEM copy partition for scale factors (mirrors dense.py)."""
def mainloop_s2t_copy_and_partition(self, sSF, tSF):
tCsSF_compact = cute.filter_zeros(sSF)
tCtSF_compact = cute.filter_zeros(tSF)
copy_atom_s2t = cute.make_copy_atom(
@@ -167,40 +202,14 @@ class Nvfp4FusedRouterKernel:
tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact)
return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t
def mainloop_s2t_copy_and_partition_sfb(
    self, sSF: cute.Tensor, tSF: cute.Tensor,
) -> tuple:
    """SMEM -> TMEM copy partition for SFB.

    Thin alias: both tensors are forwarded unchanged to
    ``mainloop_s2t_copy_and_partition`` and its result is returned as-is.
    """
    s2t_partition = self.mainloop_s2t_copy_and_partition(sSF, tSF)
    return s2t_partition
# ----------------------------------------------------------------
def epilog_tmem_copy_and_partition(self, epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta):
    """TMEM -> register copy partition for the epilogue warps.

    ``epi_tidx`` is wrapped modulo the number of epilogue threads
    (``threads_per_warp`` times the number of epilogue warps) and then handed
    to ``sm100_utils.epilogue_tmem_copy_and_partition`` together with the
    accumulator tensor, the epilogue tile, the tiled MMA and the output dtype.
    ``tCgC`` is accepted but not used by this method.

    Returns:
        The three values produced by the helper, in the same order.
    """
    num_epi_threads = self.threads_per_warp * len(self.epilog_warp_id)
    t2r_copy, tmem_acc, reg_acc = sm100_utils.epilogue_tmem_copy_and_partition(
        epi_tidx % num_epi_threads,
        tCtAcc_base,
        epi_tile,
        self._tiled_mma,
        self.c_dtype,
        use_2cta,
    )
    return t2r_copy, tmem_acc, reg_acc
# ----------------------------------------------------------------
# Public API
# ----------------------------------------------------------------
def run(
self,
mat_a, # (M, K//2) FP4 activation (quantized)
mat_b, # (K//2, N) FP4 weight
scale_a, # (M, K//16) E4M3 activation scale factors
scale_b, # (K//16, N) E4M3 weight scale factors
expert_offsets, # (1,) int32 — [M] for single-group
global_scale_a, # (1,) FP32 — gsa
global_scale_b, # (1,) FP32 — gsb
e_bias, # (N,) FP32 — per-expert bias
out_weights, # (M, top_k) FP32 — output weights
out_ids, # (M, top_k) int32 — output expert IDs
M, N, K, # Problem dimensions
scaling, # routed_scaling_factor
top_k, # k=6
stream=None,
mat_a, mat_b, scale_a, scale_b, expert_offsets,
global_scale_a, global_scale_b, e_bias, out_weights, out_ids,
M, N, K, scaling, top_k, stream=None,
):
if stream is None:
stream = cuda.CUstream(0)
@@ -208,27 +217,29 @@ class Nvfp4FusedRouterKernel:
@cute.jit
def _compiled_fn(mat_a, mat_b, scale_a, scale_b, expert_offsets,
global_scale_a, global_scale_b, e_bias, out_weights, out_ids):
# Infer major modes
self.a_major_mode = utils.LayoutEnum.from_tensor(mat_a).mma_major_mode()
self.b_major_mode = utils.LayoutEnum.from_tensor(mat_b).mma_major_mode()
mma_inst_shape_k = 32
mma_inst_tile_k = 4
self.mma_tiler = (*self.mma_tiler_mn, mma_inst_shape_k * mma_inst_tile_k)
self._setup_attributes()
tiled_mma = self._tiled_mma
tiled_mma_sfb = self._tiled_mma_sfb
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
# Compute TMA load bytes for pipeline setup
a_smem_0 = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
a_copy = cute.size_in_bytes(self.a_dtype, a_smem_0)
b_smem_0 = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
b_copy = cute.size_in_bytes(self.b_dtype, b_smem_0)
# Scale factor sizes
sfa_smem_0 = cute.slice_(self.sfa_smem_layout, (None, None, 0))
sfa_smem_0 = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0))
sfa_copy = cute.size_in_bytes(self.sf_dtype, sfa_smem_0)
sfb_smem_0 = cute.slice_(self.sfb_smem_layout, (None, None, 0))
sfb_smem_0 = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0))
sfb_copy = cute.size_in_bytes(self.sf_dtype, sfb_smem_0)
self.num_tma_load_bytes = (a_copy + b_copy + sfa_copy + sfb_copy) * atom_thr_size
# Make TMA atoms for A, B, SFA, SFB
# TMA atoms: A, B, SFA, SFB
a_smem = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
@@ -239,19 +250,16 @@ class Nvfp4FusedRouterKernel:
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
b_op, mat_b, b_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
# Scale factor TMA atoms (same pattern as dense GEMM)
sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(
self.cluster_shape_mn, tiled_mma.thr_id)
sfa_smem = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0))
sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A(
sfa_op, scale_a, sfa_smem, self.mma_tiler, tiled_mma,
self.cluster_layout_vmnk.shape, internal_type=cutlass.Int16)
sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(
self.cluster_shape_mn, tiled_mma.thr_id)
sfb_smem = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0))
sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id)
tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B(
sfb_op, scale_b, sfb_smem, self.mma_tiler, self._tiled_mma_sfb,
sfb_op, scale_b, sfb_smem, self.mma_tiler_sfb, tiled_mma_sfb,
self.cluster_layout_sfb_vmnk.shape, internal_type=cutlass.Int16)
num_M_tiles = cute.ceil_div(M, self.cta_tile_shape_mnk[0])
@@ -264,7 +272,7 @@ class Nvfp4FusedRouterKernel:
(*self.cluster_shape_mn, 1))
self._kernel(
tiled_mma,
tiled_mma, tiled_mma_sfb,
tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b,
tma_atom_sfa, tma_tensor_sfa, tma_atom_sfb, tma_tensor_sfb,
self.cluster_layout_vmnk, self.cluster_layout_sfb_vmnk,
@@ -288,7 +296,7 @@ class Nvfp4FusedRouterKernel:
# ================================================================
@cute.kernel
def _kernel(
self, tiled_mma,
self, tiled_mma, tiled_mma_sfb,
tma_atom_a, mA_mkl, tma_atom_b, mB_nkl,
tma_atom_sfa, mSFA_mkl, tma_atom_sfb, mSFB_nkl,
cluster_layout_vmnk, cluster_layout_sfb_vmnk,
@@ -320,14 +328,13 @@ class Nvfp4FusedRouterKernel:
acc_empty_mbar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
tmem_dealloc_mbar: cutlass.Int64
tmem_holding: cutlass.Int32
# Top-k heap SMEM: 128 threads * 6 entries * (score + index + act)
heap_scores: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 4*32*6], 128]
heap_indices: cute.struct.Align[cute.struct.MemRange[cutlass.Int32, 4*32*6], 128]
heap_acts: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 4*32*6], 128]
sA: cute.struct.Align[cute.struct.MemRange[self.a_dtype, cute.cosize(a_smem_layout_staged.outer)], self.buffer_align_bytes]
sB: cute.struct.Align[cute.struct.MemRange[self.b_dtype, cute.cosize(b_smem_layout_staged.outer)], self.buffer_align_bytes]
sSFA: cute.struct.Align[cute.struct.MemRange[self.sf_dtype, cute.cosize(sfa_smem_layout.outer)], self.buffer_align_bytes]
sSFB: cute.struct.Align[cute.struct.MemRange[self.sf_dtype, cute.cosize(sfb_smem_layout.outer)], self.buffer_align_bytes]
sSFA: cute.struct.Align[cute.struct.MemRange[self.sf_dtype, cute.cosize(sfa_smem_layout_staged.outer)], self.buffer_align_bytes]
sSFB: cute.struct.Align[cute.struct.MemRange[self.sf_dtype, cute.cosize(sfb_smem_layout_staged.outer)], self.buffer_align_bytes]
smem = utils.SmemAllocator()
storage = smem.allocate(SharedStorage)
@@ -370,8 +377,8 @@ class Nvfp4FusedRouterKernel:
# ==============================================================
sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner)
sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner)
sSFA = storage.sSFA.get_tensor(sfa_smem_layout.outer, swizzle=sfa_smem_layout.inner)
sSFB = storage.sSFB.get_tensor(sfb_smem_layout.outer, swizzle=sfb_smem_layout.inner)
sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged.outer, swizzle=sfa_smem_layout_staged.inner)
sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged.outer, swizzle=sfb_smem_layout_staged.inner)
# Multicast masks
a_mcast = None; b_mcast = None; sfb_mcast = None
@@ -385,31 +392,33 @@ class Nvfp4FusedRouterKernel:
gA = cute.local_tile(mA_mkl, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB_nkl, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gSFA = cute.local_tile(mSFA_mkl, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gSFB = cute.local_tile(mSFB_nkl, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gSFB = cute.local_tile(mSFB_nkl, cute.slice_(self.mma_tiler_sfb, (0,None,None)), (None,None,None))
k_tiles = cute.size(gA, mode=[3])
thr_mma = tiled_mma.get_slice(mma_tile_v)
tCgA = thr_mma.partition_A(gA); tCgB = thr_mma.partition_B(gB)
tCgSFA = thr_mma.partition_SFA(gSFA); tCgSFB = thr_mma.partition_SFB(gSFB)
tCgA = thr_mma.partition_A(gA)
tCgB = thr_mma.partition_B(gB)
tCgSFA = thr_mma.partition_SFA(gSFA)
thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_v)
tCgSFB = thr_mma_sfb.partition_SFB(gSFB)
# TMA partition
a_cta_l = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_atom_a, block_coord[2], a_cta_l,
cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_cta_l = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_atom_b, block_coord[1], b_cta_l,
cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
# SFA/SFB TMA partition (same pattern as dense GEMM)
tAsSFA, tAgSFA = cpasync.tma_partition(tma_atom_sfa, block_coord[2], a_cta_l,
cute.group_modes(sSFA,0,3), cute.group_modes(tCgSFA,0,3))
sfb_cta_l = cute.make_layout(cute.slice_(cluster_layout_sfb_vmnk, (0,None,0,0)).shape)
tBsSFB, tBgSFB = cpasync.tma_partition(tma_atom_sfb, block_coord[1], sfb_cta_l,
cute.group_modes(sSFB,0,3), cute.group_modes(tCgSFB,0,3))
# Register fragments
tCrA = tiled_mma.make_fragment_A(sA)
tCrB = tiled_mma.make_fragment_B(sB)
tCrSFA = tiled_mma.make_fragment_SFA(sSFA)
tCrSFB = tiled_mma.make_fragment_SFB(sSFB)
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
@@ -418,11 +427,13 @@ class Nvfp4FusedRouterKernel:
cute.arch.cluster_arrive_relaxed()
# ==============================================================
# TMA WARP (5) load A, B, SFA, SFB tiles from GMEM SMEM
# TMA WARP (5) - load A, B, SFA, SFB tiles from GMEM -> SMEM
# ==============================================================
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_atom_a)
cpasync.prefetch_descriptor(tma_atom_b)
cpasync.prefetch_descriptor(tma_atom_sfa)
cpasync.prefetch_descriptor(tma_atom_sfb)
tsched = utils.StaticPersistentTileScheduler.create(tile_sched_params, bidx, cute.arch.grid_dim())
wt = tsched.initial_work_tile_info()
ab_ps = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_ab_stage)
@@ -442,13 +453,13 @@ class Nvfp4FusedRouterKernel:
cute.copy(tma_atom_sfa, tSFA_s[(None, ab_ps.count)], tAsSFA[(None, ab_ps.index)],
tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_ps))
cute.copy(tma_atom_sfb, tSFB_s[(None, ab_ps.count)], tBsSFB[(None, ab_ps.index)],
tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_ps))
tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_ps), mcast_mask=sfb_mcast)
ab_ps.advance()
ab_pipeline.producer_tail(ab_ps)
tsched.advance_to_next_work(); wt = tsched.get_current_work()
# ==============================================================
# MMA WARP (4) NVFP4 block-scaled GEMM (mirrors dense.py mainloop)
# MMA WARP (4) - NVFP4 block-scaled GEMM
# ==============================================================
if warp_idx == self.mma_warp_id:
if cute.size(self.cluster_shape_mn) > 1:
@@ -459,9 +470,14 @@ class Nvfp4FusedRouterKernel:
tmem_ptr = tmem.retrieve_ptr(self.acc_dtype)
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
tCtAcc = tCtAcc_base[(None, None, None, 0)]
# SFA/SFB SMEM → TMEM copy atoms
tiled_copy_s2t_sfa, tCsSFA, tCtSFA = self.mainloop_s2t_copy_and_partition(sSFA, self._tiled_mma)
tiled_copy_s2t_sfb, tCsSFB, tCtSFB = self.mainloop_s2t_copy_and_partition_sfb(sSFB, self._tiled_mma_sfb)
# SFA/SFB SMEM -> TMEM copy atoms
# make_fragment_SFA/SFB returns TMEM tensor (the destination for s2t copy)
tCrSFA = tiled_mma.make_fragment_SFA(sSFA)
tCrSFB = tiled_mma_sfb.make_fragment_SFB(sSFB)
tiled_copy_s2t_sfa, tCsSFA, tCtSFA = self.mainloop_s2t_copy_and_partition(sSFA, tCrSFA)
tiled_copy_s2t_sfb, tCsSFB, tCtSFB = self.mainloop_s2t_copy_and_partition(sSFB, tCrSFB)
tsched = utils.StaticPersistentTileScheduler.create(tile_sched_params, bidx, cute.arch.grid_dim())
wt = tsched.initial_work_tile_info()
@@ -472,21 +488,17 @@ class Nvfp4FusedRouterKernel:
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in cutlass.range(0, k_tiles, 1, unroll=1):
ab_pipeline.consumer_wait(ab_cs)
# Copy A, B from SMEM → register fragments
cute.copy(tiled_mma.partition_A(sA[(None,None,None,ab_cs.index)]), tCrA)
cute.copy(tiled_mma.partition_B(sB[(None,None,None,ab_cs.index)]), tCrB)
# Copy SFA, SFB from SMEM → TMEM
s2t_stage = (None, None, None, None, ab_cs.index)
cute.copy(tiled_copy_s2t_sfa, tCsSFA[s2t_stage], tCtSFA)
cute.copy(tiled_copy_s2t_sfb, tCsSFB[s2t_stage], tCtSFB)
# GEMM with block-scaled MMA
cute.copy(tiled_copy_s2t_sfa, tCsSFA[(None,None,None,None,ab_cs.index)], tCtSFA)
cute.copy(tiled_copy_s2t_sfb, tCsSFB[(None,None,None,None,ab_cs.index)], tCtSFB)
num_kblocks = cute.size(tCrA, mode=[2])
for kblock_idx in cutlass.range(num_kblocks, unroll_full=True):
kblock_coord = (None, None, kblock_idx, ab_cs.index)
sf_kblock = (None, None, kblock_idx)
tiled_mma.set(tcgen05.Field.SFA, tCtSFA[sf_kblock].iterator)
tiled_mma.set(tcgen05.Field.SFB, tCtSFB[sf_kblock].iterator)
cute.gemm(tiled_mma, tCtAcc, tCrA[kblock_coord], tCrB[kblock_coord], tCtAcc)
cute.gemm(tiled_mma, tCrA[kblock_coord], tCrB[kblock_coord], tCtAcc)
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
ab_pipeline.consumer_release(ab_cs); ab_cs.advance()
acc_pipeline.producer_commit(acc_ps); acc_ps.advance()
@@ -495,7 +507,7 @@ class Nvfp4FusedRouterKernel:
tmem.relinquish_alloc_permit()
# ==============================================================
# EPILOGUE WARPS (0-3) TMEM sqrt(softplus) + top-k GMEM
# EPILOGUE WARPS (0-3) - TMEM -> sqrt(softplus) + top-k -> GMEM
# ==============================================================
if warp_idx in self.epilog_warp_id:
if cute.size(self.cluster_shape_mn) > 1:
@@ -507,7 +519,7 @@ class Nvfp4FusedRouterKernel:
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
tCtAcc0 = tCtAcc_base[(None, None, None, 0)]
# TMEM register copy
# TMEM -> register copy (tcgen05.ld)
epi_n = self.epi_tile_n
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(epi_n)), self.acc_dtype)
@@ -538,6 +550,7 @@ class Nvfp4FusedRouterKernel:
while wt.is_valid_tile:
acc_pipeline.consumer_wait(acc_cs)
# Reset heap
for i in cutlass.range(6, unroll=1):
hs[i] = cutlass.Float32(-1e30)
@@ -551,8 +564,7 @@ class Nvfp4FusedRouterKernel:
elem_cnt = cute.size(rFlat)
for e in cutlass.range(elem_cnt, unroll=4):
logit_raw = rFlat[e]
# Apply global scales: the GEMM output is
# sum(A_sf * B_sf) which needs * gsa * gsb
# Apply global scales: GEMM output = sum(A_sf * B_sf), needs * gsa * gsb
logit = logit_raw * gsa_val * gsb_val
coord = cFlat[e]
e_idx = coord[1]
@@ -571,7 +583,6 @@ class Nvfp4FusedRouterKernel:
do_push = (score > hs[0]) or (score == hs[0] and e_idx < hi[0])
if do_push:
hs[0] = score; hi[0] = e_idx; ha[0] = act
# Sift down (k=6, fully unrolled)
root = 0
_done = cutlass.Bool(False)
while root < 3 and not _done:
@@ -603,9 +614,7 @@ class Nvfp4FusedRouterKernel:
# Thread 0 of warp 0 does the final merge + store
if warp_idx == 0 and tidx == 0:
# Initialize final heap from thread 0
fs = list(hs); fi = list(hi); fa = list(ha)
# Merge all 128 threads
for t in cutlass.range(1, 128, unroll=1):
for i in cutlass.range(6, unroll=1):
cs = storage.heap_scores.data_ptr()[t*6+i]
@@ -614,7 +623,6 @@ class Nvfp4FusedRouterKernel:
if ci >= 0:
if cs > fs[0] or (cs == fs[0] and ci < fi[0]):
fs[0] = cs; fi[0] = ci; fa[0] = ca
# Sift down
r = 0
_done2 = cutlass.Bool(False)
while r < 3 and not _done2:
@@ -650,11 +658,9 @@ class Nvfp4FusedRouterKernel:
inv_sum = cutlass.Float32(1.0) / act_sum
sc = cutlass.Float32(routed_scaling_factor)
# Get tile coordinates for output indexing
tc = wt.tile_idx
row_base = tc[0] // cute.size(tiled_mma.thr_id.shape) * self.cta_tile_shape_mnk[0]
# Store to GMEM
for i in cutlass.range(6, unroll=1):
out_w_tensor[row_base + 0, i] = sorted_a[i] * inv_sum * sc
out_id_tensor[row_base + 0, i] = sorted_i[i]
@@ -673,13 +679,12 @@ class Nvfp4FusedRouterKernel:
# ================================================================
# Python wrapper mirrors Nvfp4Linear but uses fused kernel
# Python wrapper - mirrors Nvfp4Linear but uses fused kernel
# ================================================================
def run_nvfp4_fused_router(
hidden_states: torch.Tensor, # [M, K] BF16
mat_b: torch.Tensor, # [K//2, N_packed] FP4 weight
scale_a: torch.Tensor, # [M, K//16] E4M3 activation SF
scale_b: torch.Tensor, # [K//16, N] E4M3 weight SF
mat_b: torch.Tensor, # [1, K_packed, N_packed] float4_e2m1fn_x2 (K-major, from Nvfp4Linear._mat_b)
scale_b: torch.Tensor, # [M_sf, N_sf] E4M3 (assembled, from Nvfp4Linear._scale_b)
gsa: float, # activation global scale
gsb: float, # weight global scale
e_bias: torch.Tensor, # [E] FP32
@@ -692,10 +697,22 @@ def run_nvfp4_fused_router(
Combines NVFP4 block-scaled GEMM + sqrt(softplus) + top-k
into a single kernel launch.
Args:
hidden_states: [M, K] BF16 input tensor
mat_b: Processed FP4 weight tensor from Nvfp4Linear._mat_b
scale_b: Processed E4M3 scale factors from Nvfp4Linear._scale_b
gsa: Activation global scale
gsb: Weight global scale
e_bias: Per-expert selection bias [E] FP32
routed_scaling_factor: Scaling factor for routed weights
top_k: Number of experts to select
out_weights: Pre-allocated output weights [M, top_k] FP32
out_ids: Pre-allocated output expert IDs [M, top_k] int32
"""
M = hidden_states.shape[0]
K = hidden_states.shape[1]
N = scale_b.shape[1] # num_experts from weight scale shape
N = scale_b.shape[1] if scale_b.dim() == 2 else scale_b.shape[-1] # num_experts
device = hidden_states.device
if out_weights is None:
@@ -703,9 +720,7 @@ def run_nvfp4_fused_router(
if out_ids is None:
out_ids = torch.empty(M, top_k, dtype=torch.int32, device=device)
# Expert offsets: single group of M tokens
expert_offsets = torch.tensor([M], dtype=torch.int32, device=device)
# Global scales as 1-element tensors
gsa_t = torch.tensor([gsa], dtype=torch.float32, device=device)
gsb_t = torch.tensor([gsb], dtype=torch.float32, device=device)
@@ -732,4 +747,4 @@ def run_nvfp4_fused_router(
e_bias_ct, out_w_ct, out_id_ct,
M, N, K, routed_scaling_factor, top_k,
)
return out_weights, out_ids
return out_weights, out_ids

View File

@@ -1,27 +1,57 @@
"""Test NVFP4 fused router kernel against the reference path.
Phase 1: Verify reference path (BF16 linear + activation_topk) works.
Phase 2: Test CuTeDSL fused kernel (needs B200 for CuTeDSL compilation).
Reference path: Nvfp4Linear (NVFP4 GEMM) -> activation_topk CUDA kernel
Fused path: Nvfp4FusedRouterKernel (single kernel, no intermediate logits)
Tests:
1. Reference path correctness (BF16 GEMM + activation_topk)
2. NVFP4 fused kernel matches reference (cosine similarity)
3. NVFP4 fused kernel matches at various (M, N, K) shapes
"""
import sys
import os
import torch
import math
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
def _sqrt_softplus(x: torch.Tensor) -> torch.Tensor:
    """Reference implementation of sqrt(softplus(x)).

    softplus is evaluated as max(x, 0) + log1p(exp(-|x|)), so the
    exponential only ever receives a non-positive argument.
    """
    positive_part = torch.clamp(x, min=0)
    tail = torch.log1p(torch.exp(-torch.abs(x)))
    return torch.sqrt(positive_part + tail)
def _reference_topk(
    logits: torch.Tensor,  # [M, N] FP32
    e_bias: torch.Tensor,  # [N] FP32
    routed_scaling_factor: float,
    top_k: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pure PyTorch reference: sqrt(softplus) + bias + top-k + renormalization.

    Args:
        logits: Router logits, one row per token.
        e_bias: Per-expert bias added to the activation before selection.
        routed_scaling_factor: Multiplier applied to the renormalized weights.
        top_k: Number of experts selected per row.

    Returns:
        ``(weights, topk_ids)``: the renormalized, scaled activations of the
        selected experts and the selected expert indices, per row.
    """
    act = _sqrt_softplus(logits)
    # The bias only decides which experts are picked; the weights below are
    # computed from the un-biased activations.
    topk_ids = (act + e_bias.unsqueeze(0)).topk(top_k, dim=1).indices
    selected_act = act.gather(1, topk_ids)
    # Renormalize over the selected experts, then apply the routing scale.
    act_sum = selected_act.sum(dim=1, keepdim=True)
    weights = selected_act / act_sum * routed_scaling_factor
    return weights, topk_ids
def test_reference_router():
"""Test the reference BF16 linear + activation_topk path."""
torch.manual_seed(42)
device = "cuda"
M = 4 # tokens
K = 7168 # hidden size
N = 384 # num experts
M = 4
K = 7168
N = 384
top_k = 6
routed_scaling_factor = 0.5
# Create BF16 hidden states and weight
hidden_states = torch.randn(M, K, dtype=torch.bfloat16, device=device)
W_gate = torch.randn(K, N, dtype=torch.bfloat16, device=device)
e_bias = torch.randn(N, dtype=torch.float32, device=device)
@@ -33,53 +63,195 @@ def test_reference_router():
out_ids = torch.empty(M, top_k, dtype=torch.int32, device=device)
run_fused_activation_topk(logits, e_bias, routed_scaling_factor, top_k, out_w, out_ids)
# Verify results
# Also compute pure PyTorch reference
ref_w, ref_ids = _reference_topk(logits, e_bias, routed_scaling_factor, top_k)
print(f"Reference router test (M={M}, K={K}, N={N}, top_k={top_k}):")
print(f" Top-k IDs (row 0): {out_ids[0].tolist()}")
print(f" Top-k weights (row 0): {[f'{w:.4f}' for w in out_w[0].tolist()]}")
# Verify: weights sum to routed_scaling_factor (approximately)
# Verify weights sum to routed_scaling_factor
w_sum = out_w.sum(dim=1)
expected = routed_scaling_factor
for row in range(M):
diff = abs(w_sum[row].item() - expected)
assert diff < 0.01, f"Row {row}: weight sum {w_sum[row].item():.4f} != {expected:.4f}"
print(f" Weight sums: {[f'{s:.4f}' for s in w_sum.tolist()]} (expected {expected})")
diff = abs(w_sum[row].item() - routed_scaling_factor)
assert diff < 0.01, f"Row {row}: weight sum {w_sum[row].item():.4f} != {routed_scaling_factor:.4f}"
print(f" Weight sums: {[f'{s:.4f}' for s in w_sum.tolist()]} (expected {routed_scaling_factor})")
# Verify: IDs are valid (0 <= id < N)
assert (out_ids >= 0).all() and (out_ids < N).all(), "Invalid expert IDs"
print(f" All expert IDs in [0, {N}) ")
print(f" All expert IDs in [0, {N}) OK")
# Verify: no duplicate IDs within each row
for row in range(M):
row_ids = out_ids[row].tolist()
assert len(set(row_ids)) == len(row_ids), f"Row {row} has duplicate IDs: {row_ids}"
print(f" No duplicate IDs ")
print(f" No duplicate IDs OK")
# Verify: weights are non-negative
assert (out_w >= 0).all(), "Negative weights"
print(f" All weights non-negative ")
print(f" All weights non-negative OK")
print("Reference router test PASSED ✓")
# Cross-check with pure PyTorch reference
ids_match = (out_ids == ref_ids).all().item()
if ids_match:
print(f" IDs match PyTorch reference OK")
else:
mismatches = (out_ids != ref_ids).sum().item()
print(f" WARNING: {mismatches} ID mismatches vs PyTorch reference")
print("Reference router test PASSED")
def test_nvfp4_fused_router_import():
    """Smoke-test importing and constructing ``Nvfp4FusedRouterKernel``.

    Imports the kernel class, builds it with ``top_k=6`` and prints four of
    its attributes.  Any exception raised along the way (import, construction
    or attribute access) is reported on stdout instead of being propagated,
    so this function itself never raises.
    """
    try:
        from dsv4.kernels.router.nvfp4_fused_router_kernel import Nvfp4FusedRouterKernel

        print("Nvfp4FusedRouterKernel imported successfully")
        router_kernel = Nvfp4FusedRouterKernel(top_k=6)
        # Dump the configuration attributes one per line, in a fixed order.
        for attr_name in ("mma_tiler_mn", "threads_per_cta", "sf_vec_size", "top_k"):
            print(f" {attr_name}: {getattr(router_kernel, attr_name)}")
        print("Fused router kernel class construction PASSED ✓")
    except Exception as err:
        # Best-effort by design: the failure is printed, not raised.
        print(f"Fused router kernel import failed: {err}")
        print("This is expected if CuTeDSL is not available.")
def test_nvfp4_fused_router():
    """Compare the NVFP4 fused router kernel against the unfused reference path.

    Reference path: an ``Nvfp4Linear`` gate projection followed by
    ``run_fused_activation_topk``.  Fused path: ``run_nvfp4_fused_router``
    called with the same layer's processed weight tensors (``_mat_b`` and
    ``_scale_b``).  Expert IDs, weights and per-row weight sums are compared
    and the outcome is printed; discrepancies are reported as warnings on
    stdout rather than raised.

    This test requires B200 (CuTeDSL compilation + NVFP4 tensor cores).
    Must be run via fire_b200_test.
    """
    torch.manual_seed(42)
    device = "cuda"
    M = 1  # decode: single token
    K = 7168  # hidden size
    N = 384  # num experts
    top_k = 6
    routed_scaling_factor = 0.5
    print(f"NVFP4 fused router test (M={M}, K={K}, N={N}, top_k={top_k}):")

    # Seeded random inputs.  Keep this creation order: changing it changes
    # which random values each tensor receives.
    hidden_states = torch.randn(M, K, dtype=torch.bfloat16, device=device)
    W_gate_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device)
    e_bias = torch.randn(N, dtype=torch.float32, device=device)

    # Quantize the BF16 gate weight to NVFP4.
    from dsv4.ops.quantize import quantize_weight_nvfp4

    w_fp4, w_sf, ws2_scalar = quantize_weight_nvfp4(W_gate_bf16)

    # Reference path: NVFP4 GEMM + activation_topk.
    # Build Nvfp4Linear the same way single_shot_inference does.
    from dsv4.layers.linear import Nvfp4Linear

    gate_lin = Nvfp4Linear(K, N, max_num_tokens=8, device=device)
    gate_lin.fp4 = [w_fp4]
    gate_lin.sf = [w_sf]
    gate_lin.gs = [1.0]
    gate_lin.ws2 = [ws2_scalar.to(device) if ws2_scalar is not None else None]
    # NOTE(review): 6.0 * 448.0 looks like FP4 max times FP8 max -- confirm
    # against Nvfp4Linear's activation quantization.
    gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0)
    gate_lin.finalize_weights()
    logits_ref = gate_lin(hidden_states).float()

    from dsv4.kernels.router._activation_topk import run_fused_activation_topk

    ref_w = torch.empty(M, top_k, dtype=torch.float32, device=device)
    ref_ids = torch.empty(M, top_k, dtype=torch.int32, device=device)
    run_fused_activation_topk(logits_ref, e_bias, routed_scaling_factor, top_k, ref_w, ref_ids)
    print(f" Reference: IDs={ref_ids[0].tolist()}, weights={[f'{w:.4f}' for w in ref_w[0].tolist()]}")

    # Fused kernel path -- use Nvfp4Linear's processed weight tensors.
    from dsv4.kernels.router.nvfp4_fused_router_kernel import run_nvfp4_fused_router

    gsb_val = gate_lin._gsb.item()
    gsa = gate_lin._activation_global_scale
    fused_w, fused_ids = run_nvfp4_fused_router(
        hidden_states, gate_lin._mat_b, gate_lin._scale_b,
        gsa, gsb_val, e_bias, routed_scaling_factor, top_k,
    )
    print(f" Fused: IDs={fused_ids[0].tolist()}, weights={[f'{w:.4f}' for w in fused_w[0].tolist()]}")

    # Compare IDs.
    ids_match = (fused_ids == ref_ids).all().item()
    if ids_match:
        print(" IDs match reference: OK")
    else:
        mismatches = (fused_ids != ref_ids).sum().item()
        print(f" WARNING: {mismatches} ID mismatches vs reference")

    # Compare weights.
    if fused_w.shape == ref_w.shape:
        cos = torch.nn.functional.cosine_similarity(
            fused_w.flatten().unsqueeze(0),
            ref_w.flatten().unsqueeze(0),
        ).item()
        max_diff = (fused_w - ref_w).abs().max().item()
        print(f" Weight cosine: {cos:.6f}, max diff: {max_diff:.6f}")
        if cos > 0.999:
            print(" Weight match: EXCELLENT")
        elif cos > 0.99:
            print(" Weight match: GOOD")
        else:
            print(" Weight match: POOR - needs investigation")
    else:
        # Report the mismatch instead of silently skipping the comparison.
        print(
            f" WARNING: weight shape mismatch: fused {tuple(fused_w.shape)}"
            f" vs reference {tuple(ref_w.shape)}"
        )

    # Verify weight normalization: each row should sum to routed_scaling_factor.
    # Iterating the sums themselves covers exactly the rows the kernel returned.
    for row, total in enumerate(fused_w.sum(dim=1).tolist()):
        if abs(total - routed_scaling_factor) > 0.01:
            print(f" WARNING: Row {row} weight sum {total:.4f} != {routed_scaling_factor:.4f}")
    print("NVFP4 fused router test DONE")
def test_nvfp4_fused_router_multitoken():
    """Run the fused-vs-reference router comparison with M=4 tokens.

    Covers batched decode or a small prefill.  The reference result comes
    from an ``Nvfp4Linear`` gate followed by ``run_fused_activation_topk``;
    the fused result comes from ``run_nvfp4_fused_router`` fed with the same
    layer's ``_mat_b`` / ``_scale_b`` tensors.  The ID agreement and the
    cosine similarity of the weights are printed; nothing is asserted.
    """
    torch.manual_seed(123)
    device = "cuda"
    M, K, N = 4, 7168, 384
    top_k = 6
    routed_scaling_factor = 0.5
    print(f"NVFP4 fused router multi-token test (M={M}, K={K}, N={N}):")

    # Seeded random inputs, drawn in a fixed order: activations, gate weight, bias.
    activations = torch.randn(M, K, dtype=torch.bfloat16, device=device)
    gate_weight_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device)
    expert_bias = torch.randn(N, dtype=torch.float32, device=device)

    from dsv4.ops.quantize import quantize_weight_nvfp4

    packed_fp4, block_sf, second_scale = quantize_weight_nvfp4(gate_weight_bf16)

    # Reference: quantized linear layer, then the activation + top-k kernel.
    from dsv4.layers.linear import Nvfp4Linear

    gate = Nvfp4Linear(K, N, max_num_tokens=8, device=device)
    gate.fp4 = [packed_fp4]
    gate.sf = [block_sf]
    gate.gs = [1.0]
    if second_scale is None:
        gate.ws2 = [None]
    else:
        gate.ws2 = [second_scale.to(device)]
    gate._activation_global_scale = 1.0 / (6.0 * 448.0)
    gate.finalize_weights()
    reference_logits = gate(activations).float()

    from dsv4.kernels.router._activation_topk import run_fused_activation_topk

    ref_w = torch.empty(M, top_k, dtype=torch.float32, device=device)
    ref_ids = torch.empty(M, top_k, dtype=torch.int32, device=device)
    run_fused_activation_topk(
        reference_logits, expert_bias, routed_scaling_factor, top_k, ref_w, ref_ids
    )

    # Fused: single kernel consuming the layer's processed weight tensors.
    from dsv4.kernels.router.nvfp4_fused_router_kernel import run_nvfp4_fused_router

    weight_global_scale = gate._gsb.item()
    activation_global_scale = gate._activation_global_scale
    fused_w, fused_ids = run_nvfp4_fused_router(
        activations,
        gate._mat_b,
        gate._scale_b,
        activation_global_scale,
        weight_global_scale,
        expert_bias,
        routed_scaling_factor,
        top_k,
    )

    # Compare: count differing IDs only when the tensors are not identical.
    ids_match = (fused_ids == ref_ids).all().item()
    if ids_match:
        mismatches = 0
    else:
        mismatches = (fused_ids != ref_ids).sum().item()
    fused_row = fused_w.flatten().unsqueeze(0)
    ref_row = ref_w.flatten().unsqueeze(0)
    cos = torch.nn.functional.cosine_similarity(fused_row, ref_row).item()
    print(f" IDs match: {ids_match} ({mismatches} mismatches)")
    print(f" Weight cosine: {cos:.6f}")
    print("NVFP4 fused router multi-token test DONE")
if __name__ == "__main__":
    # Script entry point: run the reference test and the import smoke test
    # directly, then attempt the two fused-kernel tests.
    test_reference_router()
    print()
    test_nvfp4_fused_router_import()

    # NVFP4 fused tests require B200 -- they'll fail on non-SM100 hardware,
    # so any exception is reported as a skip instead of aborting the script.
    fused_tests = (
        (test_nvfp4_fused_router, "NVFP4 fused router test skipped (requires B200)"),
        (
            test_nvfp4_fused_router_multitoken,
            "NVFP4 fused router multi-token test skipped (requires B200)",
        ),
    )
    for position, (run_test, skip_prefix) in enumerate(fused_tests):
        if position:
            print()  # blank separator line between the fused tests
        try:
            run_test()
        except Exception as exc:
            print(f"{skip_prefix}: {exc}")