[Bugfix] Add missing extra_tensors arg to DeviceCommunicatorBase.dispatch (#31644)
Signed-off-by: kunzh <zhikun.wu@outlook.com>
This commit is contained in:
@@ -285,6 +285,7 @@ class DeviceCommunicatorBase:
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Dispatch the hidden states and router logits to the appropriate device.
|
||||
|
||||
@@ -23,11 +23,11 @@ class XpuCommunicator(DeviceCommunicatorBase):
|
||||
):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
if self.use_all2all:
|
||||
if self.all2all_backend != "naive":
|
||||
if self.all2all_backend != "naive": # type: ignore[has-type]
|
||||
logger.warning(
|
||||
"`%s` all2all manager is not supported on XPU. "
|
||||
"Falling back to `naive` all2all manager for XPU.",
|
||||
self.all2all_backend,
|
||||
self.all2all_backend, # type: ignore[has-type]
|
||||
)
|
||||
self.all2all_backend = "naive"
|
||||
if self.all2all_backend == "naive":
|
||||
@@ -78,12 +78,15 @@ class XpuCommunicator(DeviceCommunicatorBase):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states, router_logits = self.all2all_manager.dispatch(
|
||||
hidden_states, router_logits, is_sequence_parallel
|
||||
return self.all2all_manager.dispatch(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
is_sequence_parallel,
|
||||
extra_tensors, # type: ignore[call-arg]
|
||||
)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
|
||||
Reference in New Issue
Block a user