From 71deeb91a9b13ce024ee2013737f7fe1537b3706 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 1 Jun 2026 10:14:29 +0000 Subject: [PATCH] Quantize BF16 gate weight to NVFP4 for fused router + add global scales to GEMM CRITICAL: Checkpoint stores gate weights as BF16, not NVFP4. Previous code fell back to BF16 cuBLAS because weight_scale was missing. Now we quantize the BF16 gate weight to NVFP4 at load time using quantize_to_nvfp4() and pass the result to the fused router kernel. Also added global scale (gsa, gsb) parameters to the kernel: - gsa (activation global scale) applied during activation quantization - gsb (weight global scale) applied, together with gsa, in the epilogue before sqrt(softplus) - The MMA output is (A * SFA) @ (B * SFB), which is missing the gsa*gsb factor - Epilogue now computes sqrt(softplus(logit * gsa * gsb)) instead of sqrt(softplus(logit)) --- .../router/nvfp4_fused_router_kernel.py | 19 ++++++++++------- single_shot_inference.py | 21 +++++++++++++++++-- 2 files changed, 31 insertions(+), 9 deletions(-) diff --git a/dsv4/kernels/router/nvfp4_fused_router_kernel.py b/dsv4/kernels/router/nvfp4_fused_router_kernel.py index 224cd52f..3445bb82 100644 --- a/dsv4/kernels/router/nvfp4_fused_router_kernel.py +++ b/dsv4/kernels/router/nvfp4_fused_router_kernel.py @@ -262,7 +262,7 @@ class Nvfp4FusedRouterKernel: # run() — Python entry point # ----------------------------------------------------------------- def run(self, mat_a, mat_b, scale_a, scale_b, mat_c, - M, N, K, stream=None): + M, N, K, gsa, gsb, stream=None): if stream is None: stream = cuda.CUstream(0) @@ -336,7 +336,7 @@ class Nvfp4FusedRouterKernel: self.c_smem_layout_staged, self.epi_tile, tile_sched_params, - M, N, K, + M, N, K, gsa, gsb, ).launch( grid=grid, block=[self.threads_per_cta, 1, 1], cluster=(*self.cluster_shape_mn, 1), @@ -359,7 +359,7 @@ class Nvfp4FusedRouterKernel: c_smem_layout_staged, epi_tile, tile_sched_params, - M, N, K): + M, N, K, gsa, gsb): warp_idx = cute.arch.warp_idx() warp_idx = 
cute.arch.make_warp_uniform(warp_idx) @@ -723,12 +723,15 @@ class Nvfp4FusedRouterKernel: acc_pipeline.consumer_release(acc_cs) acc_cs.advance() - # Activation: sqrt(softplus(logit)) - # softplus(x) = max(x, 0) + log(1 + exp(-|x|)) - # This replaces SwiGLU in the MoE epilogue + # Activation: sqrt(softplus(logit * gsa * gsb)) + # Global scales are applied before the activation, same as + # how MoE epilogue applies them before SwiGLU. + # The MMA output is (A * SFA) @ (B * SFB), missing gsa*gsb. + scale = cutlass.Float32(gsa * gsb) acc_vec = tTR_rAcc.load() for e in cutlass.range(cute.size(acc_vec), unroll=4): - logit = acc_vec[e] + logit = acc_vec[e] * scale + # softplus(x) = max(x, 0) + log(1 + exp(-|x|)) abs_x = cute.math.absf(logit) pos = cute.math.fmax(logit, cutlass.Float32(0.0)) exp_neg = cute.math.exp(-abs_x) @@ -856,6 +859,8 @@ def run_nvfp4_fused_router( scale_b=cute_sfb, mat_c=cute_c, M=N, N=E, K=K, + gsa=gsa, + gsb=gsb_val, ) # Add e_bias (selection bias) and run top-k diff --git a/single_shot_inference.py b/single_shot_inference.py index 001a3853..52fcbe13 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -706,11 +706,28 @@ def main(): gate_input_scale=gate_isc.to(dev) if gate_isc is not None else torch.tensor(1.0 / (6.0 * 448.0), device=dev), ) else: - # BF16 fallback + # BF16 gate weight: quantize to NVFP4 for fused kernel gw = all_w.get(f"{pfx}.gate.weight") if gw is not None: if gw.shape == (cfg["n_routed_experts"], H): gw = gw.T.contiguous() - router.load_weights(W_gate=gw.bfloat16().to(dev), e_bias=eb.to(dev, torch.float32)) + gw = gw.bfloat16().to(dev) + # Quantize BF16 → NVFP4 for fused router kernel + from dsv4.ops.quantize import quantize_to_nvfp4 + gw_fp4, gw_sf, gw_gs = quantize_to_nvfp4(gw) + router.load_weights(e_bias=eb.to(dev, torch.float32)) + # gsb (weight global scale) = gw_gs from quantization + # gsa (activation global scale) = 1.0 (applied during activation quantization inside kernel) + # Actually: gsa is 
passed to quantize_activation_nvfp4 inside run_nvfp4_fused_router + # We need to compute the correct gsa. For NVFP4, gsa = 1/(max_val * 448) + # But since activation is quantized at runtime, gsa = input_scale from Nvfp4Linear = 1/(6*448) + router.load_nvfp4_fused_gate( + gate_weight=gw_fp4, + gate_weight_scale=gw_sf, + gate_ws2=torch.tensor(gw_gs, device=dev), # gsb = weight global scale + gate_input_scale=torch.tensor(1.0 / (6.0 * 448.0), device=dev), # gsa = activation global scale + ) + else: + router.load_weights(e_bias=eb.to(dev, torch.float32)) router.finalize_weights(); routers[li] = router moe = Nvfp4MoE(num_experts=cfg["n_routed_experts"], hidden_size=H,