[BugFix] Fix clean shutdown issues (#8492)

Author: Nick Hill
Date: 2024-09-16 17:33:46 +01:00
Committed by: GitHub
Parent: 837c1968f9
Commit: acd5511b6d
11 changed files with 213 additions and 134 deletions

vllm/engine/async_llm_engine.py

@@ -1,8 +1,10 @@
 import asyncio
 import time
+import weakref
 from functools import partial
 from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
                     Mapping, Optional, Set, Tuple, Type, Union)
+from weakref import ReferenceType
 
 import vllm.envs as envs
 from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
@@ -26,6 +28,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
+from vllm.utils import weak_bind
 
 logger = init_logger(__name__)
 
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -450,9 +453,6 @@ class AsyncLLMEngine:
     method yields the outputs from the :class:`LLMEngine` to the caller.
 
     Args:
-        worker_use_ray: Whether to use Ray for model workers. Required for
-            distributed execution. Should be the same as
-            `parallel_config.worker_use_ray`.
         log_requests: Whether to log the requests.
         start_engine_loop: If True, the background task to run the engine
             will be automatically started in the generate call.
@@ -463,23 +463,22 @@ class AsyncLLMEngine:
     _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
 
     def __init__(self,
-                 worker_use_ray: bool,
                  *args,
                  log_requests: bool = True,
                  start_engine_loop: bool = True,
                  **kwargs) -> None:
-        self.worker_use_ray = worker_use_ray
         self.log_requests = log_requests
         self.engine = self._engine_class(*args, **kwargs)
 
         # This ensures quick processing of request outputs
         # so the append to asyncio queues is not delayed,
         # especially for multi-step.
-        #
-        self.use_process_request_outputs_callback = True
+        self.use_process_request_outputs_callback = (
+            self.engine.model_config.use_async_output_proc)
 
         if self.use_process_request_outputs_callback:
             self.engine.process_request_outputs_callback = \
-                self.process_request_outputs
+                weak_bind(self.process_request_outputs)
 
         self.background_loop: Optional[asyncio.Future] = None
         # We need to keep a reference to unshielded
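
The callback handed to the inner LLMEngine is now wrapped with weak_bind(...) rather than stored as a plain bound method, so the stored callback does not keep a strong reference back to the AsyncLLMEngine wrapper. A minimal sketch of what a weak-binding helper of this kind can look like, purely illustrative; the real vllm.utils.weak_bind may be implemented differently:

import weakref
from typing import Any, Callable

def weak_bind(bound_method: Callable) -> Callable:
    """Illustrative sketch of a weak-binding helper (the real
    vllm.utils.weak_bind may differ): keep only a weak reference to the
    method's instance, and turn calls into no-ops once it is collected."""
    unbound = bound_method.__func__
    instance_ref = weakref.ref(bound_method.__self__)

    def weak_bound(*args: Any, **kwargs: Any) -> None:
        instance = instance_ref()
        if instance is not None:
            unbound(instance, *args, **kwargs)

    return weak_bound
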
@@ -492,6 +491,11 @@ class AsyncLLMEngine:
         # Lazy initialized fields
         self._request_tracker: RequestTracker
 
+    def __del__(self):
+        if rt := getattr(self, "request_tracker", None):
+            # Wake up engine loop so that it will exit cleanly
+            rt.new_requests_event.set()
+
     @classmethod
     def _get_executor_cls(
             cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
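
The new __del__ wakes the engine loop by setting the request tracker's new_requests_event; a loop parked in wait_for_new_requests() would otherwise never get another chance to observe that the engine is gone. A toy demonstration of that wake-up idea, using hypothetical stand-in names rather than vLLM's classes:

import asyncio

class Tracker:
    """Stand-in for the request tracker: it only owns the wake-up event."""
    def __init__(self) -> None:
        self.new_requests_event = asyncio.Event()

async def demo() -> None:
    tracker = Tracker()
    # A task parked the way the engine loop parks in wait_for_new_requests().
    waiter = asyncio.create_task(tracker.new_requests_event.wait())
    await asyncio.sleep(0)              # let the waiter start waiting
    # What the engine's __del__ does: wake the waiter so the loop can
    # notice shutdown and exit instead of pending forever.
    tracker.new_requests_event.set()
    await waiter

asyncio.run(demo())
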
@@ -502,15 +506,12 @@ class AsyncLLMEngine:
                 raise TypeError(
                     "distributed_executor_backend must be a subclass of "
                     f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
-            if distributed_executor_backend.uses_ray:  # type: ignore
-                initialize_ray_cluster(engine_config.parallel_config)
             executor_class = distributed_executor_backend
         elif engine_config.device_config.device_type == "neuron":
             from vllm.executor.neuron_executor import NeuronExecutorAsync
             executor_class = NeuronExecutorAsync
         elif engine_config.device_config.device_type == "tpu":
             if distributed_executor_backend == "ray":
-                initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
                 executor_class = RayTPUExecutorAsync
             else:
@@ -531,11 +532,9 @@ class AsyncLLMEngine:
                 from vllm.executor.xpu_executor import XPUExecutorAsync
                 executor_class = XPUExecutorAsync
             elif distributed_executor_backend == "ray":
-                initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                 executor_class = RayXPUExecutorAsync
             elif distributed_executor_backend == "mp":
-                initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.multiproc_xpu_executor import (
                     MultiprocessingXPUExecutorAsync)
                 executor_class = MultiprocessingXPUExecutorAsync
@@ -543,7 +542,6 @@ class AsyncLLMEngine:
                 raise RuntimeError(
                     "Not supported distributed execution model on XPU device.")
         elif distributed_executor_backend == "ray":
-            initialize_ray_cluster(engine_config.parallel_config)
             from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
             executor_class = RayGPUExecutorAsync
         elif distributed_executor_backend == "mp":
@@ -559,19 +557,23 @@ class AsyncLLMEngine:
     def from_engine_args(
         cls,
         engine_args: AsyncEngineArgs,
+        engine_config: Optional[EngineConfig] = None,
         start_engine_loop: bool = True,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
         stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
     ) -> "AsyncLLMEngine":
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
-        engine_config = engine_args.create_engine_config()
+        if engine_config is None:
+            engine_config = engine_args.create_engine_config()
 
         executor_class = cls._get_executor_cls(engine_config)
 
+        if executor_class.uses_ray:
+            initialize_ray_cluster(engine_config.parallel_config)
+
         # Create the async LLM engine.
         engine = cls(
-            executor_class.uses_ray,
             **engine_config.to_dict(),
             executor_class=executor_class,
             log_requests=not engine_args.disable_log_requests,
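
from_engine_args now accepts an already-built EngineConfig and only creates one when none is passed, and Ray initialization happens in a single place, gated on executor_class.uses_ray. A minimal usage sketch of the new parameter (the model name is a placeholder, not a recommendation):

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

engine_args = AsyncEngineArgs(model="facebook/opt-125m")   # placeholder model
# Build (and optionally inspect) the config up front...
engine_config = engine_args.create_engine_config()
# ...then hand it to from_engine_args so it is not recreated internally.
engine = AsyncLLMEngine.from_engine_args(engine_args, engine_config=engine_config)
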
@@ -628,7 +630,7 @@ class AsyncLLMEngine:
         self._request_tracker = RequestTracker()
 
         self._background_loop_unshielded = asyncio.get_event_loop(
-        ).create_task(self.run_engine_loop())
+        ).create_task(self.run_engine_loop(weakref.ref(self)))
         self._background_loop_unshielded.add_done_callback(
             partial(_log_task_completion, error_callback=self._error_callback))
         self.background_loop = asyncio.shield(self._background_loop_unshielded)
@@ -698,9 +700,16 @@ class AsyncLLMEngine:
     async def _engine_abort(self, request_ids: Iterable[str]):
         self.engine.abort_request(request_ids)
 
-    async def run_engine_loop(self):
+    @staticmethod
+    async def run_engine_loop(engine_ref: ReferenceType):
+        """We use a weakref to the engine so that the running loop
+        doesn't prevent the engine being garbage collected."""
+        engine: Optional["AsyncLLMEngine"] = engine_ref()
+        if not engine:
+            return
+
         pipeline_parallel_size = \
-            self.engine.parallel_config.pipeline_parallel_size
+            engine.engine.parallel_config.pipeline_parallel_size
         has_requests_in_progress = [False] * pipeline_parallel_size
         while True:
             if not any(has_requests_in_progress):
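
run_engine_loop is now a static method that receives only a weakref.ref to the engine: a task created from a bound method would hold a strong reference to self for its whole lifetime, so the engine could never be garbage collected while the loop was running. A condensed sketch of the same pattern in isolation, illustrative only and not vLLM code:

import asyncio
import weakref
from typing import Optional

class Service:
    """Toy owner of a background loop that never pins its owner."""

    def start(self) -> None:
        # Pass a weak reference instead of `self` so the task does not
        # keep the Service instance alive.
        self._task = asyncio.get_event_loop().create_task(
            Service._run(weakref.ref(self)))

    @staticmethod
    async def _run(service_ref: "weakref.ref[Service]") -> None:
        while True:
            service: Optional[Service] = service_ref()
            if service is None:
                return                  # owner collected: exit cleanly
            await service.step()
            del service                 # drop the strong reference between steps
            await asyncio.sleep(0)

    async def step(self) -> None:
        await asyncio.sleep(0.1)
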
@@ -711,11 +720,21 @@ class AsyncLLMEngine:
                 # timeout, and unblocks the RPC thread in the workers so that
                 # they can process any other queued control plane messages,
                 # such as add/remove lora adapters.
-                await self.engine.stop_remote_worker_execution_loop_async()
-                await self._request_tracker.wait_for_new_requests()
+                await engine.engine.stop_remote_worker_execution_loop_async()
+                request_tracker = engine._request_tracker
+                # Allow engine to be garbage collected while
+                # waiting for new requests
+                del engine
+                await asyncio.sleep(0)
+                if engine_ref() is None:
+                    return
+                await request_tracker.wait_for_new_requests()
+                engine = engine_ref()
+                if not engine:
+                    return
                 logger.debug("Got new requests!")
                 requests_in_progress = [
-                    asyncio.create_task(self.engine_step(ve))
+                    asyncio.create_task(engine.engine_step(ve))
                     for ve in range(pipeline_parallel_size)
                 ]
                 has_requests_in_progress = [True] * pipeline_parallel_size
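
While idle, the loop keeps only the request tracker, deletes its strong engine reference, and yields once so a dropped engine can actually be collected before the wait begins; after waking it re-derefs the weakref and exits if the engine is gone. The same wait-without-pinning idiom in generic form (illustrative sketch with hypothetical names):

import asyncio
import weakref

async def idle_wait(owner_ref: weakref.ref, event: asyncio.Event) -> bool:
    """Park on `event` without holding a strong reference to the owner;
    return False if the owner was garbage collected in the meantime."""
    await asyncio.sleep(0)      # yield so a just-dropped owner can be collected
    if owner_ref() is None:
        return False
    await event.wait()          # woken by new work, or by the owner's __del__
    return owner_ref() is not None
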
@@ -733,19 +752,20 @@ class AsyncLLMEngine:
                         result = task.result()
                         virtual_engine = requests_in_progress.index(task)
                         has_unfinished_requests = (
-                            self.engine.has_unfinished_requests_for_virtual_engine(
+                            engine.engine.
+                            has_unfinished_requests_for_virtual_engine(
                                 virtual_engine))
                         if result or has_unfinished_requests:
                             requests_in_progress[virtual_engine] = (
                                 asyncio.create_task(
-                                    self.engine_step(virtual_engine)))
+                                    engine.engine_step(virtual_engine)))
                             has_requests_in_progress[virtual_engine] = True
                         else:
                             has_requests_in_progress[virtual_engine] = False
             except asyncio.TimeoutError as exc:
                 logger.error(
                     "Engine iteration timed out. This should never happen!")
-                self.set_errored(exc)
+                engine.set_errored(exc)
                 raise
             await asyncio.sleep(0)