Fix warmup_gsa: handle multi-element _gsa_buf (Nvfp4GroupedLinear per-group gsa)
This commit is contained in:
@@ -1886,7 +1886,8 @@ def main():
|
||||
if pl is None: continue
|
||||
for key, lin in pl.items():
|
||||
if hasattr(lin, '_gsa_buf') and hasattr(lin, '_use_runtime_gsa') and lin._use_runtime_gsa:
|
||||
fixed_gsa = lin._gsa_buf.item() # One-time sync
|
||||
# Nvfp4GroupedLinear has a per-group gsa, so _gsa_buf can hold more than one element; reduce it to a scalar (max) for the fixed gsa
|
||||
fixed_gsa = lin._gsa_buf.max().item() if lin._gsa_buf.numel() > 1 else lin._gsa_buf.item()
|
||||
lin._activation_global_scale = fixed_gsa
|
||||
lin._use_runtime_gsa = False
|
||||
n_fixed += 1
|
||||
|
||||
Reference in New Issue
Block a user