Fix: add _use_runtime_gsa=True to router gate GEMM in single_shot
The checkpoint-path gate was using the checkpoint's input_scale as the activation global scale (gsa, i.e. `_activation_global_scale`) — the same E4M3 overflow bug we fixed in Nvfp4Linear/Nvfp4MoE/etc. The runtime-quantized BF16 path was using a fixed gsa of 1/(6*448). Both paths now compute gsa from the actual activation magnitude at runtime; the previously assigned values remain only as placeholders.
This commit is contained in:
@@ -728,7 +728,8 @@ def main():
|
||||
isc_v = gate_isc.float().item() if gate_isc is not None else 1.0/(6.0*448.0)
|
||||
gate_lin.gs = [1.0]
|
||||
gate_lin.ws2 = [torch.tensor([ws2_v], device=dev, dtype=torch.float32)]
|
||||
gate_lin._activation_global_scale = isc_v
|
||||
gate_lin._activation_global_scale = isc_v # placeholder — runtime gsa overrides this
|
||||
gate_lin._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow
|
||||
gate_lin.finalize_weights()
|
||||
router.load_nvfp4_gate(gate_lin)
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
@@ -746,7 +747,8 @@ def main():
|
||||
gate_lin.sf = [g_sf]
|
||||
gate_lin.gs = [g_gs]
|
||||
gate_lin.ws2 = [torch.tensor([g_gs], device=dev, dtype=torch.float32)]
|
||||
gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0)
|
||||
gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0) # placeholder — runtime gsa overrides
|
||||
gate_lin._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow
|
||||
gate_lin.finalize_weights()
|
||||
router.load_nvfp4_gate(gate_lin)
|
||||
router.load_weights(e_bias=eb.to(dev, torch.float32))
|
||||
|
||||
Reference in New Issue
Block a user