Add distributed model executor abstraction (#3191)
This commit is contained in:
@@ -2,8 +2,8 @@ import asyncio
|
||||
import os
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
|
||||
Union, AsyncIterator, Callable)
|
||||
from typing import (Callable, Dict, Iterable, List, Optional, Set, Tuple, Type,
|
||||
Union, AsyncIterator)
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
@@ -11,7 +11,7 @@ from vllm.lora.request import LoRARequest
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.engine.ray_utils import initialize_cluster, ray
|
||||
from vllm.engine.ray_utils import initialize_ray_cluster, ray
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
@@ -208,17 +208,10 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
|
||||
if not scheduler_outputs.is_empty():
|
||||
# Execute the model.
|
||||
all_outputs = await self._run_workers_async(
|
||||
"execute_model",
|
||||
driver_kwargs={
|
||||
"seq_group_metadata_list": seq_group_metadata_list,
|
||||
"blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
|
||||
"blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
|
||||
"blocks_to_copy": scheduler_outputs.blocks_to_copy,
|
||||
})
|
||||
|
||||
# Only the driver worker returns the sampling results.
|
||||
output = all_outputs[0]
|
||||
output = await self.model_executor.execute_model_async(
|
||||
seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
|
||||
scheduler_outputs.blocks_to_swap_out,
|
||||
scheduler_outputs.blocks_to_copy)
|
||||
else:
|
||||
output = []
|
||||
|
||||
@@ -268,37 +261,8 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
async def _run_workers_async(
|
||||
self,
|
||||
method: str,
|
||||
*args,
|
||||
driver_args: Optional[List[Any]] = None,
|
||||
driver_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Runs the given method on all workers."""
|
||||
coros = []
|
||||
|
||||
if driver_args is None:
|
||||
driver_args = args
|
||||
if driver_kwargs is None:
|
||||
driver_kwargs = kwargs
|
||||
|
||||
# Run the driver worker asynchronously.
|
||||
driver_executor = getattr(self.driver_worker, method)
|
||||
coros.append(asyncio.get_event_loop().run_in_executor(
|
||||
None, partial(driver_executor, *driver_args, **driver_kwargs)))
|
||||
|
||||
# Run the ray workers asynchronously.
|
||||
for worker in self.workers:
|
||||
coros.append(worker.execute_method.remote(method, *args, **kwargs))
|
||||
|
||||
all_outputs = await asyncio.gather(*coros)
|
||||
return all_outputs
|
||||
|
||||
async def check_health_async(self):
|
||||
"""Raises an error if engine is unhealthy."""
|
||||
self._check_if_any_actor_is_dead()
|
||||
async def check_health_async(self) -> None:
|
||||
self.model_executor.check_health()
|
||||
|
||||
|
||||
class AsyncLLMEngine:
|
||||
@@ -353,6 +317,34 @@ class AsyncLLMEngine:
|
||||
self._request_tracker: Optional[RequestTracker] = None
|
||||
self._errored_with: Optional[BaseException] = None
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(cls,
|
||||
engine_args: AsyncEngineArgs,
|
||||
start_engine_loop: bool = True) -> "AsyncLLMEngine":
|
||||
"""Creates an async LLM engine from the engine arguments."""
|
||||
# Create the engine configs.
|
||||
engine_configs = engine_args.create_engine_configs()
|
||||
parallel_config = engine_configs[2]
|
||||
if parallel_config.worker_use_ray or engine_args.engine_use_ray:
|
||||
initialize_ray_cluster(parallel_config)
|
||||
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
|
||||
executor_class = RayGPUExecutorAsync
|
||||
else:
|
||||
assert parallel_config.world_size == 1, (
|
||||
"Ray is required if parallel_config.world_size > 1.")
|
||||
from vllm.executor.gpu_executor import GPUExecutorAsync
|
||||
executor_class = GPUExecutorAsync
|
||||
# Create the async LLM engine.
|
||||
engine = cls(parallel_config.worker_use_ray,
|
||||
engine_args.engine_use_ray,
|
||||
*engine_configs,
|
||||
executor_class,
|
||||
log_requests=not engine_args.disable_log_requests,
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
max_log_len=engine_args.max_log_len,
|
||||
start_engine_loop=start_engine_loop)
|
||||
return engine
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return (self.background_loop is not None
|
||||
@@ -670,35 +662,13 @@ class AsyncLLMEngine:
|
||||
else:
|
||||
return self.engine.get_model_config()
|
||||
|
||||
@classmethod
|
||||
def from_engine_args(cls,
|
||||
engine_args: AsyncEngineArgs,
|
||||
start_engine_loop: bool = True) -> "AsyncLLMEngine":
|
||||
"""Creates an async LLM engine from the engine arguments."""
|
||||
# Create the engine configs.
|
||||
engine_configs = engine_args.create_engine_configs()
|
||||
parallel_config = engine_configs[2]
|
||||
# Initialize the cluster.
|
||||
placement_group = initialize_cluster(parallel_config,
|
||||
engine_args.engine_use_ray)
|
||||
# Create the async LLM engine.
|
||||
engine = cls(parallel_config.worker_use_ray,
|
||||
engine_args.engine_use_ray,
|
||||
*engine_configs,
|
||||
placement_group,
|
||||
log_requests=not engine_args.disable_log_requests,
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
max_log_len=engine_args.max_log_len,
|
||||
start_engine_loop=start_engine_loop)
|
||||
return engine
|
||||
|
||||
async def do_log_stats(self) -> None:
|
||||
if self.engine_use_ray:
|
||||
await self.engine.do_log_stats.remote()
|
||||
else:
|
||||
self.engine.do_log_stats()
|
||||
|
||||
async def check_health(self):
|
||||
async def check_health(self) -> None:
|
||||
"""Raises an error if engine is unhealthy."""
|
||||
t = time.perf_counter()
|
||||
logger.debug("Starting health check...")
|
||||
|
||||
Reference in New Issue
Block a user