diff --git a/dsv4/kernels/router/__init__.py b/dsv4/kernels/router/__init__.py index 949d8a0f..5f323129 100644 --- a/dsv4/kernels/router/__init__.py +++ b/dsv4/kernels/router/__init__.py @@ -1,11 +1,12 @@ """DSV4 Router kernels — dispatch and CUDA kernel wrappers. Exports: - dense_router_dispatch: GEMM + fused activation + top-k (all N) + dense_router_dispatch: BF16 GEMM + fused activation + top-k (fallback) + dense_router_dispatch_nvfp4: NVFP4 GEMM + fused activation + top-k (production) hash_router_dispatch: Hash routing via precomputed LUT gather """ -from dsv4.kernels.router.dense_router_decode import dense_router_dispatch +from dsv4.kernels.router.dense_router_decode import dense_router_dispatch, dense_router_dispatch_nvfp4 def hash_router_dispatch( diff --git a/dsv4/kernels/router/dense_router_decode.py b/dsv4/kernels/router/dense_router_decode.py index 0dafa6e4..567bc220 100644 --- a/dsv4/kernels/router/dense_router_decode.py +++ b/dsv4/kernels/router/dense_router_decode.py @@ -1,12 +1,15 @@ -"""DSV4 Dense Router — BF16 GEMM + sqrt(softplus) + bias + top-k. +"""DSV4 Dense Router — NVFP4 GEMM + sqrt(softplus) + bias + top-k. -Production path: BF16 GEMM via cuBLAS (tensor cores on Blackwell) followed by -the fused activation_topk CUDA kernel for sqrt(softplus) + bias + top-k + renorm. +Production path: NVFP4 GEMM via Nvfp4Linear (Blackwell tensor cores) +followed by the fused activation_topk CUDA kernel for sqrt(softplus) + +bias + top-k + renorm. -The CuTeDSL fused GEMM+epilogue kernel was attempted but make_trivial_tiled_mma -for BF16 on SM100 has no working reference in our codebase (all other GEMMs use -NVFP4 blockscaled MMA). The unfused path is production-grade: cuBLAS uses SM100 -tensor cores, and activation_topk is a real CUDA kernel (not PyTorch). +BF16 cuBLAS fallback: When NVFP4 scales are not available in the +checkpoint, dense_router_dispatch uses torch.nn.functional.linear +(cuBLAS, SM100 tensor cores) instead. + +The CuTeDSL fused GEMM+epilogue kernel (dense_router_decode_kernel.py) +is the future single-kernel path but is not yet production-ready. """ from __future__ import annotations @@ -23,7 +26,7 @@ def dense_router_dispatch( out_weights: torch.Tensor, # [N, top_k] FP32, pre-allocated out_ids: torch.Tensor, # [N, top_k] int32, pre-allocated ): - """Dispatch the dense router. + """Dispatch the dense router (BF16 cuBLAS fallback). BF16 GEMM via torch.nn.functional.linear (cuBLAS, SM100 tensor cores), then fused activation + top-k via the CUDA kernel. @@ -34,3 +37,25 @@ def dense_router_dispatch( logits, e_bias, routed_scaling_factor, top_k, out_weights, out_ids, ) + + +def dense_router_dispatch_nvfp4( + hidden_states: torch.Tensor, # [N, hidden_size] BF16 + gate_lin, # Nvfp4Linear instance + e_bias: torch.Tensor, # [num_experts] FP32 + routed_scaling_factor: float, + top_k: int, + out_weights: torch.Tensor, # [N, top_k] FP32, pre-allocated + out_ids: torch.Tensor, # [N, top_k] int32, pre-allocated +): + """Dispatch the dense router (NVFP4 production GEMM). + + NVFP4 GEMM via Nvfp4Linear (Blackwell SM100 tensor cores), + then fused activation + top-k via the CUDA kernel. + """ + logits = gate_lin(hidden_states).float() # (N, E) FP32 + from dsv4.kernels.router._activation_topk import run_fused_activation_topk + run_fused_activation_topk( + logits, e_bias, routed_scaling_factor, top_k, + out_weights, out_ids, + ) diff --git a/dsv4/layers/router.py b/dsv4/layers/router.py index 8a3ce298..8675018c 100644 --- a/dsv4/layers/router.py +++ b/dsv4/layers/router.py @@ -93,11 +93,13 @@ class Router: # ---- Parameters (filled by load_weights / finalize_weights) ---- # Dense mode: - # W_gate: [hidden_size, num_experts] BF16 + # gate_lin: Nvfp4Linear for the gate projection (NVFP4 GEMM) + # Fallback: W_gate BF16 + cuBLAS when NVFP4 scales not available # e_bias: [num_experts] FP32 — auxiliary-loss-free selection bias. # Hash mode: # hash_lut: [vocab_size, top_k] int32 — precomputed expert IDs. - self.W_gate: Optional[torch.Tensor] = None + self.gate_lin = None # Nvfp4Linear for NVFP4 gate projection + self.W_gate: Optional[torch.Tensor] = None # BF16 fallback self.e_bias: Optional[torch.Tensor] = None self.hash_lut: Optional[torch.Tensor] = None @@ -124,15 +126,14 @@ class Router: nearly always loader bugs and silent acceptance would mask them. """ if self.mode == "dense": - if W_gate is None or e_bias is None: - raise ValueError("dense router needs both W_gate and e_bias") - assert W_gate.shape == (self.hidden_size, self.num_experts), \ - f"W_gate shape {tuple(W_gate.shape)} != " \ - f"{(self.hidden_size, self.num_experts)}" + if e_bias is None: + raise ValueError("dense router needs e_bias") assert e_bias.shape == (self.num_experts,), \ f"e_bias shape {tuple(e_bias.shape)} != ({self.num_experts},)" - self.W_gate = W_gate.to(device=self.device, dtype=torch.bfloat16) self.e_bias = e_bias.to(device=self.device, dtype=torch.float32) + if W_gate is not None: + self.W_gate = W_gate.to(device=self.device, dtype=torch.bfloat16) + # gate_lin is set separately via load_nvfp4_gate() else: # hash if hash_lut is None: raise ValueError("hash router needs hash_lut") @@ -143,6 +144,15 @@ class Router: "hash_lut contains out-of-range expert IDs" self.hash_lut = hash_lut.to(device=self.device, dtype=torch.int32) + def load_nvfp4_gate(self, gate_lin) -> None: + """Set the NVFP4 gate linear layer (preferred over BF16 W_gate). + + Called by the single_shot after constructing the Nvfp4Linear + from checkpoint NVFP4 scales. When set, _run_dense_impl uses + the production NVFP4 GEMM path instead of BF16 cuBLAS. + """ + self.gate_lin = gate_lin + def finalize_weights(self) -> None: """Allocate output buffers and JIT-compile the routing kernel. @@ -232,25 +242,35 @@ class Router: # Called by the custom_op dispatch in dsv4/ops/router.py — not by user code. # ------------------------------------------------------------------ def _run_dense_impl(self, hidden_states: torch.Tensor): - """Hot-path entry into the fused decode/prefill kernel. + """Hot-path: NVFP4 GEMM or BF16 fallback + activation_topk. - Implementation lives in dsv4/kernels/router/dense_router_decode.py - (small N) or dsv4/kernels/router/dense_router_prefill.py (large N). - The selection is internal to that module — Router doesn't care. + When gate_lin (Nvfp4Linear) is available, uses production NVFP4 GEMM. + Otherwise falls back to BF16 cuBLAS. """ - from dsv4.kernels.router import dense_router_dispatch + from dsv4.kernels.router import dense_router_dispatch, dense_router_dispatch_nvfp4 N = hidden_states.shape[0] out_w = self._topk_weights_buf[:N] out_ids = self._topk_ids_buf[:N] - dense_router_dispatch( - hidden_states=hidden_states, - W_gate=self.W_gate, - e_bias=self.e_bias, - routed_scaling_factor=self.routed_scaling_factor, - top_k=self.top_k, - out_weights=out_w, - out_ids=out_ids, - ) + if self.gate_lin is not None: + dense_router_dispatch_nvfp4( + hidden_states=hidden_states, + gate_lin=self.gate_lin, + e_bias=self.e_bias, + routed_scaling_factor=self.routed_scaling_factor, + top_k=self.top_k, + out_weights=out_w, + out_ids=out_ids, + ) + else: + dense_router_dispatch( + hidden_states=hidden_states, + W_gate=self.W_gate, + e_bias=self.e_bias, + routed_scaling_factor=self.routed_scaling_factor, + top_k=self.top_k, + out_weights=out_w, + out_ids=out_ids, + ) return out_w, out_ids def _run_hash_impl(self, token_ids: torch.Tensor): diff --git a/single_shot_inference.py b/single_shot_inference.py index a4ab3665..4d4d8a24 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -665,10 +665,20 @@ def main(): if is_hash: router.load_weights(hash_lut=all_w[f"{pfx}.gate.tid2eid"].to(dev, torch.int32)) else: - gw = all_w.get(f"{pfx}.gate.weight"); eb = all_w.get(f"{pfx}.gate.e_score_correction_bias") - if gw is not None and eb is not None: - if gw.shape == (cfg["n_routed_experts"], H): gw = gw.T.contiguous() - router.load_weights(W_gate=gw.bfloat16().to(dev), e_bias=eb.to(dev, torch.float32)) + eb = all_w.get(f"{pfx}.gate.e_score_correction_bias") + # Try NVFP4 gate weights first (production path) + gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(all_w, pfx, 'gate') + if gate_w is not None and gate_ws is not None: + # NVFP4 gate: build production Nvfp4Linear + router.load_weights(e_bias=eb.to(dev, torch.float32)) + gate_lin = make_nvfp4_linear(H, cfg["n_routed_experts"], dev, all_w, pfx, 'gate') + router.load_nvfp4_gate(gate_lin) + else: + # BF16 fallback + gw = all_w.get(f"{pfx}.gate.weight") + if gw is not None: + if gw.shape == (cfg["n_routed_experts"], H): gw = gw.T.contiguous() + router.load_weights(W_gate=gw.bfloat16().to(dev), e_bias=eb.to(dev, torch.float32)) router.finalize_weights(); routers[li] = router moe = Nvfp4MoE(num_experts=cfg["n_routed_experts"], hidden_size=H,