[MoE Refactor] Integrate Naive Prepare Finalize into MK (#32567)
Signed-off-by: Robert Shaw <robshaw@redhat.com> Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com> Co-authored-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: amirkl94 <203507526+amirkl94@users.noreply.github.com>
This commit is contained in:
@@ -59,7 +59,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
|
||||
|
||||
return buffer
|
||||
|
||||
def dispatch(
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
@@ -84,6 +84,34 @@ class NaiveAll2AllManager(All2AllManagerBase):
|
||||
|
||||
return hidden_states, router_logits
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if extra_tensors is not None:
|
||||
raise NotImplementedError(
|
||||
"extra_tensors is not supported for NaiveAll2AllManager"
|
||||
)
|
||||
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
assert dp_metadata is not None
|
||||
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
|
||||
|
||||
hidden_states = self.naive_multicast(
|
||||
hidden_states, cu_tokens_across_sp_cpu, is_sequence_parallel
|
||||
)
|
||||
topk_weights = self.naive_multicast(
|
||||
topk_weights, cu_tokens_across_sp_cpu, is_sequence_parallel
|
||||
)
|
||||
topk_ids = self.naive_multicast(
|
||||
topk_ids, cu_tokens_across_sp_cpu, is_sequence_parallel
|
||||
)
|
||||
return hidden_states, topk_weights, topk_ids
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
@@ -114,7 +142,7 @@ class AgRsAll2AllManager(All2AllManagerBase):
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def dispatch(
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
@@ -148,6 +176,46 @@ class AgRsAll2AllManager(All2AllManagerBase):
|
||||
return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:])
|
||||
return gathered_tensors[0], gathered_tensors[1]
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Gather hidden_states and router_logits from all dp ranks.
|
||||
"""
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
assert dp_metadata is not None
|
||||
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
assert sizes is not None
|
||||
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
|
||||
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
|
||||
|
||||
tensors_to_gather = [hidden_states, topk_weights, topk_ids]
|
||||
if extra_tensors is not None:
|
||||
tensors_to_gather.extend(extra_tensors)
|
||||
|
||||
gathered_tensors = dist_group.all_gatherv(
|
||||
tensors_to_gather,
|
||||
dim=0,
|
||||
sizes=sizes,
|
||||
)
|
||||
|
||||
hidden_states = gathered_tensors[0]
|
||||
topk_weights = gathered_tensors[1]
|
||||
topk_ids = gathered_tensors[2]
|
||||
|
||||
if extra_tensors is None:
|
||||
return hidden_states, topk_weights, topk_ids
|
||||
|
||||
return hidden_states, topk_weights, topk_ids, gathered_tensors[3:]
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
@@ -216,7 +284,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
|
||||
pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
|
||||
)
|
||||
|
||||
def dispatch(
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
@@ -225,6 +293,19 @@ class PPLXAll2AllManager(All2AllManagerBase):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
@@ -264,7 +345,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
def get_handle(self, kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
@@ -273,6 +354,19 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import threading
|
||||
from typing import Any
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
import torch
|
||||
@@ -64,13 +63,32 @@ class All2AllManagerBase:
|
||||
# and reuse it for the same config.
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> Any:
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
# Subclasses should either:
|
||||
# - implement handling for extra_tensors, or
|
||||
# - raise a clear error if extra_tensors is not supported.
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
# Subclasses should either:
|
||||
# - implement handling for extra_tensors, or
|
||||
# - raise a clear error if extra_tensors is not supported.
|
||||
@@ -280,7 +298,7 @@ class DeviceCommunicatorBase:
|
||||
for module in moe_modules:
|
||||
module.maybe_init_modular_kernel()
|
||||
|
||||
def dispatch(
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
@@ -294,8 +312,29 @@ class DeviceCommunicatorBase:
|
||||
Dispatch the hidden states and router logits to the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
if extra_tensors is not None:
|
||||
return hidden_states, router_logits, extra_tensors
|
||||
return hidden_states, router_logits
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Dispatch the hidden states and topk weights/ids to the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
if extra_tensors is not None:
|
||||
return hidden_states, topk_weights, topk_ids, extra_tensors
|
||||
return hidden_states, topk_weights, topk_ids
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
|
||||
@@ -130,29 +130,65 @@ class CpuCommunicator(DeviceCommunicatorBase):
|
||||
) -> dict[str, torch.Tensor | Any]:
|
||||
return self.dist_module.recv_tensor_dict(src)
|
||||
|
||||
def dispatch( # type: ignore[override]
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Dispatch the hidden states and router logits to the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
|
||||
assert self.all2all_manager is not None
|
||||
return self.all2all_manager.dispatch(
|
||||
return self.all2all_manager.dispatch_router_logits(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
is_sequence_parallel,
|
||||
extra_tensors, # type: ignore[call-arg]
|
||||
extra_tensors,
|
||||
)
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Dispatch the hidden states and topk weights/ids to the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
assert self.all2all_manager is not None
|
||||
return self.all2all_manager.dispatch(
|
||||
hidden_states,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
is_sequence_parallel,
|
||||
extra_tensors=extra_tensors,
|
||||
)
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Combine the hidden states and router logits from the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states = self.all2all_manager.combine(
|
||||
hidden_states, is_sequence_parallel
|
||||
return self.all2all_manager.combine(
|
||||
hidden_states,
|
||||
is_sequence_parallel,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class _CPUSHMDistributed:
|
||||
|
||||
@@ -322,7 +322,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
return output_list
|
||||
|
||||
def dispatch( # type: ignore[override]
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
@@ -332,19 +332,52 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
tuple[torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Dispatch the hidden states and router logits to the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
|
||||
assert self.all2all_manager is not None
|
||||
return self.all2all_manager.dispatch(
|
||||
return self.all2all_manager.dispatch_router_logits(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
is_sequence_parallel,
|
||||
extra_tensors, # type: ignore[call-arg]
|
||||
extra_tensors,
|
||||
)
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Dispatch the hidden states and topk weights/ids to the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
assert self.all2all_manager is not None
|
||||
return self.all2all_manager.dispatch(
|
||||
hidden_states,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
is_sequence_parallel,
|
||||
extra_tensors=extra_tensors,
|
||||
)
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Combine the hidden states and router logits from the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states = self.all2all_manager.combine(
|
||||
hidden_states, is_sequence_parallel
|
||||
return self.all2all_manager.combine(
|
||||
hidden_states,
|
||||
is_sequence_parallel,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
import torch.distributed as dist
|
||||
from flashinfer.comm.mnnvl import CommBackend as CommBackend
|
||||
|
||||
@@ -23,5 +25,14 @@ class CustomCommunicator(CommBackend):
|
||||
dist.all_gather_object(gathered, data, group=self._group)
|
||||
return gathered
|
||||
|
||||
# NOTE(rob): CommBackend is an abstract class, and bcast/barrier
|
||||
# are unimplemented on vLLM side. If we need to utilize these
|
||||
# methods in the future, can create a concrete implementation.
|
||||
def bcast(self, data: Any, root: int) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
def barrier(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def Split(self, color: int, key: int) -> "CustomCommunicator":
|
||||
return self
|
||||
|
||||
@@ -196,26 +196,62 @@ class XpuCommunicator(DeviceCommunicatorBase):
|
||||
def broadcast(self, input_: torch.Tensor, src: int = 0) -> None:
|
||||
dist.broadcast(input_, src=src, group=self.device_group)
|
||||
|
||||
def dispatch(
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Dispatch the hidden states and router logits to the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
|
||||
assert self.all2all_manager is not None
|
||||
return self.all2all_manager.dispatch(
|
||||
return self.all2all_manager.dispatch_router_logits(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
is_sequence_parallel,
|
||||
extra_tensors, # type: ignore[call-arg]
|
||||
extra_tensors,
|
||||
)
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
"""
|
||||
Dispatch the hidden states and topk weights/ids to the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
assert self.all2all_manager is not None
|
||||
return self.all2all_manager.dispatch(
|
||||
hidden_states,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
is_sequence_parallel,
|
||||
extra_tensors=extra_tensors,
|
||||
)
|
||||
|
||||
def combine(
|
||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Combine the hidden states and router logits from the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states = self.all2all_manager.combine(
|
||||
hidden_states, is_sequence_parallel
|
||||
return self.all2all_manager.combine(
|
||||
hidden_states,
|
||||
is_sequence_parallel,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
@@ -1000,7 +1000,7 @@ class GroupCoordinator:
|
||||
if self.device_communicator is not None:
|
||||
self.device_communicator.prepare_communication_buffer_for_model(model)
|
||||
|
||||
def dispatch(
|
||||
def dispatch_router_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
@@ -1011,7 +1011,7 @@ class GroupCoordinator:
|
||||
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
):
|
||||
if self.device_communicator is not None:
|
||||
return self.device_communicator.dispatch( # type: ignore[call-arg]
|
||||
return self.device_communicator.dispatch_router_logits(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
is_sequence_parallel,
|
||||
@@ -1020,6 +1020,28 @@ class GroupCoordinator:
|
||||
else:
|
||||
return hidden_states, router_logits
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
is_sequence_parallel: bool = False,
|
||||
extra_tensors: list[torch.Tensor] | None = None,
|
||||
) -> (
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||
| tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
||||
):
|
||||
if self.device_communicator is not None:
|
||||
return self.device_communicator.dispatch(
|
||||
hidden_states,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
is_sequence_parallel,
|
||||
extra_tensors,
|
||||
)
|
||||
else:
|
||||
return hidden_states, topk_weights, topk_ids
|
||||
|
||||
def combine(
|
||||
self, hidden_states, is_sequence_parallel: bool = False
|
||||
) -> torch.Tensor:
|
||||
|
||||
Reference in New Issue
Block a user