[Feature] Pipeline Parallel Async send/recv, 2.9% E2E throughput improvement (#33368)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
@@ -4,6 +4,7 @@
 import gc
 import os
+from collections.abc import Callable
 from contextlib import AbstractContextManager, nullcontext
 from types import NoneType
 from typing import TYPE_CHECKING, Any, cast
@@ -30,6 +31,7 @@ from vllm.distributed.kv_transfer import (
     has_kv_transfer_group,
 )
 from vllm.distributed.parallel_state import (
+    Handle,
     get_pcp_group,
     get_pp_group,
     get_tp_group,
@@ -68,6 +70,38 @@ if TYPE_CHECKING:
     from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 
+class AsyncIntermediateTensors(IntermediateTensors):
+    """IntermediateTensors with lazy comm synchronization"""
+
+    def __init__(
+        self,
+        tensors: dict[str, torch.Tensor],
+        comm_handles: list[Handle] | None = None,
+        comm_postprocess: list[Callable[[], None]] | None = None,
+    ) -> None:
+        super().__init__(tensors)
+        self._comm_handles = comm_handles
+        self._comm_postprocess = comm_postprocess
+        self._comm_waited = False
+
+    def wait_for_comm(self) -> None:
+        if self._comm_waited:
+            return
+        if self._comm_handles:
+            for handle in self._comm_handles:
+                handle.wait()
+        if self._comm_postprocess:
+            for fn in self._comm_postprocess:
+                fn()
+        self._comm_waited = True
+
+    def __getattribute__(self, name: str):
+        # ensure `.tensors` is ready before use
+        if name == "tensors" and not object.__getattribute__(self, "_comm_waited"):
+            object.__getattribute__(self, "wait_for_comm")()
+        return object.__getattribute__(self, name)
+
+
 class Worker(WorkerBase):
     def __init__(
         self,
@@ -113,6 +147,8 @@ class Worker(WorkerBase):
             raise ValueError(f"Unknown profiler type: {self.profiler_config.profiler}")
 
         self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
+        # pending non-blocking PP send work from the previous iteration
+        self._pp_send_work: list[Handle] = []
 
     def sleep(self, level: int = 1) -> None:
         from vllm.device_allocator.cumem import CuMemAllocator
@@ -600,6 +636,12 @@ class Worker(WorkerBase):
     def execute_model(
         self, scheduler_output: "SchedulerOutput"
     ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
+        # ensure any previous non-blocking PP sends are complete
+        if self._pp_send_work:
+            for handle in self._pp_send_work:
+                handle.wait()
+            self._pp_send_work = []
+
         intermediate_tensors = None
         forward_pass = scheduler_output.total_num_scheduled_tokens > 0
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -637,12 +679,18 @@
         }
 
         if forward_pass and not get_pp_group().is_first_rank:
-            tensor_dict = get_pp_group().recv_tensor_dict(
-                all_gather_group=get_tp_group(),
-                all_gather_tensors=all_gather_tensors,
+            tensor_dict, comm_handles, comm_postprocess = (
+                get_pp_group().irecv_tensor_dict(
+                    all_gather_group=get_tp_group(),
+                    all_gather_tensors=all_gather_tensors,
+                )
             )
             assert tensor_dict is not None
-            intermediate_tensors = IntermediateTensors(tensor_dict)
+            intermediate_tensors = AsyncIntermediateTensors(
+                tensor_dict,
+                comm_handles=comm_handles,
+                comm_postprocess=comm_postprocess,
+            )
 
         with self.annotate_profile(scheduler_output):
             output = self.model_runner.execute_model(
@@ -660,7 +708,8 @@
                 and not get_pp_group().is_last_rank
             )
 
-            get_pp_group().send_tensor_dict(
+            # launch non-blocking send of intermediate tensors
+            self._pp_send_work = get_pp_group().isend_tensor_dict(
                 output.tensors,
                 all_gather_group=get_tp_group(),
                 all_gather_tensors=all_gather_tensors,
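
The subtlest part of the diff is how AsyncIntermediateTensors hides the receive latency: the irecv handles are only waited on lazily, the first time `.tensors` is actually read, via a `__getattribute__` hook. Below is a minimal, self-contained sketch of that same pattern; `FakeHandle` and `LazyIntermediateTensors` are illustrative stand-ins for vLLM's communication handle and IntermediateTensors, not the real classes.

import torch


class FakeHandle:
    """Stand-in for an async-communication work handle (not vLLM's Handle)."""

    def __init__(self) -> None:
        self.completed = False

    def wait(self) -> None:
        self.completed = True


class LazyIntermediateTensors:
    """Holds received tensors; defers handle.wait() until `.tensors` is first read."""

    def __init__(self, tensors: dict[str, torch.Tensor],
                 handles: list[FakeHandle]) -> None:
        self.tensors = tensors
        self._handles = handles
        self._waited = False

    def wait_for_comm(self) -> None:
        if self._waited:
            return
        for handle in self._handles:
            handle.wait()
        self._waited = True

    def __getattribute__(self, name: str):
        # Intercept `.tensors` so the pending receive is completed lazily,
        # right before the tensors are actually consumed.
        if name == "tensors" and not object.__getattribute__(self, "_waited"):
            object.__getattribute__(self, "wait_for_comm")()
        return object.__getattribute__(self, name)


handle = FakeHandle()
it = LazyIntermediateTensors({"hidden_states": torch.zeros(2, 4)}, [handle])
assert not handle.completed      # constructing the wrapper does not synchronize
_ = it.tensors                   # first access triggers wait_for_comm()
assert handle.completed

The effect is that the non-first pipeline ranks can keep doing CPU-side work (building metadata, scheduling) after posting the receive, and only block once the model runner actually touches the incoming activations.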
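On the send side, the blocking send_tensor_dict becomes isend_tensor_dict, and the returned handles are only waited on at the top of the next execute_model call, so the transfer overlaps with everything the worker does between the two forward passes. A rough sketch of that deferral pattern follows, using a thread-backed stand-in for the real non-blocking send; the `isend` helper, `Handle`, and `Stage` names here are assumptions for illustration, not vLLM APIs.

import threading
import time


class Handle:
    """Stand-in for an async-send work handle backed by a thread."""

    def __init__(self, thread: threading.Thread) -> None:
        self._thread = thread

    def wait(self) -> None:
        self._thread.join()


def isend(payload: dict, sink: list[dict]) -> Handle:
    """Pretend non-blocking send: delivers the payload on a background thread."""

    def _send() -> None:
        time.sleep(0.01)  # simulated transfer latency
        sink.append(dict(payload))

    thread = threading.Thread(target=_send)
    thread.start()
    return Handle(thread)


class Stage:
    """Mimics the worker: finish last step's sends, run the step, launch new sends."""

    def __init__(self, sink: list[dict]) -> None:
        self._sink = sink
        self._pending: list[Handle] = []

    def execute_step(self, step: int) -> dict:
        # Complete the previous iteration's non-blocking sends first,
        # mirroring the wait at the top of execute_model().
        for handle in self._pending:
            handle.wait()
        self._pending = []

        output = {"step": step}  # placeholder for the forward pass
        # Launch this step's send without blocking on it.
        self._pending = [isend(output, self._sink)]
        return output

    def flush(self) -> None:
        for handle in self._pending:
            handle.wait()
        self._pending = []


sink: list[dict] = []
stage = Stage(sink)
for i in range(3):
    stage.execute_step(i)
stage.flush()
print(sink)  # [{'step': 0}, {'step': 1}, {'step': 2}]

Deferring the wait by one iteration is what lets the send overlap with sampling and with scheduling of the next batch, which is where the reported 2.9% end-to-end throughput gain comes from.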