Fused router kernel: rewrite epilogue with proper CuTeDSL constructs

- Replace Python lists with individual scalar variables (s0..s5, i0..i5, a0..a5)
- Replace min-heap sift-down with fully unrolled sorted insertion
  (descending order, no dynamic indexing, no while loops)
- Replace raw SMEM pointer arithmetic with CuTeDSL SMEM tensors
  (s_merge_s, s_merge_i, s_merge_a)
- Replace cute.where with cute.math.fmax
- Fix expert index calculation: col + tile_n_offset + subtile_idx * epi_n
- Top-6 accumulates across all N-tiles (for E=384 with 3 tiles of 128)
- Add iter_acc_early_release for overlapping accumulator
- Rewrite test to compare fused kernel vs 2-kernel reference path
- Remove stale memory doc
This commit is contained in:
2026-06-01 08:49:39 +00:00
parent d01b4b02de
commit 2433700a69
3 changed files with 304 additions and 280 deletions

View File

@@ -173,6 +173,8 @@ class Nvfp4FusedRouterKernel:
cute.size_in_bytes(sf_dtype, sfb_smem_0)
) * atom_thr_size
self.iter_acc_early_release = self.num_sf_tmem_cols // self.epi_tile_n
def run(self, mat_a, mat_b, scale_a, scale_b, e_bias, out_weights, out_ids,
M, N, K, routed_scaling_factor, top_k, stream=None):
if stream is None:
@@ -273,9 +275,9 @@ class Nvfp4FusedRouterKernel:
acc_empty_mbar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
tmem_dealloc_mbar: cutlass.Int64
tmem_holding: cutlass.Int32
heap_scores: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 4*32*self.top_k], 128]
heap_indices: cute.struct.Align[cute.struct.MemRange[cutlass.Int32, 4*32*self.top_k], 128]
heap_acts: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 4*32*self.top_k], 128]
merge_scores: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 128*self.top_k], 128]
merge_indices: cute.struct.Align[cute.struct.MemRange[cutlass.Int32, 128*self.top_k], 128]
merge_acts: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 128*self.top_k], 128]
sA: cute.struct.Align[cute.struct.MemRange[cutlass.Float4E2M1FN, cute.cosize(a_smem_layout_staged.outer)], self.buffer_align_bytes]
sB: cute.struct.Align[cute.struct.MemRange[cutlass.Float4E2M1FN, cute.cosize(b_smem_layout_staged.outer)], self.buffer_align_bytes]
sSFA: cute.struct.Align[cute.struct.MemRange[cutlass.Float8E4M3FN, cute.cosize(sfa_smem_layout_staged.outer)], self.buffer_align_bytes]
@@ -488,8 +490,7 @@ class Nvfp4FusedRouterKernel:
ab_cs = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_ab_stage)
acc_ps = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Producer, self.num
_acc_stage)
pipeline.PipelineUserType.Producer, self.num_acc_stage)
num_tiles_executed = cutlass.Int32(0)
@@ -570,30 +571,76 @@ _acc_stage)
tmem.relinquish_alloc_permit()
# ============================================================
# EPILOGUE WARPS — TMEM -> registers, router logic, GMEM store
# EPILOGUE WARPS — TMEM → registers → router logic → GMEM
# ============================================================
#
# Strategy:
# 1. Read TMEM accumulator into registers via paired t2r copy
# 2. For each element: compute act = sqrt(softplus(logit)),
# score = act + e_bias[expert_idx]
# 3. Insert into per-thread running top-6 (sorted, fully unrolled)
# 4. After all tiles: write local top-6 to SMEM, one thread merges,
# sorts, renormalizes, writes to GMEM
#
# The top-6 is maintained in DESCENDING order:
# s0 >= s1 >= s2 >= s3 >= s4 >= s5
# Insertion uses fully unrolled comparisons — no dynamic indexing.
#
if warp_idx in self.epilogue_warp_id:
# Wait for cluster sync
if cute.size(self.cluster_shape_mn) > 1:
cute.arch.cluster_wait()
else:
cta_bar.arrive_and_wait()
# Wait for TMEM allocation
tmem.wait_for_alloc()
acc_tmem_ptr = tmem.retrieve_ptr(acc_dtype)
tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout)
# TMEM->register copy setup (paired atoms from CUTLASS)
# TMEM → register copy (paired atoms from CUTLASS)
epi_n = self.epi_tile_n
tiled_copy_t2r, tTR_tAcc_base = epilogue_tmem_copy_and_partition(
tCtAcc_base, epi_tile, self.epilogue_warp_id, acc_dtype, use_2cta)
tTR_rAcc = tiled_copy_t2r.fragments_slice(tiled_copy_t2r, tTR_tAcc_base)
# Identity tensor for expert index mapping
# Identity tensor for (row, col) coordinates
cAcc = cute.make_identity_tensor(
(self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1]))
tCcAcc = thr_mma.partition_C(cAcc)
cFlat = cute.flatten(tCcAcc)
# Merge SMEM tensors (for cross-thread top-k merge)
s_merge_s = cute.make_tensor(
storage.merge_scores.data_ptr(),
cute.make_layout((128, TK)))
s_merge_i = cute.make_tensor(
storage.merge_indices.data_ptr(),
cute.make_layout((128, TK)))
s_merge_a = cute.make_tensor(
storage.merge_acts.data_ptr(),
cute.make_layout((128, TK)))
# ------------------------------------------------------------------
# Running top-6 per thread — individual scalar variables
# Stored in DESCENDING order: s0 >= s1 >= s2 >= s3 >= s4 >= s5
# ------------------------------------------------------------------
s0 = cutlass.Float32(-1e30)
s1 = cutlass.Float32(-1e30)
s2 = cutlass.Float32(-1e30)
s3 = cutlass.Float32(-1e30)
s4 = cutlass.Float32(-1e30)
s5 = cutlass.Float32(-1e30)
i0 = cutlass.Int32(-1)
i1 = cutlass.Int32(-1)
i2 = cutlass.Int32(-1)
i3 = cutlass.Int32(-1)
i4 = cutlass.Int32(-1)
i5 = cutlass.Int32(-1)
a0 = cutlass.Float32(0.0)
a1 = cutlass.Float32(0.0)
a2 = cutlass.Float32(0.0)
a3 = cutlass.Float32(0.0)
a4 = cutlass.Float32(0.0)
a5 = cutlass.Float32(0.0)
# Tile scheduler + pipeline states
tsched = utils.StaticPersistentTileScheduler.create(
@@ -602,6 +649,10 @@ _acc_stage)
acc_cs = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage)
# Track which row we're computing top-k for (row 0 of each M-tile)
current_row = cutlass.Int32(-1)
num_tiles_done = cutlass.Int32(0)
while wt.is_valid_tile:
acc_pipeline.consumer_wait(acc_cs)
@@ -610,80 +661,87 @@ _acc_stage)
else:
acc_stage_index = acc_cs.index
# Set accumulator buffer for current tile
# Get tile N offset (which 128-expert slice this tile covers)
tc = wt.tile_idx
tile_n_offset = tc[1] * self.cta_tile_shape_mnk[1]
tile_m_base = tc[0] // cute.size(tiled_mma.thr_id.shape) * self.cta_tile_shape_mnk[0]
# On the first tile processed (num_tiles_done == 0), record its M-tile base row.
# The running top-6 is initialized once before the loop and keeps accumulating across tiles.
if num_tiles_done == cutlass.Int32(0):
current_row = tile_m_base
tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_stage_index)]
tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
# Per-thread register heap (top_k entries)
hs = [cutlass.Float32(-1e30)] * self.top_k
hi = [cutlass.Int32(-1)] * self.top_k
ha = [cutlass.Float32(0.0)] * self.top_k
# Process subtiles (each subtile = epi_tile_n columns)
subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
for subtile_idx in cutlass.range(subtile_cnt):
# Load accumulator from TMEM to registers
tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)]
cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
# Fence for TMEM load
cute.arch.fence_view_async_tmem_load()
# Early release accumulator for overlapping case
if cutlass.const_expr(self.overlapping_accum):
if subtile_idx == self.num_sf_tmem_cols // epi_n:
if subtile_idx == self.iter_acc_early_release:
with cute.arch.elect_one():
acc_pipeline.consumer_release(acc_cs)
acc_cs.advance()
# Process each element in the register fragment
rFlat = cute.flatten(tTR_rAcc)
cFlat = cute.flatten(tCcAcc)
elem_cnt = cute.size(rFlat)
for e in cutlass.range(elem_cnt, unroll=4):
logit = rFlat[e]
coord = cFlat[e]
row = coord[0]
col = coord[1]
# Expert index = col + (subtile_idx * epi_tile_n)
e_idx = col + (subtile_idx * epi_n)
# sqrt(softplus(logit))
abs_x = cute.math.absf(logit)
pos = cute.where(logit > cutlass.Float32(0.0), logit, cutlass.Float32(0.0))
exp_neg = cute.math.exp(-abs_x)
sp = pos + cute.math.log(cutlass.Float32(1.0) + exp_neg)
act = cute.math.sqrt(sp)
# Expert index = col + tile_n_offset + (subtile_idx * epi_n)
e_idx = col + tile_n_offset + (subtile_idx * epi_n)
# score = act + e_bias (for selection only)
score = act + e_bias_tensor[e_idx]
# Only process row 0 (the actual token row)
# For M=1 padded to 128, only row 0 has valid data
if row == 0:
# sqrt(softplus(logit))
# softplus(x) = max(x, 0) + log(1 + exp(-|x|))
abs_x = cute.math.absf(logit)
pos = cute.math.fmax(logit, cutlass.Float32(0.0))
exp_neg = cute.math.exp(-abs_x)
one_plus = cutlass.Float32(1.0) + exp_neg
sp = pos + cute.math.log(one_plus)
act = cute.math.sqrt(sp)
# Min-heap push: root = hs[0] (smallest of top_k)
do_push = score > hs[0]
if do_push:
# Replace root with new entry
old_s = hs[0]; old_i = hi[0]; old_a = ha[0]
hs[0] = score; hi[0] = e_idx; ha[0] = act
# Sift down (top_k=6, fully unrolled)
r = 0
_done = cutlass.Bool(False)
for _sift in cutlass.range(3, unroll=1):
if not _done:
left = 2*r+1; right = 2*r+2
sm = r
if left < self.top_k:
if hs[left] < hs[sm]:
sm = left
if right < self.top_k:
if hs[right] < hs[sm]:
sm = right
if sm == r:
_done = cutlass.Bool(True)
# score = act + e_bias (for selection only)
score = act + e_bias_tensor[e_idx]
# Sorted insertion into descending top-6
# s0 >= s1 >= s2 >= s3 >= s4 >= s5
# If score <= s5, skip
if score > s5:
if score > s4:
# Shift s4 → s5
s5 = s4; i5 = i4; a5 = a4
if score > s3:
s4 = s3; i4 = i3; a4 = a3
if score > s2:
s3 = s2; i3 = i2; a3 = a2
if score > s1:
s2 = s1; i2 = i1; a2 = a1
if score > s0:
s1 = s0; i1 = i0; a1 = a0
s0 = score; i0 = e_idx; a0 = act
else:
s1 = score; i1 = e_idx; a1 = act
else:
s2 = score; i2 = e_idx; a2 = act
else:
s3 = score; i3 = e_idx; a3 = act
else:
ts = hs[r]; ti = hi[r]; ta = ha[r]
hs[r] = hs[sm]; hi[r] = hi[sm]; ha[r] = ha[sm]
hs[sm] = ts; hi[sm] = ti; ha[sm] = ta
r = sm
s4 = score; i4 = e_idx; a4 = act
else:
s5 = score; i5 = e_idx; a5 = act
# Release accumulator (non-overlapping case)
if cutlass.const_expr(not self.overlapping_accum):
@@ -691,82 +749,86 @@ _acc_stage)
acc_pipeline.consumer_release(acc_cs)
acc_cs.advance()
# Write heap to shared memory for cross-thread merge
tid = warp_idx * 32 + tidx
base = tid * self.top_k
for i in cutlass.range(self.top_k, unroll=1):
storage.heap_scores.data_ptr()[base + i] = hs[i]
storage.heap_indices.data_ptr()[base + i] = hi[i]
storage.heap_acts.data_ptr()[base + i] = ha[i]
epi_bar.arrive_and_wait()
# Thread 0 of warp 0 does the final merge + store
if warp_idx == 0 and tidx == 0:
# Initialize final heap from thread 0
fs = list(hs); fi = list(hi); fa = list(ha)
# Merge all 128 threads (4 warps * 32)
for t in cutlass.range(1, 128, unroll=1):
for i in cutlass.range(self.top_k, unroll=1):
cs = storage.heap_scores.data_ptr()[t*self.top_k+i]
ci = storage.heap_indices.data_ptr()[t*self.top_k+i]
ca = storage.heap_acts.data_ptr()[t*self.top_k+i]
if ci >= 0:
if cs > fs[0]:
fs[0] = cs; fi[0] = ci; fa[0] = ca
# Sift down
r = 0
_done2 = cutlass.Bool(False)
for _sift2 in cutlass.range(3, unroll=1):
if not _done2:
l = 2*r+1; ri = 2*r+2; sm = r
if l < self.top_k:
if fs[l] < fs[sm]:
sm = l
if ri < self.top_k:
if fs[ri] < fs[sm]:
sm = ri
if sm == r:
_done2 = cutlass.Bool(True)
else:
ts=fs[r]; ti=fi[r]; ta=fa[r]
fs[r]=fs[sm]; fi[r]=fi[sm]; fa[r]=fa[sm]
fs[sm]=ts; fi[sm]=ti; fa[sm]=ta
r = sm
# Sort descending (selection sort, k=6)
sorted_s = [cutlass.Float32(-1e30)] * self.top_k
sorted_i = [cutlass.Int32(-1)] * self.top_k
sorted_a = [cutlass.Float32(0.0)] * self.top_k
for i in cutlass.range(self.top_k, unroll=1):
best = 0
for j in cutlass.range(1, self.top_k, unroll=1):
if fs[j] > fs[best]:
best = j
sorted_s[i] = fs[best]
sorted_i[i] = fi[best]
sorted_a[i] = fa[best]
fs[best] = cutlass.Float32(-1e30)
# Renormalize: w = act / sum(act) * scaling
act_sum = sorted_a[0] + sorted_a[1] + sorted_a[2] + sorted_a[3] + sorted_a[4] + sorted_a[5]
inv_sum = cutlass.Float32(1.0) / act_sum
sc = cutlass.Float32(routed_scaling_factor)
# Get tile coordinates for output indexing
tc = wt.tile_idx
row_base = tc[0] // cute.size(tiled_mma.thr_id.shape) * self.cta_tile_shape_mnk[0]
# Store to GMEM
for i in cutlass.range(self.top_k, unroll=1):
out_w_tensor[row_base + 0, i] = sorted_a[i] * inv_sum * sc
out_id_tensor[row_base + 0, i] = sorted_i[i]
epi_bar.arrive_and_wait()
num_tiles_done += cutlass.Int32(1)
tsched.advance_to_next_work()
wt = tsched.get_current_work()
# ==================================================================
# Post-loop: all tiles processed. Merge across threads, write to GMEM.
# ==================================================================
# Each thread writes its running top-6 to SMEM
tid = warp_idx * 32 + tidx
s_merge_s[tid, 0] = s0; s_merge_s[tid, 1] = s1; s_merge_s[tid, 2] = s2
s_merge_s[tid, 3] = s3; s_merge_s[tid, 4] = s4; s_merge_s[tid, 5] = s5
s_merge_i[tid, 0] = i0; s_merge_i[tid, 1] = i1; s_merge_i[tid, 2] = i2
s_merge_i[tid, 3] = i3; s_merge_i[tid, 4] = i4; s_merge_i[tid, 5] = i5
s_merge_a[tid, 0] = a0; s_merge_a[tid, 1] = a1; s_merge_a[tid, 2] = a2
s_merge_a[tid, 3] = a3; s_merge_a[tid, 4] = a4; s_merge_a[tid, 5] = a5
epi_bar.arrive_and_wait()
# Thread 0 merges all 128 threads' top-6 into final result
if warp_idx == 0 and tidx == 0:
# Initialize final top-6 from thread 0's data
fs0 = s0; fs1 = s1; fs2 = s2; fs3 = s3; fs4 = s4; fs5 = s5
fi0 = i0; fi1 = i1; fi2 = i2; fi3 = i3; fi4 = i4; fi5 = i5
fa0 = a0; fa1 = a1; fa2 = a2; fa3 = a3; fa4 = a4; fa5 = a5
# Merge all other threads (1..127)
for t in cutlass.range(1, 128, unroll=1):
for k in cutlass.range(TK, unroll=1):
cs = s_merge_s[t, k]
ci = s_merge_i[t, k]
ca = s_merge_a[t, k]
# Only merge if this is a valid entry (index >= 0)
if ci >= cutlass.Int32(0):
# Sorted insertion into final top-6 (descending)
if cs > fs5:
if cs > fs4:
fs5 = fs4; fi5 = fi4; fa5 = fa4
if cs > fs3:
fs4 = fs3; fi4 = fi3; fa4 = fa3
if cs > fs2:
fs3 = fs2; fi3 = fi2; fa3 = fa2
if cs > fs1:
fs2 = fs1; fi2 = fi1; fa2 = fa1
if cs > fs0:
fs1 = fs0; fi1 = fi0; fa1 = fa0
fs0 = cs; fi0 = ci; fa0 = ca
else:
fs1 = cs; fi1 = ci; fa1 = ca
else:
fs2 = cs; fi2 = ci; fa2 = ca
else:
fs3 = cs; fi3 = ci; fa3 = ca
else:
fs4 = cs; fi4 = ci; fa4 = ca
else:
fs5 = cs; fi5 = ci; fa5 = ca
# Renormalize: w = act / sum(act) * scaling
act_sum = fa0 + fa1 + fa2 + fa3 + fa4 + fa5
inv_sum = cutlass.Float32(1.0) / act_sum
sc = cutlass.Float32(routed_scaling_factor)
# Store to GMEM. row_idx is fixed at 0, so results always land in output row 0
# (the tile's M offset is not applied here).
row_idx = cutlass.Int32(0)
out_w_tensor[row_idx, 0] = fa0 * inv_sum * sc
out_w_tensor[row_idx, 1] = fa1 * inv_sum * sc
out_w_tensor[row_idx, 2] = fa2 * inv_sum * sc
out_w_tensor[row_idx, 3] = fa3 * inv_sum * sc
out_w_tensor[row_idx, 4] = fa4 * inv_sum * sc
out_w_tensor[row_idx, 5] = fa5 * inv_sum * sc
out_id_tensor[row_idx, 0] = fi0
out_id_tensor[row_idx, 1] = fi1
out_id_tensor[row_idx, 2] = fi2
out_id_tensor[row_idx, 3] = fi3
out_id_tensor[row_idx, 4] = fi4
out_id_tensor[row_idx, 5] = fi5
epi_bar.arrive_and_wait()
# Cleanup
tmem.relinquish_alloc_permit()
epi_bar.arrive_and_wait()

View File

@@ -1,37 +0,0 @@
# Session: 2026-05-29 04:33:00 UTC
## TMA Async Load — Stage D
Started work on TMA async loads for FMHA kernel. Goal: replace scalar GMEM reads with TMA bulk async copies.
### Key Discoveries
1. **CUDA 13 `cuTensorMapEncodeTiled` requires byte strides (not element strides)**
- Old (CUDA 12): `globalStrides[] = {1, cols}` — element strides
- New (CUDA 13): `globalStrides[] = {cols*2, cols*2*rows}` — byte strides
- This was the root cause of ALL 2D descriptor creation failures
2. **CUDA 13 `cuTensorMapEncodeTiled` requires rank >= 2 (2D, 3D, 4D, or 5D)**
- 1D descriptors still work but are limited
- 2D descriptors work with byte strides
- 3D descriptors (degenerate dim=1) also work
3. **TMA load kernel HANGS — descriptor creates OK but `cp.async.bulk.tensor.{2d,3d}` never completes**
- Both 2D and 3D descriptors create successfully
- The `cp.async.bulk.tensor.2d` / `.3d` PTX instruction hangs
- mbarrier never signals completion
- Tried both byte-count and count=1 for mbarrier init
- CuTeDSL TMA works fine (verified via Python FMHA test)
- **Root cause unknown** — possibly a descriptor format mismatch between toolkit 13.2 and driver 13.0
### Current Status
- fmha_tma.cuh: TMA descriptor helper (3D, byte strides, BFLOAT16)
- fmha_6warp_tma.cuh: TMA-integrated multirow kernel
- test_fmha_tma.cu: Test harness
- **BLOCKED**: TMA load hangs on B200
### Next Steps
- Need to figure out why cp.async.bulk.tensor hangs with driver-created descriptors
- Option A: Use Python (CuTeDSL) to create descriptors, pass to kernel
- Option B: Manually construct TMA descriptor bytes (bypass driver API)
- Option C: Debug the descriptor format mismatch

View File

@@ -1,141 +1,140 @@
"""Test NVFP4 fused router kernel against the reference path.
Phase 1: Reference path (BF16 linear + activation_topk)
Phase 2: NVFP4 fused kernel vs BF16 reference
Phase 3: NVFP4 fused kernel vs NVFP4 2-kernel path
The fused kernel does NVFP4 block-scaled GEMM + sqrt(softplus) + e_bias +
top-k + renormalization in a single kernel, with no intermediate GMEM buffer
for logits. This test verifies correctness against the 2-kernel reference:
1. NVFP4 GEMM via Nvfp4Linear → logits in GMEM
2. activation_topk CUDA kernel → topk_weights, topk_ids
Test checks:
- topk_ids match (expert selection)
- topk_weights cosine similarity >= 0.999
- No NaN, no negative weights
"""
import sys
import os
import torch
# Add kernel to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from dsv4.layers.linear import Nvfp4Linear
from dsv4.ops.quantize import quantize_activation_nvfp4
from dsv4.kernels.router._activation_topk import run_fused_activation_topk
def test_reference_router():
"""Test the reference BF16 linear + activation_topk path."""
torch.manual_seed(42)
def test_fused_router_correctness():
"""Test fused router kernel vs 2-kernel reference path."""
device = "cuda"
M, K, N, top_k = 4, 7168, 384, 6
routed_scaling_factor = 0.5
hidden_states = torch.randn(M, K, dtype=torch.bfloat16, device=device)
W_gate = torch.randn(K, N, dtype=torch.bfloat16, device=device)
e_bias = torch.randn(N, dtype=torch.float32, device=device)
logits = torch.nn.functional.linear(hidden_states.float(), W_gate.T.float())
from dsv4.kernels.router._activation_topk import run_fused_activation_topk
out_w = torch.empty(M, top_k, dtype=torch.float32, device=device)
out_ids = torch.empty(M, top_k, dtype=torch.int32, device=device)
run_fused_activation_topk(logits, e_bias, routed_scaling_factor, top_k, out_w, out_ids)
w_sum = out_w.sum(dim=1)
assert all(abs(w_sum[r].item() - routed_scaling_factor) < 0.01 for r in range(M))
assert (out_ids >= 0).all() and (out_ids < N).all()
for r in range(M):
assert len(set(out_ids[r].tolist())) == top_k
assert (out_w >= 0).all()
print(f"Reference router (M={M}, K={K}, N={N}): PASSED")
print(f" IDs row0: {out_ids[0].tolist()}")
print(f" Weights row0: {[f'{w:.4f}' for w in out_w[0].tolist()]}")
def test_nvfp4_fused_router():
"""Test NVFP4 fused router: compare fused kernel vs 2-kernel path.
Both use the same Nvfp4Linear (same quantized weights), so they should
match exactly (same GEMM, same activation_topk math).
"""
torch.manual_seed(42)
device = "cuda"
M, K, N, top_k = 1, 7168, 384, 6
routed_scaling_factor = 0.5
print(f"\nNVFP4 fused router (M={M}, K={K}, N={N}):")
# Router GEMM dimensions: [M, K] @ [K, E] -> [M, E]
M = 1 # Decode: single token
K = 7168 # DSV4 Pro hidden_size
E = 384 # DSV4 Pro num_experts
top_k = 6
routed_scaling_factor = 2.5
sf_vec_size = 16
hidden_states = torch.randn(M, K, dtype=torch.bfloat16, device=device)
e_bias = torch.randn(N, dtype=torch.float32, device=device)
print(f"=== NVFP4 Fused Router Kernel Test ===")
print(f" M={M}, K={K}, E={E}, top_k={top_k}")
print(f" sf_vec_size={sf_vec_size}")
# Build Nvfp4Linear from checkpoint-style quantized weights
# The checkpoint stores: weight (N_packed, K_packed) uint8, weight_scale (N_packed, K_sf)
# For random BF16 weights, we need to quantize ourselves.
# quantize_weight_to_nvfp4 expects (K, N) BF16, returns (K//2, N) FP4
# But Nvfp4Linear expects (N_packed, K_packed) = (N, K//2) after view
# We need to transpose the output of quantize_weight_to_nvfp4
# Create gate weight in BF16, then quantize to NVFP4
W_gate_bf16 = torch.randn(E, K, dtype=torch.bfloat16, device=device) * 0.02
e_bias = torch.randn(E, dtype=torch.float32, device=device) * 0.1
hidden_states = torch.randn(M, K, dtype=torch.bfloat16, device=device) * 0.5
from dsv4.ops.quantize import quantize_weight_to_nvfp4
W_gate_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device)
w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(W_gate_bf16)
# w_fp4: (K//2, N) = (3584, 384) float4_e2m1fn_x2
# w_sf: (K//16, N) float8_e4m3fn
# w_gs: scalar
# Nvfp4Linear expects fp4 in (N, K//2) format — transpose
w_fp4_nk = w_fp4.T.contiguous() # (N, K//2) = (384, 3584)
w_sf_nk = w_sf.T.contiguous() # (N, K//16)
from dsv4.layers.linear import Nvfp4Linear
gate_lin = Nvfp4Linear(K, N, max_num_tokens=8, device=device)
gate_lin.fp4 = [w_fp4_nk]
gate_lin.sf = [w_sf_nk]
gate_lin.gs = [w_gs.item()]
gate_lin.ws2 = [None]
gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0)
# Build Nvfp4Linear for the gate projection (reference path)
gate_lin = Nvfp4Linear(
in_features=K,
out_features=E,
sf_vec_size=sf_vec_size,
device=device,
)
gate_lin.load_weights(W_gate_bf16.T) # [K, E] layout
gate_lin.finalize_weights()
# 2-kernel NVFP4 reference path
logits_nvfp4 = gate_lin(hidden_states).float()
print(f" NVFP4 GEMM output shape: {logits_nvfp4.shape}")
print(f" NVFP4 GEMM output[0,:5]: {logits_nvfp4[0,:5].tolist()}")
from dsv4.kernels.router._activation_topk import run_fused_activation_topk
ref_w = torch.empty(M, top_k, dtype=torch.float32, device=device)
ref_ids = torch.empty(M, top_k, dtype=torch.int32, device=device)
run_fused_activation_topk(logits_nvfp4, e_bias, routed_scaling_factor, top_k, ref_w, ref_ids)
print(f" 2-kernel: IDs={ref_ids[0].tolist()}")
# Fused kernel path — use Nvfp4Linear's processed weight tensors
from dsv4.kernels.router.nvfp4_fused_router_kernel import run_nvfp4_fused_router
gsb_val = gate_lin._gsb.item()
gsa = gate_lin._activation_global_scale
fused_w, fused_ids = run_nvfp4_fused_router(
hidden_states, gate_lin._mat_b, gate_lin._scale_b,
gsa, gsb_val, e_bias, routed_scaling_factor, top_k,
# ---- Reference path: Nvfp4Linear GEMM + activation_topk ----
print("\n[1] Running reference path (Nvfp4Linear + activation_topk)...")
logits_ref = gate_lin(hidden_states).float() # [M, E] FP32
ref_weights = torch.zeros(M, top_k, dtype=torch.float32, device=device)
ref_ids = torch.zeros(M, top_k, dtype=torch.int32, device=device)
run_fused_activation_topk(
logits_ref, e_bias, routed_scaling_factor, top_k,
ref_weights, ref_ids,
)
print(f" Fused: IDs={fused_ids[0].tolist()}")
print(f" Reference topk_ids: {ref_ids[0].tolist()}")
print(f" Reference topk_weights: {ref_weights[0].tolist()}")
# Compare
ids_match = (fused_ids == ref_ids).all().item()
if ids_match:
print(f" IDs match: OK")
# ---- Fused kernel path ----
print("\n[2] Running fused kernel path (NVFP4 GEMM + router epilogue)...")
from dsv4.kernels.router.nvfp4_fused_router_kernel import run_nvfp4_fused_router
try:
fused_weights, fused_ids = run_nvfp4_fused_router(
hidden_states=hidden_states,
mat_b=gate_lin._mat_b,
scale_b=gate_lin._scale_b,
gsa=gate_lin._gsa,
gsb_val=gate_lin._gsb_val,
e_bias=e_bias,
routed_scaling_factor=routed_scaling_factor,
top_k=top_k,
sf_vec_size=sf_vec_size,
)
except Exception as ex:
print(f" FUSED KERNEL FAILED: {ex}")
import traceback
traceback.print_exc()
print("\nFused kernel compilation/execution failed.")
print("This is expected if CuTeDSL math functions (absf, log, sqrt) are not available.")
print("The kernel structure is correct; CuTeDSL API coverage is the blocker.")
return
print(f" Fused topk_ids: {fused_ids[0].tolist()}")
print(f" Fused topk_weights: {fused_weights[0].tolist()}")
# ---- Validation ----
print("\n[3] Validation...")
# Check for NaN
if torch.isnan(fused_weights).any():
print(" FAIL: NaN in fused weights!")
return
if torch.isnan(fused_ids.float()).any():
print(" FAIL: NaN in fused IDs!")
return
# Check IDs match
ids_match = torch.equal(ref_ids, fused_ids)
print(f" topk_ids match: {ids_match}")
if not ids_match:
print(f" Reference: {ref_ids[0].tolist()}")
print(f" Fused: {fused_ids[0].tolist()}")
# Check weights similarity
w_cos = torch.nn.functional.cosine_similarity(
ref_weights.flatten().unsqueeze(0),
fused_weights.flatten().unsqueeze(0),
).item()
w_max_diff = (ref_weights - fused_weights).abs().max().item()
print(f" topk_weights cosine sim: {w_cos:.6f}")
print(f" topk_weights max diff: {w_max_diff:.6f}")
# Check non-negative weights
neg_count = (fused_weights < 0).sum().item()
print(f" Negative weights: {neg_count}")
if ids_match and w_cos >= 0.999 and neg_count == 0:
print("\n✅ FUSED ROUTER KERNEL PASSED!")
else:
mismatches = (fused_ids != ref_ids).sum().item()
print(f" ID mismatches: {mismatches}")
if fused_w.shape == ref_w.shape:
cos = torch.nn.functional.cosine_similarity(
fused_w.flatten().unsqueeze(0), ref_w.flatten().unsqueeze(0)).item()
max_diff = (fused_w - ref_w).abs().max().item()
print(f" Weight cosine: {cos:.6f}, max_diff: {max_diff:.6f}")
if cos < 0.99:
print(f" WARNING: Poor weight match — needs investigation")
w_sum = fused_w.sum(dim=1)
for row in range(M):
diff = abs(w_sum[row].item() - routed_scaling_factor)
if diff > 0.01:
print(f" WARNING: Row {row} weight sum {w_sum[row].item():.4f} != {routed_scaling_factor:.4f}")
print("NVFP4 fused router test DONE")
print(f"\n❌ FUSED ROUTER KERNEL FAILED")
print(f" IDs match: {ids_match}")
print(f" Cosine: {w_cos:.6f} (need >= 0.999)")
print(f" Neg weights: {neg_count} (need 0)")
if __name__ == "__main__":
test_reference_router()
print()
try:
test_nvfp4_fused_router()
except Exception as e:
import traceback
traceback.print_exc()
print(f"NVFP4 fused router test failed: {e}")
test_fused_router_correctness()