[Core] Simplifications to executor classes (#4071)

This commit is contained in:
Nick Hill
2024-04-15 13:05:09 -07:00
committed by GitHub
parent 0003e9154b
commit eb46fbfda2
5 changed files with 44 additions and 109 deletions

View File

@@ -1,10 +1,9 @@
 import os
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Set, Tuple
 
 import torch
 
-from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig)
+from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -16,23 +15,13 @@ logger = init_logger(__name__)
 
 class CPUExecutor(ExecutorBase):
 
-    def __init__(self, model_config: ModelConfig, cache_config: CacheConfig,
-                 parallel_config: ParallelConfig,
-                 scheduler_config: SchedulerConfig,
-                 device_config: DeviceConfig,
-                 lora_config: Optional[LoRAConfig], *args, **kwargs) -> None:
-        assert device_config.device_type == "cpu"
-        assert lora_config is None, "cpu backend doesn't support LoRA"
-        model_config = _verify_and_get_model_config(model_config)
-        cache_config = _verify_and_get_cache_config(cache_config)
-        scheduler_config = _verify_and_get_scheduler_config(scheduler_config)
-
-        self.model_config = model_config
-        self.cache_config = cache_config
-        self.lora_config = lora_config
-        self.parallel_config = parallel_config
-        self.scheduler_config = scheduler_config
-        self.device_config = device_config
+    def _init_executor(self) -> None:
+        assert self.device_config.device_type == "cpu"
+        assert self.lora_config is None, "cpu backend doesn't support LoRA"
+        self.model_config = _verify_and_get_model_config(self.model_config)
+        self.cache_config = _verify_and_get_cache_config(self.cache_config)
+        self.scheduler_config = _verify_and_get_scheduler_config(
+            self.scheduler_config)
 
         # Instantiate the worker and load the model to CPU.
         self._init_worker()
@@ -99,7 +88,7 @@ class CPUExecutor(ExecutorBase):
     def remove_lora(self, lora_id: int) -> bool:
         return self.driver_worker.remove_lora(lora_id)
 
-    def list_loras(self) -> List[int]:
+    def list_loras(self) -> Set[int]:
         return self.driver_worker.list_loras()
 
     def check_health(self) -> None: