[V1][Frontend] Improve Shutdown And Logs (#11737)

Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Andrew Feldman <afeldman@neuralmagic.com>
Co-authored-by: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Author: Robert Shaw
Date: 2025-04-16 22:48:34 -04:00
Committed by: GitHub
Parent: 3c776dcefb
Commit: 2b05b8ce69

16 changed files with 1031 additions and 347 deletions

vllm/v1/engine/async_llm.py

@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

 import asyncio
-import logging
-import os
 from collections.abc import AsyncGenerator, Mapping
 from copy import copy
 from typing import Optional, Union
@@ -26,9 +24,10 @@ from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import Device, cdiv, kill_process_tree
+from vllm.utils import Device, cdiv
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.core_client import EngineCoreClient
+from vllm.v1.engine.core_client import AsyncMPClient, DPAsyncMPClient
+from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
 from vllm.v1.engine.output_processor import (OutputProcessor,
                                              RequestOutputCollector)
 from vllm.v1.engine.parallel_sampling import ParentRequest
@@ -61,8 +60,6 @@ class AsyncLLM(EngineClient):
             "AsyncLLMEngine.from_vllm_config(...) or explicitly set "
             "VLLM_USE_V1=0 or 1 and report this issue on Github.")

-        assert start_engine_loop
-
         self.model_config = vllm_config.model_config
         self.vllm_config = vllm_config
         self.log_requests = log_requests
@@ -99,15 +96,23 @@ class AsyncLLM(EngineClient):
             log_stats=self.log_stats)

         # EngineCore (starts the engine in background process).
-        self.engine_core = EngineCoreClient.make_client(
-            multiprocess_mode=True,
-            asyncio_mode=True,
+        core_client_class = AsyncMPClient if (
+            vllm_config.parallel_config.data_parallel_size
+            == 1) else DPAsyncMPClient
+
+        self.engine_core = core_client_class(
             vllm_config=vllm_config,
             executor_class=executor_class,
             log_stats=self.log_stats,
         )

+        self.output_handler: Optional[asyncio.Task] = None
+        try:
+            # Start output handler eagerly if we are in the asyncio eventloop.
+            asyncio.get_running_loop()
+            self._run_output_handler()
+        except RuntimeError:
+            pass
+
     @classmethod
     def from_vllm_config(
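
The eager-start logic in this hunk relies on asyncio.get_running_loop() raising RuntimeError when no event loop is running: the handler task starts in __init__ when the client is constructed inside a loop, and is otherwise deferred to the first generate() call. A minimal standalone sketch of the pattern (class and method names are illustrative, not from this diff):

import asyncio
from typing import Optional

class BackgroundClient:
    """Illustrative only: the start-eagerly-or-defer idiom."""

    def __init__(self) -> None:
        self.task: Optional[asyncio.Task] = None
        try:
            asyncio.get_running_loop()  # raises RuntimeError outside a loop
            self.start()                # constructed inside a loop: start now
        except RuntimeError:
            pass                        # deferred: started on first request

    def start(self) -> None:
        if self.task is None:           # idempotent, like _run_output_handler
            self.task = asyncio.create_task(self._loop())

    async def _loop(self) -> None:
        while True:
            await asyncio.sleep(1)      # stand-in for pulling engine outputs
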
@@ -165,6 +170,9 @@ class AsyncLLM(EngineClient):
             usage_context=usage_context,
         )

+    def __del__(self):
+        self.shutdown()
+
     def shutdown(self):
         """Shutdown, cleaning up the background proc and IPC."""
@@ -187,6 +195,9 @@ class AsyncLLM(EngineClient):
     ) -> RequestOutputCollector:
         """Add new request to the AsyncLLM."""

+        if self.errored:
+            raise EngineDeadError()
+
         assert isinstance(params, SamplingParams), \
             "Pooling is not supported in V1"
@@ -261,9 +272,7 @@ class AsyncLLM(EngineClient):
         # We start the output_handler on the first call to generate() so
         # we can call __init__ before the event loop, which enables us
         # to handle startup failure gracefully in the OpenAI server.
-        if self.output_handler is None:
-            self.output_handler = asyncio.create_task(
-                self._run_output_handler())
+        self._run_output_handler()

         q = await self.add_request(
             request_id,
@@ -288,62 +297,96 @@ class AsyncLLM(EngineClient):
                 finished = out.finished
                 yield out

-            # If the request is disconnected by the client, the
-            # generate() task will be canceled. So, we abort the
-            # request if we end up here.
+        # If the request is disconnected by the client, generate()
+        # is cancelled. So, we abort the request if we end up here.
         except asyncio.CancelledError:
             await self.abort(request_id)
+            if self.log_requests:
+                logger.info("Request %s aborted.", request_id)
             raise

-    async def _run_output_handler(self):
+        # Engine is dead. Do not abort since we shut down.
+        except EngineDeadError:
+            if self.log_requests:
+                logger.info("Request %s failed (engine dead).", request_id)
+            raise
+
+        # Request validation error.
+        except ValueError:
+            if self.log_requests:
+                logger.info("Request %s failed (bad request).", request_id)
+            raise
+
+        # Unexpected error in the generate() task (possibly recoverable).
+        except Exception as e:
+            await self.abort(request_id)
+            if self.log_requests:
+                logger.info("Request %s failed.", request_id)
+            raise EngineGenerateError() from e
+
+    def _run_output_handler(self):
         """Background loop: pulls from EngineCore and pushes to AsyncStreams."""
-        try:
-            while True:
-                # 1) Pull EngineCoreOutputs from the EngineCore.
-                outputs = await self.engine_core.get_output_async()
-                num_outputs = len(outputs.outputs)
+        if self.output_handler is not None:
+            return

-                iteration_stats = IterationStats() if (
-                    self.log_stats and num_outputs) else None
+        # Ensure that the task doesn't have a circular ref back to the AsyncLLM
+        # object, or else it won't be garbage collected and cleaned up properly.
+        engine_core = self.engine_core
+        output_processor = self.output_processor
+        log_stats = self.log_stats
+        stat_loggers = self.stat_loggers if log_stats else None

-                # Split outputs into chunks of at most
-                # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
-                # event loop for too long.
-                if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
-                    slices = (outputs.outputs, )
-                else:
-                    slices = np.array_split(
-                        outputs.outputs,
-                        cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))
+        async def output_handler():
+            try:
+                while True:
+                    # 1) Pull EngineCoreOutputs from the EngineCore.
+                    outputs = await engine_core.get_output_async()
+                    num_outputs = len(outputs.outputs)

-                for i, outputs_slice in enumerate(slices):
-                    # 2) Process EngineCoreOutputs.
-                    processed_outputs = self.output_processor.process_outputs(
-                        outputs_slice, outputs.timestamp, iteration_stats)
-                    # NOTE: RequestOutputs are pushed to their queues.
-                    assert not processed_outputs.request_outputs
+                    iteration_stats = IterationStats() if (
+                        log_stats and num_outputs) else None

-                    # Allow other asyncio tasks to run between chunks
-                    if i + 1 < len(slices):
-                        await asyncio.sleep(0)
+                    # Split outputs into chunks of at most
+                    # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
+                    # event loop for too long.
+                    if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
+                        slices = (outputs.outputs, )
+                    else:
+                        slices = np.array_split(
+                            outputs.outputs,
+                            cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))

-                # 3) Abort any reqs that finished due to stop strings.
-                await self.engine_core.abort_requests_async(
-                    processed_outputs.reqs_to_abort)
+                    for i, outputs_slice in enumerate(slices):
+                        # 2) Process EngineCoreOutputs.
+                        processed_outputs = output_processor.process_outputs(
+                            outputs_slice, outputs.timestamp, iteration_stats)
+                        # NOTE: RequestOutputs are pushed to their queues.
+                        assert not processed_outputs.request_outputs

-                # 4) Logging.
-                # TODO(rob): make into a coroutine and launch it in
-                # background thread once Prometheus overhead is non-trivial.
-                self._record_stats(
-                    engine_index=outputs.engine_index,
-                    scheduler_stats=outputs.scheduler_stats,
-                    iteration_stats=iteration_stats,
-                )
+                        # Allow other asyncio tasks to run between chunks
+                        if i + 1 < len(slices):
+                            await asyncio.sleep(0)

-        except Exception as e:
-            logger.exception("EngineCore output handler hit an error: %s", e)
-            kill_process_tree(os.getpid())
+                    # 3) Abort any reqs that finished due to stop strings.
+                    await engine_core.abort_requests_async(
+                        processed_outputs.reqs_to_abort)
+
+                    # 4) Logging.
+                    # TODO(rob): make into a coroutine and launch it in
+                    # background thread once Prometheus overhead is non-trivial.
+                    if stat_loggers:
+                        assert outputs.scheduler_stats is not None
+                        AsyncLLM._record_stats(
+                            stat_loggers[outputs.engine_index],
+                            scheduler_stats=outputs.scheduler_stats,
+                            iteration_stats=iteration_stats,
+                        )
+            except Exception as e:
+                logger.exception("AsyncLLM output_handler failed.")
+                output_processor.propagate_error(e)
+
+        self.output_handler = asyncio.create_task(output_handler())

     async def abort(self, request_id: str) -> None:
         """Abort RequestId in OutputProcessor and EngineCore."""
@@ -354,17 +397,15 @@ class AsyncLLM(EngineClient):
         if self.log_requests:
             logger.info("Aborted request %s.", request_id)

+    @staticmethod
     def _record_stats(
-        self,
-        scheduler_stats: Optional[SchedulerStats],
+        stat_loggers: list[StatLoggerBase],
+        scheduler_stats: SchedulerStats,
         iteration_stats: Optional[IterationStats],
-        engine_index: int = 0,
     ):
-        if not self.log_stats:
-            return
-
-        assert scheduler_stats is not None
-        for stat_logger in self.stat_loggers[engine_index]:
+        """static so that it can be used from the output_handler task
+        without a circular ref to AsyncLLM."""
+        for stat_logger in stat_loggers:
             stat_logger.record(scheduler_stats=scheduler_stats,
                                iteration_stats=iteration_stats)
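
Marking _record_stats as a @staticmethod matters for the same garbage-collection reason: resolving it through the class means the output-handler closure never captures an AsyncLLM instance. A tiny illustration of the difference (names are illustrative):

class A:
    def bound(self) -> None:
        ...                     # A().bound carries the instance with it

    @staticmethod
    def unbound(loggers) -> None:
        ...                     # A.unbound is a plain function, no instance

callback = A.unbound            # safe to hold from a long-lived task
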
@@ -451,16 +492,17 @@ class AsyncLLM(EngineClient):

     @property
     def is_running(self) -> bool:
-        return True
+        # Is None before the loop is started.
+        return self.output_handler is None or not self.output_handler.done()

     @property
     def is_stopped(self) -> bool:
-        return False
+        return self.errored

     @property
     def errored(self) -> bool:
-        return False
+        return self.engine_core.resources.engine_dead or not self.is_running

     @property
     def dead_error(self) -> BaseException:
-        return Exception()  # TODO: implement
+        return EngineDeadError()
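
These properties now derive liveness from real state: errored flips to true once EngineCore reports a dead engine or the output-handler task has finished, and dead_error carries a concrete exception instead of a placeholder. A hypothetical health probe built on top of them (the function is an assumption, not part of this diff):

async def check_health(llm) -> None:
    """Hypothetical probe: raise the engine's dead_error when unhealthy."""
    if llm.errored:
        raise llm.dead_error  # EngineDeadError under this diff
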