From db6e3545da2382bead90c23f9186884937fc3692 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 1 Jun 2026 17:25:04 +0000 Subject: [PATCH] Fix: add _use_runtime_gsa=True to router gate GEMM in single_shot MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The checkpoint-path gate was using the checkpoint's input_scale as gsa — the same E4M3 overflow bug we fixed in Nvfp4Linear/Nvfp4MoE/etc. The runtime-quantized BF16 path was using 1/(6*448) as a fixed gsa. Both now compute gsa from actual activation magnitude at runtime. --- single_shot_inference.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/single_shot_inference.py b/single_shot_inference.py index cc8eafd4..d33f4160 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -728,7 +728,8 @@ def main(): isc_v = gate_isc.float().item() if gate_isc is not None else 1.0/(6.0*448.0) gate_lin.gs = [1.0] gate_lin.ws2 = [torch.tensor([ws2_v], device=dev, dtype=torch.float32)] - gate_lin._activation_global_scale = isc_v + gate_lin._activation_global_scale = isc_v # placeholder — runtime gsa overrides this + gate_lin._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow gate_lin.finalize_weights() router.load_nvfp4_gate(gate_lin) router.load_weights(e_bias=eb.to(dev, torch.float32)) @@ -746,7 +747,8 @@ def main(): gate_lin.sf = [g_sf] gate_lin.gs = [g_gs] gate_lin.ws2 = [torch.tensor([g_gs], device=dev, dtype=torch.float32)] - gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0) + gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0) # placeholder — runtime gsa overrides + gate_lin._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow gate_lin.finalize_weights() router.load_nvfp4_gate(gate_lin) router.load_weights(e_bias=eb.to(dev, torch.float32))