diff --git a/dsv4/kernels/router/__init__.py b/dsv4/kernels/router/__init__.py index 5f323129..611364b1 100644 --- a/dsv4/kernels/router/__init__.py +++ b/dsv4/kernels/router/__init__.py @@ -2,11 +2,16 @@ Exports: dense_router_dispatch: BF16 GEMM + fused activation + top-k (fallback) - dense_router_dispatch_nvfp4: NVFP4 GEMM + fused activation + top-k (production) + dense_router_dispatch_nvfp4: NVFP4 GEMM + fused activation + top-k (2-kernel) + dense_router_dispatch_nvfp4_fused: NVFP4 fused single-kernel GEMM + router epilogue hash_router_dispatch: Hash routing via precomputed LUT gather """ -from dsv4.kernels.router.dense_router_decode import dense_router_dispatch, dense_router_dispatch_nvfp4 +from dsv4.kernels.router.dense_router_decode import ( + dense_router_dispatch, + dense_router_dispatch_nvfp4, + dense_router_dispatch_nvfp4_fused, +) def hash_router_dispatch( diff --git a/dsv4/kernels/router/dense_router_decode.py b/dsv4/kernels/router/dense_router_decode.py index 567bc220..7a83b58f 100644 --- a/dsv4/kernels/router/dense_router_decode.py +++ b/dsv4/kernels/router/dense_router_decode.py @@ -1,15 +1,14 @@ """DSV4 Dense Router — NVFP4 GEMM + sqrt(softplus) + bias + top-k. -Production path: NVFP4 GEMM via Nvfp4Linear (Blackwell tensor cores) -followed by the fused activation_topk CUDA kernel for sqrt(softplus) + -bias + top-k + renorm. - -BF16 cuBLAS fallback: When NVFP4 scales are not available in the -checkpoint, dense_router_dispatch uses torch.nn.functional.linear -(cuBLAS, SM100 tensor cores) instead. - -The CuTeDSL fused GEMM+epilogue kernel (dense_router_decode_kernel.py) -is the future single-kernel path but is not yet production-ready. +Production paths (in priority order): +1. NVFP4 fused router kernel (nvfp4_fused_router_kernel.py): + Single-kernel blockscaled GEMM + fused router epilogue. + No intermediate GMEM buffer. Pure NVFP4 + Blackwell tensor cores. +2. 
NVFP4 GEMM + activation_topk (2-kernel path): + Nvfp4Linear (Blackwell tensor cores) + fused activation_topk CUDA kernel. +3. BF16 cuBLAS fallback: When NVFP4 scales are not available in the + checkpoint, dense_router_dispatch uses torch.nn.functional.linear + (cuBLAS, SM100 tensor cores) instead. """ from __future__ import annotations @@ -48,7 +47,7 @@ def dense_router_dispatch_nvfp4( out_weights: torch.Tensor, # [N, top_k] FP32, pre-allocated out_ids: torch.Tensor, # [N, top_k] int32, pre-allocated ): - """Dispatch the dense router (NVFP4 production GEMM). + """Dispatch the dense router (NVFP4 production GEMM, 2-kernel path). NVFP4 GEMM via Nvfp4Linear (Blackwell SM100 tensor cores), then fused activation + top-k via the CUDA kernel. @@ -59,3 +58,47 @@ def dense_router_dispatch_nvfp4( logits, e_bias, routed_scaling_factor, top_k, out_weights, out_ids, ) + + +def dense_router_dispatch_nvfp4_fused( + hidden_states: torch.Tensor, # [N, hidden_size] BF16 + gate_weight: torch.Tensor, # [K_packed, E_packed] uint8 NVFP4 weight + gate_weight_scale: torch.Tensor, # [K_sf, E_sf] FP8 E4M3 weight scale + gate_ws2: torch.Tensor, # weight_scale_2 (scalar or per-output) + gate_input_scale: torch.Tensor, # input_scale (activation global scale base) + e_bias: torch.Tensor, # [num_experts] FP32 + routed_scaling_factor: float, + top_k: int, + out_weights: torch.Tensor, # [N, top_k] FP32, pre-allocated + out_ids: torch.Tensor, # [N, top_k] int32, pre-allocated +): + """Dispatch the dense router (NVFP4 fused single-kernel path). + + Single kernel: NVFP4 blockscaled GEMM + fused router epilogue. + Activation is quantized to NVFP4 inside the kernel. + No intermediate GMEM buffer. Pure NVFP4 + Blackwell tensor cores. 
+ """ + from dsv4.kernels.router.nvfp4_fused_router_kernel import run_nvfp4_fused_router + + # Global scales: + # gsa (activation global scale) = input_scale from checkpoint + # gsb (weight global scale) = weight_scale_2 (NOT input_scale * ws2) + gsa = gate_input_scale.float().item() if gate_input_scale.numel() == 1 else gate_input_scale.float().mean().item() + gsb_val = gate_ws2.float().item() if gate_ws2.numel() == 1 else gate_ws2.float().mean().item() + + # The fused kernel handles activation quantization internally + # and returns its results, which are copied into out_weights / out_ids below + result_w, result_ids = run_nvfp4_fused_router( + hidden_states=hidden_states, + mat_b=gate_weight, + scale_b=gate_weight_scale, + gsa=gsa, + gsb_val=gsb_val, + e_bias=e_bias, + routed_scaling_factor=routed_scaling_factor, + top_k=top_k, + ) + # Copy results into pre-allocated buffers + N = hidden_states.shape[0] + out_weights[:N].copy_(result_w[:N]) + out_ids[:N].copy_(result_ids[:N]) diff --git a/dsv4/kernels/router/nvfp4_fused_router_kernel.py b/dsv4/kernels/router/nvfp4_fused_router_kernel.py index d57476d2..39a8bafa 100644 --- a/dsv4/kernels/router/nvfp4_fused_router_kernel.py +++ b/dsv4/kernels/router/nvfp4_fused_router_kernel.py @@ -929,6 +929,8 @@ class Nvfp4FusedRouterKernel: if fs5 > m0_s: m0_s = fs5; m0_i = fi5; m0_a = fa5; m0_k = cutlass.Int32(5) # Swap position 0 with the max (flat conditionals by position) + t_s = fs0; t_i = fi0; t_a = fa0 + fs0 = m0_s; fi0 = m0_i; fa0 = m0_a if m0_k == cutlass.Int32(1): fs1 = t_s; fi1 = t_i; fa1 = t_a if m0_k == cutlass.Int32(2): diff --git a/dsv4/layers/router.py b/dsv4/layers/router.py index 8675018c..f49dc44a 100644 --- a/dsv4/layers/router.py +++ b/dsv4/layers/router.py @@ -92,13 +92,22 @@ class Router: self.device = device # ---- Parameters (filled by load_weights / finalize_weights) ---- - # Dense mode: - # gate_lin: Nvfp4Linear for the gate projection (NVFP4 GEMM) - # Fallback: W_gate BF16 + cuBLAS when NVFP4 scales not available - # e_bias: 
[num_experts] FP32 — auxiliary-loss-free selection bias. + # Dense mode — fused NVFP4 kernel (single-kernel, preferred): + # gate_weight: raw NVFP4 gate weight tensor [K_packed, E_packed] uint8 + # gate_weight_scale: weight scale [K_sf, E_sf] FP8 E4M3 + # gate_ws2: weight_scale_2 (global scale base) + # gate_input_scale: input_scale (activation global scale base) + # Dense mode — 2-kernel NVFP4 path (fallback): + # gate_lin: Nvfp4Linear for the gate projection + # Dense mode — BF16 fallback: + # W_gate: BF16 weight for cuBLAS when NVFP4 scales not available # Hash mode: # hash_lut: [vocab_size, top_k] int32 — precomputed expert IDs. - self.gate_lin = None # Nvfp4Linear for NVFP4 gate projection + self.gate_weight = None # Raw NVFP4 weight for fused kernel + self.gate_weight_scale = None # FP8 E4M3 scale for fused kernel + self.gate_ws2 = None # weight_scale_2 for fused kernel + self.gate_input_scale = None # input_scale for fused kernel + self.gate_lin = None # Nvfp4Linear for 2-kernel NVFP4 path self.W_gate: Optional[torch.Tensor] = None # BF16 fallback self.e_bias: Optional[torch.Tensor] = None self.hash_lut: Optional[torch.Tensor] = None @@ -145,7 +154,7 @@ class Router: self.hash_lut = hash_lut.to(device=self.device, dtype=torch.int32) def load_nvfp4_gate(self, gate_lin) -> None: - """Set the NVFP4 gate linear layer (preferred over BF16 W_gate). + """Set the NVFP4 gate linear layer (2-kernel path). Called by the single_shot after constructing the Nvfp4Linear from checkpoint NVFP4 scales. When set, _run_dense_impl uses @@ -153,6 +162,19 @@ class Router: """ self.gate_lin = gate_lin + def load_nvfp4_fused_gate(self, gate_weight, gate_weight_scale, + gate_ws2, gate_input_scale) -> None: + """Set raw NVFP4 gate tensors for the fused single-kernel path. + + Preferred over load_nvfp4_gate (2-kernel) when available. + The fused kernel handles activation quantization + GEMM + + router epilogue in a single kernel launch. 
+ """ + self.gate_weight = gate_weight.to(device=self.device) + self.gate_weight_scale = gate_weight_scale.to(device=self.device) + self.gate_ws2 = gate_ws2.to(device=self.device) if gate_ws2 is not None else None + self.gate_input_scale = gate_input_scale.to(device=self.device) + def finalize_weights(self) -> None: """Allocate output buffers and JIT-compile the routing kernel. @@ -242,16 +264,33 @@ class Router: # Called by the custom_op dispatch in dsv4/ops/router.py — not by user code. # ------------------------------------------------------------------ def _run_dense_impl(self, hidden_states: torch.Tensor): - """Hot-path: NVFP4 GEMM or BF16 fallback + activation_topk. + """Hot-path: fused NVFP4, 2-kernel NVFP4, or BF16 fallback. - When gate_lin (Nvfp4Linear) is available, uses production NVFP4 GEMM. - Otherwise falls back to BF16 cuBLAS. + Priority: + 1. Fused NVFP4 kernel (single-kernel GEMM + router epilogue) + 2. 2-kernel NVFP4 path (Nvfp4Linear + activation_topk) + 3. BF16 cuBLAS fallback """ - from dsv4.kernels.router import dense_router_dispatch, dense_router_dispatch_nvfp4 N = hidden_states.shape[0] out_w = self._topk_weights_buf[:N] out_ids = self._topk_ids_buf[:N] - if self.gate_lin is not None: + if self.gate_weight is not None: + # Fused single-kernel path (preferred) + from dsv4.kernels.router import dense_router_dispatch_nvfp4_fused + dense_router_dispatch_nvfp4_fused( + hidden_states=hidden_states, + gate_weight=self.gate_weight, + gate_weight_scale=self.gate_weight_scale, + gate_ws2=self.gate_ws2, + gate_input_scale=self.gate_input_scale, + e_bias=self.e_bias, + routed_scaling_factor=self.routed_scaling_factor, + top_k=self.top_k, + out_weights=out_w, + out_ids=out_ids, + ) + elif self.gate_lin is not None: + from dsv4.kernels.router import dense_router_dispatch_nvfp4 dense_router_dispatch_nvfp4( hidden_states=hidden_states, gate_lin=self.gate_lin, @@ -262,6 +301,7 @@ class Router: out_ids=out_ids, ) else: + from dsv4.kernels.router import 
dense_router_dispatch dense_router_dispatch( hidden_states=hidden_states, W_gate=self.W_gate, diff --git a/single_shot_inference.py b/single_shot_inference.py index fdddb012..001a3853 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -697,10 +697,14 @@ def main(): # Try NVFP4 gate weights first (production path) gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(all_w, pfx, 'gate') if gate_w is not None and gate_ws is not None: - # NVFP4 gate: build production Nvfp4Linear + # NVFP4 gate: load raw tensors for fused single-kernel path router.load_weights(e_bias=eb.to(dev, torch.float32)) - gate_lin = make_nvfp4_linear(H, cfg["n_routed_experts"], dev, all_w, pfx, 'gate') - router.load_nvfp4_gate(gate_lin) + router.load_nvfp4_fused_gate( + gate_weight=gate_w.to(dev), + gate_weight_scale=gate_ws.to(dev), + gate_ws2=gate_ws2.to(dev) if gate_ws2 is not None else torch.tensor(1.0, device=dev), + gate_input_scale=gate_isc.to(dev) if gate_isc is not None else torch.tensor(1.0 / (6.0 * 448.0), device=dev), + ) else: # BF16 fallback gw = all_w.get(f"{pfx}.gate.weight")