Enable Allgather/ReduceScatter backend for NaiveAllToAll (#23964)

Signed-off-by: Shu Wang. <shuw@nvidia.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: Shu Wang <shuw@nvidia.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Shu Wang
2025-09-18 10:52:58 -05:00
committed by GitHub
parent b419937c78
commit 2ea50e977a
3 changed files with 55 additions and 5 deletions

View File

@@ -5,6 +5,7 @@ from typing import Any
import torch
import torch.distributed as dist
from vllm.distributed import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils import has_deep_ep, has_pplx
@@ -69,6 +70,44 @@ class NaiveAll2AllManager(All2AllManagerBase):
pass
class AgRsAll2AllManager(All2AllManagerBase):
"""
An implementation of all2all communication based on
all-gather (dispatch) and reduce-scatter (combine).
"""
def __init__(self, cpu_group):
super().__init__(cpu_group)
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
"""
Gather hidden_states and router_logits from all dp ranks.
"""
sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank()
hidden_states, router_logits = get_dp_group().all_gatherv(
[hidden_states, router_logits],
dim=0,
sizes=sizes,
)
return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Reduce-scatter hidden_states across all dp ranks.
"""
sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank()
hidden_states = get_dp_group().reduce_scatterv(hidden_states,
dim=0,
sizes=sizes)
return hidden_states
def destroy(self):
pass
class PPLXAll2AllManager(All2AllManagerBase):
"""
All2All communication based on PPLX kernels.