[Core] Simplify core kv-cache blocks initialization logic (#36521)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -117,18 +117,7 @@ class EngineCore:
|
||||
self._eep_scale_up_before_kv_init()
|
||||
|
||||
# Setup KV Caches and update CacheConfig after profiling.
|
||||
num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
|
||||
vllm_config
|
||||
)
|
||||
if kv_cache_config.kv_cache_groups:
|
||||
vllm_config.cache_config.block_size = min(
|
||||
g.kv_cache_spec.block_size for g in kv_cache_config.kv_cache_groups
|
||||
)
|
||||
vllm_config.validate_block_size()
|
||||
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))
|
||||
|
||||
kv_cache_config = self._initialize_kv_caches(vllm_config)
|
||||
self.structured_output_manager = StructuredOutputManager(vllm_config)
|
||||
|
||||
# Setup scheduler.
|
||||
@@ -233,9 +222,7 @@ class EngineCore:
|
||||
enable_envs_cache()
|
||||
|
||||
@instrument(span_name="Prepare model")
|
||||
def _initialize_kv_caches(
|
||||
self, vllm_config: VllmConfig
|
||||
) -> tuple[int, int, KVCacheConfig]:
|
||||
def _initialize_kv_caches(self, vllm_config: VllmConfig) -> KVCacheConfig:
|
||||
start = time.time()
|
||||
|
||||
# Get all kv cache needed by the model
|
||||
@@ -276,8 +263,14 @@ class EngineCore:
|
||||
self.collective_rpc("update_max_model_len", args=(max_model_len_after,))
|
||||
|
||||
scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs)
|
||||
num_gpu_blocks = scheduler_kv_cache_config.num_blocks
|
||||
num_cpu_blocks = 0
|
||||
vllm_config.cache_config.num_gpu_blocks = scheduler_kv_cache_config.num_blocks
|
||||
kv_cache_groups = scheduler_kv_cache_config.kv_cache_groups
|
||||
if kv_cache_groups:
|
||||
vllm_config.cache_config.block_size = min(
|
||||
g.kv_cache_spec.block_size for g in kv_cache_groups
|
||||
)
|
||||
|
||||
vllm_config.validate_block_size()
|
||||
|
||||
# Initialize kv cache and warmup the execution
|
||||
self.model_executor.initialize_from_config(kv_cache_configs)
|
||||
@@ -288,7 +281,7 @@ class EngineCore:
|
||||
elapsed,
|
||||
scope="local",
|
||||
)
|
||||
return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config
|
||||
return scheduler_kv_cache_config
|
||||
|
||||
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
|
||||
return self.model_executor.supported_tasks
|
||||
|
||||
@@ -203,21 +203,17 @@ class Worker(WorkerBase):
|
||||
self.model_runner.init_fp8_kv_scales()
|
||||
|
||||
def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
|
||||
if self.vllm_config.model_config.enable_sleep_mode:
|
||||
from vllm.device_allocator.cumem import CuMemAllocator
|
||||
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
if tag == "weights":
|
||||
assert allocator.get_current_usage() == 0, (
|
||||
"Sleep mode can only be used for one instance per process."
|
||||
)
|
||||
return allocator.use_memory_pool(tag=tag)
|
||||
else:
|
||||
if not self.vllm_config.model_config.enable_sleep_mode:
|
||||
return nullcontext()
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
|
||||
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
from vllm.device_allocator.cumem import CuMemAllocator
|
||||
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
if tag == "weights":
|
||||
assert allocator.get_current_usage() == 0, (
|
||||
"Sleep mode can only be used for one instance per process."
|
||||
)
|
||||
return allocator.use_memory_pool(tag=tag)
|
||||
|
||||
@instrument(span_name="Init device")
|
||||
def init_device(self):
|
||||
|
||||
@@ -104,10 +104,6 @@ class WorkerBase:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
|
||||
"""Initialize the KV cache with the given size in blocks."""
|
||||
raise NotImplementedError
|
||||
|
||||
def reset_mm_cache(self) -> None:
|
||||
reset_fn = getattr(self.model_runner, "reset_mm_cache", None)
|
||||
if callable(reset_fn):
|
||||
|
||||
Reference in New Issue
Block a user