diff --git a/single_shot_inference.py b/single_shot_inference.py index cc8eafd4..d33f4160 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -728,7 +728,8 @@ def main(): isc_v = gate_isc.float().item() if gate_isc is not None else 1.0/(6.0*448.0) gate_lin.gs = [1.0] gate_lin.ws2 = [torch.tensor([ws2_v], device=dev, dtype=torch.float32)] - gate_lin._activation_global_scale = isc_v + gate_lin._activation_global_scale = isc_v # placeholder — runtime gsa overrides this + gate_lin._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow gate_lin.finalize_weights() router.load_nvfp4_gate(gate_lin) router.load_weights(e_bias=eb.to(dev, torch.float32)) @@ -746,7 +747,8 @@ def main(): gate_lin.sf = [g_sf] gate_lin.gs = [g_gs] gate_lin.ws2 = [torch.tensor([g_gs], device=dev, dtype=torch.float32)] - gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0) + gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0) # placeholder — runtime gsa overrides + gate_lin._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow gate_lin.finalize_weights() router.load_nvfp4_gate(gate_lin) router.load_weights(e_bias=eb.to(dev, torch.float32))