From 3dd95ce77bd4994fd4b337794dd7ce255b7684af Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 1 Jun 2026 03:43:09 +0000 Subject: [PATCH] fix: set activation global scales AFTER _ensure_stacked/_ensure_initialized (which override them) --- single_shot_inference.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/single_shot_inference.py b/single_shot_inference.py index f789520a..0be94b1d 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -482,11 +482,10 @@ def _load_moe_weights_stacked(all_w, li, pfx, dev, moe, cfg): l2_sf_stacked = torch.stack(l2_sf_list).to(dev) if l2_sf_list else None del l1_fp4_list, l1_sf_list, l2_fp4_list, l2_sf_list moe.prepare_weights_from_stacked(l1_stacked, l1_sf_stacked, l1_gs_list, l2_stacked, l2_sf_stacked, l2_gs_list) - # Set activation global scales from input_scale (gsa = input_scale, gsb = weight_scale_2) - if l1_gsa_list: - moe._l1_activation_global_scale = l1_gsa_list[0] # Use first expert's input_scale - if l2_gsa_list: - moe._l2_activation_global_scale = l2_gsa_list[0] # Use first expert's input_scale + # Save activation global scales — _ensure_stacked will override them from l1_gs (which is 1.0) + # We must re-set them AFTER _ensure_stacked + moe._saved_l1_gsa = l1_gsa_list[0] if l1_gsa_list else 1.0 / (6.0 * 448.0) + moe._saved_l2_gsa = l2_gsa_list[0] if l2_gsa_list else 1.0 / (6.0 * 448.0) moe.l1_ws2 = l1_ws2_list moe.l2_ws2 = l2_ws2_list @@ -500,14 +499,16 @@ def _load_shared_expert_weights(all_w, li, pfx, dev, se, cfg): l1_isc = gisc.float().item() if gisc is not None else 1.0 / (6.0 * 448.0) se.l1_gs = [1.0] # gsb base — ws2 will be folded in by finalize_weights se.l1_ws2 = [gws2.to(dev) if gws2 is not None else None] - se._l1_activation_global_scale = l1_isc # gsa = input_scale + se._l1_activation_global_scale = l1_isc # Will be overridden by _ensure_initialized + se._saved_l1_gsa = l1_isc # Save for after _ensure_initialized if dw is not None: se.l2_fp4 = [dw.to(dev)] se.l2_sf = [dws.to(dev)] if dws is not None else [torch.zeros(1, device=dev, dtype=torch.float8_e4m3fn)] l2_isc = disc.float().item() if disc is not None else 1.0 / (6.0 * 448.0) se.l2_gs = [1.0] # gsb base se.l2_ws2 = [dws2.to(dev) if dws2 is not None else None] - se._l2_activation_global_scale = l2_isc # gsa = input_scale + se._l2_activation_global_scale = l2_isc # Will be overridden by _ensure_initialized + se._saved_l2_gsa = l2_isc # Save for after _ensure_initialized def _cache_layer_weights_no_experts(all_w, n_layers, devices): cached = {} @@ -649,6 +650,11 @@ def main(): _load_moe_weights_stacked(all_w, li, pfx, dev, moe, cfg) # EAGERLY process stacked weights → K-major + swizzle, free raw tensors moe._ensure_stacked() + # Fix activation global scales — _ensure_stacked sets gsa from l1_gs (which is 1.0) + if hasattr(moe, '_saved_l1_gsa'): + moe._l1_activation_global_scale = moe._saved_l1_gsa + if hasattr(moe, '_saved_l2_gsa'): + moe._l2_activation_global_scale = moe._saved_l2_gsa moe_runners[li] = moe se = Nvfp4SharedExpert(hidden_size=H, intermediate_size=cfg.get("moe_intermediate_size", 3072), @@ -656,6 +662,12 @@ def main(): _load_shared_expert_weights(all_w, li, pfx, dev, se, cfg) # EAGERLY process shared expert weights se._ensure_initialized() + # Fix activation global scales — _ensure_initialized sets gsa from l1_gs (which is 1.0) + # The correct gsa is the input_scale from the checkpoint, saved in _saved_l1_gsa + if hasattr(se, '_saved_l1_gsa'): + se._l1_activation_global_scale = se._saved_l1_gsa + if hasattr(se, '_saved_l2_gsa'): + se._l2_activation_global_scale = se._saved_l2_gsa se_runners[li] = se if (li+1) % 10 == 0: print(f" Built {li+1}/{n_layers} MoE layers") torch.cuda.empty_cache()