fix: set activation global scales AFTER _ensure_stacked/_ensure_initialized (which override them)

This commit is contained in:
2026-06-01 03:43:09 +00:00
parent 27c63b01d6
commit 3dd95ce77b

View File

@@ -482,11 +482,10 @@ def _load_moe_weights_stacked(all_w, li, pfx, dev, moe, cfg):
l2_sf_stacked = torch.stack(l2_sf_list).to(dev) if l2_sf_list else None
del l1_fp4_list, l1_sf_list, l2_fp4_list, l2_sf_list
moe.prepare_weights_from_stacked(l1_stacked, l1_sf_stacked, l1_gs_list, l2_stacked, l2_sf_stacked, l2_gs_list)
# Set activation global scales from input_scale (gsa = input_scale, gsb = weight_scale_2)
if l1_gsa_list:
moe._l1_activation_global_scale = l1_gsa_list[0] # Use first expert's input_scale
if l2_gsa_list:
moe._l2_activation_global_scale = l2_gsa_list[0] # Use first expert's input_scale
# Save activation global scales — _ensure_stacked will override them from l1_gs (which is 1.0)
# We must re-set them AFTER _ensure_stacked
moe._saved_l1_gsa = l1_gsa_list[0] if l1_gsa_list else 1.0 / (6.0 * 448.0)
moe._saved_l2_gsa = l2_gsa_list[0] if l2_gsa_list else 1.0 / (6.0 * 448.0)
moe.l1_ws2 = l1_ws2_list
moe.l2_ws2 = l2_ws2_list
@@ -500,14 +499,16 @@ def _load_shared_expert_weights(all_w, li, pfx, dev, se, cfg):
l1_isc = gisc.float().item() if gisc is not None else 1.0 / (6.0 * 448.0)
se.l1_gs = [1.0] # gsb base — ws2 will be folded in by finalize_weights
se.l1_ws2 = [gws2.to(dev) if gws2 is not None else None]
se._l1_activation_global_scale = l1_isc # gsa = input_scale
se._l1_activation_global_scale = l1_isc # Will be overridden by _ensure_initialized
se._saved_l1_gsa = l1_isc # Save for after _ensure_initialized
if dw is not None:
se.l2_fp4 = [dw.to(dev)]
se.l2_sf = [dws.to(dev)] if dws is not None else [torch.zeros(1, device=dev, dtype=torch.float8_e4m3fn)]
l2_isc = disc.float().item() if disc is not None else 1.0 / (6.0 * 448.0)
se.l2_gs = [1.0] # gsb base
se.l2_ws2 = [dws2.to(dev) if dws2 is not None else None]
se._l2_activation_global_scale = l2_isc # gsa = input_scale
se._l2_activation_global_scale = l2_isc # Will be overridden by _ensure_initialized
se._saved_l2_gsa = l2_isc # Save for after _ensure_initialized
def _cache_layer_weights_no_experts(all_w, n_layers, devices):
cached = {}
@@ -649,6 +650,11 @@ def main():
_load_moe_weights_stacked(all_w, li, pfx, dev, moe, cfg)
# EAGERLY process stacked weights → K-major + swizzle, free raw tensors
moe._ensure_stacked()
# Fix activation global scales — _ensure_stacked sets gsa from l1_gs (which is 1.0)
if hasattr(moe, '_saved_l1_gsa'):
moe._l1_activation_global_scale = moe._saved_l1_gsa
if hasattr(moe, '_saved_l2_gsa'):
moe._l2_activation_global_scale = moe._saved_l2_gsa
moe_runners[li] = moe
se = Nvfp4SharedExpert(hidden_size=H, intermediate_size=cfg.get("moe_intermediate_size", 3072),
@@ -656,6 +662,12 @@ def main():
_load_shared_expert_weights(all_w, li, pfx, dev, se, cfg)
# EAGERLY process shared expert weights
se._ensure_initialized()
# Fix activation global scales — _ensure_initialized sets gsa from l1_gs (which is 1.0)
# The correct gsa is the input_scale from the checkpoint, saved in _saved_l1_gsa
if hasattr(se, '_saved_l1_gsa'):
se._l1_activation_global_scale = se._saved_l1_gsa
if hasattr(se, '_saved_l2_gsa'):
se._l2_activation_global_scale = se._saved_l2_gsa
se_runners[li] = se
if (li+1) % 10 == 0: print(f" Built {li+1}/{n_layers} MoE layers")
torch.cuda.empty_cache()