fix: set activation global scales AFTER _ensure_stacked/_ensure_initialized (which override them)
This commit is contained in:
@@ -482,11 +482,10 @@ def _load_moe_weights_stacked(all_w, li, pfx, dev, moe, cfg):
|
||||
l2_sf_stacked = torch.stack(l2_sf_list).to(dev) if l2_sf_list else None
|
||||
del l1_fp4_list, l1_sf_list, l2_fp4_list, l2_sf_list
|
||||
moe.prepare_weights_from_stacked(l1_stacked, l1_sf_stacked, l1_gs_list, l2_stacked, l2_sf_stacked, l2_gs_list)
|
||||
# Set activation global scales from input_scale (gsa = input_scale, gsb = weight_scale_2)
|
||||
if l1_gsa_list:
|
||||
moe._l1_activation_global_scale = l1_gsa_list[0] # Use first expert's input_scale
|
||||
if l2_gsa_list:
|
||||
moe._l2_activation_global_scale = l2_gsa_list[0] # Use first expert's input_scale
|
||||
# Save activation global scales — _ensure_stacked will override them from l1_gs (which is 1.0)
|
||||
# We must re-set them AFTER _ensure_stacked
|
||||
moe._saved_l1_gsa = l1_gsa_list[0] if l1_gsa_list else 1.0 / (6.0 * 448.0)
|
||||
moe._saved_l2_gsa = l2_gsa_list[0] if l2_gsa_list else 1.0 / (6.0 * 448.0)
|
||||
moe.l1_ws2 = l1_ws2_list
|
||||
moe.l2_ws2 = l2_ws2_list
|
||||
|
||||
@@ -500,14 +499,16 @@ def _load_shared_expert_weights(all_w, li, pfx, dev, se, cfg):
|
||||
l1_isc = gisc.float().item() if gisc is not None else 1.0 / (6.0 * 448.0)
|
||||
se.l1_gs = [1.0] # gsb base — ws2 will be folded in by finalize_weights
|
||||
se.l1_ws2 = [gws2.to(dev) if gws2 is not None else None]
|
||||
se._l1_activation_global_scale = l1_isc # gsa = input_scale
|
||||
se._l1_activation_global_scale = l1_isc # Will be overridden by _ensure_initialized
|
||||
se._saved_l1_gsa = l1_isc # Save for after _ensure_initialized
|
||||
if dw is not None:
|
||||
se.l2_fp4 = [dw.to(dev)]
|
||||
se.l2_sf = [dws.to(dev)] if dws is not None else [torch.zeros(1, device=dev, dtype=torch.float8_e4m3fn)]
|
||||
l2_isc = disc.float().item() if disc is not None else 1.0 / (6.0 * 448.0)
|
||||
se.l2_gs = [1.0] # gsb base
|
||||
se.l2_ws2 = [dws2.to(dev) if dws2 is not None else None]
|
||||
se._l2_activation_global_scale = l2_isc # gsa = input_scale
|
||||
se._l2_activation_global_scale = l2_isc # Will be overridden by _ensure_initialized
|
||||
se._saved_l2_gsa = l2_isc # Save for after _ensure_initialized
|
||||
|
||||
def _cache_layer_weights_no_experts(all_w, n_layers, devices):
|
||||
cached = {}
|
||||
@@ -649,6 +650,11 @@ def main():
|
||||
_load_moe_weights_stacked(all_w, li, pfx, dev, moe, cfg)
|
||||
# EAGERLY process stacked weights → K-major + swizzle, free raw tensors
|
||||
moe._ensure_stacked()
|
||||
# Fix activation global scales — _ensure_stacked sets gsa from l1_gs (which is 1.0)
|
||||
if hasattr(moe, '_saved_l1_gsa'):
|
||||
moe._l1_activation_global_scale = moe._saved_l1_gsa
|
||||
if hasattr(moe, '_saved_l2_gsa'):
|
||||
moe._l2_activation_global_scale = moe._saved_l2_gsa
|
||||
moe_runners[li] = moe
|
||||
|
||||
se = Nvfp4SharedExpert(hidden_size=H, intermediate_size=cfg.get("moe_intermediate_size", 3072),
|
||||
@@ -656,6 +662,12 @@ def main():
|
||||
_load_shared_expert_weights(all_w, li, pfx, dev, se, cfg)
|
||||
# EAGERLY process shared expert weights
|
||||
se._ensure_initialized()
|
||||
# Fix activation global scales — _ensure_initialized sets gsa from l1_gs (which is 1.0)
|
||||
# The correct gsa values are the checkpoint input_scales, saved in _saved_l1_gsa and _saved_l2_gsa
|
||||
if hasattr(se, '_saved_l1_gsa'):
|
||||
se._l1_activation_global_scale = se._saved_l1_gsa
|
||||
if hasattr(se, '_saved_l2_gsa'):
|
||||
se._l2_activation_global_scale = se._saved_l2_gsa
|
||||
se_runners[li] = se
|
||||
if (li+1) % 10 == 0: print(f" Built {li+1}/{n_layers} MoE layers")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
Reference in New Issue
Block a user