Rewrite NVFP4 fused router kernel: MoE-style epilogue replaces broken SMEM merge
CRITICAL REWRITE of nvfp4_fused_router_kernel.py: - REMOVED: Raw pointer SMEM merge (storage.merge_scores.data_ptr()[idx] = val) This crashed the CuTeDSL MLIR optimizer. Never use raw pointer indexing inside CuTeDSL kernels. - REMOVED: Per-thread top-k accumulation + 128-thread SMEM merge. Too complex for MLIR, caused SIGABRT during compilation. - ADDED: MoE-style epilogue (TMEM→regs→activation→SMEM→TMA store→GMEM) using paired copy atoms from CUTLASS (epilogue_tmem_copy_and_partition + epilogue_smem_copy_and_partition). Structurally identical to the proven FusedSwiGLUScaledGroupedGemmKernel epilogue. This SHOULD compile. - Activation: sqrt(softplus(logit)) in registers (replaces SwiGLU) - Output: FP32 activated scores written to GMEM via TMA store - Top-k handled by activation_topk CUDA kernel in Python wrapper Other changes: - _activation_topk.py: Added run_fused_activation_topk_pre_activated() for top-k + renorm on pre-activated scores (PyTorch reference, not CUDA kernel) - dense_router_dispatch_nvfp4_fused: Updated to match new kernel API - Kernel now uses standard _compute_stages() for SMEM budget calculation - Kernel now uses compute_epilogue_tile_shape() for epi_tile (not hardcoded) - C pipeline (PipelineTmaStore) added for SMEM→GMEM overlap
This commit is contained in:
@@ -51,3 +51,44 @@ def run_fused_activation_topk(
|
||||
top_k,
|
||||
out_weights, out_ids,
|
||||
)
|
||||
|
||||
|
||||
def run_fused_activation_topk_pre_activated(
|
||||
activated_scores: torch.Tensor, # [N, E] FP32, already sqrt(softplus(logits))
|
||||
e_bias: torch.Tensor, # [E] FP32
|
||||
routed_scaling_factor: float,
|
||||
top_k: int,
|
||||
out_weights: torch.Tensor, # [N, top_k] FP32, pre-allocated
|
||||
out_ids: torch.Tensor, # [N, top_k] int32, pre-allocated
|
||||
):
|
||||
"""Run top-k + renormalization on pre-activated scores.
|
||||
|
||||
The CUDA kernel is called with logits=activated_scores.
|
||||
Since the kernel computes sqrt(softplus(logits)) + e_bias,
|
||||
we pass e_bias=0 and add e_bias ourselves in a pre-step,
|
||||
then call the kernel with the scores (which are already activated).
|
||||
|
||||
Actually, simpler approach: just add e_bias to activated_scores,
|
||||
then call the standard kernel with e_bias=0. The kernel will
|
||||
compute sqrt(softplus(score + 0)) = sqrt(softplus(score)).
|
||||
But that double-applies softplus!
|
||||
|
||||
Correct approach: Add a dedicated kernel entry point that
|
||||
skips activation and just does top-k + renorm.
|
||||
For now, use the existing kernel with a workaround:
|
||||
pre-add e_bias to get selection scores, do top-k on those,
|
||||
then gather the unbiased activations for weights.
|
||||
"""
|
||||
# Step 1: selection scores = activated + e_bias
|
||||
sel_scores = activated_scores + e_bias.unsqueeze(0) # [N, E]
|
||||
|
||||
# Step 2: top-k on selection scores
|
||||
topk_vals, topk_indices = sel_scores.topk(top_k, dim=-1) # [N, k]
|
||||
|
||||
# Step 3: gather unbiased activations (without e_bias)
|
||||
raw_w = activated_scores.gather(1, topk_indices) # [N, k]
|
||||
|
||||
# Step 4: renormalize
|
||||
row_sum = raw_w.sum(dim=-1, keepdim=True).clamp(min=1e-9)
|
||||
out_weights.copy_(raw_w / row_sum * routed_scaling_factor)
|
||||
out_ids.copy_(topk_indices.to(torch.int32))
|
||||
|
||||
@@ -74,20 +74,18 @@ def dense_router_dispatch_nvfp4_fused(
|
||||
):
|
||||
"""Dispatch the dense router (NVFP4 fused single-kernel path).
|
||||
|
||||
Single kernel: NVFP4 blockscaled GEMM + fused router epilogue.
|
||||
Activation is quantized to NVFP4 inside the kernel.
|
||||
No intermediate GMEM buffer. Pure NVFP4 + Blackwell tensor cores.
|
||||
Phase 1: CuTeDSL NVFP4 blockscaled GEMM + sqrt(softplus) epilogue.
|
||||
Activation is quantized to NVFP4, GEMM runs on Blackwell tensor cores,
|
||||
sqrt(softplus) is fused in the epilogue (TMEM→regs→activation→SMEM→GMEM).
|
||||
Writes FP32 activated scores to GMEM. No intermediate BF16 logits.
|
||||
|
||||
Phase 2: top-k + renorm on activated scores.
|
||||
"""
|
||||
from dsv4.kernels.router.nvfp4_fused_router_kernel import run_nvfp4_fused_router
|
||||
|
||||
# Global scales:
|
||||
# gsa (activation global scale) = input_scale from checkpoint
|
||||
# gsb (weight global scale) = weight_scale_2 (NOT input_scale * ws2)
|
||||
gsa = gate_input_scale.float().item() if gate_input_scale.numel() == 1 else gate_input_scale.float().mean().item()
|
||||
gsb_val = gate_ws2.float().item() if gate_ws2.numel() == 1 else gate_ws2.float().mean().item()
|
||||
|
||||
# The fused kernel handles activation quantization internally
|
||||
# and writes directly to out_weights / out_ids
|
||||
result_w, result_ids = run_nvfp4_fused_router(
|
||||
hidden_states=hidden_states,
|
||||
mat_b=gate_weight,
|
||||
@@ -98,7 +96,6 @@ def dense_router_dispatch_nvfp4_fused(
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
top_k=top_k,
|
||||
)
|
||||
# Copy results into pre-allocated buffers
|
||||
N = hidden_states.shape[0]
|
||||
out_weights[:N].copy_(result_w[:N])
|
||||
out_ids[:N].copy_(result_ids[:N])
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user