[Core] Simplify async KV output aggregation (#28327)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -29,6 +29,7 @@ import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import destroy_distributed_environment, destroy_model_parallel
|
||||
from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_dp_group,
|
||||
get_ep_group,
|
||||
@@ -57,8 +58,13 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FutureWrapper(Future):
|
||||
def __init__(self, futures_queue: deque[tuple["FutureWrapper", Callable]]):
|
||||
def __init__(
|
||||
self,
|
||||
futures_queue: deque[tuple["FutureWrapper", Callable]],
|
||||
aggregate: Callable = lambda x: x,
|
||||
):
|
||||
self.futures_queue = futures_queue
|
||||
self.aggregate = aggregate
|
||||
super().__init__()
|
||||
|
||||
def result(self, timeout=None):
|
||||
@@ -72,7 +78,7 @@ class FutureWrapper(Future):
|
||||
|
||||
def wait_for_response(self, get_response: Callable):
|
||||
try:
|
||||
response = get_response()
|
||||
response = self.aggregate(get_response())
|
||||
with suppress(InvalidStateError):
|
||||
self.set_result(response)
|
||||
except Exception as e:
|
||||
@@ -160,7 +166,6 @@ class MultiprocExecutor(Executor):
|
||||
self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
|
||||
|
||||
self.output_rank = self._get_output_rank()
|
||||
self.has_connector = self.vllm_config.kv_transfer_config is not None
|
||||
|
||||
def start_worker_monitor(self):
|
||||
workers = self.workers
|
||||
@@ -199,44 +204,27 @@ class MultiprocExecutor(Executor):
|
||||
def execute_model( # type: ignore[override]
|
||||
self, scheduler_output: SchedulerOutput, non_block: bool = False
|
||||
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
|
||||
return self._execute_with_aggregation(
|
||||
"execute_model", scheduler_output, non_block=non_block
|
||||
return self.collective_rpc(
|
||||
"execute_model",
|
||||
args=(scheduler_output,),
|
||||
unique_reply_rank=self.output_rank,
|
||||
non_block=non_block,
|
||||
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
|
||||
kv_output_aggregator=self.kv_output_aggregator,
|
||||
)
|
||||
|
||||
def sample_tokens( # type: ignore[override]
|
||||
self, grammar_output: GrammarOutput | None, non_block: bool = False
|
||||
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
|
||||
return self._execute_with_aggregation( # type: ignore[return-value]
|
||||
"sample_tokens", grammar_output, non_block=non_block
|
||||
)
|
||||
|
||||
def _execute_with_aggregation(
|
||||
self, method: str, *args, non_block: bool = False
|
||||
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
|
||||
if not self.has_connector:
|
||||
# get output only from a single worker (output_rank)
|
||||
return self.collective_rpc(
|
||||
method,
|
||||
args=args,
|
||||
unique_reply_rank=self.output_rank,
|
||||
non_block=non_block,
|
||||
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
# get output from all workers
|
||||
outputs = self.collective_rpc(
|
||||
method,
|
||||
args=args,
|
||||
return self.collective_rpc(
|
||||
"sample_tokens",
|
||||
args=(grammar_output,),
|
||||
unique_reply_rank=self.output_rank,
|
||||
non_block=non_block,
|
||||
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
|
||||
kv_output_aggregator=self.kv_output_aggregator,
|
||||
)
|
||||
|
||||
# aggregate all workers output to a single output
|
||||
assert self.kv_output_aggregator is not None
|
||||
if non_block:
|
||||
return self.kv_output_aggregator.async_aggregate(outputs, self.output_rank)
|
||||
return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
|
||||
|
||||
def execute_dummy_batch(self) -> None:
|
||||
self.collective_rpc("execute_dummy_batch", unique_reply_rank=self.output_rank)
|
||||
|
||||
@@ -254,8 +242,10 @@ class MultiprocExecutor(Executor):
|
||||
kwargs: dict | None = None,
|
||||
non_block: bool = False,
|
||||
unique_reply_rank: int | None = None,
|
||||
kv_output_aggregator: KVOutputAggregator = None,
|
||||
) -> Any | list[Any] | Future[Any | list[Any]]:
|
||||
"""Returns single result if unique_reply_rank is provided, otherwise list."""
|
||||
"""Returns single result if unique_reply_rank and/or kv_output_aggregator
|
||||
is provided, otherwise list."""
|
||||
|
||||
if self.is_failed:
|
||||
raise RuntimeError("Executor failed.")
|
||||
@@ -263,20 +253,23 @@ class MultiprocExecutor(Executor):
|
||||
deadline = None if timeout is None else time.monotonic() + timeout
|
||||
kwargs = kwargs or {}
|
||||
|
||||
# NOTE: If the args are heterogeneous, then we pack them into a list,
|
||||
# and unpack them in the method of every worker, because every worker
|
||||
# knows their own rank.
|
||||
if kv_output_aggregator is not None:
|
||||
output_rank = None
|
||||
aggregate: Callable[[Any], Any] = partial(
|
||||
kv_output_aggregator.aggregate, output_rank=unique_reply_rank or 0
|
||||
)
|
||||
else:
|
||||
output_rank = unique_reply_rank
|
||||
aggregate = lambda x: x
|
||||
|
||||
if isinstance(method, str):
|
||||
send_method = method
|
||||
else:
|
||||
send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, unique_reply_rank))
|
||||
self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, output_rank))
|
||||
|
||||
workers = (
|
||||
(self.workers[unique_reply_rank],)
|
||||
if unique_reply_rank is not None
|
||||
else self.workers
|
||||
(self.workers[output_rank],) if output_rank is not None else self.workers
|
||||
)
|
||||
|
||||
shutdown_event = self.shutdown_event
|
||||
@@ -299,10 +292,10 @@ class MultiprocExecutor(Executor):
|
||||
" stack trace above for the root cause"
|
||||
)
|
||||
responses.append(result)
|
||||
return responses[0] if unique_reply_rank is not None else responses
|
||||
return responses[0] if output_rank is not None else responses
|
||||
|
||||
if non_block:
|
||||
future = FutureWrapper(self.futures_queue)
|
||||
future = FutureWrapper(self.futures_queue, aggregate=aggregate)
|
||||
self.futures_queue.appendleft((future, get_response))
|
||||
return future
|
||||
|
||||
@@ -311,7 +304,7 @@ class MultiprocExecutor(Executor):
|
||||
future, get_fut_response = self.futures_queue.pop()
|
||||
future.wait_for_response(get_fut_response)
|
||||
|
||||
return get_response()
|
||||
return aggregate(get_response())
|
||||
|
||||
@staticmethod
|
||||
def _ensure_worker_termination(worker_procs: list[BaseProcess]):
|
||||
|
||||
Reference in New Issue
Block a user