[Misc] [Core] Implement RFC "Augment BaseExecutor interfaces to enable hardware-agnostic speculative decoding" (#3837)
@@ -10,7 +10,6 @@ from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                          VisionLanguageConfig)
 from vllm.engine.ray_utils import RayWorkerVllm, ray
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
-from vllm.executor.utils import check_block_size_valid
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
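
The only change in this hunk is dropping the `check_block_size_valid` import: the executor no longer validates the profiled block count itself (the call disappears in the `initialize_cache` hunk below, and the check moves up to the engine). For orientation, the dropped helper's contract is roughly the following sketch; a hypothetical reconstruction, not the exact vllm implementation:

def check_block_size_valid(num_gpu_blocks: int, block_size: int,
                           max_model_len: int) -> None:
    # Sketch: the KV cache must hold at least one maximum-length sequence.
    if num_gpu_blocks <= 0:
        raise ValueError("No available memory for the cache blocks. "
                         "Try increasing `gpu_memory_utilization`.")
    max_seq_len = block_size * num_gpu_blocks
    if max_model_len > max_seq_len:
        raise ValueError(
            f"The model's max seq len ({max_model_len}) is larger than the "
            f"maximum number of tokens that can be stored in the KV cache "
            f"({max_seq_len}). Try increasing `gpu_memory_utilization` or "
            f"decreasing `max_model_len`.")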
@@ -65,9 +64,6 @@ class RayGPUExecutor(ExecutorBase):
         # Create the parallel GPU workers.
         self._init_workers_ray(placement_group)

-        # Profile the memory usage and initialize the cache.
-        self._init_cache()
-
         self.forward_dag = None
         if USE_RAY_COMPILED_DAG:
             self.forward_dag = self._compiled_ray_dag()
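
With `_init_cache()` gone from the constructor, cache sizing becomes a two-phase protocol that the engine drives explicitly. A minimal sketch of the executor interface this diff implements; the method names come from the diff, while the abstract-base shape shown here is an assumption:

from abc import ABC, abstractmethod
from typing import Tuple


class ExecutorBase(ABC):
    """Hardware-agnostic executor: the engine, not the backend, decides
    when profiling happens and which cache sizes are finally used."""

    @abstractmethod
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profile memory usage; return (num_gpu_blocks, num_cpu_blocks)."""
        raise NotImplementedError

    @abstractmethod
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Allocate the KV cache with sizes chosen by the caller."""
        raise NotImplementedError

Separating "measure" from "allocate" is what makes the interface hardware-agnostic: for example, a speculative-decoding executor can profile once and then size caches for a draft and a target model against the same memory budget.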
@@ -154,8 +150,8 @@ class RayGPUExecutor(ExecutorBase):
         scheduler_config = copy.deepcopy(self.scheduler_config)
         device_config = copy.deepcopy(self.device_config)
         lora_config = copy.deepcopy(self.lora_config)
+        cache_config = copy.deepcopy(self.cache_config)
         vision_language_config = copy.deepcopy(self.vision_language_config)
-        kv_cache_dtype = self.cache_config.cache_dtype

         # Initialize the actual workers with the Worker class.
         for rank, (worker, (node_id, _)) in enumerate(
@@ -165,32 +161,32 @@ class RayGPUExecutor(ExecutorBase):
             local_rank = node_workers[node_id].index(rank)
             worker.init_worker.remote(
                 lambda rank=rank, local_rank=local_rank: Worker(
-                    model_config,
-                    parallel_config,
-                    scheduler_config,
-                    device_config,
-                    local_rank,
-                    rank,
-                    distributed_init_method,
+                    model_config=model_config,
+                    parallel_config=parallel_config,
+                    scheduler_config=scheduler_config,
+                    device_config=device_config,
+                    cache_config=cache_config,
+                    local_rank=local_rank,
+                    rank=rank,
+                    distributed_init_method=distributed_init_method,
                     lora_config=lora_config,
                     vision_language_config=vision_language_config,
-                    kv_cache_dtype=kv_cache_dtype,
                 ))

         # Initialize the driver worker with the Worker class.
         driver_rank = 0
         driver_local_rank = node_workers[driver_node_id].index(driver_rank)
         self.driver_worker = Worker(
-            self.model_config,
-            self.parallel_config,
-            self.scheduler_config,
-            self.device_config,
-            driver_local_rank,
-            driver_rank,
-            distributed_init_method,
+            model_config=self.model_config,
+            parallel_config=self.parallel_config,
+            scheduler_config=self.scheduler_config,
+            device_config=self.device_config,
+            cache_config=self.cache_config,
+            local_rank=driver_local_rank,
+            rank=driver_rank,
+            distributed_init_method=distributed_init_method,
             lora_config=self.lora_config,
             vision_language_config=self.vision_language_config,
-            kv_cache_dtype=kv_cache_dtype,
             is_driver_worker=True,
         )
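
Both call sites now pass every argument by keyword, and the worker receives the whole `cache_config` instead of a bare `kv_cache_dtype` string. A sketch of the constructor signature implied by the call sites; parameter order and defaults here are assumptions:

class Worker:
    def __init__(
        self,
        model_config,
        parallel_config,
        scheduler_config,
        device_config,
        cache_config,            # new: carries block_size, cache_dtype, swap size
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        lora_config=None,
        vision_language_config=None,
        is_driver_worker: bool = False,
    ) -> None:
        ...

Keyword-only call sites also mean that inserting `cache_config` into the parameter list cannot silently shift the meaning of any positional argument that follows it.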
@@ -201,35 +197,18 @@ class RayGPUExecutor(ExecutorBase):
             max_parallel_loading_workers,
         )

-    def _init_cache(self) -> None:
-        """Profiles the memory usage and initializes the KV cache.
+    def determine_num_available_blocks(self) -> tuple[int, int]:
+        """Determine the number of available KV blocks.

-        The engine will first conduct a profiling of the existing memory usage.
-        Then, it calculate the maximum possible number of GPU and CPU blocks
-        that can be allocated with the remaining free memory.
-        More details can be found in the
-        :meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method
-        from class :class:`~vllm.worker.Worker`.
+        This invokes `determine_num_available_blocks` on each worker and takes
+        the min of the results, guaranteeing that the selected cache sizes are
+        compatible with all workers.

-        Afterwards, as there may be multiple workers,
-        we take the minimum number of blocks across all workers
-        to ensure this can be applied to all of them.
-
-        Finally, the engine will initialize the KV cache
-        with the calculated number of blocks.
-
-        .. tip::
-            You may limit the usage of GPU memory
-            by adjusting the `gpu_memory_utilization` parameter.
+        Returns:
+            - tuple[num_gpu_blocks, num_cpu_blocks]
         """
         # Get the maximum number of blocks that can be allocated on GPU and CPU.
-        num_blocks = self._run_workers(
-            "profile_num_available_blocks",
-            block_size=self.cache_config.block_size,
-            gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
-            cpu_swap_space=self.cache_config.swap_space_bytes,
-            cache_dtype=self.cache_config.cache_dtype,
-        )
+        num_blocks = self._run_workers("determine_num_available_blocks", )

         # Since we use a shared centralized controller, we take the minimum
         # number of blocks across all workers to make sure all the memory
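
The new docstring's guarantee is easiest to see with concrete numbers: per-worker profiles can differ (for example, when one rank's GPU hosts extra state), and the elementwise min picks sizes that every worker can satisfy. The block counts below are made up for illustration:

# Hypothetical (gpu_blocks, cpu_blocks) results from three workers.
num_blocks = [(1630, 2048), (1595, 2048), (1612, 2048)]

num_gpu_blocks = min(b[0] for b in num_blocks)  # 1595
num_cpu_blocks = min(b[1] for b in num_blocks)  # 2048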
@@ -237,26 +216,25 @@ class RayGPUExecutor(ExecutorBase):
         num_gpu_blocks = min(b[0] for b in num_blocks)
         num_cpu_blocks = min(b[1] for b in num_blocks)

-        if self.cache_config.forced_num_gpu_blocks is not None:
-            forced_num_gpu_blocks = self.cache_config.forced_num_gpu_blocks
-            logger.info(f"Replacing profiled {num_gpu_blocks=} with "
-                        f"{forced_num_gpu_blocks=}")
-            num_gpu_blocks = forced_num_gpu_blocks
+        return num_gpu_blocks, num_cpu_blocks

+    def initialize_cache(self, num_gpu_blocks: int,
+                         num_cpu_blocks: int) -> None:
+        """Initialize the KV cache in all workers.
+        """
+
+        # NOTE: We log here to avoid multiple logs when number of workers is
+        # greater than one. We could log in the engine, but not all executors
+        # have GPUs.
         logger.info(f"# GPU blocks: {num_gpu_blocks}, "
                     f"# CPU blocks: {num_cpu_blocks}")

-        check_block_size_valid(num_gpu_blocks, self.cache_config.block_size,
-                               self.model_config.max_model_len)
-
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks

-        # Initialize the cache.
-        self._run_workers("init_cache_engine", cache_config=self.cache_config)
-        # Warm up the model. This includes capturing the model into CUDA graph
-        # if enforce_eager is False.
-        self._run_workers("warm_up_model")
+        self._run_workers("initialize_cache",
+                          num_gpu_blocks=num_gpu_blocks,
+                          num_cpu_blocks=num_cpu_blocks)

     def execute_model(self,
                       seq_group_metadata_list: List[SequenceGroupMetadata],
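
The pieces deleted above do not vanish: the `forced_num_gpu_blocks` override and the block-size check move up to the engine, while cache-engine construction and model warm-up apparently fold into the workers' `initialize_cache`. A sketch of the resulting engine-side sequence under that assumption; the attribute names here are illustrative, not the exact vllm code:

def _initialize_kv_caches(self) -> None:
    # Phase 1: ask the hardware-specific executor how much memory fits.
    num_gpu_blocks, num_cpu_blocks = (
        self.executor.determine_num_available_blocks())

    # Policy now lives in one place, outside any backend.
    if self.cache_config.forced_num_gpu_blocks is not None:
        num_gpu_blocks = self.cache_config.forced_num_gpu_blocks

    check_block_size_valid(num_gpu_blocks, self.cache_config.block_size,
                           self.model_config.max_model_len)

    # Phase 2: hand the final sizes back down for allocation.
    self.executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)

Because the engine owns this sequence, any ExecutorBase implementation (Ray GPU, single GPU, CPU, or a future speculative-decoding executor) sees the same two-call surface and never has to reimplement the override or validation logic.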