feat: NVFP4 gate projection for router (replaces BF16 cuBLAS)

The dense router now runs its gate projection as an NVFP4 GEMM via
Nvfp4Linear when NVFP4 scales are available in the checkpoint. On that
path the BF16 cuBLAS GEMM is replaced by Blackwell SM100 tensor-core
NVFP4 acceleration; BF16 cuBLAS remains as the fallback when the
checkpoint carries no NVFP4 scales.

Changes:
- dsv4/layers/router.py: add gate_lin (Nvfp4Linear) alongside W_gate
  fallback. New load_nvfp4_gate() method.
- dsv4/kernels/router/dense_router_decode.py: add
  dense_router_dispatch_nvfp4() using Nvfp4Linear + activation_topk
- dsv4/kernels/router/__init__.py: export new function
- single_shot_inference.py: load NVFP4 gate weights when available,
  fall back to BF16 when not
This commit is contained in:
2026-06-01 05:58:56 +00:00
parent 9f14cb17d1
commit cf2b7ab7ec
4 changed files with 92 additions and 36 deletions

View File

@@ -1,11 +1,12 @@
"""DSV4 Router kernels — dispatch and CUDA kernel wrappers.
Exports:
dense_router_dispatch: GEMM + fused activation + top-k (all N)
dense_router_dispatch: BF16 GEMM + fused activation + top-k (fallback)
dense_router_dispatch_nvfp4: NVFP4 GEMM + fused activation + top-k (production)
hash_router_dispatch: Hash routing via precomputed LUT gather
"""
from dsv4.kernels.router.dense_router_decode import dense_router_dispatch
from dsv4.kernels.router.dense_router_decode import dense_router_dispatch, dense_router_dispatch_nvfp4
def hash_router_dispatch(

View File

@@ -1,12 +1,15 @@
"""DSV4 Dense Router — BF16 GEMM + sqrt(softplus) + bias + top-k.
"""DSV4 Dense Router — NVFP4 GEMM + sqrt(softplus) + bias + top-k.
Production path: BF16 GEMM via cuBLAS (tensor cores on Blackwell) followed by
the fused activation_topk CUDA kernel for sqrt(softplus) + bias + top-k + renorm.
Production path: NVFP4 GEMM via Nvfp4Linear (Blackwell tensor cores)
followed by the fused activation_topk CUDA kernel for sqrt(softplus) +
bias + top-k + renorm.
The CuTeDSL fused GEMM+epilogue kernel was attempted but make_trivial_tiled_mma
for BF16 on SM100 has no working reference in our codebase (all other GEMMs use
NVFP4 blockscaled MMA). The unfused path is production-grade: cuBLAS uses SM100
tensor cores, and activation_topk is a real CUDA kernel (not PyTorch).
BF16 cuBLAS fallback: When NVFP4 scales are not available in the
checkpoint, dense_router_dispatch uses torch.nn.functional.linear
(cuBLAS, SM100 tensor cores) instead.
The CuTeDSL fused GEMM+epilogue kernel (dense_router_decode_kernel.py)
is the future single-kernel path but is not yet production-ready.
"""
from __future__ import annotations
@@ -23,7 +26,7 @@ def dense_router_dispatch(
out_weights: torch.Tensor, # [N, top_k] FP32, pre-allocated
out_ids: torch.Tensor, # [N, top_k] int32, pre-allocated
):
"""Dispatch the dense router.
"""Dispatch the dense router (BF16 cuBLAS fallback).
BF16 GEMM via torch.nn.functional.linear (cuBLAS, SM100 tensor cores),
then fused activation + top-k via the CUDA kernel.
@@ -34,3 +37,25 @@ def dense_router_dispatch(
logits, e_bias, routed_scaling_factor, top_k,
out_weights, out_ids,
)
def dense_router_dispatch_nvfp4(
    hidden_states: torch.Tensor,  # [N, hidden_size] BF16
    gate_lin,                     # Nvfp4Linear instance
    e_bias: torch.Tensor,         # [num_experts] FP32
    routed_scaling_factor: float,
    top_k: int,
    out_weights: torch.Tensor,    # [N, top_k] FP32, pre-allocated
    out_ids: torch.Tensor,        # [N, top_k] int32, pre-allocated
):
    """Run the dense router with the NVFP4 gate projection (production path).

    The gate logits come from calling ``gate_lin`` on ``hidden_states``
    (NVFP4 GEMM via Nvfp4Linear, Blackwell SM100 tensor cores). They are
    cast to FP32 and handed to the fused activation + top-k CUDA kernel
    wrapper together with the selection bias, scaling factor and ``top_k``.

    Nothing is returned; ``out_weights`` and ``out_ids`` are the
    pre-allocated buffers passed to the kernel wrapper as its outputs.
    """
    gate_out = gate_lin(hidden_states)
    logits_fp32 = gate_out.float()  # (N, E) FP32

    # Imported at call time, after the GEMM, exactly as in the original flow.
    from dsv4.kernels.router._activation_topk import run_fused_activation_topk

    run_fused_activation_topk(
        logits_fp32,
        e_bias,
        routed_scaling_factor,
        top_k,
        out_weights,
        out_ids,
    )

View File

@@ -93,11 +93,13 @@ class Router:
# ---- Parameters (filled by load_weights / finalize_weights) ----
# Dense mode:
# W_gate: [hidden_size, num_experts] BF16
# gate_lin: Nvfp4Linear for the gate projection (NVFP4 GEMM)
# Fallback: W_gate BF16 + cuBLAS when NVFP4 scales not available
# e_bias: [num_experts] FP32 — auxiliary-loss-free selection bias.
# Hash mode:
# hash_lut: [vocab_size, top_k] int32 — precomputed expert IDs.
self.W_gate: Optional[torch.Tensor] = None
self.gate_lin = None # Nvfp4Linear for NVFP4 gate projection
self.W_gate: Optional[torch.Tensor] = None # BF16 fallback
self.e_bias: Optional[torch.Tensor] = None
self.hash_lut: Optional[torch.Tensor] = None
@@ -124,15 +126,14 @@ class Router:
nearly always loader bugs and silent acceptance would mask them.
"""
if self.mode == "dense":
if W_gate is None or e_bias is None:
raise ValueError("dense router needs both W_gate and e_bias")
assert W_gate.shape == (self.hidden_size, self.num_experts), \
f"W_gate shape {tuple(W_gate.shape)} != " \
f"{(self.hidden_size, self.num_experts)}"
if e_bias is None:
raise ValueError("dense router needs e_bias")
assert e_bias.shape == (self.num_experts,), \
f"e_bias shape {tuple(e_bias.shape)} != ({self.num_experts},)"
self.W_gate = W_gate.to(device=self.device, dtype=torch.bfloat16)
self.e_bias = e_bias.to(device=self.device, dtype=torch.float32)
if W_gate is not None:
self.W_gate = W_gate.to(device=self.device, dtype=torch.bfloat16)
# gate_lin is set separately via load_nvfp4_gate()
else: # hash
if hash_lut is None:
raise ValueError("hash router needs hash_lut")
@@ -143,6 +144,15 @@ class Router:
"hash_lut contains out-of-range expert IDs"
self.hash_lut = hash_lut.to(device=self.device, dtype=torch.int32)
def load_nvfp4_gate(self, gate_lin) -> None:
    """Attach the NVFP4 gate projection layer to this router.

    The single-shot script calls this once it has built the Nvfp4Linear
    from the checkpoint's NVFP4 scales. A router with ``gate_lin`` set
    takes the NVFP4 GEMM path in ``_run_dense_impl``; without it the
    BF16 ``W_gate`` cuBLAS fallback is used.
    """
    # Stored as-is: no validation or device/dtype conversion happens here.
    self.gate_lin = gate_lin
def finalize_weights(self) -> None:
"""Allocate output buffers and JIT-compile the routing kernel.
@@ -232,25 +242,35 @@ class Router:
# Called by the custom_op dispatch in dsv4/ops/router.py — not by user code.
# ------------------------------------------------------------------
def _run_dense_impl(self, hidden_states: torch.Tensor):
    """Hot-path: NVFP4 GEMM or BF16 fallback + activation_topk.

    When gate_lin (Nvfp4Linear) is available, uses production NVFP4 GEMM.
    Otherwise falls back to BF16 cuBLAS with W_gate.

    Args:
        hidden_states: Token activations; ``shape[0]`` is the token count
            N used to slice the pre-allocated output buffers.

    Returns:
        ``(out_w, out_ids)``: the first N rows of ``self._topk_weights_buf``
        and ``self._topk_ids_buf``, which are passed to the dispatch
        function as its output targets.

    Raises:
        RuntimeError: If neither ``gate_lin`` nor ``W_gate`` is loaded.
    """
    use_nvfp4 = self.gate_lin is not None
    if not use_nvfp4 and self.W_gate is None:
        # load_weights() accepts a dense router without W_gate (the NVFP4
        # gate is attached separately), so both can be missing here. Fail
        # with a clear message instead of passing None into the BF16 GEMM.
        raise RuntimeError(
            "dense router has no gate projection: call load_nvfp4_gate() "
            "or load_weights(W_gate=...) before running"
        )

    from dsv4.kernels.router import dense_router_dispatch, dense_router_dispatch_nvfp4

    N = hidden_states.shape[0]
    out_w = self._topk_weights_buf[:N]
    out_ids = self._topk_ids_buf[:N]

    if use_nvfp4:
        dense_router_dispatch_nvfp4(
            hidden_states=hidden_states,
            gate_lin=self.gate_lin,
            e_bias=self.e_bias,
            routed_scaling_factor=self.routed_scaling_factor,
            top_k=self.top_k,
            out_weights=out_w,
            out_ids=out_ids,
        )
    else:
        dense_router_dispatch(
            hidden_states=hidden_states,
            W_gate=self.W_gate,
            e_bias=self.e_bias,
            routed_scaling_factor=self.routed_scaling_factor,
            top_k=self.top_k,
            out_weights=out_w,
            out_ids=out_ids,
        )
    return out_w, out_ids
def _run_hash_impl(self, token_ids: torch.Tensor):

View File

@@ -665,10 +665,20 @@ def main():
if is_hash:
router.load_weights(hash_lut=all_w[f"{pfx}.gate.tid2eid"].to(dev, torch.int32))
else:
gw = all_w.get(f"{pfx}.gate.weight"); eb = all_w.get(f"{pfx}.gate.e_score_correction_bias")
if gw is not None and eb is not None:
if gw.shape == (cfg["n_routed_experts"], H): gw = gw.T.contiguous()
router.load_weights(W_gate=gw.bfloat16().to(dev), e_bias=eb.to(dev, torch.float32))
eb = all_w.get(f"{pfx}.gate.e_score_correction_bias")
# Try NVFP4 gate weights first (production path)
gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(all_w, pfx, 'gate')
if gate_w is not None and gate_ws is not None:
# NVFP4 gate: build production Nvfp4Linear
router.load_weights(e_bias=eb.to(dev, torch.float32))
gate_lin = make_nvfp4_linear(H, cfg["n_routed_experts"], dev, all_w, pfx, 'gate')
router.load_nvfp4_gate(gate_lin)
else:
# BF16 fallback
gw = all_w.get(f"{pfx}.gate.weight")
if gw is not None:
if gw.shape == (cfg["n_routed_experts"], H): gw = gw.T.contiguous()
router.load_weights(W_gate=gw.bfloat16().to(dev), e_bias=eb.to(dev, torch.float32))
router.finalize_weights(); routers[li] = router
moe = Nvfp4MoE(num_experts=cfg["n_routed_experts"], hidden_size=H,