Quantize BF16 gate weight to NVFP4 for fused router + add global scales to GEMM

CRITICAL: Checkpoint stores gate weights as BF16, not NVFP4.
Previous code fell back to BF16 cuBLAS because weight_scale was missing.
Now we quantize the BF16 gate weight to NVFP4 at load time using
quantize_to_nvfp4() and pass the result to the fused router kernel.

Also added global scale (gsa, gsb) parameters to the kernel:
- gsa (activation global scale) applied during activation quantization
- gsb (weight global scale) applied in epilogue before sqrt(softplus)
- The MMA output is (A * SFA) @ (B * SFB), missing gsa*gsb
- Epilogue now computes sqrt(softplus(logit * gsa * gsb))
  instead of sqrt(softplus(logit))
This commit is contained in:
2026-06-01 10:14:29 +00:00
parent 24fed15ed6
commit 71deeb91a9
2 changed files with 31 additions and 9 deletions

View File

@@ -262,7 +262,7 @@ class Nvfp4FusedRouterKernel:
# run() — Python entry point
# -----------------------------------------------------------------
def run(self, mat_a, mat_b, scale_a, scale_b, mat_c,
M, N, K, stream=None):
M, N, K, gsa, gsb, stream=None):
if stream is None:
stream = cuda.CUstream(0)
@@ -336,7 +336,7 @@ class Nvfp4FusedRouterKernel:
self.c_smem_layout_staged,
self.epi_tile,
tile_sched_params,
M, N, K,
M, N, K, gsa, gsb,
).launch(
grid=grid, block=[self.threads_per_cta, 1, 1],
cluster=(*self.cluster_shape_mn, 1),
@@ -359,7 +359,7 @@ class Nvfp4FusedRouterKernel:
c_smem_layout_staged,
epi_tile,
tile_sched_params,
M, N, K):
M, N, K, gsa, gsb):
warp_idx = cute.arch.warp_idx()
warp_idx = cute.arch.make_warp_uniform(warp_idx)
@@ -723,12 +723,15 @@ class Nvfp4FusedRouterKernel:
acc_pipeline.consumer_release(acc_cs)
acc_cs.advance()
# Activation: sqrt(softplus(logit))
# softplus(x) = max(x, 0) + log(1 + exp(-|x|))
# This replaces SwiGLU in the MoE epilogue
# Activation: sqrt(softplus(logit * gsa * gsb))
# Global scales are applied before the activation, same as
# how MoE epilogue applies them before SwiGLU.
# The MMA output is (A * SFA) @ (B * SFB), missing gsa*gsb.
scale = cutlass.Float32(gsa * gsb)
acc_vec = tTR_rAcc.load()
for e in cutlass.range(cute.size(acc_vec), unroll=4):
logit = acc_vec[e]
logit = acc_vec[e] * scale
# softplus(x) = max(x, 0) + log(1 + exp(-|x|))
abs_x = cute.math.absf(logit)
pos = cute.math.fmax(logit, cutlass.Float32(0.0))
exp_neg = cute.math.exp(-abs_x)
@@ -856,6 +859,8 @@ def run_nvfp4_fused_router(
scale_b=cute_sfb,
mat_c=cute_c,
M=N, N=E, K=K,
gsa=gsa,
gsb=gsb_val,
)
# Add e_bias (selection bias) and run top-k

View File

@@ -706,11 +706,28 @@ def main():
gate_input_scale=gate_isc.to(dev) if gate_isc is not None else torch.tensor(1.0 / (6.0 * 448.0), device=dev),
)
else:
# BF16 fallback
# BF16 gate weight: quantize to NVFP4 for fused kernel
gw = all_w.get(f"{pfx}.gate.weight")
if gw is not None:
if gw.shape == (cfg["n_routed_experts"], H): gw = gw.T.contiguous()
router.load_weights(W_gate=gw.bfloat16().to(dev), e_bias=eb.to(dev, torch.float32))
gw = gw.bfloat16().to(dev)
# Quantize BF16 → NVFP4 for fused router kernel
from dsv4.ops.quantize import quantize_to_nvfp4
gw_fp4, gw_sf, gw_gs = quantize_to_nvfp4(gw)
router.load_weights(e_bias=eb.to(dev, torch.float32))
# gsb (weight global scale) = gw_gs from quantization
# gsa (activation global scale) = 1.0 (applied during activation quantization inside kernel)
# Actually: gsa is passed to quantize_activation_nvfp4 inside run_nvfp4_fused_router
# We need to compute the correct gsa. For NVFP4, gsa = 1/(max_val * 448)
# But since activation is quantized at runtime, gsa = input_scale from Nvfp4Linear = 1/(6*448)
router.load_nvfp4_fused_gate(
gate_weight=gw_fp4,
gate_weight_scale=gw_sf,
gate_ws2=torch.tensor(gw_gs, device=dev), # gsb = weight global scale
gate_input_scale=torch.tensor(1.0 / (6.0 * 448.0), device=dev), # gsa = activation global scale
)
else:
router.load_weights(e_bias=eb.to(dev, torch.float32))
router.finalize_weights(); routers[li] = router
moe = Nvfp4MoE(num_experts=cfg["n_routed_experts"], hidden_size=H,