Step 1: hash_router.cu — hash router
  - One thread per token; gathers from a [vocab_size, k] LUT.
  - Uniform 1/k weights, FP32 output.
  - The ~3 MB LUT fits in L2 across repeated decode calls.
Step 2: topk_select.cu — general top-k primitive
  - Per-thread register min-heap (k=6, unrolled at compile time).
  - Shared-memory merge: thread 0 merges the 64 partial heaps.
  - Tie-breaking: on equal scores, the lower index wins.
  - Reusable by the CSA indexer.
Step 3: activation_topk.cu — fused sqrt(softplus) + bias + top-k + renormalization
  - Single kernel covering all 6 steps of the router math; no intermediate buffers.
  - Numerically stable softplus: max(x, 0) + log1p(exp(-|x|)).
  - The per-thread heap stores the unbiased activation alongside each candidate.
  - Shared-memory merge → sort descending → renormalize → store.
Step 4: dense_router_decode.py — CuTeDSL fused GEMM kernel (skeleton)
  - BF16 GEMM with tcgen05.mma, FP32 accumulator.
  - Custom epilogue: activation + bias + top-k (structure defined; TMA/MMA boilerplate still needed).
  - Dispatch: N <= 64 uses the fused decode kernel; N > 64 uses the prefill path.
Step 5: dense_router_prefill.py — prefill path
  - torch.nn.functional.linear for the GEMM (DeepGEMM integration deferred).
  - Calls activation_topk for fused post-GEMM processing.
Step 6: Router class + ops/router.py + test_router.py
  - Router: construction-time mode (dense/hash), weight loading, custom_op dispatch.
  - ops/router.py: torch.library.custom_op wrappers, integer-keyed registry.
  - test_router.py: spec-oracle tests (DO NOT RUN — Carmine is testing Stage C).

Test strategy: each kernel is tested against its mathematical spec in FP32. There is no reference implementation and no pair of debug streams to compare. The oracle IS the math.
274 lines · 11 KiB · Python
"""DSV4 Router — token-to-expert assignment.
|
|
|
|
Two routing modes that share an output shape:
|
|
- 'dense': sqrt(softplus(X @ W_gate)) + per-expert bias, top-k selection.
|
|
Used by MoE layers 3+ (the bulk of the network).
|
|
- 'hash': deterministic per-token-ID lookup, uniform weights.
|
|
Used by the first 3 MoE layers per DSV4 §2.1.
|
|
|
|
Both modes produce (topk_weights, topk_ids) suitable for direct
|
|
consumption by Nvfp4MoE.run().
|
|
|
|
CUDA-graph-compatible: pre-allocated buffers, no CPU-GPU syncs.
|
|
Selection between modes is by layer_idx at construction time —
|
|
the kernel path is fixed once the Router is built so the dispatch
|
|
is constant-folded by torch.compile.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
from typing import Optional, Literal
|
|
import torch
|
|
|
|
from dsv4.ops.router import (
|
|
register_router,
|
|
dense_router_op,
|
|
hash_router_op,
|
|
)
|
|
|
|
|
|
# Routing strategy accepted by Router(mode=...); fixed for the Router's lifetime.
RouterMode = Literal["dense", "hash"]
|
|
|
|
|
|
class Router:
    """DSV4 expert router.

    Per the DeepSeek-V4 paper (§2.1):

    - Affinity activation is sqrt(softplus(·)), replacing V3's sigmoid(·).
    - Auxiliary-loss-free strategy: a learned per-expert bias (loaded
      from checkpoint, frozen at inference) is added to the activation
      for SELECTION only. The actual gating weight applied to expert
      outputs uses the UNBIASED activation.
    - First 3 MoE layers use Hash routing (Roller et al. 2021): a
      precomputed [vocab_size, k] LUT mapping token IDs to expert IDs.
      No gate GEMM is performed.
    - Sequence-wise balance loss is training-only; not applied here.

    Lifecycle: construct -> load_weights() -> finalize_weights() -> call.

    Parameters
    ----------
    hidden_size : int
        Model hidden dimension. Must match W_gate's K dimension.
    num_experts : int
        Total routed experts (Flash: 256, Pro: 384). Shared experts are
        handled separately by Nvfp4SharedExpert.
    top_k : int
        Experts activated per token. DSV4 uses 6. Must satisfy
        1 <= top_k <= num_experts.
    routed_scaling_factor : float
        Post-renormalization scale on gating weights. DSV3 used 2.5;
        verify against the V4 checkpoint config — may be per-layer.
        Only forwarded to the dense kernel; the hash path emits uniform
        1/top_k weights without it.
        NOTE(review): confirm hash layers should be unscaled.
    mode : {'dense', 'hash'}
        Routing strategy. Decided at construction; cannot change at runtime.
    vocab_size : int, optional
        Required when mode='hash'. The LUT is [vocab_size, top_k] int32.
    max_num_tokens : int
        Upper bound on N for pre-allocated buffer sizing. A call with more
        tokens raises instead of overrunning the buffers.
    device : str
        CUDA device.

    Raises
    ------
    ValueError
        On an unknown mode, a missing vocab_size in hash mode, or
        non-positive / inconsistent sizes.
    """

    def __init__(
        self,
        hidden_size: int,
        num_experts: int,
        top_k: int = 6,
        routed_scaling_factor: float = 2.5,
        *,
        mode: RouterMode,
        vocab_size: Optional[int] = None,
        max_num_tokens: int = 8192,
        device: str = "cuda",
    ):
        if mode not in ("dense", "hash"):
            raise ValueError(f"unknown router mode: {mode!r}")
        if mode == "hash" and vocab_size is None:
            raise ValueError("vocab_size is required when mode='hash'")
        if hidden_size <= 0 or num_experts <= 0:
            raise ValueError(
                "hidden_size and num_experts must be positive, got "
                f"hidden_size={hidden_size}, num_experts={num_experts}"
            )
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"top_k must be in [1, num_experts={num_experts}], got {top_k}"
            )
        if max_num_tokens <= 0:
            raise ValueError(
                f"max_num_tokens must be positive, got {max_num_tokens}"
            )
        if mode == "hash" and vocab_size <= 0:
            raise ValueError(f"vocab_size must be positive, got {vocab_size}")

        self.hidden_size = hidden_size
        self.num_experts = num_experts
        self.top_k = top_k
        self.routed_scaling_factor = routed_scaling_factor
        self.mode = mode
        self.vocab_size = vocab_size
        self.max_num_tokens = max_num_tokens
        self.device = device

        # ---- Parameters (filled by load_weights / finalize_weights) ----
        # Dense mode:
        #   W_gate:   [hidden_size, num_experts] BF16
        #   e_bias:   [num_experts] FP32 — auxiliary-loss-free selection bias.
        # Hash mode:
        #   hash_lut: [vocab_size, top_k] int32 — precomputed expert IDs.
        self.W_gate: Optional[torch.Tensor] = None
        self.e_bias: Optional[torch.Tensor] = None
        self.hash_lut: Optional[torch.Tensor] = None

        # ---- Pre-allocated output buffers (cudagraph-safe) ----
        self._topk_weights_buf: Optional[torch.Tensor] = None
        self._topk_ids_buf: Optional[torch.Tensor] = None

        # Runner ID assigned on first call (see custom_op pattern).
        self._runner_id: Optional[int] = None

    # ------------------------------------------------------------------
    # Weight loading
    # ------------------------------------------------------------------
    def load_weights(
        self,
        W_gate: Optional[torch.Tensor] = None,
        e_bias: Optional[torch.Tensor] = None,
        hash_lut: Optional[torch.Tensor] = None,
    ) -> None:
        """Populate router parameters from a checkpoint shard.

        Dense mode expects (W_gate, e_bias). Hash mode expects (hash_lut).
        Mismatches with self.mode raise immediately — these errors are
        nearly always loader bugs, and silent acceptance would mask them.
        Passing the other mode's parameters counts as a mismatch.

        Tensors are copied/converted to self.device with the dtypes the
        kernels expect and made contiguous. The custom kernels presumably
        consume raw storage — TODO confirm.

        Raises
        ------
        ValueError
            On missing or extra parameters for this mode, a wrong shape,
            or out-of-range expert IDs in hash_lut. Real exceptions are
            used instead of asserts, so validation survives `python -O`.
        """
        if self.mode == "dense":
            if W_gate is None or e_bias is None:
                raise ValueError("dense router needs both W_gate and e_bias")
            if hash_lut is not None:
                raise ValueError("dense router does not take hash_lut")
            expected_w = (self.hidden_size, self.num_experts)
            if tuple(W_gate.shape) != expected_w:
                raise ValueError(
                    f"W_gate shape {tuple(W_gate.shape)} != {expected_w}"
                )
            if tuple(e_bias.shape) != (self.num_experts,):
                raise ValueError(
                    f"e_bias shape {tuple(e_bias.shape)} != ({self.num_experts},)"
                )
            self.W_gate = W_gate.to(
                device=self.device, dtype=torch.bfloat16
            ).contiguous()
            self.e_bias = e_bias.to(
                device=self.device, dtype=torch.float32
            ).contiguous()
            return

        # hash
        if hash_lut is None:
            raise ValueError("hash router needs hash_lut")
        if W_gate is not None or e_bias is not None:
            raise ValueError("hash router does not take W_gate or e_bias")
        expected_lut = (self.vocab_size, self.top_k)
        if tuple(hash_lut.shape) != expected_lut:
            raise ValueError(
                f"hash_lut shape {tuple(hash_lut.shape)} != {expected_lut}"
            )
        # One-time host sync at load time; never on the forward path.
        out_of_range = (hash_lut < 0) | (hash_lut >= self.num_experts)
        if bool(out_of_range.any()):
            raise ValueError("hash_lut contains out-of-range expert IDs")
        self.hash_lut = hash_lut.to(
            device=self.device, dtype=torch.int32
        ).contiguous()

    def finalize_weights(self) -> None:
        """Allocate output buffers and JIT-compile the routing kernel.

        Mirrors the finalize_weights() pattern in Nvfp4Linear: a one-time
        setup step called after all parameters are loaded. Triggers
        kernel compilation so the first forward isn't paying that cost.

        Raises
        ------
        RuntimeError
            If load_weights() has not populated the parameters this mode
            needs. Warming up (or later running) a kernel against None
            weights would fail far from the real cause.
        """
        if self.mode == "dense":
            required = {"W_gate": self.W_gate, "e_bias": self.e_bias}
        else:
            required = {"hash_lut": self.hash_lut}
        missing = [name for name, value in required.items() if value is None]
        if missing:
            raise RuntimeError(
                f"Router.load_weights() must set {', '.join(missing)} "
                "before finalize_weights()"
            )

        self._topk_weights_buf = torch.empty(
            self.max_num_tokens, self.top_k,
            dtype=torch.float32, device=self.device,
        )
        self._topk_ids_buf = torch.empty(
            self.max_num_tokens, self.top_k,
            dtype=torch.int32, device=self.device,
        )

        # Eager JIT — dispatcher knows our mode and triggers the right
        # kernel's compile path. See dsv4/ops/router.py.
        from dsv4.ops.router import warmup_router_compilation

        warmup_router_compilation(self)

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------
    def __call__(
        self,
        hidden_states: Optional[torch.Tensor],
        token_ids: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Produce (topk_weights, topk_ids) for downstream Nvfp4MoE.

        Parameters
        ----------
        hidden_states : Tensor [N, hidden_size] bfloat16
            Required for dense mode. Ignored for hash mode (kept in the
            signature so the call site is mode-agnostic).
        token_ids : Tensor [N] int32, optional
            Required for hash mode. Ignored for dense mode.

        Returns
        -------
        topk_weights : Tensor [N, top_k] float32
        topk_ids     : Tensor [N, top_k] int32

        Raises
        ------
        RuntimeError
            If finalize_weights() has not been called.
        ValueError
            If this mode's input is missing or has the wrong rank/width,
            or if N exceeds the pre-allocated buffer capacity.

        Notes
        -----
        Both outputs are views into pre-allocated buffers — do not retain
        them across router calls. Nvfp4MoE consumes them immediately,
        which matches its existing contract.
        """
        if self._topk_weights_buf is None:
            raise RuntimeError("Router.finalize_weights() not called")

        if self.mode == "dense":
            if hidden_states is None:
                raise ValueError("dense router requires hidden_states")
            # Host-side metadata checks only — no device sync.
            if (
                hidden_states.dim() != 2
                or hidden_states.shape[1] != self.hidden_size
            ):
                raise ValueError(
                    f"hidden_states must be [N, {self.hidden_size}], got "
                    f"{tuple(hidden_states.shape)}"
                )
            return self._run_dense(hidden_states)

        if token_ids is None:
            raise ValueError("hash router requires token_ids")
        if token_ids.dim() != 1:
            raise ValueError(
                f"token_ids must be 1-D [N], got {tuple(token_ids.shape)}"
            )
        return self._run_hash(token_ids)

    # ------------------------------------------------------------------
    # Mode-specific dispatch — each routes through a torch.library.custom_op
    # so Dynamo / torch.compile treats the kernel as opaque.
    # ------------------------------------------------------------------
    def _ensure_registered(self) -> int:
        """Register this Router with the op registry once; return its ID."""
        if self._runner_id is None:
            self._runner_id = register_router(self)
        return self._runner_id

    def _run_dense(self, hidden_states: torch.Tensor):
        return dense_router_op(
            hidden_states,
            self._ensure_registered(),
            self.num_experts,
            self.top_k,
        )

    def _run_hash(self, token_ids: torch.Tensor):
        return hash_router_op(
            token_ids,
            self._ensure_registered(),
            self.top_k,
        )

    # ------------------------------------------------------------------
    # Called by the custom_op dispatch in dsv4/ops/router.py — not by user code.
    # ------------------------------------------------------------------
    def _output_views(self, num_tokens: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the first `num_tokens` rows of both output buffers.

        Slicing past a tensor's end silently truncates. Without this check,
        an oversized batch would hand the kernel buffers shorter than N
        rows. Capacity is read from the buffer itself, so it stays correct
        even if max_num_tokens is edited after finalize_weights(). The
        check lives here, inside the opaque custom op, rather than in
        __call__, so torch.compile does not guard on N.

        Raises
        ------
        RuntimeError
            If finalize_weights() has not allocated the buffers.
        ValueError
            If num_tokens exceeds the buffer capacity.
        """
        if self._topk_weights_buf is None or self._topk_ids_buf is None:
            raise RuntimeError("Router.finalize_weights() not called")
        capacity = self._topk_weights_buf.shape[0]
        if num_tokens > capacity:
            raise ValueError(
                f"router got {num_tokens} tokens but its buffers hold at most "
                f"{capacity} (max_num_tokens)"
            )
        return (
            self._topk_weights_buf[:num_tokens],
            self._topk_ids_buf[:num_tokens],
        )

    def _run_dense_impl(self, hidden_states: torch.Tensor):
        """Hot-path entry into the fused decode/prefill kernel.

        Implementation lives in dsv4/kernels/router/dense_router_decode.py
        (small N) or dsv4/kernels/router/dense_router_prefill.py (large N).
        The selection is internal to that module — Router doesn't care.
        """
        from dsv4.kernels.router import dense_router_dispatch

        out_w, out_ids = self._output_views(hidden_states.shape[0])
        if out_w.shape[0] == 0:
            # Nothing to route; skip the launch (a zero-sized grid is not a
            # valid kernel launch configuration).
            return out_w, out_ids
        dense_router_dispatch(
            hidden_states=hidden_states,
            W_gate=self.W_gate,
            e_bias=self.e_bias,
            routed_scaling_factor=self.routed_scaling_factor,
            top_k=self.top_k,
            out_weights=out_w,
            out_ids=out_ids,
        )
        return out_w, out_ids

    def _run_hash_impl(self, token_ids: torch.Tensor):
        """Hot-path entry into the hash gather kernel.

        Reached from the hash_router_op custom op in dsv4/ops/router.py.
        Launches the gather kernel (dsv4/kernels/cuda/hash_router.cu)
        through hash_router_dispatch in dsv4.kernels.router.
        routed_scaling_factor is not forwarded on this path.
        """
        from dsv4.kernels.router import hash_router_dispatch

        out_w, out_ids = self._output_views(token_ids.shape[0])
        if out_w.shape[0] == 0:
            # Nothing to route; skip the launch (a zero-sized grid is not a
            # valid kernel launch configuration).
            return out_w, out_ids
        hash_router_dispatch(
            token_ids=token_ids,
            hash_lut=self.hash_lut,
            top_k=self.top_k,
            out_weights=out_w,  # filled with 1/k
            out_ids=out_ids,
        )
        return out_w, out_ids