[V1] Move more control of kv cache initialization from model_executor to EngineCore (#11960)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple, Type
|
||||
from typing import Type
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
|
||||
|
||||
@@ -31,11 +32,15 @@ class Executor(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def initialize(self, num_gpu_blocks: int) -> None:
|
||||
def initialize(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
def determine_available_memory(self) -> int: # in bytes
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_kv_cache_spec(self) -> KVCacheSpec:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -23,6 +23,7 @@ from vllm.logger import init_logger
|
||||
from vllm.utils import (get_distributed_init_method, get_mp_context,
|
||||
get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx)
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
@@ -90,29 +91,33 @@ class MultiprocExecutor(Executor):
|
||||
for w in self.workers:
|
||||
w.worker_response_mq.wait_until_ready()
|
||||
|
||||
def initialize(self, num_gpu_blocks: int) -> None:
|
||||
def initialize(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
Initialize the KV caches and begin the model execution loop of the
|
||||
underlying workers.
|
||||
"""
|
||||
logger.info("# GPU blocks: %d", num_gpu_blocks)
|
||||
self.collective_rpc("initialize_cache", args=(num_gpu_blocks, ))
|
||||
self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
|
||||
self.collective_rpc("compile_or_warm_up_model")
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
def determine_available_memory(self) -> int:
|
||||
"""
|
||||
Determine the number of available KV blocks by invoking the
|
||||
Determine the available memory (in bytes) for KV cache by invoking the
|
||||
underlying worker.
|
||||
"""
|
||||
num_blocks = self.collective_rpc("determine_num_available_blocks")
|
||||
memory_sizes = self.collective_rpc("determine_available_memory")
|
||||
|
||||
# Since we use a shared centralized controller, we take the minimum
|
||||
# number of blocks across all workers to make sure all the memory
|
||||
# memory size across all workers to make sure all the memory
|
||||
# operators can be applied to all workers.
|
||||
num_gpu_blocks = min(b[0] for b in num_blocks)
|
||||
num_cpu_blocks = min(b[1] for b in num_blocks)
|
||||
return min(memory_sizes)
|
||||
|
||||
return num_gpu_blocks, num_cpu_blocks
|
||||
def get_kv_cache_spec(self) -> KVCacheSpec:
|
||||
"""
|
||||
Get all kv cache needed by the model by invoking the underlying worker.
|
||||
"""
|
||||
kv_cache_specs = self.collective_rpc("get_kv_cache_spec")
|
||||
assert all(s == kv_cache_specs[0] for s in kv_cache_specs)
|
||||
return kv_cache_specs[0]
|
||||
|
||||
def collective_rpc(self,
|
||||
method: str,
|
||||
|
||||
@@ -10,6 +10,7 @@ from vllm.utils import get_distributed_init_method, get_ip, get_open_port
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.executor.ray_utils import (RayWorkerWrapper,
|
||||
initialize_ray_cluster, ray)
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
|
||||
if ray is not None:
|
||||
@@ -211,39 +212,40 @@ class RayExecutor(Executor):
|
||||
distributed_init_method=distributed_init_method,
|
||||
)
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
def determine_available_memory(self) -> int:
|
||||
"""
|
||||
Determine the number of available KV blocks.
|
||||
Determine the available GPU memory in bytes.
|
||||
|
||||
This invokes `determine_num_available_blocks` on each worker and takes
|
||||
This invokes `determine_available_memory` on each worker and takes
|
||||
the min of the results, guaranteeing that the selected cache sizes are
|
||||
compatible with all workers.
|
||||
|
||||
Returns:
|
||||
- tuple[num_gpu_blocks, num_cpu_blocks]
|
||||
"""
|
||||
# Get the maximum number of blocks that can be allocated on GPU and CPU.
|
||||
num_blocks = self._run_workers("determine_num_available_blocks")
|
||||
|
||||
memory_sizes = self._run_workers("determine_available_memory")
|
||||
|
||||
# Since we use a shared centralized controller, we take the minimum
|
||||
# number of blocks across all workers to make sure all the memory
|
||||
# memory size across all workers to make sure all the memory
|
||||
# operators can be applied to all workers.
|
||||
num_gpu_blocks = min(b[0] for b in num_blocks)
|
||||
num_cpu_blocks = min(b[1] for b in num_blocks)
|
||||
return min(memory_sizes)
|
||||
|
||||
return num_gpu_blocks, num_cpu_blocks
|
||||
|
||||
def initialize(self, num_gpu_blocks: int) -> None:
|
||||
def initialize(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
Initialize the KV cache in all workers.
|
||||
"""
|
||||
# NOTE: This is logged in the executor because there can be >1 worker
|
||||
# with other executors. We could log in the engine level, but work
|
||||
# remains to abstract away the device for non-GPU configurations.
|
||||
logger.info("# GPU blocks: %d", num_gpu_blocks)
|
||||
self._run_workers("initialize_cache", num_gpu_blocks)
|
||||
self._run_workers("initialize_cache", kv_cache_config)
|
||||
self._run_workers("compile_or_warm_up_model")
|
||||
|
||||
def get_kv_cache_spec(self) -> KVCacheSpec:
|
||||
"""
|
||||
Get all kv cache needed by the model
|
||||
|
||||
This invokes `get_kv_cache_spec` on each worker and asserts that
|
||||
they are identical. The KVCacheSpec is then returned.
|
||||
"""
|
||||
kv_cache_specs = self._run_workers("get_kv_cache_spec")
|
||||
assert all(s == kv_cache_specs[0] for s in kv_cache_specs)
|
||||
return kv_cache_specs[0]
|
||||
|
||||
def _run_workers(
|
||||
self,
|
||||
method: str,
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import os
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.worker.gpu_worker import Worker
|
||||
|
||||
@@ -49,20 +50,22 @@ class UniprocExecutor(Executor):
|
||||
distributed_init_method=distributed_init_method,
|
||||
)
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
"""Determine the number of available KV blocks by invoking the
|
||||
underlying worker.
|
||||
def determine_available_memory(self) -> int:
|
||||
"""Determine the available memory (in bytes) for KV cache by invoking
|
||||
the underlying worker.
|
||||
"""
|
||||
return self.worker.determine_num_available_blocks()
|
||||
return self.worker.determine_available_memory()
|
||||
|
||||
def initialize(self, num_gpu_blocks: int) -> None:
|
||||
def get_kv_cache_spec(self) -> KVCacheSpec:
|
||||
"""Get all kv cache needed by the model by invoking the underlying
|
||||
worker.
|
||||
"""
|
||||
return self.worker.get_kv_cache_spec()
|
||||
|
||||
def initialize(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""Initialize the KV cache by invoking the underlying worker.
|
||||
"""
|
||||
# NOTE: This is logged in the executor because there can be >1 worker
|
||||
# with other executors. We could log in the engine level, but work
|
||||
# remains to abstract away the device for non-GPU configurations.
|
||||
logger.info("# GPU blocks: %d", num_gpu_blocks)
|
||||
self.worker.initialize_cache(num_gpu_blocks)
|
||||
self.worker.initialize_cache(kv_cache_config)
|
||||
self.worker.compile_or_warm_up_model()
|
||||
|
||||
def execute_model(
|
||||
|
||||
Reference in New Issue
Block a user