Fix warmup_gsa: handle multi-element _gsa_buf (Nvfp4GroupedLinear per-group gsa)

This commit is contained in:
2026-06-03 19:49:54 +00:00
parent 486f74d900
commit 2661cebe9a

View File

@@ -1886,7 +1886,8 @@ def main():
if pl is None: continue
for key, lin in pl.items():
if hasattr(lin, '_gsa_buf') and hasattr(lin, '_use_runtime_gsa') and lin._use_runtime_gsa:
fixed_gsa = lin._gsa_buf.item() # One-time sync
# Nvfp4GroupedLinear has per-group gsa; reduce to scalar (max) for fixed gsa
fixed_gsa = lin._gsa_buf.max().item() if lin._gsa_buf.numel() > 1 else lin._gsa_buf.item()
lin._activation_global_scale = fixed_gsa
lin._use_runtime_gsa = False
n_fixed += 1