[Core] Simplifications to executor classes (#4071)
This commit is contained in:
@@ -3,11 +3,8 @@ import copy
|
||||
import os
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
||||
TensorizerConfig, VisionLanguageConfig)
|
||||
from vllm.engine.ray_utils import RayWorkerVllm, ray
|
||||
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
||||
from vllm.logger import init_logger
|
||||
@@ -32,27 +29,8 @@ USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
|
||||
|
||||
class RayGPUExecutor(ExecutorBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
cache_config: CacheConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
device_config: DeviceConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
speculative_config: Optional[SpeculativeConfig],
|
||||
tensorizer_config: Optional[TensorizerConfig],
|
||||
) -> None:
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
self.lora_config = lora_config
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.device_config = device_config
|
||||
self.vision_language_config = vision_language_config
|
||||
self.tensorizer_config = tensorizer_config
|
||||
assert (not speculative_config
|
||||
def _init_executor(self) -> None:
|
||||
assert (not self.speculative_config
|
||||
), "Speculative decoding not yet supported for RayGPU backend."
|
||||
|
||||
assert self.parallel_config.worker_use_ray
|
||||
@@ -273,7 +251,7 @@ class RayGPUExecutor(ExecutorBase):
|
||||
lora_id=lora_id,
|
||||
)
|
||||
|
||||
def list_loras(self) -> List[int]:
|
||||
def list_loras(self) -> Set[int]:
|
||||
return self._run_workers("list_loras")
|
||||
|
||||
def _run_workers(
|
||||
@@ -416,7 +394,3 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
|
||||
# Only the driver worker returns the sampling results.
|
||||
output = all_outputs[0]
|
||||
return output
|
||||
|
||||
async def check_health_async(self) -> None:
|
||||
"""Raises an error if engine is unhealthy."""
|
||||
self._check_if_any_actor_is_dead()
|
||||
|
||||
Reference in New Issue
Block a user