[MoE Refactor] Integrate Naive Prepare Finalize into MK (#32567)
Signed-off-by: Robert Shaw <robshaw@redhat.com> Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com> Co-authored-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: amirkl94 <203507526+amirkl94@users.noreply.github.com>
This commit is contained in:
@@ -130,29 +130,65 @@ class CpuCommunicator(DeviceCommunicatorBase):
|
||||
) -> dict[str, torch.Tensor | Any]:
|
||||
return self.dist_module.recv_tensor_dict(src)
|
||||
|
||||
def dispatch( # type: ignore[override]
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Dispatch the hidden states and router logits to the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
|
||||
assert self.all2all_manager is not None
|
||||
return self.all2all_manager.dispatch(
|
||||
return self.all2all_manager.dispatch_router_logits(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
is_sequence_parallel,
|
||||
extra_tensors, # type: ignore[call-arg]
|
||||
extra_tensors,
|
||||
)
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Dispatch the hidden states and topk weights/ids to the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
assert self.all2all_manager is not None
|
||||
return self.all2all_manager.dispatch(
|
||||
hidden_states,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
is_sequence_parallel,
|
||||
extra_tensors=extra_tensors,
|
||||
)
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Combine the hidden states and router logits from the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states = self.all2all_manager.combine(
|
||||
hidden_states, is_sequence_parallel
|
||||
return self.all2all_manager.combine(
|
||||
hidden_states,
|
||||
is_sequence_parallel,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class _CPUSHMDistributed:
|
||||
|
||||
Reference in New Issue
Block a user