[Misc] [Core] Implement RFC "Augment BaseExecutor interfaces to enable hardware-agnostic speculative decoding" (#3837)
This commit is contained in:
@@ -35,7 +35,6 @@ class CPUExecutor(ExecutorBase):
|
||||
|
||||
# Instantiate the worker and load the model to CPU.
|
||||
self._init_worker()
|
||||
self._init_cache()
|
||||
|
||||
def _init_worker(self):
|
||||
from vllm.worker.cpu_worker import CPUWorker
|
||||
@@ -46,10 +45,11 @@ class CPUExecutor(ExecutorBase):
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_ip(), get_open_port())
|
||||
self.driver_worker = CPUWorker(
|
||||
self.model_config,
|
||||
self.parallel_config,
|
||||
self.scheduler_config,
|
||||
self.device_config,
|
||||
model_config=self.model_config,
|
||||
parallel_config=self.parallel_config,
|
||||
scheduler_config=self.scheduler_config,
|
||||
device_config=self.device_config,
|
||||
cache_config=self.cache_config,
|
||||
local_rank=0,
|
||||
rank=0,
|
||||
distributed_init_method=distributed_init_method,
|
||||
@@ -60,35 +60,21 @@ class CPUExecutor(ExecutorBase):
|
||||
self.driver_worker.init_device()
|
||||
self.driver_worker.load_model()
|
||||
|
||||
def _init_cache(self) -> None:
|
||||
num_cpu_blocks = self.driver_worker.get_cpu_cache_block_num(
|
||||
block_size=self.cache_config.block_size,
|
||||
cache_space=self.cache_config.cpu_kvcache_space_bytes,
|
||||
cache_dtype=self.cache_config.cache_dtype,
|
||||
)
|
||||
def determine_num_available_blocks(self) -> tuple[int, int]:
|
||||
"""Determine the number of available KV blocks by invoking the
|
||||
underlying worker.
|
||||
"""
|
||||
return self.driver_worker.determine_num_available_blocks()
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int,
|
||||
num_cpu_blocks: int) -> None:
|
||||
"""Initialize the KV cache by invoking the underlying worker.
|
||||
"""
|
||||
# NOTE: We log here to avoid multiple logs when number of workers is
|
||||
# greater than one. We could log in the engine, but not all executors
|
||||
# have GPUs.
|
||||
logger.info(f"# CPU blocks: {num_cpu_blocks}")
|
||||
if num_cpu_blocks <= 0:
|
||||
raise ValueError("No available memory for the cache blocks. "
|
||||
"Try increasing `VLLM_CPU_KVCACHE_SPACE` when "
|
||||
"initializing the engine.")
|
||||
|
||||
max_seq_len = self.cache_config.block_size * num_cpu_blocks
|
||||
if self.model_config.max_model_len > max_seq_len:
|
||||
raise ValueError(
|
||||
f"The model's max seq len ({self.model_config.max_model_len}) "
|
||||
"is larger than the maximum number of tokens that can be "
|
||||
f"stored in KV cache ({max_seq_len}). Try increasing "
|
||||
"`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when "
|
||||
"initializing the engine.")
|
||||
|
||||
# Note: To reuse the cache management procedure,
|
||||
# use cpu cache as 'gpu cache'.
|
||||
self.cache_config.num_gpu_blocks = num_cpu_blocks # type: ignore
|
||||
self.cache_config.num_cpu_blocks = 0 # type: ignore
|
||||
|
||||
# Initialize the cache.
|
||||
self.driver_worker.init_cache_engine(cache_config=self.cache_config)
|
||||
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
|
||||
|
||||
def execute_model(self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
@@ -104,13 +90,13 @@ class CPUExecutor(ExecutorBase):
|
||||
return output
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
raise NotImplementedError("LoRA is not implemented for cpu backend.")
|
||||
return self.driver_worker.add_lora(lora_request)
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
raise NotImplementedError("LoRA is not implemented for cpu backend.")
|
||||
return self.driver_worker.remove_lora(lora_id)
|
||||
|
||||
def list_loras(self) -> List[int]:
|
||||
raise NotImplementedError("LoRA is not implemented for cpu backend.")
|
||||
return self.driver_worker.list_loras()
|
||||
|
||||
def check_health(self) -> None:
|
||||
# CPUExecutor will always be healthy as long as
|
||||
|
||||
Reference in New Issue
Block a user