Enable loading of fused expert weights in the Transformers modelling backend (#36997)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2026-03-14 07:01:06 +00:00
committed by GitHub
parent 74fe80ee95
commit ffa5d74f15
2 changed files with 45 additions and 16 deletions

View File

@@ -1342,22 +1342,41 @@ class FusedMoE(CustomOp):
weight_name = qual_name.replace(weight_name, param_name)
param_name = weight_name.removeprefix(f"{self.layer_name}.")
param = getattr(self, param_name)
success = self.weight_loader(
param=param,
loaded_weight=loaded_weight,
weight_name=weight_name,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
if success:
logger.debug(
"Loaded %s for expert %d into %s",
param_name,
expert_id,
self.layer_name,
# Fused expert weights can be identified by their 3D tensors
if loaded_weight.dim() == 3:
# Repurpose expert_id as shard_idx for deconcatenating w1 and w3
if shard_id in {"w1", "w3"}:
shard_idx = expert_id
experts_shard = loaded_weight.chunk(2, dim=1)[shard_idx]
else:
experts_shard = loaded_weight
start = 0
else:
# loaded_weight is a single expert weight, so we add a dummy expert
# dimension to unify the loading logic with the fused case
experts_shard = loaded_weight.unsqueeze(0)
start = expert_id
# Unified loading logic for fused and non-fused experts
loaded_experts = experts_shard.unbind()
for expert_id, loaded_expert in enumerate(loaded_experts, start=start):
success = self.weight_loader(
param=param,
loaded_weight=loaded_expert,
weight_name=weight_name,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
yield param_name
if success:
logger.debug(
"Loaded expert %d of shard %s into %s for layer %s",
expert_id,
shard_id,
param_name,
self.layer_name,
)
yield param_name
def get_expert_weights(self) -> Iterable[torch.Tensor]:
def _maybe_make_contiguous(

View File

@@ -156,6 +156,17 @@ class MoEMixin(MixtureOfExperts):
Params for weights, fp8 weight scales, fp8 activation scales
(param_name, weight_name, expert_id, shard_id)
"""
# Models saved with fused experts. These are checkpoints released:
# - After Transformers v5
# - Before Transformers v5, but re-saved with save_original_format=False
# In the fused experts case, we repurpose the expert_id as shard_idx for
# deconcatenating w1 and w3 in FusedMoE.load_weights.
expert_mapping = [
("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
("experts.w13_weight", "experts.gate_up_proj", 1, "w3"),
("experts.w2_weight", "experts.down_proj", 0, "w2"),
]
# Models saved with ModuleList experts
ckpt_names = [
# (ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name)
("gate_proj", "down_proj", "up_proj"), # Most common MoE style
@@ -164,7 +175,6 @@ class MoEMixin(MixtureOfExperts):
]
num_experts = self.model_config.get_num_experts()
num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts
expert_mapping = []
for gate_proj, down_proj, up_proj in ckpt_names:
expert_mapping.extend(
FusedMoE.make_expert_params_mapping(