[BugFix] Fix clean shutdown issues (#8492)
@@ -1,8 +1,10 @@
 import asyncio
 import time
+import weakref
 from functools import partial
 from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
                     Mapping, Optional, Set, Tuple, Type, Union)
+from weakref import ReferenceType

 import vllm.envs as envs
 from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
@@ -26,6 +28,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
+from vllm.utils import weak_bind

 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -450,9 +453,6 @@ class AsyncLLMEngine:
     method yields the outputs from the :class:`LLMEngine` to the caller.

     Args:
-        worker_use_ray: Whether to use Ray for model workers. Required for
-            distributed execution. Should be the same as
-            `parallel_config.worker_use_ray`.
         log_requests: Whether to log the requests.
         start_engine_loop: If True, the background task to run the engine
             will be automatically started in the generate call.
@@ -463,23 +463,22 @@ class AsyncLLMEngine:
     _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

     def __init__(self,
-                 worker_use_ray: bool,
                  *args,
                  log_requests: bool = True,
                  start_engine_loop: bool = True,
                  **kwargs) -> None:
-        self.worker_use_ray = worker_use_ray
         self.log_requests = log_requests
         self.engine = self._engine_class(*args, **kwargs)

         # This ensures quick processing of request outputs
         # so the append to asyncio queues is not delayed,
         # especially for multi-step.
-        #
-        self.use_process_request_outputs_callback = True
+        self.use_process_request_outputs_callback = (
+            self.engine.model_config.use_async_output_proc)
+
         if self.use_process_request_outputs_callback:
             self.engine.process_request_outputs_callback = \
-                self.process_request_outputs
+                weak_bind(self.process_request_outputs)

         self.background_loop: Optional[asyncio.Future] = None
         # We need to keep a reference to unshielded
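Note on the `weak_bind` change above: storing the plain bound method `self.process_request_outputs` on `self.engine` creates a strong reference cycle (engine -> callback -> AsyncLLMEngine) that can keep the async wrapper alive and block clean shutdown; binding the callback weakly breaks that cycle. A minimal sketch of what such a helper could look like (illustrative only; the real `vllm.utils.weak_bind` may be implemented differently):

    import weakref
    from typing import Callable

    def weak_bind_sketch(bound_method: Callable) -> Callable:
        # Keep the owning instance behind a weakref so the stored callback
        # does not keep it alive.
        owner_ref = weakref.ref(bound_method.__self__)
        func = bound_method.__func__

        def call(*args, **kwargs):
            owner = owner_ref()
            if owner is not None:  # No-op once the owner has been collected.
                func(owner, *args, **kwargs)

        return call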
@@ -492,6 +491,11 @@ class AsyncLLMEngine:
         # Lazy initialized fields
         self._request_tracker: RequestTracker

+    def __del__(self):
+        if rt := getattr(self, "request_tracker", None):
+            # Wake up engine loop so that it will exit cleanly
+            rt.new_requests_event.set()
+
     @classmethod
     def _get_executor_cls(
             cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
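The new `__del__` goes through `getattr(..., None)` rather than a direct attribute access because finalizers can run on partially initialized objects (for example, if `__init__` raised before the tracker existed), and an exception inside `__del__` is only reported, never propagated. A small standalone illustration of that defensive pattern (not vLLM code):

    import asyncio

    class Worker:
        def __init__(self, fail: bool = False):
            if fail:
                # __del__ still runs even though the event was never created.
                raise RuntimeError("init failed early")
            self.new_requests_event = asyncio.Event()

        def __del__(self):
            # getattr with a default keeps finalization from raising a second
            # error on a half-constructed instance.
            if event := getattr(self, "new_requests_event", None):
                event.set()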
@@ -502,15 +506,12 @@ class AsyncLLMEngine:
                 raise TypeError(
                     "distributed_executor_backend must be a subclass of "
                     f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
-            if distributed_executor_backend.uses_ray:  # type: ignore
-                initialize_ray_cluster(engine_config.parallel_config)
             executor_class = distributed_executor_backend
         elif engine_config.device_config.device_type == "neuron":
             from vllm.executor.neuron_executor import NeuronExecutorAsync
             executor_class = NeuronExecutorAsync
         elif engine_config.device_config.device_type == "tpu":
             if distributed_executor_backend == "ray":
-                initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
                 executor_class = RayTPUExecutorAsync
             else:
@@ -531,11 +532,9 @@ class AsyncLLMEngine:
                 from vllm.executor.xpu_executor import XPUExecutorAsync
                 executor_class = XPUExecutorAsync
             elif distributed_executor_backend == "ray":
-                initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                 executor_class = RayXPUExecutorAsync
             elif distributed_executor_backend == "mp":
-                initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.multiproc_xpu_executor import (
                     MultiprocessingXPUExecutorAsync)
                 executor_class = MultiprocessingXPUExecutorAsync
@@ -543,7 +542,6 @@ class AsyncLLMEngine:
                 raise RuntimeError(
                     "Not supported distributed execution model on XPU device.")
         elif distributed_executor_backend == "ray":
-            initialize_ray_cluster(engine_config.parallel_config)
             from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
             executor_class = RayGPUExecutorAsync
         elif distributed_executor_backend == "mp":
@@ -559,19 +557,23 @@ class AsyncLLMEngine:
     def from_engine_args(
             cls,
             engine_args: AsyncEngineArgs,
+            engine_config: Optional[EngineConfig] = None,
             start_engine_loop: bool = True,
             usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
             stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
     ) -> "AsyncLLMEngine":
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
-        engine_config = engine_args.create_engine_config()
+        if engine_config is None:
+            engine_config = engine_args.create_engine_config()

         executor_class = cls._get_executor_cls(engine_config)

+        if executor_class.uses_ray:
+            initialize_ray_cluster(engine_config.parallel_config)
+
         # Create the async LLM engine.
         engine = cls(
-            executor_class.uses_ray,
             **engine_config.to_dict(),
             executor_class=executor_class,
             log_requests=not engine_args.disable_log_requests,
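With the optional `engine_config` parameter, callers that have already built a config can pass it in and skip the rebuild, and Ray cluster initialization is now decided once via `executor_class.uses_ray` instead of being repeated in each device branch of `_get_executor_cls`. A hypothetical usage sketch (model name and arguments are placeholders, not from this commit):

    engine_args = AsyncEngineArgs(model="facebook/opt-125m")
    config = engine_args.create_engine_config()   # reuse the already-built config
    engine = AsyncLLMEngine.from_engine_args(engine_args, engine_config=config)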
@@ -628,7 +630,7 @@ class AsyncLLMEngine:
         self._request_tracker = RequestTracker()

         self._background_loop_unshielded = asyncio.get_event_loop(
-        ).create_task(self.run_engine_loop())
+        ).create_task(self.run_engine_loop(weakref.ref(self)))
         self._background_loop_unshielded.add_done_callback(
             partial(_log_task_completion, error_callback=self._error_callback))
         self.background_loop = asyncio.shield(self._background_loop_unshielded)
@@ -698,9 +700,16 @@ class AsyncLLMEngine:
     async def _engine_abort(self, request_ids: Iterable[str]):
         self.engine.abort_request(request_ids)

-    async def run_engine_loop(self):
+    @staticmethod
+    async def run_engine_loop(engine_ref: ReferenceType):
+        """We use a weakref to the engine so that the running loop
+        doesn't prevent the engine being garbage collected."""
+        engine: Optional["AsyncLLMEngine"] = engine_ref()
+        if not engine:
+            return
+
         pipeline_parallel_size = \
-            self.engine.parallel_config.pipeline_parallel_size
+            engine.engine.parallel_config.pipeline_parallel_size
         has_requests_in_progress = [False] * pipeline_parallel_size
         while True:
             if not any(has_requests_in_progress):
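Making `run_engine_loop` a `@staticmethod` that receives `engine_ref: ReferenceType` is the core of the shutdown fix: the long-running background task no longer captures `self`, so dropping the last external reference to the engine actually lets it be garbage collected. A stripped-down sketch of the pattern with a hypothetical `Service` class (not vLLM code):

    import asyncio
    import weakref
    from typing import Optional
    from weakref import ReferenceType

    class Service:
        def start(self) -> None:
            # The task captures only a weakref, so it does not pin the Service.
            self._task = asyncio.get_event_loop().create_task(
                Service._run(weakref.ref(self)))

        @staticmethod
        async def _run(service_ref: ReferenceType) -> None:
            while True:
                service: Optional["Service"] = service_ref()
                if service is None:
                    return  # Owner collected: exit the loop cleanly.
                await service.step()
                del service             # Drop the strong ref before yielding...
                await asyncio.sleep(0)  # ...so GC can reclaim the owner here.

        async def step(self) -> None:
            pass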
@@ -711,11 +720,21 @@ class AsyncLLMEngine:
                 # timeout, and unblocks the RPC thread in the workers so that
                 # they can process any other queued control plane messages,
                 # such as add/remove lora adapters.
-                await self.engine.stop_remote_worker_execution_loop_async()
-                await self._request_tracker.wait_for_new_requests()
+                await engine.engine.stop_remote_worker_execution_loop_async()
+                request_tracker = engine._request_tracker
+                # Allow engine to be garbage collected while
+                # waiting for new requests
+                del engine
+                await asyncio.sleep(0)
+                if engine_ref() is None:
+                    return
+                await request_tracker.wait_for_new_requests()
+                engine = engine_ref()
+                if not engine:
+                    return
                 logger.debug("Got new requests!")
                 requests_in_progress = [
-                    asyncio.create_task(self.engine_step(ve))
+                    asyncio.create_task(engine.engine_step(ve))
                     for ve in range(pipeline_parallel_size)
                 ]
                 has_requests_in_progress = [True] * pipeline_parallel_size
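This wait sequence is where an idle engine used to be kept alive forever: the loop now keeps a strong reference only to the request tracker and drops `engine` before blocking, so the engine can be collected while the loop waits. If the engine is already gone after the `asyncio.sleep(0)` yield, the loop exits immediately; if it is collected during the wait, its `__del__` sets the new-requests event, the wait returns, and the re-resolved weakref comes back empty. An annotated restatement of that sequence (comments added for clarity, code as in the diff):

    request_tracker = engine._request_tracker
    del engine                       # allow the engine to be GC'd while idle
    await asyncio.sleep(0)           # yield so other tasks can drop their refs
    if engine_ref() is None:
        return                       # engine already collected; stop the loop
    await request_tracker.wait_for_new_requests()
    engine = engine_ref()            # re-acquire a strong reference
    if not engine:
        return                       # collected while we were waiting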
@@ -733,19 +752,20 @@ class AsyncLLMEngine:
                         result = task.result()
                         virtual_engine = requests_in_progress.index(task)
                         has_unfinished_requests = (
-                            self.engine.has_unfinished_requests_for_virtual_engine(
+                            engine.engine.
+                            has_unfinished_requests_for_virtual_engine(
                                 virtual_engine))
                         if result or has_unfinished_requests:
                             requests_in_progress[virtual_engine] = (
                                 asyncio.create_task(
-                                    self.engine_step(virtual_engine)))
+                                    engine.engine_step(virtual_engine)))
                             has_requests_in_progress[virtual_engine] = True
                         else:
                             has_requests_in_progress[virtual_engine] = False
             except asyncio.TimeoutError as exc:
                 logger.error(
                     "Engine iteration timed out. This should never happen!")
-                self.set_errored(exc)
+                engine.set_errored(exc)
                 raise
             await asyncio.sleep(0)
