[Core] Simplifications to executor classes (#4071)
This commit is contained in:
@@ -1,10 +1,9 @@
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, List, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig)
|
||||
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
@@ -16,23 +15,13 @@ logger = init_logger(__name__)
|
||||
|
||||
class CPUExecutor(ExecutorBase):
|
||||
|
||||
def __init__(self, model_config: ModelConfig, cache_config: CacheConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
device_config: DeviceConfig,
|
||||
lora_config: Optional[LoRAConfig], *args, **kwargs) -> None:
|
||||
assert device_config.device_type == "cpu"
|
||||
assert lora_config is None, "cpu backend doesn't support LoRA"
|
||||
model_config = _verify_and_get_model_config(model_config)
|
||||
cache_config = _verify_and_get_cache_config(cache_config)
|
||||
scheduler_config = _verify_and_get_scheduler_config(scheduler_config)
|
||||
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
self.lora_config = lora_config
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.device_config = device_config
|
||||
def _init_executor(self) -> None:
|
||||
assert self.device_config.device_type == "cpu"
|
||||
assert self.lora_config is None, "cpu backend doesn't support LoRA"
|
||||
self.model_config = _verify_and_get_model_config(self.model_config)
|
||||
self.cache_config = _verify_and_get_cache_config(self.cache_config)
|
||||
self.scheduler_config = _verify_and_get_scheduler_config(
|
||||
self.scheduler_config)
|
||||
|
||||
# Instantiate the worker and load the model to CPU.
|
||||
self._init_worker()
|
||||
@@ -99,7 +88,7 @@ class CPUExecutor(ExecutorBase):
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
return self.driver_worker.remove_lora(lora_id)
|
||||
|
||||
def list_loras(self) -> List[int]:
|
||||
def list_loras(self) -> Set[int]:
|
||||
return self.driver_worker.list_loras()
|
||||
|
||||
def check_health(self) -> None:
|
||||
|
||||
Reference in New Issue
Block a user