Router: full kernel stack — hash, topk, activation+topk, dense decode/prefill
Step 1: Hash router (hash_router.cu)
- One thread per token, gather from [vocab_size, k] LUT
- Uniform 1/k weights, FP32 output
- 3 MB LUT fits in L2 for repeated decode calls
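The Step 1 gather can be sketched in plain Python as below. This is a reference for the math only, not the CUDA kernel; `hash_route_reference` and the toy LUT values are hypothetical names and data chosen for illustration.

```python
from typing import List, Sequence, Tuple


def hash_route_reference(
    token_ids: Sequence[int],
    hash_lut: Sequence[Sequence[int]],
) -> Tuple[List[List[float]], List[List[int]]]:
    """Reference for the hash router: one LUT row per token, uniform weights.

    hash_lut is laid out as [vocab_size, k]. The function name and the
    list-of-lists layout are illustrative stand-ins, not the kernel's API.
    """
    topk_weights: List[List[float]] = []
    topk_ids: List[List[int]] = []
    for tok in token_ids:  # the kernel maps one thread to each token
        row = list(hash_lut[tok])  # gather the k expert IDs for this token
        k = len(row)
        topk_ids.append(row)
        topk_weights.append([1.0 / k] * k)  # uniform 1/k weights
    return topk_weights, topk_ids


# Toy LUT with vocab_size=4 and k=2 (hypothetical values).
lut = [[0, 3], [1, 2], [2, 0], [3, 1]]
weights, ids = hash_route_reference([2, 0, 2], lut)
```

Repeated token IDs hit the same LUT row, which is why a cache-resident LUT helps repeated decode calls.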
Step 2: topk_select.cu — general top-k primitive
- Per-thread register min-heap (k=6, compile-time unrolled)
- Shared memory merge: thread 0 merges 64 partial heaps
- Tie-breaking: lower index wins on equal scores
- Reusable by CSA indexer
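The heap-plus-merge scheme in Step 2 can be sketched in plain Python as below. It is a minimal model, not the kernel: the strided assignment of elements to threads and the names `_offer` and `topk_partitioned` are assumptions made for illustration. Storing `(score, -index)` makes ordinary tuple comparison implement "lower index wins on equal scores".

```python
import heapq
from typing import List, Sequence, Tuple

Entry = Tuple[float, int]  # (score, -index): larger tuple means better candidate


def _offer(heap: List[Entry], entry: Entry, k: int) -> None:
    """Keep the k best entries seen so far in a size-k min-heap."""
    if len(heap) < k:
        heapq.heappush(heap, entry)
    elif entry > heap[0]:  # strictly better than the current worst
        heapq.heapreplace(heap, entry)


def topk_partitioned(
    scores: Sequence[float], k: int, num_threads: int
) -> List[Tuple[float, int]]:
    """Per-thread partial heaps followed by a single merge pass."""
    partial: List[Entry] = []
    for t in range(num_threads):
        heap: List[Entry] = []
        # Assumed strided partition: thread t sees indices t, t+T, t+2T, ...
        for idx in range(t, len(scores), num_threads):
            _offer(heap, (scores[idx], -idx), k)
        partial.extend(heap)
    merged: List[Entry] = []
    for entry in partial:  # models thread 0 merging the partial heaps
        _offer(merged, entry, k)
    # Best first; restore the original (non-negated) index.
    return [(s, -neg_idx) for s, neg_idx in sorted(merged, reverse=True)]


result = topk_partitioned([0.5, 0.9, 0.5, 0.9, 0.1], k=3, num_threads=2)
```

In the example, indices 0 and 2 tie at 0.5 for the last slot and index 0 is kept.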
Step 3: activation_topk.cu — fused sqrt(softplus) + bias + topk + renorm
- Single kernel: all 6 steps of the router math, no intermediate buffers
- Numerically stable softplus: max(x,0) + log1p(exp(-|x|))
- Per-thread heap with unbiased activation co-stored
- Shared memory merge → sort descending → renormalize → store
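The Step 3 math can be written as a plain-Python reference as below. It is a sketch under stated assumptions, not the fused kernel: the name `activation_topk_reference` is hypothetical, the output order (descending biased score, lower index first on ties) is assumed, and the post-renormalization `scale` follows the Router docstring's description of `routed_scaling_factor`.

```python
import math
from typing import List, Sequence, Tuple


def softplus_stable(x: float) -> float:
    """softplus(x) = log(1 + exp(x)) as max(x, 0) + log1p(exp(-|x|))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def activation_topk_reference(
    logits: Sequence[float],
    e_bias: Sequence[float],
    k: int,
    scale: float,
) -> Tuple[List[float], List[int]]:
    """sqrt(softplus) + bias for selection, unbiased activation for weights."""
    acts = [math.sqrt(softplus_stable(x)) for x in logits]  # unbiased
    biased = [a + b for a, b in zip(acts, e_bias)]  # used for selection only
    # Assumed ordering: descending biased score, lower index first on ties.
    ids = sorted(range(len(logits)), key=lambda i: (-biased[i], i))[:k]
    # Renormalize over the selected experts, then apply the scale.
    # Assumes at least one selected activation is nonzero.
    total = sum(acts[i] for i in ids)
    weights = [scale * acts[i] / total for i in ids]
    return weights, ids


# Expert 2 has a hugely negative logit but a large bias: it is selected,
# yet its gating weight comes from the unbiased activation, which is 0.
weights, ids = activation_topk_reference(
    logits=[0.0, 4.0, -1000.0, 2.0],
    e_bias=[0.0, 0.0, 10.0, 0.0],
    k=2,
    scale=2.5,
)
```

The stable form avoids overflow for large positive inputs, since `exp` is only ever evaluated at a non-positive argument.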
Step 4: dense_router_decode.py — CuTeDSL fused GEMM kernel (skeleton)
- BF16 GEMM with tcgen05.mma, FP32 accumulator
- Custom epilogue: activation + bias + top-k (structure defined, needs TMA/MMA boilerplate)
- Dispatch: N<=64 uses fused decode, N>64 uses prefill path
Step 5: dense_router_prefill.py — prefill path
- torch.nn.functional.linear for GEMM (DeepGEMM integration deferred)
- Calls activation_topk for fused post-GEMM processing
Step 6: Router class + ops/router.py + test_router.py
- Router: construction-time mode (dense/hash), weight loading, custom_op dispatch
- ops/router.py: torch.library.custom_op wrappers, integer-keyed registry
- test_router.py: spec oracle tests (DO NOT RUN — Carmine is testing Stage C)
Test strategy: each kernel is tested against its mathematical spec in FP32.
There is no separate reference implementation, so there are never two debug streams to reconcile. The oracle IS the math.
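One way to read "the oracle is the math" is to check a kernel's output against spec invariants directly. The sketch below does that for the dense router in plain Python; it is not taken from test_router.py, and `spec_activation`, `spec_violations`, the tolerance, and the sample values are all hypothetical.

```python
import math
from typing import List, Sequence


def spec_activation(x: float) -> float:
    """sqrt(softplus(x)) using the stable softplus form from Step 3."""
    return math.sqrt(max(x, 0.0) + math.log1p(math.exp(-abs(x))))


def spec_violations(
    logits: Sequence[float],
    e_bias: Sequence[float],
    out_weights: Sequence[float],
    out_ids: Sequence[int],
    scale: float,
    tol: float = 1e-6,
) -> List[str]:
    """Return the spec invariants that one token's output violates."""
    num_experts = len(logits)
    if any(not 0 <= i < num_experts for i in out_ids):
        return ["expert ID out of range"]
    problems: List[str] = []
    if len(set(out_ids)) != len(out_ids):
        problems.append("duplicate expert IDs")
    # Selection must be top-k by biased score.
    score = [spec_activation(x) + b for x, b in zip(logits, e_bias)]
    chosen = set(out_ids)
    worst_chosen = min(score[i] for i in out_ids)
    best_rejected = max(
        (score[i] for i in range(num_experts) if i not in chosen),
        default=-math.inf,
    )
    if best_rejected > worst_chosen + tol:
        problems.append("selection is not top-k by biased score")
    # Weights must be the renormalized, scaled unbiased activations.
    total = sum(spec_activation(logits[i]) for i in out_ids)
    for w, i in zip(out_weights, out_ids):
        expected = scale * spec_activation(logits[i]) / total
        if abs(w - expected) > tol:
            problems.append(f"weight for expert {i} is off")
    return problems


logits = [0.0, 4.0, -1000.0, 2.0]
bias = [0.0, 0.0, 10.0, 0.0]
good = spec_violations(logits, bias, [0.0, 2.5], [2, 1], scale=2.5)
bad = spec_violations(logits, bias, [1.25, 1.25], [1, 3], scale=2.5)
```

Here `good` is empty, while `bad` reports that expert 2 (biased score 10.0) was wrongly left out.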
2026-05-21 21:54:05 +00:00
"""DSV4 Router — token-to-expert assignment.

Two routing modes that share an output shape:
  - 'dense': sqrt(softplus(X @ W_gate)) + per-expert bias, top-k selection.
             Used by MoE layers 3+ (the bulk of the network).
  - 'hash':  deterministic per-token-ID lookup, uniform weights.
             Used by the first 3 MoE layers per DSV4 §2.1.

Both modes produce (topk_weights, topk_ids) suitable for direct
consumption by Nvfp4MoE.run().

CUDA-graph-compatible: pre-allocated buffers, no CPU-GPU syncs.
Selection between modes is by layer_idx at construction time —
the kernel path is fixed once the Router is built so the dispatch
is constant-folded by torch.compile.
"""

from __future__ import annotations
from typing import Optional, Literal
import torch

from dsv4.ops.router import (
    register_router,
    dense_router_op,
    hash_router_op,
)


RouterMode = Literal["dense", "hash"]


class Router:
    """DSV4 expert router.

    Per the DeepSeek-V4 paper (§2.1):
      - Affinity activation is sqrt(softplus(·)), replacing V3's sigmoid(·).
      - Auxiliary-loss-free strategy: a learned per-expert bias (loaded
        from checkpoint, frozen at inference) is added to the activation
        for SELECTION only. The actual gating weight applied to expert
        outputs uses the UNBIASED activation.
      - First 3 MoE layers use Hash routing (Roller et al. 2021): a
        precomputed [vocab_size, k] LUT mapping token IDs to expert IDs.
        No gate GEMM is performed.
      - Sequence-wise balance loss is training-only; not applied here.

    Parameters
    ----------
    hidden_size : int
        Model hidden dimension. Must match W_gate's K dimension.
    num_experts : int
        Total routed experts (Flash: 256, Pro: 384). Shared experts are
        handled separately by Nvfp4SharedExpert.
    top_k : int
        Experts activated per token. DSV4 uses 6.
    routed_scaling_factor : float
        Post-renormalization scale on gating weights. DSV3 used 2.5;
        verify against the V4 checkpoint config — may be per-layer.
    mode : {'dense', 'hash'}
        Routing strategy. Decided at construction; cannot change at runtime.
    vocab_size : int, optional
        Required when mode='hash'. The LUT is [vocab_size, top_k] int32.
    max_num_tokens : int
        Upper bound on N for pre-allocated buffer sizing.
    device : str
        CUDA device.
    """

    def __init__(
        self,
        hidden_size: int,
        num_experts: int,
        top_k: int = 6,
        routed_scaling_factor: float = 2.5,
        *,
        mode: RouterMode,
        vocab_size: Optional[int] = None,
        max_num_tokens: int = 8192,
        device: str = "cuda",
    ):
        if mode == "hash" and vocab_size is None:
            raise ValueError("vocab_size is required when mode='hash'")
        if mode not in ("dense", "hash"):
            raise ValueError(f"unknown router mode: {mode!r}")

        self.hidden_size = hidden_size
        self.num_experts = num_experts
        self.top_k = top_k
        self.routed_scaling_factor = routed_scaling_factor
        self.mode = mode
        self.vocab_size = vocab_size
        self.max_num_tokens = max_num_tokens
        self.device = device

        # ---- Parameters (filled by load_weights / finalize_weights) ----
        # Dense mode:
        #   W_gate: [hidden_size, num_experts] BF16
        #   e_bias: [num_experts] FP32 — auxiliary-loss-free selection bias.
        # Hash mode:
        #   hash_lut: [vocab_size, top_k] int32 — precomputed expert IDs.
        self.W_gate: Optional[torch.Tensor] = None
        self.e_bias: Optional[torch.Tensor] = None
        self.hash_lut: Optional[torch.Tensor] = None

        # ---- Pre-allocated output buffers (cudagraph-safe) ----
        self._topk_weights_buf: Optional[torch.Tensor] = None
        self._topk_ids_buf: Optional[torch.Tensor] = None

        # Runner ID assigned on first call (see custom_op pattern).
        self._runner_id: Optional[int] = None

    # ------------------------------------------------------------------
    # Weight loading
    # ------------------------------------------------------------------
    def load_weights(
        self,
        W_gate: Optional[torch.Tensor] = None,
        e_bias: Optional[torch.Tensor] = None,
        hash_lut: Optional[torch.Tensor] = None,
    ) -> None:
        """Populate router parameters from a checkpoint shard.

        Dense mode expects (W_gate, e_bias). Hash mode expects (hash_lut).
        Mismatches with self.mode raise immediately — these errors are
        nearly always loader bugs and silent acceptance would mask them.
        """
        if self.mode == "dense":
            if W_gate is None or e_bias is None:
                raise ValueError("dense router needs both W_gate and e_bias")
            assert W_gate.shape == (self.hidden_size, self.num_experts), \
                f"W_gate shape {tuple(W_gate.shape)} != " \
                f"{(self.hidden_size, self.num_experts)}"
            assert e_bias.shape == (self.num_experts,), \
                f"e_bias shape {tuple(e_bias.shape)} != ({self.num_experts},)"
            self.W_gate = W_gate.to(device=self.device, dtype=torch.bfloat16)
            self.e_bias = e_bias.to(device=self.device, dtype=torch.float32)
        else:  # hash
            if hash_lut is None:
                raise ValueError("hash router needs hash_lut")
            assert hash_lut.shape == (self.vocab_size, self.top_k), \
                f"hash_lut shape {tuple(hash_lut.shape)} != " \
                f"{(self.vocab_size, self.top_k)}"
            assert (hash_lut >= 0).all() and (hash_lut < self.num_experts).all(), \
                "hash_lut contains out-of-range expert IDs"
            self.hash_lut = hash_lut.to(device=self.device, dtype=torch.int32)

    def finalize_weights(self) -> None:
        """Allocate output buffers and JIT-compile the routing kernel.

        Mirrors the finalize_weights() pattern in Nvfp4Linear: a one-time
        setup step called after all parameters are loaded. Triggers
        kernel compilation so the first forward isn't paying that cost.
        """
        self._topk_weights_buf = torch.empty(
            self.max_num_tokens, self.top_k,
            dtype=torch.float32, device=self.device,
        )
        self._topk_ids_buf = torch.empty(
            self.max_num_tokens, self.top_k,
            dtype=torch.int32, device=self.device,
        )

        # Eager JIT — dispatcher knows our mode and triggers the right
        # kernel's compile path. See dsv4/ops/router.py.
        from dsv4.ops.router import warmup_router_compilation
        warmup_router_compilation(self)

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------
    def __call__(
        self,
        hidden_states: torch.Tensor,
        token_ids: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Produce (topk_weights, topk_ids) for downstream Nvfp4MoE.

        Parameters
        ----------
        hidden_states : Tensor [N, hidden_size] bfloat16
            Required for dense mode. Ignored for hash mode (kept in the
            signature so the call site is mode-agnostic).
        token_ids : Tensor [N] int32, optional
            Required for hash mode. Ignored for dense mode.

        Returns
        -------
        topk_weights : Tensor [N, top_k] float32
        topk_ids : Tensor [N, top_k] int32

        Notes
        -----
        Both outputs are views into pre-allocated buffers — do not retain
        them across router calls. Nvfp4MoE consumes them immediately,
        which matches its existing contract.
        """
        if self._topk_weights_buf is None:
            raise RuntimeError("Router.finalize_weights() not called")

        if self.mode == "dense":
            if hidden_states is None:
                raise ValueError("dense router requires hidden_states")
            return self._run_dense(hidden_states)
        else:
            if token_ids is None:
                raise ValueError("hash router requires token_ids")
            return self._run_hash(token_ids)

    # ------------------------------------------------------------------
    # Mode-specific dispatch — each routes through a torch.library.custom_op
    # so Dynamo / torch.compile treats the kernel as opaque.
    # ------------------------------------------------------------------
    def _run_dense(self, hidden_states: torch.Tensor):
        if self._runner_id is None:
            self._runner_id = register_router(self)
        return dense_router_op(
            hidden_states,
            self._runner_id,
            self.num_experts,
            self.top_k,
        )

    def _run_hash(self, token_ids: torch.Tensor):
        if self._runner_id is None:
            self._runner_id = register_router(self)
        return hash_router_op(
            token_ids,
            self._runner_id,
            self.top_k,
        )

    # ------------------------------------------------------------------
    # Called by the custom_op dispatch in dsv4/ops/router.py — not by user code.
    # ------------------------------------------------------------------
    def _run_dense_impl(self, hidden_states: torch.Tensor):
        """Hot-path entry into the fused decode/prefill kernel.

        Implementation lives in dsv4/kernels/router/dense_router_decode.py
        (small N) or dsv4/kernels/router/dense_router_prefill.py (large N).
        The selection is internal to that module — Router doesn't care.
        """
        from dsv4.kernels.router import dense_router_dispatch
        N = hidden_states.shape[0]
        out_w = self._topk_weights_buf[:N]
        out_ids = self._topk_ids_buf[:N]
        dense_router_dispatch(
            hidden_states=hidden_states,
            W_gate=self.W_gate,
            e_bias=self.e_bias,
            routed_scaling_factor=self.routed_scaling_factor,
            top_k=self.top_k,
            out_weights=out_w,
            out_ids=out_ids,
        )
        return out_w, out_ids

    def _run_hash_impl(self, token_ids: torch.Tensor):
        """Hot-path entry into the hash gather kernel.

        Implementation lives in dsv4/kernels/cuda/hash_router.cu via the
        wrapper in dsv4/ops/router.py.
        """
        from dsv4.kernels.router import hash_router_dispatch
        N = token_ids.shape[0]
        out_w = self._topk_weights_buf[:N]
        out_ids = self._topk_ids_buf[:N]
        hash_router_dispatch(
            token_ids=token_ids,
            hash_lut=self.hash_lut,
            top_k=self.top_k,
            out_weights=out_w,  # filled with 1/k
            out_ids=out_ids,
        )
        return out_w, out_ids