[Bugfix][WideEP] Apply TP Attn + EP MoE fix to other models (#24982)
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
This commit is contained in:
committed by
GitHub
parent
f9df8b4ad7
commit
a5354b3ed2
@@ -6,7 +6,7 @@ import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed import get_dp_group
|
||||
from vllm.distributed import get_dp_group, get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import has_deep_ep, has_pplx
|
||||
@@ -34,41 +34,60 @@ class NaiveAll2AllManager(All2AllManagerBase):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def naive_multicast(self, x: torch.Tensor,
|
||||
cu_tokens_across_dp_cpu: torch.Tensor):
|
||||
cu_tokens_across_sp_cpu: torch.Tensor,
|
||||
is_sequence_parallel: bool) -> torch.Tensor:
|
||||
assert (len(x.shape) == 2)
|
||||
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
|
||||
buffer = torch.empty((cu_tokens_across_sp_cpu[-1], x.size(1)),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
|
||||
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
||||
self.dp_rank - 1]
|
||||
end = cu_tokens_across_dp_cpu[self.dp_rank]
|
||||
rank = self.rank if is_sequence_parallel else self.dp_rank
|
||||
world_size = (self.world_size
|
||||
if is_sequence_parallel else self.dp_world_size)
|
||||
|
||||
start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1]
|
||||
end = cu_tokens_across_sp_cpu[rank]
|
||||
buffer[start:end, :].copy_(x)
|
||||
for idx in range(self.dp_world_size):
|
||||
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
|
||||
end = cu_tokens_across_dp_cpu[idx]
|
||||
self.dp_group.broadcast(buffer[start:end, :], idx)
|
||||
for idx in range(world_size):
|
||||
start = 0 if idx == 0 else cu_tokens_across_sp_cpu[idx - 1]
|
||||
end = cu_tokens_across_sp_cpu[idx]
|
||||
get_ep_group().broadcast(buffer[start:end, :], idx)
|
||||
|
||||
return buffer
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
sizes = get_forward_context(
|
||||
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
hidden_states, router_logits = get_dp_group().all_gatherv(
|
||||
[hidden_states, router_logits],
|
||||
dim=0,
|
||||
sizes=sizes,
|
||||
)
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
|
||||
|
||||
hidden_states = self.naive_multicast(hidden_states,
|
||||
cu_tokens_across_sp_cpu,
|
||||
is_sequence_parallel)
|
||||
router_logits = self.naive_multicast(router_logits,
|
||||
cu_tokens_across_sp_cpu,
|
||||
is_sequence_parallel)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
sizes = get_forward_context(
|
||||
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
hidden_states = get_dp_group().reduce_scatterv(hidden_states,
|
||||
dim=0,
|
||||
sizes=sizes)
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
|
||||
ep_rank = self.rank if is_sequence_parallel else self.dp_rank
|
||||
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
|
||||
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
|
||||
|
||||
start = 0 if ep_rank == 0 else cu_tokens_across_sp_cpu[ep_rank - 1]
|
||||
end = cu_tokens_across_sp_cpu[ep_rank]
|
||||
|
||||
all_hidden_states = get_ep_group().all_reduce(hidden_states)
|
||||
hidden_states = all_hidden_states[start:end, :]
|
||||
return hidden_states
|
||||
|
||||
def destroy(self):
|
||||
@@ -84,29 +103,40 @@ class AgRsAll2AllManager(All2AllManagerBase):
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Gather hidden_states and router_logits from all dp ranks.
|
||||
"""
|
||||
sizes = get_forward_context(
|
||||
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
hidden_states, router_logits = get_dp_group().all_gatherv(
|
||||
|
||||
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
|
||||
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
|
||||
hidden_states, router_logits = dist_group.all_gatherv(
|
||||
[hidden_states, router_logits],
|
||||
dim=0,
|
||||
sizes=sizes,
|
||||
)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
"""
|
||||
Reduce-scatter hidden_states across all dp ranks.
|
||||
"""
|
||||
sizes = get_forward_context(
|
||||
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
hidden_states = get_dp_group().reduce_scatterv(hidden_states,
|
||||
dim=0,
|
||||
sizes=sizes)
|
||||
|
||||
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
|
||||
hidden_states = dist_group.reduce_scatterv(hidden_states,
|
||||
dim=0,
|
||||
sizes=sizes)
|
||||
return hidden_states
|
||||
|
||||
def destroy(self):
|
||||
@@ -148,11 +178,17 @@ class PPLXAll2AllManager(All2AllManagerBase):
|
||||
kwargs, pplx.AllToAll.internode
|
||||
if self.internode else pplx.AllToAll.intranode)
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
@@ -184,11 +220,17 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
def get_handle(self, kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
@@ -395,4 +437,4 @@ class FlashInferAllToAllManager(All2AllManagerBase):
|
||||
self.workspace_tensor = None
|
||||
self.prepare_workspace_tensor = None
|
||||
self.mapping = None
|
||||
self.initialized = False
|
||||
self.initialized = False
|
||||
|
||||
Reference in New Issue
Block a user