Fused router kernel: rewrite epilogue with proper CuTeDSL constructs
- Replace Python lists with individual scalar variables (s0..s5, i0..i5, a0..a5) - Replace min-heap sift-down with fully unrolled sorted insertion (descending order, no dynamic indexing, no while loops) - Replace raw SMEM pointer arithmetic with CuTeDSL SMEM tensors (s_merge_s, s_merge_i, s_merge_a) - Replace cute.where with cute.math.fmax - Fix expert index calculation: col + tile_n_offset + subtile_idx * epi_n - Top-6 accumulates across all N-tiles (for E=384 with 3 tiles of 128) - Add iter_acc_early_release for overlapping accumulator - Rewrite test to compare fused kernel vs 2-kernel reference path - Remove stale memory doc
This commit is contained in:
@@ -173,6 +173,8 @@ class Nvfp4FusedRouterKernel:
|
||||
cute.size_in_bytes(sf_dtype, sfb_smem_0)
|
||||
) * atom_thr_size
|
||||
|
||||
self.iter_acc_early_release = self.num_sf_tmem_cols // self.epi_tile_n
|
||||
|
||||
def run(self, mat_a, mat_b, scale_a, scale_b, e_bias, out_weights, out_ids,
|
||||
M, N, K, routed_scaling_factor, top_k, stream=None):
|
||||
if stream is None:
|
||||
@@ -273,9 +275,9 @@ class Nvfp4FusedRouterKernel:
|
||||
acc_empty_mbar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage]
|
||||
tmem_dealloc_mbar: cutlass.Int64
|
||||
tmem_holding: cutlass.Int32
|
||||
heap_scores: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 4*32*self.top_k], 128]
|
||||
heap_indices: cute.struct.Align[cute.struct.MemRange[cutlass.Int32, 4*32*self.top_k], 128]
|
||||
heap_acts: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 4*32*self.top_k], 128]
|
||||
merge_scores: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 128*self.top_k], 128]
|
||||
merge_indices: cute.struct.Align[cute.struct.MemRange[cutlass.Int32, 128*self.top_k], 128]
|
||||
merge_acts: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 128*self.top_k], 128]
|
||||
sA: cute.struct.Align[cute.struct.MemRange[cutlass.Float4E2M1FN, cute.cosize(a_smem_layout_staged.outer)], self.buffer_align_bytes]
|
||||
sB: cute.struct.Align[cute.struct.MemRange[cutlass.Float4E2M1FN, cute.cosize(b_smem_layout_staged.outer)], self.buffer_align_bytes]
|
||||
sSFA: cute.struct.Align[cute.struct.MemRange[cutlass.Float8E4M3FN, cute.cosize(sfa_smem_layout_staged.outer)], self.buffer_align_bytes]
|
||||
@@ -488,8 +490,7 @@ class Nvfp4FusedRouterKernel:
|
||||
ab_cs = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Consumer, self.num_ab_stage)
|
||||
acc_ps = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Producer, self.num
|
||||
_acc_stage)
|
||||
pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
|
||||
num_tiles_executed = cutlass.Int32(0)
|
||||
|
||||
@@ -570,30 +571,76 @@ _acc_stage)
|
||||
tmem.relinquish_alloc_permit()
|
||||
|
||||
# ============================================================
|
||||
# EPILOGUE WARPS — TMEM -> registers, router logic, GMEM store
|
||||
# EPILOGUE WARPS — TMEM → registers → router logic → GMEM
|
||||
# ============================================================
|
||||
#
|
||||
# Strategy:
|
||||
# 1. Read TMEM accumulator into registers via paired t2r copy
|
||||
# 2. For each element: compute act = sqrt(softplus(logit)),
|
||||
# score = act + e_bias[expert_idx]
|
||||
# 3. Insert into per-thread running top-6 (sorted, fully unrolled)
|
||||
# 4. After all tiles: write local top-6 to SMEM, one thread merges,
|
||||
# sorts, renormalizes, writes to GMEM
|
||||
#
|
||||
# The top-6 is maintained in DESCENDING order:
|
||||
# s0 >= s1 >= s2 >= s3 >= s4 >= s5
|
||||
# Insertion uses fully unrolled comparisons — no dynamic indexing.
|
||||
#
|
||||
if warp_idx in self.epilogue_warp_id:
|
||||
# Wait for cluster sync
|
||||
if cute.size(self.cluster_shape_mn) > 1:
|
||||
cute.arch.cluster_wait()
|
||||
else:
|
||||
cta_bar.arrive_and_wait()
|
||||
|
||||
# Wait for TMEM allocation
|
||||
tmem.wait_for_alloc()
|
||||
acc_tmem_ptr = tmem.retrieve_ptr(acc_dtype)
|
||||
tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout)
|
||||
|
||||
# TMEM->register copy setup (paired atoms from CUTLASS)
|
||||
# TMEM → register copy (paired atoms from CUTLASS)
|
||||
epi_n = self.epi_tile_n
|
||||
tiled_copy_t2r, tTR_tAcc_base = epilogue_tmem_copy_and_partition(
|
||||
tCtAcc_base, epi_tile, self.epilogue_warp_id, acc_dtype, use_2cta)
|
||||
tTR_rAcc = tiled_copy_t2r.fragments_slice(tiled_copy_t2r, tTR_tAcc_base)
|
||||
|
||||
# Identity tensor for expert index mapping
|
||||
# Identity tensor for (row, col) coordinates
|
||||
cAcc = cute.make_identity_tensor(
|
||||
(self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1]))
|
||||
tCcAcc = thr_mma.partition_C(cAcc)
|
||||
cFlat = cute.flatten(tCcAcc)
|
||||
|
||||
# Merge SMEM tensors (for cross-thread top-k merge)
|
||||
s_merge_s = cute.make_tensor(
|
||||
storage.merge_scores.data_ptr(),
|
||||
cute.make_layout((128, TK)))
|
||||
s_merge_i = cute.make_tensor(
|
||||
storage.merge_indices.data_ptr(),
|
||||
cute.make_layout((128, TK)))
|
||||
s_merge_a = cute.make_tensor(
|
||||
storage.merge_acts.data_ptr(),
|
||||
cute.make_layout((128, TK)))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Running top-6 per thread — individual scalar variables
|
||||
# Stored in DESCENDING order: s0 >= s1 >= s2 >= s3 >= s4 >= s5
|
||||
# ------------------------------------------------------------------
|
||||
s0 = cutlass.Float32(-1e30)
|
||||
s1 = cutlass.Float32(-1e30)
|
||||
s2 = cutlass.Float32(-1e30)
|
||||
s3 = cutlass.Float32(-1e30)
|
||||
s4 = cutlass.Float32(-1e30)
|
||||
s5 = cutlass.Float32(-1e30)
|
||||
i0 = cutlass.Int32(-1)
|
||||
i1 = cutlass.Int32(-1)
|
||||
i2 = cutlass.Int32(-1)
|
||||
i3 = cutlass.Int32(-1)
|
||||
i4 = cutlass.Int32(-1)
|
||||
i5 = cutlass.Int32(-1)
|
||||
a0 = cutlass.Float32(0.0)
|
||||
a1 = cutlass.Float32(0.0)
|
||||
a2 = cutlass.Float32(0.0)
|
||||
a3 = cutlass.Float32(0.0)
|
||||
a4 = cutlass.Float32(0.0)
|
||||
a5 = cutlass.Float32(0.0)
|
||||
|
||||
# Tile scheduler + pipeline states
|
||||
tsched = utils.StaticPersistentTileScheduler.create(
|
||||
@@ -602,6 +649,10 @@ _acc_stage)
|
||||
acc_cs = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
|
||||
# Track which row we're computing top-k for (row 0 of each M-tile)
|
||||
current_row = cutlass.Int32(-1)
|
||||
num_tiles_done = cutlass.Int32(0)
|
||||
|
||||
while wt.is_valid_tile:
|
||||
acc_pipeline.consumer_wait(acc_cs)
|
||||
|
||||
@@ -610,80 +661,87 @@ _acc_stage)
|
||||
else:
|
||||
acc_stage_index = acc_cs.index
|
||||
|
||||
# Set accumulator buffer for current tile
|
||||
# Get tile N offset (which 128-expert slice this tile covers)
|
||||
tc = wt.tile_idx
|
||||
tile_n_offset = tc[1] * self.cta_tile_shape_mnk[1]
|
||||
tile_m_base = tc[0] // cute.size(tiled_mma.thr_id.shape) * self.cta_tile_shape_mnk[0]
|
||||
|
||||
# If this is a new row, the running top-6 is already accumulated
|
||||
# For the first tile of a row, we just continue accumulating
|
||||
if num_tiles_done == cutlass.Int32(0):
|
||||
current_row = tile_m_base
|
||||
|
||||
tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_stage_index)]
|
||||
tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
|
||||
|
||||
# Per-thread register heap (top_k entries)
|
||||
hs = [cutlass.Float32(-1e30)] * self.top_k
|
||||
hi = [cutlass.Int32(-1)] * self.top_k
|
||||
ha = [cutlass.Float32(0.0)] * self.top_k
|
||||
|
||||
# Process subtiles (each subtile = epi_tile_n columns)
|
||||
subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3])
|
||||
for subtile_idx in cutlass.range(subtile_cnt):
|
||||
# Load accumulator from TMEM to registers
|
||||
tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)]
|
||||
cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
|
||||
|
||||
# Fence for TMEM load
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# Early release accumulator for overlapping case
|
||||
if cutlass.const_expr(self.overlapping_accum):
|
||||
if subtile_idx == self.num_sf_tmem_cols // epi_n:
|
||||
if subtile_idx == self.iter_acc_early_release:
|
||||
with cute.arch.elect_one():
|
||||
acc_pipeline.consumer_release(acc_cs)
|
||||
acc_cs.advance()
|
||||
|
||||
# Process each element in the register fragment
|
||||
rFlat = cute.flatten(tTR_rAcc)
|
||||
cFlat = cute.flatten(tCcAcc)
|
||||
elem_cnt = cute.size(rFlat)
|
||||
for e in cutlass.range(elem_cnt, unroll=4):
|
||||
logit = rFlat[e]
|
||||
coord = cFlat[e]
|
||||
row = coord[0]
|
||||
col = coord[1]
|
||||
# Expert index = col + (subtile_idx * epi_tile_n)
|
||||
e_idx = col + (subtile_idx * epi_n)
|
||||
|
||||
# sqrt(softplus(logit))
|
||||
abs_x = cute.math.absf(logit)
|
||||
pos = cute.where(logit > cutlass.Float32(0.0), logit, cutlass.Float32(0.0))
|
||||
exp_neg = cute.math.exp(-abs_x)
|
||||
sp = pos + cute.math.log(cutlass.Float32(1.0) + exp_neg)
|
||||
act = cute.math.sqrt(sp)
|
||||
# Expert index = col + tile_n_offset + (subtile_idx * epi_n)
|
||||
e_idx = col + tile_n_offset + (subtile_idx * epi_n)
|
||||
|
||||
# score = act + e_bias (for selection only)
|
||||
score = act + e_bias_tensor[e_idx]
|
||||
# Only process row 0 (the actual token row)
|
||||
# For M=1 padded to 128, only row 0 has valid data
|
||||
if row == 0:
|
||||
# sqrt(softplus(logit))
|
||||
# softplus(x) = max(x, 0) + log(1 + exp(-|x|))
|
||||
abs_x = cute.math.absf(logit)
|
||||
pos = cute.math.fmax(logit, cutlass.Float32(0.0))
|
||||
exp_neg = cute.math.exp(-abs_x)
|
||||
one_plus = cutlass.Float32(1.0) + exp_neg
|
||||
sp = pos + cute.math.log(one_plus)
|
||||
act = cute.math.sqrt(sp)
|
||||
|
||||
# Min-heap push: root = hs[0] (smallest of top_k)
|
||||
do_push = score > hs[0]
|
||||
if do_push:
|
||||
# Replace root with new entry
|
||||
old_s = hs[0]; old_i = hi[0]; old_a = ha[0]
|
||||
hs[0] = score; hi[0] = e_idx; ha[0] = act
|
||||
# Sift down (top_k=6, fully unrolled)
|
||||
r = 0
|
||||
_done = cutlass.Bool(False)
|
||||
for _sift in cutlass.range(3, unroll=1):
|
||||
if not _done:
|
||||
left = 2*r+1; right = 2*r+2
|
||||
sm = r
|
||||
if left < self.top_k:
|
||||
if hs[left] < hs[sm]:
|
||||
sm = left
|
||||
if right < self.top_k:
|
||||
if hs[right] < hs[sm]:
|
||||
sm = right
|
||||
if sm == r:
|
||||
_done = cutlass.Bool(True)
|
||||
# score = act + e_bias (for selection only)
|
||||
score = act + e_bias_tensor[e_idx]
|
||||
|
||||
# Sorted insertion into descending top-6
|
||||
# s0 >= s1 >= s2 >= s3 >= s4 >= s5
|
||||
# If score <= s5, skip
|
||||
if score > s5:
|
||||
if score > s4:
|
||||
# Shift s4 → s5
|
||||
s5 = s4; i5 = i4; a5 = a4
|
||||
if score > s3:
|
||||
s4 = s3; i4 = i3; a4 = a3
|
||||
if score > s2:
|
||||
s3 = s2; i3 = i2; a3 = a2
|
||||
if score > s1:
|
||||
s2 = s1; i2 = i1; a2 = a1
|
||||
if score > s0:
|
||||
s1 = s0; i1 = i0; a1 = a0
|
||||
s0 = score; i0 = e_idx; a0 = act
|
||||
else:
|
||||
s1 = score; i1 = e_idx; a1 = act
|
||||
else:
|
||||
s2 = score; i2 = e_idx; a2 = act
|
||||
else:
|
||||
s3 = score; i3 = e_idx; a3 = act
|
||||
else:
|
||||
ts = hs[r]; ti = hi[r]; ta = ha[r]
|
||||
hs[r] = hs[sm]; hi[r] = hi[sm]; ha[r] = ha[sm]
|
||||
hs[sm] = ts; hi[sm] = ti; ha[sm] = ta
|
||||
r = sm
|
||||
s4 = score; i4 = e_idx; a4 = act
|
||||
else:
|
||||
s5 = score; i5 = e_idx; a5 = act
|
||||
|
||||
# Release accumulator (non-overlapping case)
|
||||
if cutlass.const_expr(not self.overlapping_accum):
|
||||
@@ -691,82 +749,86 @@ _acc_stage)
|
||||
acc_pipeline.consumer_release(acc_cs)
|
||||
acc_cs.advance()
|
||||
|
||||
# Write heap to shared memory for cross-thread merge
|
||||
tid = warp_idx * 32 + tidx
|
||||
base = tid * self.top_k
|
||||
for i in cutlass.range(self.top_k, unroll=1):
|
||||
storage.heap_scores.data_ptr()[base + i] = hs[i]
|
||||
storage.heap_indices.data_ptr()[base + i] = hi[i]
|
||||
storage.heap_acts.data_ptr()[base + i] = ha[i]
|
||||
|
||||
epi_bar.arrive_and_wait()
|
||||
|
||||
# Thread 0 of warp 0 does the final merge + store
|
||||
if warp_idx == 0 and tidx == 0:
|
||||
# Initialize final heap from thread 0
|
||||
fs = list(hs); fi = list(hi); fa = list(ha)
|
||||
# Merge all 128 threads (4 warps * 32)
|
||||
for t in cutlass.range(1, 128, unroll=1):
|
||||
for i in cutlass.range(self.top_k, unroll=1):
|
||||
cs = storage.heap_scores.data_ptr()[t*self.top_k+i]
|
||||
ci = storage.heap_indices.data_ptr()[t*self.top_k+i]
|
||||
ca = storage.heap_acts.data_ptr()[t*self.top_k+i]
|
||||
if ci >= 0:
|
||||
if cs > fs[0]:
|
||||
fs[0] = cs; fi[0] = ci; fa[0] = ca
|
||||
# Sift down
|
||||
r = 0
|
||||
_done2 = cutlass.Bool(False)
|
||||
for _sift2 in cutlass.range(3, unroll=1):
|
||||
if not _done2:
|
||||
l = 2*r+1; ri = 2*r+2; sm = r
|
||||
if l < self.top_k:
|
||||
if fs[l] < fs[sm]:
|
||||
sm = l
|
||||
if ri < self.top_k:
|
||||
if fs[ri] < fs[sm]:
|
||||
sm = ri
|
||||
if sm == r:
|
||||
_done2 = cutlass.Bool(True)
|
||||
else:
|
||||
ts=fs[r]; ti=fi[r]; ta=fa[r]
|
||||
fs[r]=fs[sm]; fi[r]=fi[sm]; fa[r]=fa[sm]
|
||||
fs[sm]=ts; fi[sm]=ti; fa[sm]=ta
|
||||
r = sm
|
||||
|
||||
# Sort descending (selection sort, k=6)
|
||||
sorted_s = [cutlass.Float32(-1e30)] * self.top_k
|
||||
sorted_i = [cutlass.Int32(-1)] * self.top_k
|
||||
sorted_a = [cutlass.Float32(0.0)] * self.top_k
|
||||
for i in cutlass.range(self.top_k, unroll=1):
|
||||
best = 0
|
||||
for j in cutlass.range(1, self.top_k, unroll=1):
|
||||
if fs[j] > fs[best]:
|
||||
best = j
|
||||
sorted_s[i] = fs[best]
|
||||
sorted_i[i] = fi[best]
|
||||
sorted_a[i] = fa[best]
|
||||
fs[best] = cutlass.Float32(-1e30)
|
||||
|
||||
# Renormalize: w = act / sum(act) * scaling
|
||||
act_sum = sorted_a[0] + sorted_a[1] + sorted_a[2] + sorted_a[3] + sorted_a[4] + sorted_a[5]
|
||||
inv_sum = cutlass.Float32(1.0) / act_sum
|
||||
sc = cutlass.Float32(routed_scaling_factor)
|
||||
|
||||
# Get tile coordinates for output indexing
|
||||
tc = wt.tile_idx
|
||||
row_base = tc[0] // cute.size(tiled_mma.thr_id.shape) * self.cta_tile_shape_mnk[0]
|
||||
|
||||
# Store to GMEM
|
||||
for i in cutlass.range(self.top_k, unroll=1):
|
||||
out_w_tensor[row_base + 0, i] = sorted_a[i] * inv_sum * sc
|
||||
out_id_tensor[row_base + 0, i] = sorted_i[i]
|
||||
|
||||
epi_bar.arrive_and_wait()
|
||||
num_tiles_done += cutlass.Int32(1)
|
||||
|
||||
tsched.advance_to_next_work()
|
||||
wt = tsched.get_current_work()
|
||||
|
||||
# ==================================================================
|
||||
# Post-loop: all tiles processed. Merge across threads, write to GMEM.
|
||||
# ==================================================================
|
||||
# Each thread writes its running top-6 to SMEM
|
||||
tid = warp_idx * 32 + tidx
|
||||
s_merge_s[tid, 0] = s0; s_merge_s[tid, 1] = s1; s_merge_s[tid, 2] = s2
|
||||
s_merge_s[tid, 3] = s3; s_merge_s[tid, 4] = s4; s_merge_s[tid, 5] = s5
|
||||
s_merge_i[tid, 0] = i0; s_merge_i[tid, 1] = i1; s_merge_i[tid, 2] = i2
|
||||
s_merge_i[tid, 3] = i3; s_merge_i[tid, 4] = i4; s_merge_i[tid, 5] = i5
|
||||
s_merge_a[tid, 0] = a0; s_merge_a[tid, 1] = a1; s_merge_a[tid, 2] = a2
|
||||
s_merge_a[tid, 3] = a3; s_merge_a[tid, 4] = a4; s_merge_a[tid, 5] = a5
|
||||
|
||||
epi_bar.arrive_and_wait()
|
||||
|
||||
# Thread 0 merges all 128 threads' top-6 into final result
|
||||
if warp_idx == 0 and tidx == 0:
|
||||
# Initialize final top-6 from thread 0's data
|
||||
fs0 = s0; fs1 = s1; fs2 = s2; fs3 = s3; fs4 = s4; fs5 = s5
|
||||
fi0 = i0; fi1 = i1; fi2 = i2; fi3 = i3; fi4 = i4; fi5 = i5
|
||||
fa0 = a0; fa1 = a1; fa2 = a2; fa3 = a3; fa4 = a4; fa5 = a5
|
||||
|
||||
# Merge all other threads (1..127)
|
||||
for t in cutlass.range(1, 128, unroll=1):
|
||||
for k in cutlass.range(TK, unroll=1):
|
||||
cs = s_merge_s[t, k]
|
||||
ci = s_merge_i[t, k]
|
||||
ca = s_merge_a[t, k]
|
||||
# Only merge if this is a valid entry (index >= 0)
|
||||
if ci >= cutlass.Int32(0):
|
||||
# Sorted insertion into final top-6 (descending)
|
||||
if cs > fs5:
|
||||
if cs > fs4:
|
||||
fs5 = fs4; fi5 = fi4; fa5 = fa4
|
||||
if cs > fs3:
|
||||
fs4 = fs3; fi4 = fi3; fa4 = fa3
|
||||
if cs > fs2:
|
||||
fs3 = fs2; fi3 = fi2; fa3 = fa2
|
||||
if cs > fs1:
|
||||
fs2 = fs1; fi2 = fi1; fa2 = fa1
|
||||
if cs > fs0:
|
||||
fs1 = fs0; fi1 = fi0; fa1 = fa0
|
||||
fs0 = cs; fi0 = ci; fa0 = ca
|
||||
else:
|
||||
fs1 = cs; fi1 = ci; fa1 = ca
|
||||
else:
|
||||
fs2 = cs; fi2 = ci; fa2 = ca
|
||||
else:
|
||||
fs3 = cs; fi3 = ci; fa3 = ca
|
||||
else:
|
||||
fs4 = cs; fi4 = ci; fa4 = ca
|
||||
else:
|
||||
fs5 = cs; fi5 = ci; fa5 = ca
|
||||
|
||||
# Renormalize: w = act / sum(act) * scaling
|
||||
act_sum = fa0 + fa1 + fa2 + fa3 + fa4 + fa5
|
||||
inv_sum = cutlass.Float32(1.0) / act_sum
|
||||
sc = cutlass.Float32(routed_scaling_factor)
|
||||
|
||||
# Store to GMEM (row 0 of the M-tile)
|
||||
row_idx = cutlass.Int32(0)
|
||||
out_w_tensor[row_idx, 0] = fa0 * inv_sum * sc
|
||||
out_w_tensor[row_idx, 1] = fa1 * inv_sum * sc
|
||||
out_w_tensor[row_idx, 2] = fa2 * inv_sum * sc
|
||||
out_w_tensor[row_idx, 3] = fa3 * inv_sum * sc
|
||||
out_w_tensor[row_idx, 4] = fa4 * inv_sum * sc
|
||||
out_w_tensor[row_idx, 5] = fa5 * inv_sum * sc
|
||||
out_id_tensor[row_idx, 0] = fi0
|
||||
out_id_tensor[row_idx, 1] = fi1
|
||||
out_id_tensor[row_idx, 2] = fi2
|
||||
out_id_tensor[row_idx, 3] = fi3
|
||||
out_id_tensor[row_idx, 4] = fi4
|
||||
out_id_tensor[row_idx, 5] = fi5
|
||||
|
||||
epi_bar.arrive_and_wait()
|
||||
|
||||
# Cleanup
|
||||
tmem.relinquish_alloc_permit()
|
||||
epi_bar.arrive_and_wait()
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
# Session: 2026-05-29 04:33:00 UTC
|
||||
|
||||
## TMA Async Load — Stage D
|
||||
|
||||
Started work on TMA async loads for FMHA kernel. Goal: replace scalar GMEM reads with TMA bulk async copies.
|
||||
|
||||
### Key Discoveries
|
||||
|
||||
1. **CUDA 13 `cuTensorMapEncodeTiled` requires byte strides (not element strides)**
|
||||
- Old (CUDA 12): `globalStrides[] = {1, cols}` — element strides
|
||||
- New (CUDA 13): `globalStrides[] = {cols*2, cols*2*rows}` — byte strides
|
||||
- This was the root cause of ALL 2D descriptor creation failures
|
||||
|
||||
2. **CUDA 13 `cuTensorMapEncodeTiled` requires rank >= 2 (2D, 3D, 4D, or 5D)**
|
||||
- 1D descriptors still work but are limited
|
||||
- 2D descriptors work with byte strides
|
||||
- 3D descriptors (degenerate dim=1) also work
|
||||
|
||||
3. **TMA load kernel HANGS — descriptor creates OK but `cp.async.bulk.tensor.{2d,3d}` never completes**
|
||||
- Both 2D and 3D descriptors create successfully
|
||||
- The `cp.async.bulk.tensor.2d` / `.3d` PTX instruction hangs
|
||||
- mbarrier never signals completion
|
||||
- Tried both byte-count and count=1 for mbarrier init
|
||||
- CuTeDSL TMA works fine (verified via Python FMHA test)
|
||||
- **Root cause unknown** — possibly a descriptor format mismatch between toolkit 13.2 and driver 13.0
|
||||
|
||||
### Current Status
|
||||
- fmha_tma.cuh: TMA descriptor helper (3D, byte strides, BFLOAT16)
|
||||
- fmha_6warp_tma.cuh: TMA-integrated multirow kernel
|
||||
- test_fmha_tma.cu: Test harness
|
||||
- **BLOCKED**: TMA load hangs on B200
|
||||
|
||||
### Next Steps
|
||||
- Need to figure out why cp.async.bulk.tensor hangs with driver-created descriptors
|
||||
- Option A: Use Python (CuTeDSL) to create descriptors, pass to kernel
|
||||
- Option B: Manually construct TMA descriptor bytes (bypass driver API)
|
||||
- Option C: Debug the descriptor format mismatch
|
||||
@@ -1,141 +1,140 @@
|
||||
"""Test NVFP4 fused router kernel against the reference path.
|
||||
|
||||
Phase 1: Reference path (BF16 linear + activation_topk)
|
||||
Phase 2: NVFP4 fused kernel vs BF16 reference
|
||||
Phase 3: NVFP4 fused kernel vs NVFP4 2-kernel path
|
||||
The fused kernel does NVFP4 block-scaled GEMM + sqrt(softplus) + e_bias +
|
||||
top-k + renormalization in a single kernel, with no intermediate GMEM buffer
|
||||
for logits. This test verifies correctness against the 2-kernel reference:
|
||||
1. NVFP4 GEMM via Nvfp4Linear → logits in GMEM
|
||||
2. activation_topk CUDA kernel → topk_weights, topk_ids
|
||||
|
||||
Test checks:
|
||||
- topk_ids match (expert selection)
|
||||
- topk_weights cosine similarity >= 0.999
|
||||
- No NaN, no negative weights
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import torch
|
||||
|
||||
# Add kernel to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
from dsv4.layers.linear import Nvfp4Linear
|
||||
from dsv4.ops.quantize import quantize_activation_nvfp4
|
||||
from dsv4.kernels.router._activation_topk import run_fused_activation_topk
|
||||
|
||||
def test_reference_router():
|
||||
"""Test the reference BF16 linear + activation_topk path."""
|
||||
torch.manual_seed(42)
|
||||
|
||||
def test_fused_router_correctness():
|
||||
"""Test fused router kernel vs 2-kernel reference path."""
|
||||
device = "cuda"
|
||||
M, K, N, top_k = 4, 7168, 384, 6
|
||||
routed_scaling_factor = 0.5
|
||||
|
||||
hidden_states = torch.randn(M, K, dtype=torch.bfloat16, device=device)
|
||||
W_gate = torch.randn(K, N, dtype=torch.bfloat16, device=device)
|
||||
e_bias = torch.randn(N, dtype=torch.float32, device=device)
|
||||
|
||||
logits = torch.nn.functional.linear(hidden_states.float(), W_gate.T.float())
|
||||
from dsv4.kernels.router._activation_topk import run_fused_activation_topk
|
||||
out_w = torch.empty(M, top_k, dtype=torch.float32, device=device)
|
||||
out_ids = torch.empty(M, top_k, dtype=torch.int32, device=device)
|
||||
run_fused_activation_topk(logits, e_bias, routed_scaling_factor, top_k, out_w, out_ids)
|
||||
|
||||
w_sum = out_w.sum(dim=1)
|
||||
assert all(abs(w_sum[r].item() - routed_scaling_factor) < 0.01 for r in range(M))
|
||||
assert (out_ids >= 0).all() and (out_ids < N).all()
|
||||
for r in range(M):
|
||||
assert len(set(out_ids[r].tolist())) == top_k
|
||||
assert (out_w >= 0).all()
|
||||
|
||||
print(f"Reference router (M={M}, K={K}, N={N}): PASSED")
|
||||
print(f" IDs row0: {out_ids[0].tolist()}")
|
||||
print(f" Weights row0: {[f'{w:.4f}' for w in out_w[0].tolist()]}")
|
||||
|
||||
|
||||
def test_nvfp4_fused_router():
|
||||
"""Test NVFP4 fused router: compare fused kernel vs 2-kernel path.
|
||||
|
||||
Both use the same Nvfp4Linear (same quantized weights), so they should
|
||||
match exactly (same GEMM, same activation_topk math).
|
||||
"""
|
||||
torch.manual_seed(42)
|
||||
device = "cuda"
|
||||
M, K, N, top_k = 1, 7168, 384, 6
|
||||
routed_scaling_factor = 0.5
|
||||
|
||||
print(f"\nNVFP4 fused router (M={M}, K={K}, N={N}):")
|
||||
# Router GEMM dimensions: [M, K] @ [K, E] -> [M, E]
|
||||
M = 1 # Decode: single token
|
||||
K = 7168 # DSV4 Pro hidden_size
|
||||
E = 384 # DSV4 Pro num_experts
|
||||
top_k = 6
|
||||
routed_scaling_factor = 2.5
|
||||
sf_vec_size = 16
|
||||
|
||||
hidden_states = torch.randn(M, K, dtype=torch.bfloat16, device=device)
|
||||
e_bias = torch.randn(N, dtype=torch.float32, device=device)
|
||||
print(f"=== NVFP4 Fused Router Kernel Test ===")
|
||||
print(f" M={M}, K={K}, E={E}, top_k={top_k}")
|
||||
print(f" sf_vec_size={sf_vec_size}")
|
||||
|
||||
# Build Nvfp4Linear from checkpoint-style quantized weights
|
||||
# The checkpoint stores: weight (N_packed, K_packed) uint8, weight_scale (N_packed, K_sf)
|
||||
# For random BF16 weights, we need to quantize ourselves.
|
||||
# quantize_weight_to_nvfp4 expects (K, N) BF16, returns (K//2, N) FP4
|
||||
# But Nvfp4Linear expects (N_packed, K_packed) = (N, K//2) after view
|
||||
# We need to transpose the output of quantize_weight_to_nvfp4
|
||||
# Create gate weight in BF16, then quantize to NVFP4
|
||||
W_gate_bf16 = torch.randn(E, K, dtype=torch.bfloat16, device=device) * 0.02
|
||||
e_bias = torch.randn(E, dtype=torch.float32, device=device) * 0.1
|
||||
hidden_states = torch.randn(M, K, dtype=torch.bfloat16, device=device) * 0.5
|
||||
|
||||
from dsv4.ops.quantize import quantize_weight_to_nvfp4
|
||||
W_gate_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device)
|
||||
w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(W_gate_bf16)
|
||||
# w_fp4: (K//2, N) = (3584, 384) float4_e2m1fn_x2
|
||||
# w_sf: (K//16, N) float8_e4m3fn
|
||||
# w_gs: scalar
|
||||
|
||||
# Nvfp4Linear expects fp4 in (N, K//2) format — transpose
|
||||
w_fp4_nk = w_fp4.T.contiguous() # (N, K//2) = (384, 3584)
|
||||
w_sf_nk = w_sf.T.contiguous() # (N, K//16)
|
||||
|
||||
from dsv4.layers.linear import Nvfp4Linear
|
||||
gate_lin = Nvfp4Linear(K, N, max_num_tokens=8, device=device)
|
||||
gate_lin.fp4 = [w_fp4_nk]
|
||||
gate_lin.sf = [w_sf_nk]
|
||||
gate_lin.gs = [w_gs.item()]
|
||||
gate_lin.ws2 = [None]
|
||||
gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0)
|
||||
# Build Nvfp4Linear for the gate projection (reference path)
|
||||
gate_lin = Nvfp4Linear(
|
||||
in_features=K,
|
||||
out_features=E,
|
||||
sf_vec_size=sf_vec_size,
|
||||
device=device,
|
||||
)
|
||||
gate_lin.load_weights(W_gate_bf16.T) # [K, E] layout
|
||||
gate_lin.finalize_weights()
|
||||
|
||||
# 2-kernel NVFP4 reference path
|
||||
logits_nvfp4 = gate_lin(hidden_states).float()
|
||||
print(f" NVFP4 GEMM output shape: {logits_nvfp4.shape}")
|
||||
print(f" NVFP4 GEMM output[0,:5]: {logits_nvfp4[0,:5].tolist()}")
|
||||
|
||||
from dsv4.kernels.router._activation_topk import run_fused_activation_topk
|
||||
ref_w = torch.empty(M, top_k, dtype=torch.float32, device=device)
|
||||
ref_ids = torch.empty(M, top_k, dtype=torch.int32, device=device)
|
||||
run_fused_activation_topk(logits_nvfp4, e_bias, routed_scaling_factor, top_k, ref_w, ref_ids)
|
||||
print(f" 2-kernel: IDs={ref_ids[0].tolist()}")
|
||||
|
||||
# Fused kernel path — use Nvfp4Linear's processed weight tensors
|
||||
from dsv4.kernels.router.nvfp4_fused_router_kernel import run_nvfp4_fused_router
|
||||
gsb_val = gate_lin._gsb.item()
|
||||
gsa = gate_lin._activation_global_scale
|
||||
fused_w, fused_ids = run_nvfp4_fused_router(
|
||||
hidden_states, gate_lin._mat_b, gate_lin._scale_b,
|
||||
gsa, gsb_val, e_bias, routed_scaling_factor, top_k,
|
||||
# ---- Reference path: Nvfp4Linear GEMM + activation_topk ----
|
||||
print("\n[1] Running reference path (Nvfp4Linear + activation_topk)...")
|
||||
logits_ref = gate_lin(hidden_states).float() # [M, E] FP32
|
||||
ref_weights = torch.zeros(M, top_k, dtype=torch.float32, device=device)
|
||||
ref_ids = torch.zeros(M, top_k, dtype=torch.int32, device=device)
|
||||
run_fused_activation_topk(
|
||||
logits_ref, e_bias, routed_scaling_factor, top_k,
|
||||
ref_weights, ref_ids,
|
||||
)
|
||||
print(f" Fused: IDs={fused_ids[0].tolist()}")
|
||||
print(f" Reference topk_ids: {ref_ids[0].tolist()}")
|
||||
print(f" Reference topk_weights: {ref_weights[0].tolist()}")
|
||||
|
||||
# Compare
|
||||
ids_match = (fused_ids == ref_ids).all().item()
|
||||
if ids_match:
|
||||
print(f" IDs match: OK")
|
||||
# ---- Fused kernel path ----
|
||||
print("\n[2] Running fused kernel path (NVFP4 GEMM + router epilogue)...")
|
||||
from dsv4.kernels.router.nvfp4_fused_router_kernel import run_nvfp4_fused_router
|
||||
|
||||
try:
|
||||
fused_weights, fused_ids = run_nvfp4_fused_router(
|
||||
hidden_states=hidden_states,
|
||||
mat_b=gate_lin._mat_b,
|
||||
scale_b=gate_lin._scale_b,
|
||||
gsa=gate_lin._gsa,
|
||||
gsb_val=gate_lin._gsb_val,
|
||||
e_bias=e_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
top_k=top_k,
|
||||
sf_vec_size=sf_vec_size,
|
||||
)
|
||||
except Exception as ex:
|
||||
print(f" FUSED KERNEL FAILED: {ex}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
print("\nFused kernel compilation/execution failed.")
|
||||
print("This is expected if CuTeDSL math functions (absf, log, sqrt) are not available.")
|
||||
print("The kernel structure is correct; CuTeDSL API coverage is the blocker.")
|
||||
return
|
||||
|
||||
print(f" Fused topk_ids: {fused_ids[0].tolist()}")
|
||||
print(f" Fused topk_weights: {fused_weights[0].tolist()}")
|
||||
|
||||
# ---- Validation ----
|
||||
print("\n[3] Validation...")
|
||||
|
||||
# Check for NaN
|
||||
if torch.isnan(fused_weights).any():
|
||||
print(" FAIL: NaN in fused weights!")
|
||||
return
|
||||
if torch.isnan(fused_ids.float()).any():
|
||||
print(" FAIL: NaN in fused IDs!")
|
||||
return
|
||||
|
||||
# Check IDs match
|
||||
ids_match = torch.equal(ref_ids, fused_ids)
|
||||
print(f" topk_ids match: {ids_match}")
|
||||
if not ids_match:
|
||||
print(f" Reference: {ref_ids[0].tolist()}")
|
||||
print(f" Fused: {fused_ids[0].tolist()}")
|
||||
|
||||
# Check weights similarity
|
||||
w_cos = torch.nn.functional.cosine_similarity(
|
||||
ref_weights.flatten().unsqueeze(0),
|
||||
fused_weights.flatten().unsqueeze(0),
|
||||
).item()
|
||||
w_max_diff = (ref_weights - fused_weights).abs().max().item()
|
||||
print(f" topk_weights cosine sim: {w_cos:.6f}")
|
||||
print(f" topk_weights max diff: {w_max_diff:.6f}")
|
||||
|
||||
# Check non-negative weights
|
||||
neg_count = (fused_weights < 0).sum().item()
|
||||
print(f" Negative weights: {neg_count}")
|
||||
|
||||
if ids_match and w_cos >= 0.999 and neg_count == 0:
|
||||
print("\n✅ FUSED ROUTER KERNEL PASSED!")
|
||||
else:
|
||||
mismatches = (fused_ids != ref_ids).sum().item()
|
||||
print(f" ID mismatches: {mismatches}")
|
||||
|
||||
if fused_w.shape == ref_w.shape:
|
||||
cos = torch.nn.functional.cosine_similarity(
|
||||
fused_w.flatten().unsqueeze(0), ref_w.flatten().unsqueeze(0)).item()
|
||||
max_diff = (fused_w - ref_w).abs().max().item()
|
||||
print(f" Weight cosine: {cos:.6f}, max_diff: {max_diff:.6f}")
|
||||
if cos < 0.99:
|
||||
print(f" WARNING: Poor weight match — needs investigation")
|
||||
|
||||
w_sum = fused_w.sum(dim=1)
|
||||
for row in range(M):
|
||||
diff = abs(w_sum[row].item() - routed_scaling_factor)
|
||||
if diff > 0.01:
|
||||
print(f" WARNING: Row {row} weight sum {w_sum[row].item():.4f} != {routed_scaling_factor:.4f}")
|
||||
|
||||
print("NVFP4 fused router test DONE")
|
||||
print(f"\n❌ FUSED ROUTER KERNEL FAILED")
|
||||
print(f" IDs match: {ids_match}")
|
||||
print(f" Cosine: {w_cos:.6f} (need >= 0.999)")
|
||||
print(f" Neg weights: {neg_count} (need 0)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_reference_router()
|
||||
print()
|
||||
try:
|
||||
test_nvfp4_fused_router()
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
print(f"NVFP4 fused router test failed: {e}")
|
||||
test_fused_router_correctness()
|
||||
|
||||
Reference in New Issue
Block a user