nvfp4-megamoe-kernel/dsv4/kernels/router/dense_router_prefill.py
biondizzle a813d2824b Router: clean up dense_router_decode.py — realistic architecture, no fake code
The first draft had a fake CuTeDSL kernel body with pass statements and
Python lists as register heaps. That is not the right way. This commit
replaces it with honest documentation of what the kernel does and what
needs to happen.

Current working path:
- All N route through torch.nn.functional.linear + activation_topk.cu
- activation_topk is a single-pass fused CUDA kernel (all 6 steps)
- This is correct and performant for all N

CuTeDSL fused decode kernel (DenseRouterDecodeKernel):
- Class structure and warp specialization defined
- Full documentation of the TMA/MMA/epilogue pipeline
- The novel part is the row-level top-k epilogue (cross-subtile heap)
- EFC framework does not apply: our epilogue reduces across a whole row, not per element
- Implementation deferred until profiling shows the GMEM round-trip
  on logits matters for decode latency
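The row-level top-k epilogue described above has to keep a per-row candidate set alive while the row's expert scores arrive one subtile at a time. The plain-Python sketch below is not CuTeDSL and says nothing about register layout. It only shows why the cross-subtile approach is exact: a subtile can contribute at most k winners, so keeping just the best k after every subtile yields the same ids as a global top-k. The function name `cross_subtile_topk` and the `subtile` width parameter are hypothetical, and ties are assumed to go to the lower expert id.

```python
import heapq
from typing import List, Sequence, Tuple


def cross_subtile_topk(scores: Sequence[float], k: int, subtile: int) -> List[int]:
    """Top-k expert ids for one row, visiting the row in column subtiles.

    Hypothetical sketch: only k candidates survive between subtiles, like a
    per-row heap carried across epilogue subtiles.
    """
    # Candidates are (score, -expert_id): a higher score wins, and on equal
    # scores the lower expert id wins.
    running: List[Tuple[float, int]] = []
    for start in range(0, len(scores), subtile):
        end = min(start + subtile, len(scores))
        tile = [(scores[e], -e) for e in range(start, end)]
        # A single subtile can contribute at most k winners.
        local = heapq.nlargest(k, tile)
        # Merge with the survivors from earlier subtiles and keep the best k.
        running = heapq.nlargest(k, running + local)
    # nlargest returns candidates in descending order; recover expert ids.
    return [-neg_id for _, neg_id in running]


print(cross_subtile_topk([1.0, 3.0, 3.0, 2.0], k=2, subtile=1))  # [1, 2]
```

Because the merge keeps a fixed k candidates no matter how many subtiles there are, the per-row state stays bounded, which is what makes a register-resident heap plausible for this epilogue.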

No fake code. No pass statements. No Python lists as GPU registers.
The working path is the activation_topk kernel. The CuTeDSL kernel
will be built on top of it when the optimization is needed.
2026-05-21 21:58:31 +00:00
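To check activation_topk against something simple, you can write an unfused per-row reference of the math listed in `dense_router_prefill`'s docstring (act = sqrt(softplus(logits)), score = act + bias, top-k, renorm) in a few lines. This is a sketch, and it rests on assumptions the source does not confirm: the bias only influences which experts are selected; the returned weights are the unbiased activations of the chosen experts, renormalized to sum to 1 and then multiplied by `routed_scaling_factor`; and ties go to the lower expert id. `activation_topk_row` is a hypothetical name, not part of the repository.

```python
import math
from typing import List, Sequence, Tuple


def softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x)): no overflow for large positive x.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def activation_topk_row(
    logits: Sequence[float],
    e_bias: Sequence[float],
    routed_scaling_factor: float,
    top_k: int,
) -> Tuple[List[float], List[int]]:
    """Hypothetical unfused reference for one token (one row of logits)."""
    act = [math.sqrt(softplus(x)) for x in logits]
    # Assumption: the bias is used only to rank experts.
    score = [a + b for a, b in zip(act, e_bias)]
    ids = sorted(range(len(score)), key=lambda e: (-score[e], e))[:top_k]
    # Assumption: weights come from the unbiased activations, renormalized.
    total = sum(act[e] for e in ids)
    weights = [routed_scaling_factor * act[e] / total for e in ids]
    return weights, ids


w, ids = activation_topk_row([0.0, 2.0, -1.0, 1.0], [0.0, 0.0, 10.0, 0.0], 2.5, 2)
print(ids)  # [2, 1]
```

In the usage line, the large bias on expert 2 pulls it into the top-k. Its weight is still smaller than expert 1's because, under the stated assumption, the weights ignore the bias.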


"""DSV4 Dense Router: prefill path.

For prefill with N >= ~256 tokens, the gate GEMM has enough work that DeepGEMM
(or the standard BF16 persistent GEMM) becomes the better choice for the
matmul, followed by a separate fused activation + top-k kernel on its output.

Currently both decode and prefill go through this path (a PyTorch FP32 GEMM
followed by activation_topk). The CuTeDSL fused decode kernel will replace
the small-N path when complete.
"""
from __future__ import annotations

import torch


def dense_router_prefill(
    hidden_states: torch.Tensor,  # [N, hidden_size] BF16
    W_gate: torch.Tensor,  # [hidden_size, num_experts] BF16
    e_bias: torch.Tensor,  # [num_experts] FP32
    routed_scaling_factor: float,
    top_k: int,
    out_weights: torch.Tensor,  # [N, top_k] FP32
    out_ids: torch.Tensor,  # [N, top_k] int32
) -> None:
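    """Prefill path: FP32 GEMM → FP32 logits → fused activation + top-k.

    Step 1: logits = hidden_states @ W_gate, computed in FP32 ([N, num_experts]).
    Step 2: fused kernel: act = sqrt(softplus(logits)), score = act + e_bias,
            top-k over score, renormalize → (out_weights, out_ids).
    Results are written into out_weights and out_ids in place.
    """
    # Upcast both operands and run the GEMM in FP32 so the logits fed to the
    # activation keep full FP32 precision. A BF16 GEMM returns BF16 logits
    # (7 mantissa bits), which is too coarse for softplus and the top-k ranking.
    # W_gate is stored as [hidden_size, num_experts], so use a plain matmul:
    # F.linear(x, W) computes x @ W.T and would expect [num_experts, hidden_size].
    logits = hidden_states.float() @ W_gate.float()

    # Deferred import: the fused kernel module is loaded on the first call.
    from dsv4.kernels.router._activation_topk import run_fused_activation_topk

    run_fused_activation_topk(
        logits, e_bias, routed_scaling_factor, top_k,
        out_weights, out_ids,
    )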
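The FP32-GEMM comment in `dense_router_prefill` is easier to see with numbers. BF16 keeps 7 mantissa bits, so in [2, 4) neighbouring representable values are 2**-6 = 0.015625 apart. The stdlib sketch below emulates FP32→BF16 round-to-nearest-even. It is not torch's conversion routine, and it omits NaN/Inf handling. It shows two logits 0.004 apart collapsing to the same BF16 value, which erases their order before the top-k ever sees them. `to_bf16` is a hypothetical helper, and the two logit values are illustrative, not measured router outputs.

```python
import struct


def to_bf16(x: float) -> float:
    """Round a finite float to the nearest BF16 value (ties to even).

    BF16 is the upper 16 bits of an FP32 word: sign, 8 exponent bits and
    7 mantissa bits. NaN/Inf handling is omitted in this sketch.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1  # lowest kept bit, used to break exact ties to even
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000  # round, then drop the low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


logits = [3.001, 3.005]  # illustrative: expert 1 has the larger logit in FP32
bf16_logits = [to_bf16(v) for v in logits]
# Both round to 3.0 in BF16, so the two experts tie and their order is lost.
print(bf16_logits)  # [3.0, 3.0]
```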