diff --git a/single_shot_inference.py b/single_shot_inference.py index d54f0416..5c914949 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -1886,7 +1886,8 @@ def main(): if pl is None: continue for key, lin in pl.items(): if hasattr(lin, '_gsa_buf') and hasattr(lin, '_use_runtime_gsa') and lin._use_runtime_gsa: - fixed_gsa = lin._gsa_buf.item() # One-time sync + # Nvfp4GroupedLinear has per-group gsa; reduce to scalar (max) for fixed gsa + fixed_gsa = lin._gsa_buf.max().item() if lin._gsa_buf.numel() > 1 else lin._gsa_buf.item() lin._activation_global_scale = fixed_gsa lin._use_runtime_gsa = False n_fixed += 1