[Core] Simplifications to executor classes (#4071)

This commit is contained in:
Nick Hill
2024-04-15 13:05:09 -07:00
committed by GitHub
parent 0003e9154b
commit eb46fbfda2
5 changed files with 44 additions and 109 deletions

View File

@@ -3,11 +3,8 @@ import copy
import os
import pickle
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
TensorizerConfig, VisionLanguageConfig)
from vllm.engine.ray_utils import RayWorkerVllm, ray
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
@@ -32,27 +29,8 @@ USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
class RayGPUExecutor(ExecutorBase):
def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
tensorizer_config: Optional[TensorizerConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.vision_language_config = vision_language_config
self.tensorizer_config = tensorizer_config
assert (not speculative_config
def _init_executor(self) -> None:
assert (not self.speculative_config
), "Speculative decoding not yet supported for RayGPU backend."
assert self.parallel_config.worker_use_ray
@@ -273,7 +251,7 @@ class RayGPUExecutor(ExecutorBase):
lora_id=lora_id,
)
def list_loras(self) -> List[int]:
def list_loras(self) -> Set[int]:
return self._run_workers("list_loras")
def _run_workers(
@@ -416,7 +394,3 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
# Only the driver worker returns the sampling results.
output = all_outputs[0]
return output
async def check_health_async(self) -> None:
    """Check engine health, raising an error if it is unhealthy.

    The check itself is carried out by
    ``self._check_if_any_actor_is_dead()``, which is called directly
    (nothing is awaited inside this coroutine).
    """
    self._check_if_any_actor_is_dead()