Wire NVFP4 fused router kernel into e2e single-shot pipeline

- Add dense_router_dispatch_nvfp4_fused() in dense_router_decode.py:
  single-kernel NVFP4 blockscaled GEMM + fused router epilogue
- Router.load_nvfp4_fused_gate(): stores raw NVFP4 tensors for fused path
- Router._run_dense_impl() dispatch priority: fused > 2-kernel > BF16
- single_shot_inference.py: loads raw NVFP4 gate weights for fused kernel
  instead of building Nvfp4Linear (which was the 2-kernel path)
- Fix selection sort bug in nvfp4_fused_router_kernel.py: pass 0 never
  saved slot 0 into the t_s/t_i/t_a temporaries before overwriting it with
  the max, so the swap write-back referenced undefined variables
- Export dense_router_dispatch_nvfp4_fused from __init__.py
This commit is contained in:
2026-06-01 09:47:48 +00:00
parent d9d3ca42b0
commit 31ebe4f2db
5 changed files with 121 additions and 27 deletions

View File

@@ -2,11 +2,16 @@
Exports:
dense_router_dispatch: BF16 GEMM + fused activation + top-k (fallback)
dense_router_dispatch_nvfp4: NVFP4 GEMM + fused activation + top-k (production)
dense_router_dispatch_nvfp4: NVFP4 GEMM + fused activation + top-k (2-kernel)
dense_router_dispatch_nvfp4_fused: NVFP4 fused single-kernel GEMM + router epilogue
hash_router_dispatch: Hash routing via precomputed LUT gather
"""
from dsv4.kernels.router.dense_router_decode import dense_router_dispatch, dense_router_dispatch_nvfp4
from dsv4.kernels.router.dense_router_decode import (
dense_router_dispatch,
dense_router_dispatch_nvfp4,
dense_router_dispatch_nvfp4_fused,
)
def hash_router_dispatch(

View File

@@ -1,15 +1,14 @@
"""DSV4 Dense Router — NVFP4 GEMM + sqrt(softplus) + bias + top-k.
Production path: NVFP4 GEMM via Nvfp4Linear (Blackwell tensor cores)
followed by the fused activation_topk CUDA kernel for sqrt(softplus) +
bias + top-k + renorm.
BF16 cuBLAS fallback: When NVFP4 scales are not available in the
checkpoint, dense_router_dispatch uses torch.nn.functional.linear
(cuBLAS, SM100 tensor cores) instead.
The CuTeDSL fused GEMM+epilogue kernel (dense_router_decode_kernel.py)
is the future single-kernel path but is not yet production-ready.
Production paths (in priority order):
1. NVFP4 fused router kernel (nvfp4_fused_router_kernel.py):
Single-kernel blockscaled GEMM + fused router epilogue.
No intermediate GMEM buffer. Pure NVFP4 + Blackwell tensor cores.
2. NVFP4 GEMM + activation_topk (2-kernel path):
Nvfp4Linear (Blackwell tensor cores) + fused activation_topk CUDA kernel.
3. BF16 cuBLAS fallback: When NVFP4 scales are not available in the
checkpoint, dense_router_dispatch uses torch.nn.functional.linear
(cuBLAS, SM100 tensor cores) instead.
"""
from __future__ import annotations
@@ -48,7 +47,7 @@ def dense_router_dispatch_nvfp4(
out_weights: torch.Tensor, # [N, top_k] FP32, pre-allocated
out_ids: torch.Tensor, # [N, top_k] int32, pre-allocated
):
"""Dispatch the dense router (NVFP4 production GEMM).
"""Dispatch the dense router (NVFP4 production GEMM, 2-kernel path).
NVFP4 GEMM via Nvfp4Linear (Blackwell SM100 tensor cores),
then fused activation + top-k via the CUDA kernel.
@@ -59,3 +58,47 @@ def dense_router_dispatch_nvfp4(
logits, e_bias, routed_scaling_factor, top_k,
out_weights, out_ids,
)
def _router_global_scale(scale: torch.Tensor) -> float:
    """Reduce a checkpoint scale tensor to the Python float the kernel takes.

    A one-element tensor is read out as-is; any other tensor is averaged
    down to a single value.
    """
    scale_fp32 = scale.float()
    if scale_fp32.numel() == 1:
        return scale_fp32.item()
    # NOTE(review): a non-scalar (e.g. per-output) scale is collapsed to its
    # mean here, which discards the per-element information — confirm this
    # is numerically acceptable for the fused kernel.
    return scale_fp32.mean().item()


def dense_router_dispatch_nvfp4_fused(
    hidden_states: torch.Tensor,       # [N, hidden_size] BF16
    gate_weight: torch.Tensor,         # [K_packed, E_packed] uint8 NVFP4 weight
    gate_weight_scale: torch.Tensor,   # [K_sf, E_sf] FP8 E4M3 weight scale
    gate_ws2: torch.Tensor,            # weight_scale_2 (scalar or per-output); None -> 1.0
    gate_input_scale: torch.Tensor,    # input_scale (activation global scale base)
    e_bias: torch.Tensor,              # [num_experts] FP32
    routed_scaling_factor: float,
    top_k: int,
    out_weights: torch.Tensor,         # [N, top_k] FP32, pre-allocated
    out_ids: torch.Tensor,             # [N, top_k] int32, pre-allocated
):
    """Dispatch the dense router (NVFP4 fused single-kernel path).

    Single kernel: NVFP4 blockscaled GEMM + fused router epilogue.
    Activation is quantized to NVFP4 inside the kernel.

    The two global scales are reduced to Python floats and handed to
    ``run_nvfp4_fused_router``; the tensors it returns are then copied into
    the caller's pre-allocated ``out_weights`` / ``out_ids`` buffers.

    ``gate_ws2`` may be ``None`` (``Router.load_nvfp4_fused_gate`` stores
    ``None`` when the checkpoint has no ``weight_scale_2``); it is then
    treated as the identity scale ``1.0``.

    Returns:
        None. Results are written in place to ``out_weights`` and ``out_ids``.
    """
    from dsv4.kernels.router.nvfp4_fused_router_kernel import run_nvfp4_fused_router

    # Global scales:
    #   gsa (activation global scale) = input_scale from checkpoint
    #   gsb (weight global scale)     = weight_scale_2 (NOT input_scale * ws2)
    # NOTE(review): .item() is a host read-back on every call; if the scales
    # live on the GPU this synchronises the hot path — consider caching the
    # floats once at weight-load time.
    gsa = _router_global_scale(gate_input_scale)
    gsb_val = 1.0 if gate_ws2 is None else _router_global_scale(gate_ws2)

    # The fused kernel handles activation quantization internally and
    # returns its own (weights, ids) tensors.
    result_w, result_ids = run_nvfp4_fused_router(
        hidden_states=hidden_states,
        mat_b=gate_weight,
        scale_b=gate_weight_scale,
        gsa=gsa,
        gsb_val=gsb_val,
        e_bias=e_bias,
        routed_scaling_factor=routed_scaling_factor,
        top_k=top_k,
    )

    # Copy the first N rows of the results into the pre-allocated buffers.
    N = hidden_states.shape[0]
    out_weights[:N].copy_(result_w[:N])
    out_ids[:N].copy_(result_ids[:N])

View File

@@ -929,6 +929,8 @@ class Nvfp4FusedRouterKernel:
if fs5 > m0_s:
m0_s = fs5; m0_i = fi5; m0_a = fa5; m0_k = cutlass.Int32(5)
# Swap position 0 with the max (flat conditionals by position)
t_s = fs0; t_i = fi0; t_a = fa0
fs0 = m0_s; fi0 = m0_i; fa0 = m0_a
if m0_k == cutlass.Int32(1):
fs1 = t_s; fi1 = t_i; fa1 = t_a
if m0_k == cutlass.Int32(2):

View File

@@ -92,13 +92,22 @@ class Router:
self.device = device
# ---- Parameters (filled by load_weights / finalize_weights) ----
# Dense mode:
# gate_lin: Nvfp4Linear for the gate projection (NVFP4 GEMM)
# Fallback: W_gate BF16 + cuBLAS when NVFP4 scales not available
# e_bias: [num_experts] FP32 — auxiliary-loss-free selection bias.
# Dense mode — fused NVFP4 kernel (single-kernel, preferred):
# gate_weight: raw NVFP4 gate weight tensor [K_packed, E_packed] uint8
# gate_weight_scale: weight scale [K_sf, E_sf] FP8 E4M3
# gate_ws2: weight_scale_2 (global scale base)
# gate_input_scale: input_scale (activation global scale base)
# Dense mode — 2-kernel NVFP4 path (fallback):
# gate_lin: Nvfp4Linear for the gate projection
# Dense mode — BF16 fallback:
# W_gate: BF16 weight for cuBLAS when NVFP4 scales not available
# Hash mode:
# hash_lut: [vocab_size, top_k] int32 — precomputed expert IDs.
self.gate_lin = None # Nvfp4Linear for NVFP4 gate projection
self.gate_weight = None # Raw NVFP4 weight for fused kernel
self.gate_weight_scale = None # FP8 E4M3 scale for fused kernel
self.gate_ws2 = None # weight_scale_2 for fused kernel
self.gate_input_scale = None # input_scale for fused kernel
self.gate_lin = None # Nvfp4Linear for 2-kernel NVFP4 path
self.W_gate: Optional[torch.Tensor] = None # BF16 fallback
self.e_bias: Optional[torch.Tensor] = None
self.hash_lut: Optional[torch.Tensor] = None
@@ -145,7 +154,7 @@ class Router:
self.hash_lut = hash_lut.to(device=self.device, dtype=torch.int32)
def load_nvfp4_gate(self, gate_lin) -> None:
"""Set the NVFP4 gate linear layer (preferred over BF16 W_gate).
"""Set the NVFP4 gate linear layer (2-kernel path).
Called by the single_shot after constructing the Nvfp4Linear
from checkpoint NVFP4 scales. When set, _run_dense_impl uses
@@ -153,6 +162,19 @@ class Router:
"""
self.gate_lin = gate_lin
def load_nvfp4_fused_gate(self, gate_weight, gate_weight_scale,
gate_ws2, gate_input_scale) -> None:
"""Set raw NVFP4 gate tensors for the fused single-kernel path.
Preferred over load_nvfp4_gate (2-kernel) when available.
The fused kernel handles activation quantization + GEMM +
router epilogue in a single kernel launch.
"""
self.gate_weight = gate_weight.to(device=self.device)
self.gate_weight_scale = gate_weight_scale.to(device=self.device)
self.gate_ws2 = gate_ws2.to(device=self.device) if gate_ws2 is not None else None
self.gate_input_scale = gate_input_scale.to(device=self.device)
def finalize_weights(self) -> None:
"""Allocate output buffers and JIT-compile the routing kernel.
@@ -242,16 +264,33 @@ class Router:
# Called by the custom_op dispatch in dsv4/ops/router.py — not by user code.
# ------------------------------------------------------------------
def _run_dense_impl(self, hidden_states: torch.Tensor):
"""Hot-path: NVFP4 GEMM or BF16 fallback + activation_topk.
"""Hot-path: fused NVFP4, 2-kernel NVFP4, or BF16 fallback.
When gate_lin (Nvfp4Linear) is available, uses production NVFP4 GEMM.
Otherwise falls back to BF16 cuBLAS.
Priority:
1. Fused NVFP4 kernel (single-kernel GEMM + router epilogue)
2. 2-kernel NVFP4 path (Nvfp4Linear + activation_topk)
3. BF16 cuBLAS fallback
"""
from dsv4.kernels.router import dense_router_dispatch, dense_router_dispatch_nvfp4
N = hidden_states.shape[0]
out_w = self._topk_weights_buf[:N]
out_ids = self._topk_ids_buf[:N]
if self.gate_lin is not None:
if self.gate_weight is not None:
# Fused single-kernel path (preferred)
from dsv4.kernels.router import dense_router_dispatch_nvfp4_fused
dense_router_dispatch_nvfp4_fused(
hidden_states=hidden_states,
gate_weight=self.gate_weight,
gate_weight_scale=self.gate_weight_scale,
gate_ws2=self.gate_ws2,
gate_input_scale=self.gate_input_scale,
e_bias=self.e_bias,
routed_scaling_factor=self.routed_scaling_factor,
top_k=self.top_k,
out_weights=out_w,
out_ids=out_ids,
)
elif self.gate_lin is not None:
from dsv4.kernels.router import dense_router_dispatch_nvfp4
dense_router_dispatch_nvfp4(
hidden_states=hidden_states,
gate_lin=self.gate_lin,
@@ -262,6 +301,7 @@ class Router:
out_ids=out_ids,
)
else:
from dsv4.kernels.router import dense_router_dispatch
dense_router_dispatch(
hidden_states=hidden_states,
W_gate=self.W_gate,

View File

@@ -697,10 +697,14 @@ def main():
# Try NVFP4 gate weights first (production path)
gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(all_w, pfx, 'gate')
if gate_w is not None and gate_ws is not None:
# NVFP4 gate: build production Nvfp4Linear
# NVFP4 gate: load raw tensors for fused single-kernel path
router.load_weights(e_bias=eb.to(dev, torch.float32))
gate_lin = make_nvfp4_linear(H, cfg["n_routed_experts"], dev, all_w, pfx, 'gate')
router.load_nvfp4_gate(gate_lin)
router.load_nvfp4_fused_gate(
gate_weight=gate_w.to(dev),
gate_weight_scale=gate_ws.to(dev),
gate_ws2=gate_ws2.to(dev) if gate_ws2 is not None else torch.tensor(1.0, device=dev),
gate_input_scale=gate_isc.to(dev) if gate_isc is not None else torch.tensor(1.0 / (6.0 * 448.0), device=dev),
)
else:
# BF16 fallback
gw = all_w.get(f"{pfx}.gate.weight")