FusedMoE support for the Transformers backend (#22650)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-03 07:12:15 +01:00
committed by GitHub
parent 39b643dc1a
commit 10d765482d
10 changed files with 485 additions and 91 deletions

View File

@@ -960,6 +960,7 @@ class FusedMoE(CustomOp):
is_sequence_parallel=False,
zero_expert_num: Optional[int] = 0,
zero_expert_type: Optional[str] = None,
expert_mapping: Optional[list[tuple[str, str, int, str]]] = None,
):
super().__init__()
if params_dtype is None:
@@ -996,6 +997,9 @@ class FusedMoE(CustomOp):
self.zero_expert_num = zero_expert_num
self.zero_expert_type = zero_expert_type
# Expert mapping used in self.load_weights
self.expert_mapping = expert_mapping
# Round up hidden size if needed.
hidden_size = maybe_roundup_hidden_size(hidden_size, moe_in_dtype,
quant_config,
@@ -1617,6 +1621,33 @@ class FusedMoE(CustomOp):
return False if return_success else None
def load_weights(
self, weights: Iterable[tuple[str,
torch.Tensor]]) -> Iterable[str]:
if (expert_mapping := self.expert_mapping) is None:
raise ValueError("`self.expert_mapping` must be provided to "
"load weights using `self.load_weights`.")
for expert_name, loaded_weight in weights:
qual_name = f"{self.layer_name}.{expert_name}"
for param_name, weight_name, expert_id, shard_id in expert_mapping:
if weight_name not in qual_name:
continue
weight_name = qual_name.replace(weight_name, param_name)
param_name = weight_name.removeprefix(f"{self.layer_name}.")
param = getattr(self, param_name)
success = self.weight_loader(
param=param,
loaded_weight=loaded_weight,
weight_name=weight_name,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
if success:
logger.debug("Loaded %s for expert %d into %s", param_name,
expert_id, self.layer_name)
yield param_name
def get_expert_weights(self) -> Iterable[torch.Tensor]:
weights = list(self.named_parameters())
assert all(weight.is_contiguous() for _, weight in weights)