Enable loading of fused expert weights in the Transformers modelling backend (#36997)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1342,22 +1342,41 @@ class FusedMoE(CustomOp):
|
||||
weight_name = qual_name.replace(weight_name, param_name)
|
||||
param_name = weight_name.removeprefix(f"{self.layer_name}.")
|
||||
param = getattr(self, param_name)
|
||||
success = self.weight_loader(
|
||||
param=param,
|
||||
loaded_weight=loaded_weight,
|
||||
weight_name=weight_name,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
return_success=True,
|
||||
)
|
||||
if success:
|
||||
logger.debug(
|
||||
"Loaded %s for expert %d into %s",
|
||||
param_name,
|
||||
expert_id,
|
||||
self.layer_name,
|
||||
# Fused expert weights can be identified by their 3D tensors
|
||||
if loaded_weight.dim() == 3:
|
||||
# Repurpose expert_id as shard_idx for deconcatenating w1 and w3
|
||||
if shard_id in {"w1", "w3"}:
|
||||
shard_idx = expert_id
|
||||
experts_shard = loaded_weight.chunk(2, dim=1)[shard_idx]
|
||||
else:
|
||||
experts_shard = loaded_weight
|
||||
start = 0
|
||||
else:
|
||||
# loaded_weight is a single expert weight, so we add a dummy expert
|
||||
# dimension to unify the loading logic with the fused case
|
||||
experts_shard = loaded_weight.unsqueeze(0)
|
||||
start = expert_id
|
||||
|
||||
# Unified loading logic for fused and non-fused experts
|
||||
loaded_experts = experts_shard.unbind()
|
||||
for expert_id, loaded_expert in enumerate(loaded_experts, start=start):
|
||||
success = self.weight_loader(
|
||||
param=param,
|
||||
loaded_weight=loaded_expert,
|
||||
weight_name=weight_name,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
return_success=True,
|
||||
)
|
||||
yield param_name
|
||||
if success:
|
||||
logger.debug(
|
||||
"Loaded expert %d of shard %s into %s for layer %s",
|
||||
expert_id,
|
||||
shard_id,
|
||||
param_name,
|
||||
self.layer_name,
|
||||
)
|
||||
yield param_name
|
||||
|
||||
def get_expert_weights(self) -> Iterable[torch.Tensor]:
|
||||
def _maybe_make_contiguous(
|
||||
|
||||
@@ -156,6 +156,17 @@ class MoEMixin(MixtureOfExperts):
|
||||
Params for weights, fp8 weight scales, fp8 activation scales
|
||||
(param_name, weight_name, expert_id, shard_id)
|
||||
"""
|
||||
# Models saved with fused experts. These are checkpoints released:
|
||||
# - After Transformers v5
|
||||
# - Before Transformers v5, but re-saved with save_original_format=False
|
||||
# In the fused experts case, we repurpose the expert_id as shard_idx for
|
||||
# deconcatenating w1 and w3 in FusedMoE.load_weights.
|
||||
expert_mapping = [
|
||||
("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
|
||||
("experts.w13_weight", "experts.gate_up_proj", 1, "w3"),
|
||||
("experts.w2_weight", "experts.down_proj", 0, "w2"),
|
||||
]
|
||||
# Models saved with ModuleList experts
|
||||
ckpt_names = [
|
||||
# (ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name)
|
||||
("gate_proj", "down_proj", "up_proj"), # Most common MoE style
|
||||
@@ -164,7 +175,6 @@ class MoEMixin(MixtureOfExperts):
|
||||
]
|
||||
num_experts = self.model_config.get_num_experts()
|
||||
num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts
|
||||
expert_mapping = []
|
||||
for gate_proj, down_proj, up_proj in ckpt_names:
|
||||
expert_mapping.extend(
|
||||
FusedMoE.make_expert_params_mapping(
|
||||
|
||||
Reference in New Issue
Block a user