FusedMoE support for the Transformers backend (#22650)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -960,6 +960,7 @@ class FusedMoE(CustomOp):
|
||||
is_sequence_parallel=False,
|
||||
zero_expert_num: Optional[int] = 0,
|
||||
zero_expert_type: Optional[str] = None,
|
||||
expert_mapping: Optional[list[tuple[str, str, int, str]]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
if params_dtype is None:
|
||||
@@ -996,6 +997,9 @@ class FusedMoE(CustomOp):
|
||||
self.zero_expert_num = zero_expert_num
|
||||
self.zero_expert_type = zero_expert_type
|
||||
|
||||
# Expert mapping used in self.load_weights
|
||||
self.expert_mapping = expert_mapping
|
||||
|
||||
# Round up hidden size if needed.
|
||||
hidden_size = maybe_roundup_hidden_size(hidden_size, moe_in_dtype,
|
||||
quant_config,
|
||||
@@ -1617,6 +1621,33 @@ class FusedMoE(CustomOp):
|
||||
|
||||
return False if return_success else None
|
||||
|
||||
def load_weights(
        self, weights: Iterable[tuple[str,
                                      torch.Tensor]]) -> Iterable[str]:
    """Load expert weights into this layer using `self.expert_mapping`.

    Each incoming name is qualified with `self.layer_name` and compared
    against every `(param_name, weight_name, expert_id, shard_id)` entry
    of the mapping. For each entry whose `weight_name` occurs in the
    qualified name, that substring is replaced with the entry's
    `param_name`, the layer prefix is stripped to obtain the attribute
    name on this layer, and the tensor is handed to `self.weight_loader`.

    Args:
        weights: Pairs of `(name, tensor)`. Each name is prefixed with
            `self.layer_name` before it is matched against the mapping.

    Yields:
        The layer-local parameter name for every load that
        `self.weight_loader` reports as successful.

    Raises:
        ValueError: If `self.expert_mapping` is None. Because this method
            is a generator, the error surfaces on the first iteration,
            not when the method is called.
    """
    if (expert_mapping := self.expert_mapping) is None:
        raise ValueError("`self.expert_mapping` must be provided to "
                         "load weights using `self.load_weights`.")

    # The layer prefix does not depend on the weight or the mapping entry,
    # so build it once instead of once per mapping match.
    layer_prefix = f"{self.layer_name}."

    for expert_name, loaded_weight in weights:
        qual_name = f"{layer_prefix}{expert_name}"
        for param_name, weight_name, expert_id, shard_id in expert_mapping:
            # Substring match: entries that do not refer to this
            # checkpoint weight are skipped.
            if weight_name not in qual_name:
                continue

            # Fully qualified name with the checkpoint naming swapped for
            # the parameter naming; this is what the loader receives.
            mapped_name = qual_name.replace(weight_name, param_name)
            # Attribute name of the parameter on this layer.
            local_name = mapped_name.removeprefix(layer_prefix)
            param = getattr(self, local_name)

            success = self.weight_loader(
                param=param,
                loaded_weight=loaded_weight,
                weight_name=mapped_name,
                shard_id=shard_id,
                expert_id=expert_id,
                return_success=True,
            )
            # Only report parameters the loader says it actually loaded.
            # Matching continues afterwards; there is no early exit.
            if success:
                logger.debug("Loaded %s for expert %d into %s", local_name,
                             expert_id, self.layer_name)
                yield local_name
|
||||
|
||||
def get_expert_weights(self) -> Iterable[torch.Tensor]:
|
||||
weights = list(self.named_parameters())
|
||||
assert all(weight.is_contiguous() for _, weight in weights)
|
||||
|
||||
Reference in New Issue
Block a user