Wire NVFP4 fused router kernel into e2e single-shot pipeline
- Add dense_router_dispatch_nvfp4_fused() in dense_router_decode.py: single-kernel NVFP4 blockscaled GEMM + fused router epilogue
- Router.load_nvfp4_fused_gate(): stores raw NVFP4 tensors for the fused path
- Router._run_dense_impl() dispatch priority: fused > 2-kernel > BF16
- single_shot_inference.py: loads raw NVFP4 gate weights for the fused kernel instead of building Nvfp4Linear (which was the 2-kernel path)
- Fix selection sort bug in nvfp4_fused_router_kernel.py: pass 0 was missing the t_s/t_i/t_a temp save before the swap, leaving those variables undefined
- Export dense_router_dispatch_nvfp4_fused from __init__.py
This commit is contained in:
@@ -2,11 +2,16 @@
|
||||
|
||||
Exports:
|
||||
dense_router_dispatch: BF16 GEMM + fused activation + top-k (fallback)
|
||||
dense_router_dispatch_nvfp4: NVFP4 GEMM + fused activation + top-k (production)
|
||||
dense_router_dispatch_nvfp4: NVFP4 GEMM + fused activation + top-k (2-kernel)
|
||||
dense_router_dispatch_nvfp4_fused: NVFP4 fused single-kernel GEMM + router epilogue
|
||||
hash_router_dispatch: Hash routing via precomputed LUT gather
|
||||
"""
|
||||
|
||||
from dsv4.kernels.router.dense_router_decode import dense_router_dispatch, dense_router_dispatch_nvfp4
|
||||
from dsv4.kernels.router.dense_router_decode import (
|
||||
dense_router_dispatch,
|
||||
dense_router_dispatch_nvfp4,
|
||||
dense_router_dispatch_nvfp4_fused,
|
||||
)
|
||||
|
||||
|
||||
def hash_router_dispatch(
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
"""DSV4 Dense Router — NVFP4 GEMM + sqrt(softplus) + bias + top-k.
|
||||
|
||||
Production path: NVFP4 GEMM via Nvfp4Linear (Blackwell tensor cores)
|
||||
followed by the fused activation_topk CUDA kernel for sqrt(softplus) +
|
||||
bias + top-k + renorm.
|
||||
|
||||
BF16 cuBLAS fallback: When NVFP4 scales are not available in the
|
||||
checkpoint, dense_router_dispatch uses torch.nn.functional.linear
|
||||
(cuBLAS, SM100 tensor cores) instead.
|
||||
|
||||
The CuTeDSL fused GEMM+epilogue kernel (dense_router_decode_kernel.py)
|
||||
is the future single-kernel path but is not yet production-ready.
|
||||
Production paths (in priority order):
|
||||
1. NVFP4 fused router kernel (nvfp4_fused_router_kernel.py):
|
||||
Single-kernel blockscaled GEMM + fused router epilogue.
|
||||
No intermediate GMEM buffer. Pure NVFP4 + Blackwell tensor cores.
|
||||
2. NVFP4 GEMM + activation_topk (2-kernel path):
|
||||
Nvfp4Linear (Blackwell tensor cores) + fused activation_topk CUDA kernel.
|
||||
3. BF16 cuBLAS fallback: When NVFP4 scales are not available in the
|
||||
checkpoint, dense_router_dispatch uses torch.nn.functional.linear
|
||||
(cuBLAS, SM100 tensor cores) instead.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -48,7 +47,7 @@ def dense_router_dispatch_nvfp4(
|
||||
out_weights: torch.Tensor, # [N, top_k] FP32, pre-allocated
|
||||
out_ids: torch.Tensor, # [N, top_k] int32, pre-allocated
|
||||
):
|
||||
"""Dispatch the dense router (NVFP4 production GEMM).
|
||||
"""Dispatch the dense router (NVFP4 production GEMM, 2-kernel path).
|
||||
|
||||
NVFP4 GEMM via Nvfp4Linear (Blackwell SM100 tensor cores),
|
||||
then fused activation + top-k via the CUDA kernel.
|
||||
@@ -59,3 +58,47 @@ def dense_router_dispatch_nvfp4(
|
||||
logits, e_bias, routed_scaling_factor, top_k,
|
||||
out_weights, out_ids,
|
||||
)
|
||||
|
||||
|
||||
def _scale_to_float(scale: torch.Tensor) -> float:
    """Reduce a scale tensor to a single Python float.

    A one-element tensor is returned as-is; a multi-element tensor is
    collapsed to its mean.
    """
    scale_f32 = scale.float()
    if scale_f32.numel() != 1:
        # NOTE(review): averaging discards per-output variation — confirm the
        # fused kernel really only supports one global scale per operand.
        scale_f32 = scale_f32.mean()
    return scale_f32.item()


def dense_router_dispatch_nvfp4_fused(
    hidden_states: torch.Tensor,        # [N, hidden_size] BF16
    gate_weight: torch.Tensor,          # [K_packed, E_packed] uint8 NVFP4 weight
    gate_weight_scale: torch.Tensor,    # [K_sf, E_sf] FP8 E4M3 weight scale
    gate_ws2: torch.Tensor,             # weight_scale_2 (scalar or per-output); None -> 1.0
    gate_input_scale: torch.Tensor,     # input_scale (activation global scale base)
    e_bias: torch.Tensor,               # [num_experts] FP32
    routed_scaling_factor: float,
    top_k: int,
    out_weights: torch.Tensor,          # [N, top_k] FP32, pre-allocated
    out_ids: torch.Tensor,              # [N, top_k] int32, pre-allocated
):
    """Dispatch the dense router (NVFP4 fused single-kernel path).

    Single kernel: NVFP4 blockscaled GEMM + fused router epilogue.
    Activation is quantized to NVFP4 inside the kernel.

    The two global scales are reduced to Python floats and handed to
    ``run_nvfp4_fused_router``; the tensors it returns are then copied into
    the first ``N`` rows of the caller's pre-allocated buffers, where ``N``
    is ``hidden_states.shape[0]``.

    Args:
        hidden_states: Router input activations.
        gate_weight: Raw NVFP4 gate weight, forwarded as ``mat_b``.
        gate_weight_scale: Block scales for the weight, forwarded as ``scale_b``.
        gate_ws2: Weight global scale. ``None`` is accepted and treated as
            ``1.0`` (a neutral scale) so that a router loaded without
            ``weight_scale_2`` does not crash here.
        gate_input_scale: Activation global scale.
        e_bias: Per-expert selection bias.
        routed_scaling_factor: Forwarded unchanged to the kernel.
        top_k: Number of experts selected per token.
        out_weights: Destination for the routing weights (written in place).
        out_ids: Destination for the selected expert ids (written in place).

    Returns:
        None. Results are delivered through ``out_weights`` / ``out_ids``.
    """
    from dsv4.kernels.router.nvfp4_fused_router_kernel import run_nvfp4_fused_router

    # Global scales:
    #   gsa (activation global scale) = input_scale from checkpoint
    #   gsb (weight global scale)     = weight_scale_2 (NOT input_scale * ws2)
    # NOTE(review): .item() forces a device-to-host sync when the scale
    # tensors live on the GPU; consider caching the floats at load time.
    gsa = _scale_to_float(gate_input_scale)
    gsb_val = 1.0 if gate_ws2 is None else _scale_to_float(gate_ws2)

    # The fused kernel handles activation quantization internally and
    # returns freshly produced weight / id tensors.
    result_w, result_ids = run_nvfp4_fused_router(
        hidden_states=hidden_states,
        mat_b=gate_weight,
        scale_b=gate_weight_scale,
        gsa=gsa,
        gsb_val=gsb_val,
        e_bias=e_bias,
        routed_scaling_factor=routed_scaling_factor,
        top_k=top_k,
    )

    # Copy results into the pre-allocated buffers; only the first N rows are
    # touched on both sides.
    N = hidden_states.shape[0]
    out_weights[:N].copy_(result_w[:N])
    out_ids[:N].copy_(result_ids[:N])
|
||||
|
||||
@@ -929,6 +929,8 @@ class Nvfp4FusedRouterKernel:
|
||||
if fs5 > m0_s:
|
||||
m0_s = fs5; m0_i = fi5; m0_a = fa5; m0_k = cutlass.Int32(5)
|
||||
# Swap position 0 with the max (flat conditionals by position)
|
||||
t_s = fs0; t_i = fi0; t_a = fa0
|
||||
fs0 = m0_s; fi0 = m0_i; fa0 = m0_a
|
||||
if m0_k == cutlass.Int32(1):
|
||||
fs1 = t_s; fi1 = t_i; fa1 = t_a
|
||||
if m0_k == cutlass.Int32(2):
|
||||
|
||||
@@ -92,13 +92,22 @@ class Router:
|
||||
self.device = device
|
||||
|
||||
# ---- Parameters (filled by load_weights / finalize_weights) ----
|
||||
# Dense mode:
|
||||
# gate_lin: Nvfp4Linear for the gate projection (NVFP4 GEMM)
|
||||
# Fallback: W_gate BF16 + cuBLAS when NVFP4 scales not available
|
||||
# e_bias: [num_experts] FP32 — auxiliary-loss-free selection bias.
|
||||
# Dense mode — fused NVFP4 kernel (single-kernel, preferred):
|
||||
# gate_weight: raw NVFP4 gate weight tensor [K_packed, E_packed] uint8
|
||||
# gate_weight_scale: weight scale [K_sf, E_sf] FP8 E4M3
|
||||
# gate_ws2: weight_scale_2 (global scale base)
|
||||
# gate_input_scale: input_scale (activation global scale base)
|
||||
# Dense mode — 2-kernel NVFP4 path (fallback):
|
||||
# gate_lin: Nvfp4Linear for the gate projection
|
||||
# Dense mode — BF16 fallback:
|
||||
# W_gate: BF16 weight for cuBLAS when NVFP4 scales not available
|
||||
# Hash mode:
|
||||
# hash_lut: [vocab_size, top_k] int32 — precomputed expert IDs.
|
||||
self.gate_lin = None # Nvfp4Linear for NVFP4 gate projection
|
||||
self.gate_weight = None # Raw NVFP4 weight for fused kernel
|
||||
self.gate_weight_scale = None # FP8 E4M3 scale for fused kernel
|
||||
self.gate_ws2 = None # weight_scale_2 for fused kernel
|
||||
self.gate_input_scale = None # input_scale for fused kernel
|
||||
self.gate_lin = None # Nvfp4Linear for 2-kernel NVFP4 path
|
||||
self.W_gate: Optional[torch.Tensor] = None # BF16 fallback
|
||||
self.e_bias: Optional[torch.Tensor] = None
|
||||
self.hash_lut: Optional[torch.Tensor] = None
|
||||
@@ -145,7 +154,7 @@ class Router:
|
||||
self.hash_lut = hash_lut.to(device=self.device, dtype=torch.int32)
|
||||
|
||||
def load_nvfp4_gate(self, gate_lin) -> None:
|
||||
"""Set the NVFP4 gate linear layer (preferred over BF16 W_gate).
|
||||
"""Set the NVFP4 gate linear layer (2-kernel path).
|
||||
|
||||
Called by the single_shot after constructing the Nvfp4Linear
|
||||
from checkpoint NVFP4 scales. When set, _run_dense_impl uses
|
||||
@@ -153,6 +162,19 @@ class Router:
|
||||
"""
|
||||
self.gate_lin = gate_lin
|
||||
|
||||
def load_nvfp4_fused_gate(self, gate_weight, gate_weight_scale,
|
||||
gate_ws2, gate_input_scale) -> None:
|
||||
"""Set raw NVFP4 gate tensors for the fused single-kernel path.
|
||||
|
||||
Preferred over load_nvfp4_gate (2-kernel) when available.
|
||||
The fused kernel handles activation quantization + GEMM +
|
||||
router epilogue in a single kernel launch.
|
||||
"""
|
||||
self.gate_weight = gate_weight.to(device=self.device)
|
||||
self.gate_weight_scale = gate_weight_scale.to(device=self.device)
|
||||
self.gate_ws2 = gate_ws2.to(device=self.device) if gate_ws2 is not None else None
|
||||
self.gate_input_scale = gate_input_scale.to(device=self.device)
|
||||
|
||||
def finalize_weights(self) -> None:
|
||||
"""Allocate output buffers and JIT-compile the routing kernel.
|
||||
|
||||
@@ -242,16 +264,33 @@ class Router:
|
||||
# Called by the custom_op dispatch in dsv4/ops/router.py — not by user code.
|
||||
# ------------------------------------------------------------------
|
||||
def _run_dense_impl(self, hidden_states: torch.Tensor):
|
||||
"""Hot-path: NVFP4 GEMM or BF16 fallback + activation_topk.
|
||||
"""Hot-path: fused NVFP4, 2-kernel NVFP4, or BF16 fallback.
|
||||
|
||||
When gate_lin (Nvfp4Linear) is available, uses production NVFP4 GEMM.
|
||||
Otherwise falls back to BF16 cuBLAS.
|
||||
Priority:
|
||||
1. Fused NVFP4 kernel (single-kernel GEMM + router epilogue)
|
||||
2. 2-kernel NVFP4 path (Nvfp4Linear + activation_topk)
|
||||
3. BF16 cuBLAS fallback
|
||||
"""
|
||||
from dsv4.kernels.router import dense_router_dispatch, dense_router_dispatch_nvfp4
|
||||
N = hidden_states.shape[0]
|
||||
out_w = self._topk_weights_buf[:N]
|
||||
out_ids = self._topk_ids_buf[:N]
|
||||
if self.gate_lin is not None:
|
||||
if self.gate_weight is not None:
|
||||
# Fused single-kernel path (preferred)
|
||||
from dsv4.kernels.router import dense_router_dispatch_nvfp4_fused
|
||||
dense_router_dispatch_nvfp4_fused(
|
||||
hidden_states=hidden_states,
|
||||
gate_weight=self.gate_weight,
|
||||
gate_weight_scale=self.gate_weight_scale,
|
||||
gate_ws2=self.gate_ws2,
|
||||
gate_input_scale=self.gate_input_scale,
|
||||
e_bias=self.e_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
top_k=self.top_k,
|
||||
out_weights=out_w,
|
||||
out_ids=out_ids,
|
||||
)
|
||||
elif self.gate_lin is not None:
|
||||
from dsv4.kernels.router import dense_router_dispatch_nvfp4
|
||||
dense_router_dispatch_nvfp4(
|
||||
hidden_states=hidden_states,
|
||||
gate_lin=self.gate_lin,
|
||||
@@ -262,6 +301,7 @@ class Router:
|
||||
out_ids=out_ids,
|
||||
)
|
||||
else:
|
||||
from dsv4.kernels.router import dense_router_dispatch
|
||||
dense_router_dispatch(
|
||||
hidden_states=hidden_states,
|
||||
W_gate=self.W_gate,
|
||||
|
||||
@@ -697,10 +697,14 @@ def main():
|
||||
# Try NVFP4 gate weights first (production path)
|
||||
gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(all_w, pfx, 'gate')
|
||||
if gate_w is not None and gate_ws is not None:
|
||||
# NVFP4 gate: build production Nvfp4Linear
|
||||
# NVFP4 gate: load raw tensors for fused single-kernel path
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
gate_lin = make_nvfp4_linear(H, cfg["n_routed_experts"], dev, all_w, pfx, 'gate')
|
||||
router.load_nvfp4_gate(gate_lin)
|
||||
router.load_nvfp4_fused_gate(
|
||||
gate_weight=gate_w.to(dev),
|
||||
gate_weight_scale=gate_ws.to(dev),
|
||||
gate_ws2=gate_ws2.to(dev) if gate_ws2 is not None else torch.tensor(1.0, device=dev),
|
||||
gate_input_scale=gate_isc.to(dev) if gate_isc is not None else torch.tensor(1.0 / (6.0 * 448.0), device=dev),
|
||||
)
|
||||
else:
|
||||
# BF16 fallback
|
||||
gw = all_w.get(f"{pfx}.gate.weight")
|
||||
|
||||
Reference in New Issue
Block a user