[Bugfix][WideEP] Apply TP Attn + EP MoE fix to other models (#24982)

Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Author: Tyler Michael Smith
Date: 2025-09-27 10:22:28 -04:00
Committed by: GitHub
Parent: f9df8b4ad7
Commit: a5354b3ed2
23 changed files with 540 additions and 375 deletions
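In short: every All2All manager's dispatch/combine pair now takes an is_sequence_parallel flag, so models that run attention tensor-parallel but MoE expert-parallel exchange tokens over the EP group instead of the DP group. A minimal sketch of the resulting interface, simplified from the diff below (the class name and body here are illustrative only, not the full vLLM implementation):

# Illustrative sketch of the updated dispatch/combine interface -- not the real vLLM class.
import torch


class All2AllManagerSketch:

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Gather tokens across the DP group, or across the EP group when the
        # model runs TP attention + EP MoE (the sequence-parallel path).
        raise NotImplementedError

    def combine(
        self,
        hidden_states: torch.Tensor,
        is_sequence_parallel: bool = False,
    ) -> torch.Tensor:
        # Hand each rank back its own slice of the expert outputs, using the
        # same group selection as dispatch().
        raise NotImplementedError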


@@ -6,7 +6,7 @@ import torch
 import torch.distributed as dist
 
 import vllm.envs as envs
-from vllm.distributed import get_dp_group
+from vllm.distributed import get_dp_group, get_ep_group
 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.utils import has_deep_ep, has_pplx
@@ -34,41 +34,60 @@ class NaiveAll2AllManager(All2AllManagerBase):
         super().__init__(cpu_group)
 
     def naive_multicast(self, x: torch.Tensor,
-                        cu_tokens_across_dp_cpu: torch.Tensor):
+                        cu_tokens_across_sp_cpu: torch.Tensor,
+                        is_sequence_parallel: bool) -> torch.Tensor:
         assert (len(x.shape) == 2)
-        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
+        buffer = torch.empty((cu_tokens_across_sp_cpu[-1], x.size(1)),
                              device=x.device,
                              dtype=x.dtype)
 
-        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
-            self.dp_rank - 1]
-        end = cu_tokens_across_dp_cpu[self.dp_rank]
+        rank = self.rank if is_sequence_parallel else self.dp_rank
+        world_size = (self.world_size
+                      if is_sequence_parallel else self.dp_world_size)
+
+        start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1]
+        end = cu_tokens_across_sp_cpu[rank]
         buffer[start:end, :].copy_(x)
-        for idx in range(self.dp_world_size):
-            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
-            end = cu_tokens_across_dp_cpu[idx]
-            self.dp_group.broadcast(buffer[start:end, :], idx)
+        for idx in range(world_size):
+            start = 0 if idx == 0 else cu_tokens_across_sp_cpu[idx - 1]
+            end = cu_tokens_across_sp_cpu[idx]
+            get_ep_group().broadcast(buffer[start:end, :], idx)
 
         return buffer
 
-    def dispatch(self, hidden_states: torch.Tensor,
-                 router_logits: torch.Tensor):
-        sizes = get_forward_context(
-        ).dp_metadata.get_chunk_sizes_across_dp_rank()
-        hidden_states, router_logits = get_dp_group().all_gatherv(
-            [hidden_states, router_logits],
-            dim=0,
-            sizes=sizes,
-        )
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        sp_size = self.tp_group.world_size if is_sequence_parallel else 1
+        dp_metadata = get_forward_context().dp_metadata
+        cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
+
+        hidden_states = self.naive_multicast(hidden_states,
+                                             cu_tokens_across_sp_cpu,
+                                             is_sequence_parallel)
+        router_logits = self.naive_multicast(router_logits,
+                                             cu_tokens_across_sp_cpu,
+                                             is_sequence_parallel)
         return hidden_states, router_logits
 
-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        sizes = get_forward_context(
-        ).dp_metadata.get_chunk_sizes_across_dp_rank()
-        hidden_states = get_dp_group().reduce_scatterv(hidden_states,
-                                                       dim=0,
-                                                       sizes=sizes)
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False) -> torch.Tensor:
+        ep_rank = self.rank if is_sequence_parallel else self.dp_rank
+
+        dp_metadata = get_forward_context().dp_metadata
+        sp_size = self.tp_group.world_size if is_sequence_parallel else 1
+        cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
+
+        start = 0 if ep_rank == 0 else cu_tokens_across_sp_cpu[ep_rank - 1]
+        end = cu_tokens_across_sp_cpu[ep_rank]
+
+        all_hidden_states = get_ep_group().all_reduce(hidden_states)
+        hidden_states = all_hidden_states[start:end, :]
         return hidden_states
 
     def destroy(self):
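For reference, the slicing in naive_multicast and combine above is driven purely by cumulative per-rank token counts. A standalone sketch with made-up counts (in vLLM these come from dp_metadata.cu_tokens_across_sp(sp_size)):

# Standalone illustration of the cumulative-token slicing used above.
# The counts are hypothetical; in vLLM they come from the forward context.
import torch

cu_tokens_across_sp_cpu = torch.tensor([4, 4, 9, 12])  # ranks 0..3 own 4, 0, 5, 3 tokens

def rank_slice(cu_tokens: torch.Tensor, rank: int) -> tuple[int, int]:
    """Return the [start, end) rows of the shared buffer owned by `rank`."""
    start = 0 if rank == 0 else int(cu_tokens[rank - 1])
    end = int(cu_tokens[rank])
    return start, end

for rank in range(len(cu_tokens_across_sp_cpu)):
    print(rank, rank_slice(cu_tokens_across_sp_cpu, rank))
# -> 0 (0, 4)   1 (4, 4)   2 (4, 9)   3 (9, 12)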
@@ -84,29 +103,40 @@ class AgRsAll2AllManager(All2AllManagerBase):
     def __init__(self, cpu_group):
         super().__init__(cpu_group)
 
-    def dispatch(self, hidden_states: torch.Tensor,
-                 router_logits: torch.Tensor):
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Gather hidden_states and router_logits from all dp ranks.
         """
         sizes = get_forward_context(
         ).dp_metadata.get_chunk_sizes_across_dp_rank()
-        hidden_states, router_logits = get_dp_group().all_gatherv(
+
+        dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
+        assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
+        hidden_states, router_logits = dist_group.all_gatherv(
             [hidden_states, router_logits],
             dim=0,
             sizes=sizes,
         )
         return hidden_states, router_logits
 
-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False) -> torch.Tensor:
         """
         Reduce-scatter hidden_states across all dp ranks.
         """
         sizes = get_forward_context(
         ).dp_metadata.get_chunk_sizes_across_dp_rank()
-        hidden_states = get_dp_group().reduce_scatterv(hidden_states,
-                                                       dim=0,
-                                                       sizes=sizes)
+
+        dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
+        hidden_states = dist_group.reduce_scatterv(hidden_states,
+                                                   dim=0,
+                                                   sizes=sizes)
         return hidden_states
 
     def destroy(self):
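The AgRs path keeps its all-gather/reduce-scatter symmetry: dispatch concatenates every rank's tokens along dim 0 and combine hands each rank back exactly the rows it contributed, with the communication group picked by the is_sequence_parallel flag. A single-process shape sketch under assumed token counts (no real process groups; torch.split stands in for the reduce-scatter and the cross-rank reduction is omitted):

# Single-process sketch of the dispatch/combine shape contract.
# `sizes` is a hypothetical per-rank token count for the chosen group.
import torch

sizes = [4, 1, 3]
hidden_size = 8
per_rank = [torch.randn(n, hidden_size) for n in sizes]  # each rank's local tokens

# dispatch ~ all_gatherv(dim=0): every rank ends up with the concatenation.
gathered = torch.cat(per_rank, dim=0)
assert gathered.shape == (sum(sizes), hidden_size)

# combine ~ reduce_scatterv(dim=0): each rank gets back only its own rows.
chunks = torch.split(gathered, sizes, dim=0)
assert [c.shape[0] for c in chunks] == sizes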
@@ -148,11 +178,17 @@ class PPLXAll2AllManager(All2AllManagerBase):
             kwargs, pplx.AllToAll.internode
             if self.internode else pplx.AllToAll.intranode)
 
-    def dispatch(self, hidden_states: torch.Tensor,
-                 router_logits: torch.Tensor):
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError
 
-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False) -> torch.Tensor:
         raise NotImplementedError
 
     def destroy(self):
@@ -184,11 +220,17 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
     def get_handle(self, kwargs):
         raise NotImplementedError
 
-    def dispatch(self, hidden_states: torch.Tensor,
-                 router_logits: torch.Tensor):
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError
 
-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False) -> torch.Tensor:
         raise NotImplementedError
 
     def destroy(self):
@@ -395,4 +437,4 @@ class FlashInferAllToAllManager(All2AllManagerBase):
         self.workspace_tensor = None
         self.prepare_workspace_tensor = None
         self.mapping = None
-        self.initialized = False
+        self.initialized = False