Fix: add _use_runtime_gsa=True to router gate GEMM in single_shot

The checkpoint-path gate was using the checkpoint's input_scale as gsa
— the same E4M3 overflow bug we fixed in Nvfp4Linear/Nvfp4MoE/etc.
The runtime-quantized BF16 path was using 1/(6*448) as a fixed gsa.

Both now compute gsa from actual activation magnitude at runtime.
This commit is contained in:
2026-06-01 17:25:04 +00:00
parent 9d57b0453b
commit db6e3545da

View File

@@ -728,7 +728,8 @@ def main():
isc_v = gate_isc.float().item() if gate_isc is not None else 1.0/(6.0*448.0)
gate_lin.gs = [1.0]
gate_lin.ws2 = [torch.tensor([ws2_v], device=dev, dtype=torch.float32)]
gate_lin._activation_global_scale = isc_v
gate_lin._activation_global_scale = isc_v # placeholder — runtime gsa overrides this
gate_lin._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow
gate_lin.finalize_weights()
router.load_nvfp4_gate(gate_lin)
router.load_weights(e_bias=eb.to(dev, torch.float32))
@@ -746,7 +747,8 @@ def main():
gate_lin.sf = [g_sf]
gate_lin.gs = [g_gs]
gate_lin.ws2 = [torch.tensor([g_gs], device=dev, dtype=torch.float32)]
gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0)
gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0) # placeholder — runtime gsa overrides
gate_lin._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow
gate_lin.finalize_weights()
router.load_nvfp4_gate(gate_lin)
router.load_weights(e_bias=eb.to(dev, torch.float32))