feat: NVFP4 gate projection for router (replaces BF16 cuBLAS)
The dense router now uses NVFP4 GEMM via Nvfp4Linear for the gate projection when NVFP4 scales are available in the checkpoint. This replaces the BF16 cuBLAS GEMM with Blackwell SM100 tensor-core NVFP4 acceleration. Changes: - dsv4/layers/router.py: add gate_lin (Nvfp4Linear) alongside W_gate fallback. New load_nvfp4_gate() method. - dsv4/kernels/router/dense_router_decode.py: add dense_router_dispatch_nvfp4() using Nvfp4Linear + activation_topk - dsv4/kernels/router/__init__.py: export new function - single_shot_inference.py: load NVFP4 gate weights when available, fall back to BF16 when not
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
"""DSV4 Router kernels — dispatch and CUDA kernel wrappers.
|
||||
|
||||
Exports:
|
||||
dense_router_dispatch: GEMM + fused activation + top-k (all N)
|
||||
dense_router_dispatch: BF16 GEMM + fused activation + top-k (fallback)
|
||||
dense_router_dispatch_nvfp4: NVFP4 GEMM + fused activation + top-k (production)
|
||||
hash_router_dispatch: Hash routing via precomputed LUT gather
|
||||
"""
|
||||
|
||||
from dsv4.kernels.router.dense_router_decode import dense_router_dispatch
|
||||
from dsv4.kernels.router.dense_router_decode import dense_router_dispatch, dense_router_dispatch_nvfp4
|
||||
|
||||
|
||||
def hash_router_dispatch(
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
"""DSV4 Dense Router — BF16 GEMM + sqrt(softplus) + bias + top-k.
|
||||
"""DSV4 Dense Router — NVFP4 GEMM + sqrt(softplus) + bias + top-k.
|
||||
|
||||
Production path: BF16 GEMM via cuBLAS (tensor cores on Blackwell) followed by
|
||||
the fused activation_topk CUDA kernel for sqrt(softplus) + bias + top-k + renorm.
|
||||
Production path: NVFP4 GEMM via Nvfp4Linear (Blackwell tensor cores)
|
||||
followed by the fused activation_topk CUDA kernel for sqrt(softplus) +
|
||||
bias + top-k + renorm.
|
||||
|
||||
The CuTeDSL fused GEMM+epilogue kernel was attempted but make_trivial_tiled_mma
|
||||
for BF16 on SM100 has no working reference in our codebase (all other GEMMs use
|
||||
NVFP4 blockscaled MMA). The unfused path is production-grade: cuBLAS uses SM100
|
||||
tensor cores, and activation_topk is a real CUDA kernel (not PyTorch).
|
||||
BF16 cuBLAS fallback: When NVFP4 scales are not available in the
|
||||
checkpoint, dense_router_dispatch uses torch.nn.functional.linear
|
||||
(cuBLAS, SM100 tensor cores) instead.
|
||||
|
||||
The CuTeDSL fused GEMM+epilogue kernel (dense_router_decode_kernel.py)
|
||||
is the future single-kernel path but is not yet production-ready.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -23,7 +26,7 @@ def dense_router_dispatch(
|
||||
out_weights: torch.Tensor, # [N, top_k] FP32, pre-allocated
|
||||
out_ids: torch.Tensor, # [N, top_k] int32, pre-allocated
|
||||
):
|
||||
"""Dispatch the dense router.
|
||||
"""Dispatch the dense router (BF16 cuBLAS fallback).
|
||||
|
||||
BF16 GEMM via torch.nn.functional.linear (cuBLAS, SM100 tensor cores),
|
||||
then fused activation + top-k via the CUDA kernel.
|
||||
@@ -34,3 +37,25 @@ def dense_router_dispatch(
|
||||
logits, e_bias, routed_scaling_factor, top_k,
|
||||
out_weights, out_ids,
|
||||
)
|
||||
|
||||
|
||||
def dense_router_dispatch_nvfp4(
|
||||
hidden_states: torch.Tensor, # [N, hidden_size] BF16
|
||||
gate_lin, # Nvfp4Linear instance
|
||||
e_bias: torch.Tensor, # [num_experts] FP32
|
||||
routed_scaling_factor: float,
|
||||
top_k: int,
|
||||
out_weights: torch.Tensor, # [N, top_k] FP32, pre-allocated
|
||||
out_ids: torch.Tensor, # [N, top_k] int32, pre-allocated
|
||||
):
|
||||
"""Dispatch the dense router (NVFP4 production GEMM).
|
||||
|
||||
NVFP4 GEMM via Nvfp4Linear (Blackwell SM100 tensor cores),
|
||||
then fused activation + top-k via the CUDA kernel.
|
||||
"""
|
||||
logits = gate_lin(hidden_states).float() # (N, E) FP32
|
||||
from dsv4.kernels.router._activation_topk import run_fused_activation_topk
|
||||
run_fused_activation_topk(
|
||||
logits, e_bias, routed_scaling_factor, top_k,
|
||||
out_weights, out_ids,
|
||||
)
|
||||
|
||||
@@ -93,11 +93,13 @@ class Router:
|
||||
|
||||
# ---- Parameters (filled by load_weights / finalize_weights) ----
|
||||
# Dense mode:
|
||||
# W_gate: [hidden_size, num_experts] BF16
|
||||
# gate_lin: Nvfp4Linear for the gate projection (NVFP4 GEMM)
|
||||
# Fallback: W_gate BF16 + cuBLAS when NVFP4 scales not available
|
||||
# e_bias: [num_experts] FP32 — auxiliary-loss-free selection bias.
|
||||
# Hash mode:
|
||||
# hash_lut: [vocab_size, top_k] int32 — precomputed expert IDs.
|
||||
self.W_gate: Optional[torch.Tensor] = None
|
||||
self.gate_lin = None # Nvfp4Linear for NVFP4 gate projection
|
||||
self.W_gate: Optional[torch.Tensor] = None # BF16 fallback
|
||||
self.e_bias: Optional[torch.Tensor] = None
|
||||
self.hash_lut: Optional[torch.Tensor] = None
|
||||
|
||||
@@ -124,15 +126,14 @@ class Router:
|
||||
nearly always loader bugs and silent acceptance would mask them.
|
||||
"""
|
||||
if self.mode == "dense":
|
||||
if W_gate is None or e_bias is None:
|
||||
raise ValueError("dense router needs both W_gate and e_bias")
|
||||
assert W_gate.shape == (self.hidden_size, self.num_experts), \
|
||||
f"W_gate shape {tuple(W_gate.shape)} != " \
|
||||
f"{(self.hidden_size, self.num_experts)}"
|
||||
if e_bias is None:
|
||||
raise ValueError("dense router needs e_bias")
|
||||
assert e_bias.shape == (self.num_experts,), \
|
||||
f"e_bias shape {tuple(e_bias.shape)} != ({self.num_experts},)"
|
||||
self.W_gate = W_gate.to(device=self.device, dtype=torch.bfloat16)
|
||||
self.e_bias = e_bias.to(device=self.device, dtype=torch.float32)
|
||||
if W_gate is not None:
|
||||
self.W_gate = W_gate.to(device=self.device, dtype=torch.bfloat16)
|
||||
# gate_lin is set separately via load_nvfp4_gate()
|
||||
else: # hash
|
||||
if hash_lut is None:
|
||||
raise ValueError("hash router needs hash_lut")
|
||||
@@ -143,6 +144,15 @@ class Router:
|
||||
"hash_lut contains out-of-range expert IDs"
|
||||
self.hash_lut = hash_lut.to(device=self.device, dtype=torch.int32)
|
||||
|
||||
def load_nvfp4_gate(self, gate_lin) -> None:
|
||||
"""Set the NVFP4 gate linear layer (preferred over BF16 W_gate).
|
||||
|
||||
Called by the single_shot after constructing the Nvfp4Linear
|
||||
from checkpoint NVFP4 scales. When set, _run_dense_impl uses
|
||||
the production NVFP4 GEMM path instead of BF16 cuBLAS.
|
||||
"""
|
||||
self.gate_lin = gate_lin
|
||||
|
||||
def finalize_weights(self) -> None:
|
||||
"""Allocate output buffers and JIT-compile the routing kernel.
|
||||
|
||||
@@ -232,25 +242,35 @@ class Router:
|
||||
# Called by the custom_op dispatch in dsv4/ops/router.py — not by user code.
|
||||
# ------------------------------------------------------------------
|
||||
def _run_dense_impl(self, hidden_states: torch.Tensor):
|
||||
"""Hot-path entry into the fused decode/prefill kernel.
|
||||
"""Hot-path: NVFP4 GEMM or BF16 fallback + activation_topk.
|
||||
|
||||
Implementation lives in dsv4/kernels/router/dense_router_decode.py
|
||||
(small N) or dsv4/kernels/router/dense_router_prefill.py (large N).
|
||||
The selection is internal to that module — Router doesn't care.
|
||||
When gate_lin (Nvfp4Linear) is available, uses production NVFP4 GEMM.
|
||||
Otherwise falls back to BF16 cuBLAS.
|
||||
"""
|
||||
from dsv4.kernels.router import dense_router_dispatch
|
||||
from dsv4.kernels.router import dense_router_dispatch, dense_router_dispatch_nvfp4
|
||||
N = hidden_states.shape[0]
|
||||
out_w = self._topk_weights_buf[:N]
|
||||
out_ids = self._topk_ids_buf[:N]
|
||||
dense_router_dispatch(
|
||||
hidden_states=hidden_states,
|
||||
W_gate=self.W_gate,
|
||||
e_bias=self.e_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
top_k=self.top_k,
|
||||
out_weights=out_w,
|
||||
out_ids=out_ids,
|
||||
)
|
||||
if self.gate_lin is not None:
|
||||
dense_router_dispatch_nvfp4(
|
||||
hidden_states=hidden_states,
|
||||
gate_lin=self.gate_lin,
|
||||
e_bias=self.e_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
top_k=self.top_k,
|
||||
out_weights=out_w,
|
||||
out_ids=out_ids,
|
||||
)
|
||||
else:
|
||||
dense_router_dispatch(
|
||||
hidden_states=hidden_states,
|
||||
W_gate=self.W_gate,
|
||||
e_bias=self.e_bias,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
top_k=self.top_k,
|
||||
out_weights=out_w,
|
||||
out_ids=out_ids,
|
||||
)
|
||||
return out_w, out_ids
|
||||
|
||||
def _run_hash_impl(self, token_ids: torch.Tensor):
|
||||
|
||||
@@ -665,10 +665,20 @@ def main():
|
||||
if is_hash:
|
||||
router.load_weights(hash_lut=all_w[f"{pfx}.gate.tid2eid"].to(dev, torch.int32))
|
||||
else:
|
||||
gw = all_w.get(f"{pfx}.gate.weight"); eb = all_w.get(f"{pfx}.gate.e_score_correction_bias")
|
||||
if gw is not None and eb is not None:
|
||||
if gw.shape == (cfg["n_routed_experts"], H): gw = gw.T.contiguous()
|
||||
router.load_weights(W_gate=gw.bfloat16().to(dev), e_bias=eb.to(dev, torch.float32))
|
||||
eb = all_w.get(f"{pfx}.gate.e_score_correction_bias")
|
||||
# Try NVFP4 gate weights first (production path)
|
||||
gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(all_w, pfx, 'gate')
|
||||
if gate_w is not None and gate_ws is not None:
|
||||
# NVFP4 gate: build production Nvfp4Linear
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
gate_lin = make_nvfp4_linear(H, cfg["n_routed_experts"], dev, all_w, pfx, 'gate')
|
||||
router.load_nvfp4_gate(gate_lin)
|
||||
else:
|
||||
# BF16 fallback
|
||||
gw = all_w.get(f"{pfx}.gate.weight")
|
||||
if gw is not None:
|
||||
if gw.shape == (cfg["n_routed_experts"], H): gw = gw.T.contiguous()
|
||||
router.load_weights(W_gate=gw.bfloat16().to(dev), e_bias=eb.to(dev, torch.float32))
|
||||
router.finalize_weights(); routers[li] = router
|
||||
|
||||
moe = Nvfp4MoE(num_experts=cfg["n_routed_experts"], hidden_size=H,
|
||||
|
||||
Reference in New Issue
Block a user