"""DSV4 Router — token-to-expert assignment.

Two routing modes that share an output shape:

- 'dense': sqrt(softplus(X @ W_gate)) + per-expert bias, top-k selection.
  Used by MoE layers 3+ (the bulk of the network).
- 'hash': deterministic per-token-ID lookup, uniform weights.
  Used by the first 3 MoE layers per DSV4 §2.1.

Both modes produce (topk_weights, topk_ids) suitable for direct consumption
by Nvfp4MoE.run(). CUDA-graph-compatible: pre-allocated buffers, no CPU-GPU
syncs on the forward path.

The mode is chosen by the caller (e.g. from the layer index) and passed to
the constructor. The kernel path is fixed once the Router is built, so the
dispatch is constant-folded by torch.compile.
"""

from __future__ import annotations

from typing import Literal, Optional

import torch

from dsv4.ops.router import (
    register_router,
    dense_router_op,
    hash_router_op,
)

RouterMode = Literal["dense", "hash"]


class Router:
    """DSV4 expert router.

    Per the DeepSeek-V4 paper (§2.1):

    - Affinity activation is sqrt(softplus(·)), replacing V3's sigmoid(·).
    - Auxiliary-loss-free strategy: a learned per-expert bias (loaded from
      checkpoint, frozen at inference) is added to the activation for
      SELECTION only. The actual gating weight applied to expert outputs
      uses the UNBIASED activation.
    - First 3 MoE layers use Hash routing (Roller et al. 2021): a
      precomputed [vocab_size, k] LUT mapping token IDs to expert IDs.
      No gate GEMM is performed.
    - Sequence-wise balance loss is training-only; not applied here.

    Parameters
    ----------
    hidden_size : int
        Model hidden dimension. Must match W_gate's K dimension.
    num_experts : int
        Total routed experts (Flash: 256, Pro: 384). Shared experts are
        handled separately by Nvfp4SharedExpert.
    top_k : int
        Experts activated per token. DSV4 uses 6. Must satisfy
        ``1 <= top_k <= num_experts``.
    routed_scaling_factor : float
        Post-renormalization scale on gating weights. DSV3 used 2.5; verify
        against the V4 checkpoint config — may be per-layer.
    mode : {'dense', 'hash'}
        Routing strategy. Decided at construction; cannot change at runtime.
    vocab_size : int, optional
        Required when mode='hash'. The LUT is [vocab_size, top_k] int32.
    max_num_tokens : int
        Upper bound on N for pre-allocated buffer sizing. Must be positive.
        A forward call with more tokens raises instead of silently
        overrunning the buffers.
    device : str
        CUDA device.

    Raises
    ------
    ValueError
        On an unknown ``mode``, a missing ``vocab_size`` in hash mode, or an
        out-of-range ``top_k`` / ``max_num_tokens``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_experts: int,
        top_k: int = 6,
        routed_scaling_factor: float = 2.5,
        *,
        mode: RouterMode,
        vocab_size: Optional[int] = None,
        max_num_tokens: int = 8192,
        device: str = "cuda",
    ):
        if mode == "hash" and vocab_size is None:
            raise ValueError("vocab_size is required when mode='hash'")
        if mode not in ("dense", "hash"):
            raise ValueError(f"unknown router mode: {mode!r}")
        # Top-k selection over fewer than top_k experts (or zero experts per
        # token) cannot produce a valid [N, top_k] assignment.
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"top_k must be in [1, num_experts={num_experts}], got {top_k}"
            )
        if max_num_tokens <= 0:
            raise ValueError(
                f"max_num_tokens must be positive, got {max_num_tokens}"
            )

        self.hidden_size = hidden_size
        self.num_experts = num_experts
        self.top_k = top_k
        self.routed_scaling_factor = routed_scaling_factor
        self.mode = mode
        self.vocab_size = vocab_size
        self.max_num_tokens = max_num_tokens
        self.device = device

        # ---- Parameters (filled by load_weights / finalize_weights) ----
        # Dense mode:
        #   W_gate: [hidden_size, num_experts] BF16
        #   e_bias: [num_experts] FP32 — auxiliary-loss-free selection bias.
        # Hash mode:
        #   hash_lut: [vocab_size, top_k] int32 — precomputed expert IDs.
        self.W_gate: Optional[torch.Tensor] = None
        self.e_bias: Optional[torch.Tensor] = None
        self.hash_lut: Optional[torch.Tensor] = None

        # ---- Pre-allocated output buffers (cudagraph-safe) ----
        # Allocated in finalize_weights(); forward returns row-slices of these.
        self._topk_weights_buf: Optional[torch.Tensor] = None
        self._topk_ids_buf: Optional[torch.Tensor] = None

        # Runner ID assigned on first call (see custom_op pattern).
        self._runner_id: Optional[int] = None

    # ------------------------------------------------------------------
    # Weight loading
    # ------------------------------------------------------------------

    def load_weights(
        self,
        W_gate: Optional[torch.Tensor] = None,
        e_bias: Optional[torch.Tensor] = None,
        hash_lut: Optional[torch.Tensor] = None,
    ) -> None:
        """Populate router parameters from a checkpoint shard.

        Dense mode expects (W_gate, e_bias); hash mode expects (hash_lut).
        Parameters that belong to the other mode are ignored.

        Raises
        ------
        ValueError
            If a parameter required by ``self.mode`` is missing, has the
            wrong shape, or (hash mode) contains out-of-range expert IDs.
            These errors are nearly always loader bugs, and silent acceptance
            would mask them. Plain ``raise`` is used instead of ``assert`` so
            the checks survive ``python -O``.
        """
        if self.mode == "dense":
            if W_gate is None or e_bias is None:
                raise ValueError("dense router needs both W_gate and e_bias")
            expected_w = (self.hidden_size, self.num_experts)
            if tuple(W_gate.shape) != expected_w:
                raise ValueError(
                    f"W_gate shape {tuple(W_gate.shape)} != {expected_w}"
                )
            if tuple(e_bias.shape) != (self.num_experts,):
                raise ValueError(
                    f"e_bias shape {tuple(e_bias.shape)} != "
                    f"({self.num_experts},)"
                )
            self.W_gate = W_gate.to(device=self.device, dtype=torch.bfloat16)
            self.e_bias = e_bias.to(device=self.device, dtype=torch.float32)
        else:  # hash
            if hash_lut is None:
                raise ValueError("hash router needs hash_lut")
            expected_lut = (self.vocab_size, self.top_k)
            if tuple(hash_lut.shape) != expected_lut:
                raise ValueError(
                    f"hash_lut shape {tuple(hash_lut.shape)} != {expected_lut}"
                )
            # Load-time only: these reductions may sync with the device,
            # which is acceptable here (not on the forward path).
            in_range = bool(
                (hash_lut >= 0).all() and (hash_lut < self.num_experts).all()
            )
            if not in_range:
                raise ValueError("hash_lut contains out-of-range expert IDs")
            self.hash_lut = hash_lut.to(device=self.device, dtype=torch.int32)

    def finalize_weights(self) -> None:
        """Allocate output buffers and JIT-compile the routing kernel.

        Mirrors the finalize_weights() pattern in Nvfp4Linear: a one-time
        setup step called after all parameters are loaded. Triggers kernel
        compilation so the first forward isn't paying that cost.

        Raises
        ------
        RuntimeError
            If the parameters required by ``self.mode`` have not been loaded.
        """
        if self.mode == "dense":
            missing = self.W_gate is None or self.e_bias is None
        else:
            missing = self.hash_lut is None
        if missing:
            raise RuntimeError(
                "Router.load_weights() must be called before finalize_weights()"
            )

        self._topk_weights_buf = torch.empty(
            self.max_num_tokens, self.top_k,
            dtype=torch.float32, device=self.device,
        )
        self._topk_ids_buf = torch.empty(
            self.max_num_tokens, self.top_k,
            dtype=torch.int32, device=self.device,
        )

        # Eager JIT — dispatcher knows our mode and triggers the right
        # kernel's compile path. See dsv4/ops/router.py.
        from dsv4.ops.router import warmup_router_compilation
        warmup_router_compilation(self)

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------

    def __call__(
        self,
        hidden_states: Optional[torch.Tensor],
        token_ids: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Produce (topk_weights, topk_ids) for downstream Nvfp4MoE.

        Parameters
        ----------
        hidden_states : Tensor [N, hidden_size] bfloat16
            Required for dense mode. Ignored for hash mode (kept in the
            signature so the call site is mode-agnostic).
        token_ids : Tensor [N] int32, optional
            Required for hash mode. Ignored for dense mode.

        Returns
        -------
        topk_weights : Tensor [N, top_k] float32
        topk_ids     : Tensor [N, top_k] int32

        Raises
        ------
        RuntimeError
            If finalize_weights() has not been called.
        ValueError
            If the input required by the mode is None, or (in the kernel
            entry) N exceeds max_num_tokens.

        Notes
        -----
        Both outputs are views into pre-allocated buffers — do not retain
        them across router calls. Nvfp4MoE consumes them immediately, which
        matches its existing contract.
        """
        if self._topk_weights_buf is None:
            raise RuntimeError("Router.finalize_weights() not called")

        if self.mode == "dense":
            if hidden_states is None:
                raise ValueError("dense router requires hidden_states")
            return self._run_dense(hidden_states)

        if token_ids is None:
            raise ValueError("hash router requires token_ids")
        return self._run_hash(token_ids)

    # ------------------------------------------------------------------
    # Mode-specific dispatch — each routes through a torch.library.custom_op
    # so Dynamo / torch.compile treats the kernel as opaque.
    # ------------------------------------------------------------------

    def _ensure_registered(self) -> int:
        """Register this router with the op dispatcher once; return its ID."""
        if self._runner_id is None:
            self._runner_id = register_router(self)
        return self._runner_id

    def _run_dense(self, hidden_states: torch.Tensor):
        return dense_router_op(
            hidden_states,
            self._ensure_registered(),
            self.num_experts,
            self.top_k,
        )

    def _run_hash(self, token_ids: torch.Tensor):
        return hash_router_op(
            token_ids,
            self._ensure_registered(),
            self.top_k,
        )

    # ------------------------------------------------------------------
    # Called by the custom_op dispatch in dsv4/ops/router.py — not by user code.
    # ------------------------------------------------------------------

    def _output_views(
        self, num_tokens: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the leading ``num_tokens`` rows of both output buffers.

        Slicing past the end of a tensor silently yields a shorter view, so
        without this check an oversized batch would hand the kernel output
        buffers that do not match its input. The comparison uses only the
        Python-int shape, so it adds no CPU-GPU sync.
        """
        if num_tokens > self.max_num_tokens:
            raise ValueError(
                f"router received {num_tokens} tokens but max_num_tokens="
                f"{self.max_num_tokens}; construct the Router with a larger "
                f"max_num_tokens"
            )
        return (
            self._topk_weights_buf[:num_tokens],
            self._topk_ids_buf[:num_tokens],
        )

    def _run_dense_impl(self, hidden_states: torch.Tensor):
        """Hot-path entry into the fused decode/prefill kernel.

        Implementation lives in dsv4/kernels/router/dense_router_decode.py
        (small N) or dsv4/kernels/router/dense_router_prefill.py (large N).
        The selection is internal to that module — Router doesn't care.
        """
        from dsv4.kernels.router import dense_router_dispatch

        out_w, out_ids = self._output_views(hidden_states.shape[0])
        dense_router_dispatch(
            hidden_states=hidden_states,
            W_gate=self.W_gate,
            e_bias=self.e_bias,
            routed_scaling_factor=self.routed_scaling_factor,
            top_k=self.top_k,
            out_weights=out_w,
            out_ids=out_ids,
        )
        return out_w, out_ids

    def _run_hash_impl(self, token_ids: torch.Tensor):
        """Hot-path entry into the hash gather kernel.

        Delegates to ``dsv4.kernels.router.hash_router_dispatch``. Per the
        original note, the kernel is dsv4/kernels/cuda/hash_router.cu,
        wrapped via dsv4/ops/router.py. TODO confirm the path.
        """
        from dsv4.kernels.router import hash_router_dispatch

        out_w, out_ids = self._output_views(token_ids.shape[0])
        # NOTE(review): routed_scaling_factor is not passed on the hash path.
        # Confirm whether hash-routed layers should also be scaled.
        hash_router_dispatch(
            token_ids=token_ids,
            hash_lut=self.hash_lut,
            top_k=self.top_k,
            out_weights=out_w,  # filled with 1/k
            out_ids=out_ids,
        )
        return out_w, out_ids