nvfp4-megamoe-kernel/dsv4/layers/router.py
biondizzle abfe4485f7 Router: full kernel stack — hash, topk, activation+topk, dense decode/prefill
Step 1: Hash router (hash_router.cu)
- One thread per token, gather from [vocab_size, k] LUT
- Uniform 1/k weights, FP32 output
- 3 MB LUT fits in L2 for repeated decode calls
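A pure-Python sketch of what the hash router computes. It assumes the LUT is a row-major [vocab_size, k] table (nested lists here). `hash_route` and `lut_bytes` are hypothetical names for illustration, not the kernel's API. The size helper checks the "3 MB" figure. It assumes vocab_size = 129,280, which is DeepSeek-V3's vocabulary; whether V4 uses the same size is also an assumption. It further assumes k = 6 with int32 entries.

```python
def hash_route(token_ids, hash_lut, k):
    """Reference semantics of hash routing: one LUT row gather per token.

    hash_lut is a [vocab_size, k] table of expert IDs. Every token gets the
    uniform weight 1/k on each of its k experts; no gate GEMM is involved.
    """
    topk_ids = [list(hash_lut[t]) for t in token_ids]   # row gather
    topk_weights = [[1.0 / k] * k for _ in token_ids]   # uniform (kernel stores FP32)
    return topk_weights, topk_ids


def lut_bytes(vocab_size, k, itemsize=4):
    """Footprint of an int32 [vocab_size, k] LUT in bytes."""
    return vocab_size * k * itemsize


# Hypothetical toy LUT: a 3-token vocabulary with k = 2.
lut = [[0, 5], [3, 1], [7, 2]]
w, ids = hash_route([2, 0, 2], lut, k=2)
print(ids)                    # [[7, 2], [0, 5], [7, 2]]
print(lut_bytes(129280, 6))   # 3102720
```

Under those assumptions the table is 3,102,720 bytes (about 3.1 MB). That matches the figure above. Repeated decode steps re-read the same small table, which is why keeping it L2-resident pays off.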

Step 2: topk_select.cu — general top-k primitive
- Per-thread register min-heap (k=6, compile-time unrolled)
- Shared memory merge: thread 0 merges 64 partial heaps
- Tie-breaking: lower index wins on equal scores
- Reusable by CSA indexer
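The two-phase structure can be modeled on the CPU. Each simulated thread keeps a size-k min-heap over its slice of the scores, and "thread 0" then merges the partial heaps. This is a sketch, not the kernel. It assumes a strided split of scores across 64 threads, since the commit does not say how scores are partitioned. `heapq` stands in for the register heap, and `topk_select` is a hypothetical name. The heap key is (score, -index), so the root is always the current worst candidate. On equal scores the higher index is evicted first, which means the lower index wins.

```python
import heapq


def _push(heap, k, score, idx):
    # Key (score, -idx): the min-heap root is the current worst candidate,
    # i.e. the lowest score and, among equal scores, the HIGHER index.
    item = (score, -idx)
    if len(heap) < k:
        heapq.heappush(heap, item)
    elif item > heap[0]:
        heapq.heapreplace(heap, item)  # evict the worst, insert the new one


def topk_select(scores, k, num_threads=64):
    # Phase 1: each simulated thread scans a strided slice into its own heap.
    partial = [[] for _ in range(num_threads)]
    for tid in range(num_threads):
        for i in range(tid, len(scores), num_threads):
            _push(partial[tid], k, scores[i], i)
    # Phase 2: "thread 0" merges every partial heap into one size-k heap.
    merged = []
    for heap in partial:
        for score, neg_idx in heap:
            _push(merged, k, score, -neg_idx)
    # Descending score; ties come out in ascending index order.
    merged.sort(reverse=True)
    return [-neg_idx for _, neg_idx in merged]


print(topk_select([1.0, 3.0, 3.0, 2.0], k=2))  # [1, 2]
```

The merge is exact. Any element in the global top-k is also in the top-k of its own slice, so the union of the partial heaps always contains the full answer.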

Step 3: activation_topk.cu — fused sqrt(softplus) + bias + topk + renorm
- Single kernel: all 6 steps of the router math, no intermediate buffers
- Numerically stable softplus: max(x,0) + log1p(exp(-|x|))
- Per-thread heap with unbiased activation co-stored
- Shared memory merge → sort descending → renormalize → store
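The per-token math can be written as a scalar reference. The sketch assumes the order of operations described in the Router docstring:

- select on the biased score;
- weight with the unbiased activation;
- renormalize over the selected experts;
- multiply by routed_scaling_factor.

`stable_softplus` and `route_token` are illustrative names. The output order (descending biased score, lower index first on ties) is an assumption about the kernel's sort key.

```python
import math


def stable_softplus(x):
    # softplus(x) = log(1 + e^x), rewritten as max(x, 0) + log1p(e^-|x|)
    # so exp() never sees a large positive argument.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def route_token(logits, e_bias, k, routed_scaling_factor=1.0):
    # Affinity used for the gating weight (UNBIASED).
    act = [math.sqrt(stable_softplus(x)) for x in logits]
    # The bias only influences which experts are selected.
    biased = [a + b for a, b in zip(act, e_bias)]
    # Top-k by biased score, descending; lower expert index wins ties.
    ids = sorted(range(len(logits)), key=lambda i: (-biased[i], i))[:k]
    # Renormalize the unbiased activations of the chosen experts, then scale.
    total = sum(act[i] for i in ids)
    weights = [routed_scaling_factor * act[i] / total for i in ids]
    return weights, ids


print(stable_softplus(1000.0))  # 1000.0 (math.log(1 + math.exp(1000.0)) overflows)
weights, ids = route_token([2.0, 1.0, 0.0], [0.0, 0.0, 10.0], k=2)
print(ids)  # [2, 0]: the bias pulls expert 2 in, but its weight stays unbiased
```

Since softplus is strictly positive, the sqrt is always defined. The renormalizing sum is therefore nonzero, so no extra guard is needed for the division.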

Step 4: dense_router_decode.py — CuTeDSL fused GEMM kernel (skeleton)
- BF16 GEMM with tcgen05.mma, FP32 accumulator
- Custom epilogue: activation + bias + top-k (structure defined, needs TMA/MMA boilerplate)
- Dispatch: N<=64 uses fused decode, N>64 uses prefill path

Step 5: dense_router_prefill.py — prefill path
- torch.nn.functional.linear for GEMM (DeepGEMM integration deferred)
- Calls activation_topk for fused post-GEMM processing
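One layout detail is worth checking in the prefill path. torch.nn.functional.linear(input, weight) computes input @ weight.T and expects weight as [out_features, in_features]. Router, however, stores W_gate as [hidden_size, num_experts]. The pure-Python sketch below mirrors F.linear's bias-free contract to show that the gate weight must be passed transposed, or multiplied directly as X @ W_gate. `gate_logits` and `DECODE_MAX_TOKENS` are hypothetical names. It also encodes the N <= 64 decode/prefill split from Step 4. Whether dense_router_prefill.py already transposes is not visible from this file.

```python
DECODE_MAX_TOKENS = 64  # hypothetical name; N <= 64 takes the fused decode path


def transpose(m):
    return [list(col) for col in zip(*m)]


def matmul(a, b):
    # [N, K] x [K, M] -> [N, M]
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def linear(x, weight):
    # Same contract as torch.nn.functional.linear without bias:
    # weight is [out_features, in_features] and y = x @ weight.T.
    return matmul(x, transpose(weight))


def gate_logits(x, w_gate):
    # w_gate is [hidden_size, num_experts], the layout Router stores, so the
    # "weight" handed to linear() is its transpose: [num_experts, hidden_size].
    path = "decode" if len(x) <= DECODE_MAX_TOKENS else "prefill"
    return path, linear(x, transpose(w_gate))


w_gate = [[1.0, 0.0, 2.0],
          [0.0, 1.0, 3.0]]                 # hidden_size = 2, num_experts = 3
print(gate_logits([[1.0, 2.0]], w_gate))  # ('decode', [[1.0, 2.0, 8.0]])
```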

Step 6: Router class + ops/router.py + test_router.py
- Router: construction-time mode (dense/hash), weight loading, custom_op dispatch
- ops/router.py: torch.library.custom_op wrappers, integer-keyed registry
- test_router.py: spec oracle tests (DO NOT RUN — Carmine is testing Stage C)
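The integer-keyed registry is needed because torch.library.custom_op infers its schema from a fixed set of argument types (tensors, ints, floats, bools, strings and a few others). The Router object itself therefore cannot be passed through the op, but an int key can. Below is a minimal sketch of the pattern. It assumes a process-global dict and monotonically increasing IDs. `_REGISTRY`, `lookup_router` and `FakeRouter` are hypothetical. `register_router` reuses the imported name, but its body here is an assumption, not the code in dsv4/ops/router.py.

```python
import itertools

_REGISTRY = {}                 # runner_id -> Router-like object
_next_id = itertools.count()   # monotonically increasing integer keys


def register_router(router):
    """Store the router and return an int key that can cross the op boundary."""
    runner_id = next(_next_id)
    _REGISTRY[runner_id] = router
    return runner_id


def lookup_router(runner_id):
    """Called inside the op body to recover the Python object."""
    return _REGISTRY[runner_id]


class FakeRouter:  # hypothetical stand-in for Router
    def _run_hash_impl(self, token_ids):
        return f"routed {len(token_ids)} tokens"


rid = register_router(FakeRouter())
print(lookup_router(rid)._run_hash_impl([1, 2, 3]))  # routed 3 tokens
```

In this sketch the registry holds a strong reference. A registered router stays alive until its entry is removed explicitly.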

Test strategy: each kernel tested against its mathematical spec in FP32.
No reference implementation, no two debug streams. The oracle IS the math.
2026-05-21 21:54:05 +00:00

274 lines
11 KiB
Python

"""DSV4 Router: token-to-expert assignment.

Two routing modes that share an output shape:

- 'dense': sqrt(softplus(X @ W_gate)) + per-expert bias, top-k selection.
  Used by MoE layers 3+ (the bulk of the network).
- 'hash': deterministic per-token-ID lookup, uniform weights.
  Used by the first 3 MoE layers per DSV4 §2.1.

Both modes produce (topk_weights, topk_ids) suitable for direct
consumption by Nvfp4MoE.run().

CUDA-graph-compatible: pre-allocated buffers, no CPU-GPU syncs.

The caller chooses the mode from layer_idx at construction time. The
kernel path is fixed once the Router is built, so the dispatch is
constant-folded by torch.compile.
"""
from __future__ import annotations

from typing import Optional, Literal

import torch

from dsv4.ops.router import (
    register_router,
    dense_router_op,
    hash_router_op,
)

RouterMode = Literal["dense", "hash"]


class Router:
    """DSV4 expert router.

    Per the DeepSeek-V4 paper (§2.1):

    - Affinity activation is sqrt(softplus(·)), replacing V3's sigmoid(·).
    - Auxiliary-loss-free strategy: a learned per-expert bias (loaded
      from checkpoint, frozen at inference) is added to the activation
      for SELECTION only. The actual gating weight applied to expert
      outputs uses the UNBIASED activation.
    - First 3 MoE layers use Hash routing (Roller et al. 2021): a
      precomputed [vocab_size, k] LUT mapping token IDs to expert IDs.
      No gate GEMM is performed.
    - Sequence-wise balance loss is training-only; not applied here.

    Parameters
    ----------
    hidden_size : int
        Model hidden dimension. Must match W_gate's K dimension.
    num_experts : int
        Total routed experts (Flash: 256, Pro: 384). Shared experts are
        handled separately by Nvfp4SharedExpert.
    top_k : int
        Experts activated per token. DSV4 uses 6.
    routed_scaling_factor : float
        Post-renormalization scale on gating weights. DSV3 used 2.5;
        verify against the V4 checkpoint config, as it may be per-layer.
    mode : {'dense', 'hash'}
        Routing strategy. Decided at construction; cannot change at runtime.
    vocab_size : int, optional
        Required when mode='hash'. The LUT is [vocab_size, top_k] int32.
    max_num_tokens : int
        Upper bound on N for pre-allocated buffer sizing.
    device : str
        CUDA device.
    """

    def __init__(
        self,
        hidden_size: int,
        num_experts: int,
        top_k: int = 6,
        routed_scaling_factor: float = 2.5,
        *,
        mode: RouterMode,
        vocab_size: Optional[int] = None,
        max_num_tokens: int = 8192,
        device: str = "cuda",
    ):
        if mode == "hash" and vocab_size is None:
            raise ValueError("vocab_size is required when mode='hash'")
        if mode not in ("dense", "hash"):
            raise ValueError(f"unknown router mode: {mode!r}")
        self.hidden_size = hidden_size
        self.num_experts = num_experts
        self.top_k = top_k
        self.routed_scaling_factor = routed_scaling_factor
        self.mode = mode
        self.vocab_size = vocab_size
        self.max_num_tokens = max_num_tokens
        self.device = device

        # ---- Parameters (filled by load_weights / finalize_weights) ----
        # Dense mode:
        #   W_gate: [hidden_size, num_experts] BF16
        #   e_bias: [num_experts] FP32, the auxiliary-loss-free selection bias.
        # Hash mode:
        #   hash_lut: [vocab_size, top_k] int32, precomputed expert IDs.
        self.W_gate: Optional[torch.Tensor] = None
        self.e_bias: Optional[torch.Tensor] = None
        self.hash_lut: Optional[torch.Tensor] = None

        # ---- Pre-allocated output buffers (cudagraph-safe) ----
        self._topk_weights_buf: Optional[torch.Tensor] = None
        self._topk_ids_buf: Optional[torch.Tensor] = None

        # Runner ID assigned on first call (see custom_op pattern).
        self._runner_id: Optional[int] = None

    # ------------------------------------------------------------------
    # Weight loading
    # ------------------------------------------------------------------

    def load_weights(
        self,
        W_gate: Optional[torch.Tensor] = None,
        e_bias: Optional[torch.Tensor] = None,
        hash_lut: Optional[torch.Tensor] = None,
    ) -> None:
        """Populate router parameters from a checkpoint shard.

        Dense mode expects (W_gate, e_bias). Hash mode expects (hash_lut).
        Mismatches with self.mode raise immediately: these errors are
        nearly always loader bugs and silent acceptance would mask them.
        """
        # Explicit raises (not assert) so validation survives `python -O`.
        if self.mode == "dense":
            if W_gate is None or e_bias is None:
                raise ValueError("dense router needs both W_gate and e_bias")
            if hash_lut is not None:
                raise ValueError("dense router does not take hash_lut")
            if W_gate.shape != (self.hidden_size, self.num_experts):
                raise ValueError(
                    f"W_gate shape {tuple(W_gate.shape)} != "
                    f"{(self.hidden_size, self.num_experts)}"
                )
            if e_bias.shape != (self.num_experts,):
                raise ValueError(
                    f"e_bias shape {tuple(e_bias.shape)} != ({self.num_experts},)"
                )
            self.W_gate = W_gate.to(device=self.device, dtype=torch.bfloat16)
            self.e_bias = e_bias.to(device=self.device, dtype=torch.float32)
        else:  # hash
            if hash_lut is None:
                raise ValueError("hash router needs hash_lut")
            if W_gate is not None or e_bias is not None:
                raise ValueError("hash router does not take W_gate or e_bias")
            if hash_lut.shape != (self.vocab_size, self.top_k):
                raise ValueError(
                    f"hash_lut shape {tuple(hash_lut.shape)} != "
                    f"{(self.vocab_size, self.top_k)}"
                )
            if not ((hash_lut >= 0).all() and (hash_lut < self.num_experts).all()):
                raise ValueError("hash_lut contains out-of-range expert IDs")
            self.hash_lut = hash_lut.to(device=self.device, dtype=torch.int32)

    def finalize_weights(self) -> None:
        """Allocate output buffers and JIT-compile the routing kernel.

        Mirrors the finalize_weights() pattern in Nvfp4Linear: a one-time
        setup step called after all parameters are loaded. Triggers
        kernel compilation so the first forward isn't paying that cost.
        """
        if self.mode == "dense":
            loaded = self.W_gate is not None and self.e_bias is not None
        else:
            loaded = self.hash_lut is not None
        if not loaded:
            raise RuntimeError(
                "Router.load_weights() must be called before finalize_weights()"
            )
        self._topk_weights_buf = torch.empty(
            self.max_num_tokens, self.top_k,
            dtype=torch.float32, device=self.device,
        )
        self._topk_ids_buf = torch.empty(
            self.max_num_tokens, self.top_k,
            dtype=torch.int32, device=self.device,
        )
        # Eager JIT: the dispatcher knows our mode and triggers the right
        # kernel's compile path. See dsv4/ops/router.py.
        from dsv4.ops.router import warmup_router_compilation
        warmup_router_compilation(self)

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------

    def __call__(
        self,
        hidden_states: Optional[torch.Tensor],
        token_ids: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Produce (topk_weights, topk_ids) for downstream Nvfp4MoE.

        Parameters
        ----------
        hidden_states : Tensor [N, hidden_size] bfloat16
            Required for dense mode. Ignored for hash mode (kept in the
            signature so the call site is mode-agnostic).
        token_ids : Tensor [N] int32, optional
            Required for hash mode. Ignored for dense mode.

        Returns
        -------
        topk_weights : Tensor [N, top_k] float32
        topk_ids : Tensor [N, top_k] int32

        Notes
        -----
        Both outputs are views into pre-allocated buffers; do not retain
        them across router calls. Nvfp4MoE consumes them immediately,
        which matches its existing contract. N must not exceed
        max_num_tokens.
        """
        if self._topk_weights_buf is None:
            raise RuntimeError("Router.finalize_weights() not called")
        if self.mode == "dense":
            if hidden_states is None:
                raise ValueError("dense router requires hidden_states")
            num_tokens = hidden_states.shape[0]
        else:
            if token_ids is None:
                raise ValueError("hash router requires token_ids")
            num_tokens = token_ids.shape[0]
        # Shape-only check (no device sync). A larger N would hand the kernel
        # [:N] output views that are shorter than its input.
        if num_tokens > self.max_num_tokens:
            raise ValueError(
                f"N={num_tokens} exceeds max_num_tokens={self.max_num_tokens}"
            )
        if self.mode == "dense":
            return self._run_dense(hidden_states)
        return self._run_hash(token_ids)

    # ------------------------------------------------------------------
    # Mode-specific dispatch: each routes through a torch.library.custom_op
    # so Dynamo / torch.compile treats the kernel as opaque.
    # ------------------------------------------------------------------

    def _run_dense(self, hidden_states: torch.Tensor):
        if self._runner_id is None:
            self._runner_id = register_router(self)
        return dense_router_op(
            hidden_states,
            self._runner_id,
            self.num_experts,
            self.top_k,
        )

    def _run_hash(self, token_ids: torch.Tensor):
        if self._runner_id is None:
            self._runner_id = register_router(self)
        return hash_router_op(
            token_ids,
            self._runner_id,
            self.top_k,
        )

    # ------------------------------------------------------------------
    # Called by the custom_op dispatch in dsv4/ops/router.py, not by user code.
    # ------------------------------------------------------------------

    def _run_dense_impl(self, hidden_states: torch.Tensor):
        """Hot-path entry into the fused decode/prefill kernel.

        Implementation lives in dsv4/kernels/router/dense_router_decode.py
        (small N) or dsv4/kernels/router/dense_router_prefill.py (large N).
        The selection is internal to that module; Router doesn't care.
        """
        from dsv4.kernels.router import dense_router_dispatch

        N = hidden_states.shape[0]
        out_w = self._topk_weights_buf[:N]
        out_ids = self._topk_ids_buf[:N]
        dense_router_dispatch(
            hidden_states=hidden_states,
            W_gate=self.W_gate,
            e_bias=self.e_bias,
            routed_scaling_factor=self.routed_scaling_factor,
            top_k=self.top_k,
            out_weights=out_w,
            out_ids=out_ids,
        )
        return out_w, out_ids

    def _run_hash_impl(self, token_ids: torch.Tensor):
        """Hot-path entry into the hash gather kernel.

        Implementation lives in dsv4/kernels/cuda/hash_router.cu via the
        wrapper in dsv4/ops/router.py.
        """
        from dsv4.kernels.router import hash_router_dispatch

        N = token_ids.shape[0]
        out_w = self._topk_weights_buf[:N]
        out_ids = self._topk_ids_buf[:N]
        hash_router_dispatch(
            token_ids=token_ids,
            hash_lut=self.hash_lut,
            top_k=self.top_k,
            out_weights=out_w,  # filled with 1/k
            out_ids=out_ids,
        )
        return out_w, out_ids