[Bugfix] Add missing extra_tensors arg to DeviceCommunicatorBase.disp… (#31644)

Signed-off-by: kunzh <zhikun.wu@outlook.com>
Author: kzwrime
Date: 2026-01-06 01:26:09 +08:00
Committed by: GitHub
Parent: c455b771fd
Commit: 21156ff199
2 changed files with 9 additions and 5 deletions

vllm/distributed/device_communicators/base_device_communicator.py

@@ -285,6 +285,7 @@ class DeviceCommunicatorBase:
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
         is_sequence_parallel: bool = False,
+        extra_tensors: list[torch.Tensor] | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Dispatch the hidden states and router logits to the appropriate device.
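Note: with this change, the base `dispatch` interface accepts an optional list of auxiliary per-token tensors that travel through the same all2all exchange as the hidden states. A minimal sketch of the caller-side effect, using a hypothetical single-rank stand-in (`NaiveCommunicator` and the tensor shapes are illustrative only, not part of this commit):

    import torch

    class NaiveCommunicator:
        # Hypothetical stand-in for DeviceCommunicatorBase, assuming the
        # updated signature above; on a single rank, dispatch is the identity.
        def dispatch(
            self,
            hidden_states: torch.Tensor,
            router_logits: torch.Tensor,
            is_sequence_parallel: bool = False,
            extra_tensors: list[torch.Tensor] | None = None,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            if extra_tensors is not None:
                # Auxiliary tensors must align with the token dimension.
                assert all(t.shape[0] == hidden_states.shape[0] for t in extra_tensors)
            return hidden_states, router_logits

    comm = NaiveCommunicator()
    h, r = comm.dispatch(
        torch.randn(4, 16),
        torch.randn(4, 8),
        extra_tensors=[torch.zeros(4, 1)],  # now accepted by the base API
    )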

vllm/distributed/device_communicators/xpu_communicator.py

@@ -23,11 +23,11 @@ class XpuCommunicator(DeviceCommunicatorBase):
     ):
         super().__init__(cpu_group, device, device_group, unique_name)
         if self.use_all2all:
-            if self.all2all_backend != "naive":
+            if self.all2all_backend != "naive":  # type: ignore[has-type]
                 logger.warning(
                     "`%s` all2all manager is not supported on XPU. "
                     "Falling back to `naive` all2all manager for XPU.",
-                    self.all2all_backend,
+                    self.all2all_backend,  # type: ignore[has-type]
                 )
                 self.all2all_backend = "naive"
             if self.all2all_backend == "naive":
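The added `# type: ignore[has-type]` comments suppress mypy's `Cannot determine type of "all2all_backend"` error, which fires when an attribute is read before mypy has finished inferring its type (typically because the defining assignments form a cycle or live in another method or class). A contrived sketch of the pattern, unrelated to vLLM's actual class layout, for type-checking illustration only:

    class Example:
        def set_a(self) -> None:
            # Without the ignore, mypy reports:
            #   Cannot determine type of "b"  [has-type]
            # because "b" is only defined in set_b, processed later.
            self.a = self.b  # type: ignore[has-type]

        def set_b(self) -> None:
            self.b = self.a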
@@ -78,12 +78,15 @@ class XpuCommunicator(DeviceCommunicatorBase):
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
         is_sequence_parallel: bool = False,
+        extra_tensors: list[torch.Tensor] | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         assert self.all2all_manager is not None
-        hidden_states, router_logits = self.all2all_manager.dispatch(
-            hidden_states, router_logits, is_sequence_parallel
+        return self.all2all_manager.dispatch(
+            hidden_states,
+            router_logits,
+            is_sequence_parallel,
+            extra_tensors,  # type: ignore[call-arg]
         )
-        return hidden_states, router_logits
 
     def combine(
         self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
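
For reference, the bug being fixed is an override signature mismatch: `XpuCommunicator.dispatch` lacked the new `extra_tensors` parameter, so callers passing it through the base interface failed with a `TypeError`. A simplified, self-contained illustration of that failure mode (the class names and one-argument signature are invented for brevity):

    import torch

    class Base:
        def dispatch(
            self, x: torch.Tensor, extra_tensors: list[torch.Tensor] | None = None
        ) -> torch.Tensor:
            return x

    class BrokenOverride(Base):
        # Pre-fix shape of the bug: the override omits the new parameter.
        def dispatch(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
            return x

    def run(comm: Base) -> torch.Tensor:
        return comm.dispatch(torch.ones(2), extra_tensors=None)

    run(Base())  # fine
    try:
        run(BrokenOverride())
    except TypeError as e:
        print(e)  # TypeError: unexpected keyword argument 'extra_tensors'

Matching the base signature (and forwarding the argument to the all2all manager, as this commit does) restores substitutability for all device communicators.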