[Feature] Pipeline Parallel Async send/recv, 2.9% E2E throughput improvement (#33368)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Wentao Ye
2026-02-13 03:38:16 -05:00
committed by GitHub
parent dddbff4624
commit 3d2a026fd0
3 changed files with 298 additions and 81 deletions

View File

@@ -4,6 +4,7 @@
import gc
import os
from collections.abc import Callable
from contextlib import AbstractContextManager, nullcontext
from types import NoneType
from typing import TYPE_CHECKING, Any, cast
@@ -30,6 +31,7 @@ from vllm.distributed.kv_transfer import (
has_kv_transfer_group,
)
from vllm.distributed.parallel_state import (
Handle,
get_pcp_group,
get_pp_group,
get_tp_group,
@@ -68,6 +70,38 @@ if TYPE_CHECKING:
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
class AsyncIntermediateTensors(IntermediateTensors):
    """IntermediateTensors whose pipeline-parallel comm is synced lazily.

    Wraps tensors produced by non-blocking receives. The first access to
    ``.tensors`` transparently waits on the outstanding comm handles and
    runs any postprocess callbacks, so callers can treat this exactly like
    a plain IntermediateTensors.
    """

    def __init__(
        self,
        tensors: dict[str, torch.Tensor],
        comm_handles: list[Handle] | None = None,
        comm_postprocess: list[Callable[[], None]] | None = None,
    ) -> None:
        """Store tensors plus the pending comm work to wait on later.

        Args:
            tensors: name -> tensor mapping (may still be in flight).
            comm_handles: async work handles to wait on before first use.
            comm_postprocess: callbacks to run after the waits complete
                (presumably finalize/copy steps — see caller).
        """
        super().__init__(tensors)
        self._comm_handles = comm_handles
        self._comm_postprocess = comm_postprocess
        # Flips to True once the handles have been waited on; makes
        # wait_for_comm() idempotent.
        self._comm_waited = False

    def wait_for_comm(self) -> None:
        """Block until all pending comm completes (no-op on repeat calls)."""
        if self._comm_waited:
            return
        # Treat None the same as an empty list for both optional fields.
        for work in self._comm_handles or ():
            work.wait()
        for callback in self._comm_postprocess or ():
            callback()
        self._comm_waited = True

    def __getattribute__(self, name: str):
        # Intercept only `.tensors`: force the lazy sync exactly once
        # before the tensor dict is handed out. object.__getattribute__
        # is used throughout to avoid re-entering this hook.
        if name == "tensors" and not object.__getattribute__(self, "_comm_waited"):
            object.__getattribute__(self, "wait_for_comm")()
        return object.__getattribute__(self, name)
class Worker(WorkerBase):
def __init__(
self,
@@ -113,6 +147,8 @@ class Worker(WorkerBase):
raise ValueError(f"Unknown profiler type: {self.profiler_config.profiler}")
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
# pending non-blocking PP send work from the previous iteration
self._pp_send_work: list[Handle] = []
def sleep(self, level: int = 1) -> None:
from vllm.device_allocator.cumem import CuMemAllocator
@@ -600,6 +636,12 @@ class Worker(WorkerBase):
def execute_model(
self, scheduler_output: "SchedulerOutput"
) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
# ensure any previous non-blocking PP sends are complete
if self._pp_send_work:
for handle in self._pp_send_work:
handle.wait()
self._pp_send_work = []
intermediate_tensors = None
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -637,12 +679,18 @@ class Worker(WorkerBase):
}
if forward_pass and not get_pp_group().is_first_rank:
tensor_dict = get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group(),
all_gather_tensors=all_gather_tensors,
tensor_dict, comm_handles, comm_postprocess = (
get_pp_group().irecv_tensor_dict(
all_gather_group=get_tp_group(),
all_gather_tensors=all_gather_tensors,
)
)
assert tensor_dict is not None
intermediate_tensors = IntermediateTensors(tensor_dict)
intermediate_tensors = AsyncIntermediateTensors(
tensor_dict,
comm_handles=comm_handles,
comm_postprocess=comm_postprocess,
)
with self.annotate_profile(scheduler_output):
output = self.model_runner.execute_model(
@@ -660,7 +708,8 @@ class Worker(WorkerBase):
and not get_pp_group().is_last_rank
)
get_pp_group().send_tensor_dict(
# launch non-blocking send of intermediate tensors
self._pp_send_work = get_pp_group().isend_tensor_dict(
output.tensors,
all_gather_group=get_tp_group(),
all_gather_tensors=all_gather_tensors,