Add distributed model executor abstraction (#3191)

This commit is contained in:
Zhuohan Li
2024-03-11 11:03:45 -07:00
committed by GitHub
parent 657061fdce
commit 4c922709b6
13 changed files with 817 additions and 508 deletions

View File

@@ -2,8 +2,8 @@ import asyncio
import os
import time
from functools import partial
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
Union, AsyncIterator, Callable)
from typing import (Callable, Dict, Iterable, List, Optional, Set, Tuple, Type,
Union, AsyncIterator)
from transformers import PreTrainedTokenizer
@@ -11,7 +11,7 @@ from vllm.lora.request import LoRARequest
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.ray_utils import initialize_cluster, ray
from vllm.engine.ray_utils import initialize_ray_cluster, ray
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
@@ -208,17 +208,10 @@ class _AsyncLLMEngine(LLMEngine):
if not scheduler_outputs.is_empty():
# Execute the model.
all_outputs = await self._run_workers_async(
"execute_model",
driver_kwargs={
"seq_group_metadata_list": seq_group_metadata_list,
"blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
"blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
"blocks_to_copy": scheduler_outputs.blocks_to_copy,
})
# Only the driver worker returns the sampling results.
output = all_outputs[0]
output = await self.model_executor.execute_model_async(
seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
scheduler_outputs.blocks_to_swap_out,
scheduler_outputs.blocks_to_copy)
else:
output = []
@@ -268,37 +261,8 @@ class _AsyncLLMEngine(LLMEngine):
lora_request=lora_request,
)
async def _run_workers_async(
self,
method: str,
*args,
driver_args: Optional[List[Any]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
coros = []
if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs
# Run the driver worker asynchronously.
driver_executor = getattr(self.driver_worker, method)
coros.append(asyncio.get_event_loop().run_in_executor(
None, partial(driver_executor, *driver_args, **driver_kwargs)))
# Run the ray workers asynchronously.
for worker in self.workers:
coros.append(worker.execute_method.remote(method, *args, **kwargs))
all_outputs = await asyncio.gather(*coros)
return all_outputs
async def check_health_async(self):
"""Raises an error if engine is unhealthy."""
self._check_if_any_actor_is_dead()
async def check_health_async(self) -> None:
self.model_executor.check_health()
class AsyncLLMEngine:
@@ -353,6 +317,34 @@ class AsyncLLMEngine:
self._request_tracker: Optional[RequestTracker] = None
self._errored_with: Optional[BaseException] = None
@classmethod
def from_engine_args(cls,
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_configs = engine_args.create_engine_configs()
parallel_config = engine_configs[2]
if parallel_config.worker_use_ray or engine_args.engine_use_ray:
initialize_ray_cluster(parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
executor_class = RayGPUExecutorAsync
else:
assert parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.")
from vllm.executor.gpu_executor import GPUExecutorAsync
executor_class = GPUExecutorAsync
# Create the async LLM engine.
engine = cls(parallel_config.worker_use_ray,
engine_args.engine_use_ray,
*engine_configs,
executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
max_log_len=engine_args.max_log_len,
start_engine_loop=start_engine_loop)
return engine
@property
def is_running(self) -> bool:
return (self.background_loop is not None
@@ -670,35 +662,13 @@ class AsyncLLMEngine:
else:
return self.engine.get_model_config()
@classmethod
def from_engine_args(cls,
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_configs = engine_args.create_engine_configs()
parallel_config = engine_configs[2]
# Initialize the cluster.
placement_group = initialize_cluster(parallel_config,
engine_args.engine_use_ray)
# Create the async LLM engine.
engine = cls(parallel_config.worker_use_ray,
engine_args.engine_use_ray,
*engine_configs,
placement_group,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
max_log_len=engine_args.max_log_len,
start_engine_loop=start_engine_loop)
return engine
async def do_log_stats(self) -> None:
if self.engine_use_ray:
await self.engine.do_log_stats.remote()
else:
self.engine.do_log_stats()
async def check_health(self):
async def check_health(self) -> None:
"""Raises an error if engine is unhealthy."""
t = time.perf_counter()
logger.debug("Starting health check...")