[Core] Simplify core kv-cache blocks initialization logic (#36521)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Nick Hill
2026-03-10 13:20:02 -07:00
committed by GitHub
parent 2a68464c5b
commit 65b2f405dc
4 changed files with 28 additions and 37 deletions

View File

@@ -117,18 +117,7 @@ class EngineCore:
self._eep_scale_up_before_kv_init()
# Setup KV Caches and update CacheConfig after profiling.
num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
vllm_config
)
if kv_cache_config.kv_cache_groups:
vllm_config.cache_config.block_size = min(
g.kv_cache_spec.block_size for g in kv_cache_config.kv_cache_groups
)
vllm_config.validate_block_size()
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))
kv_cache_config = self._initialize_kv_caches(vllm_config)
self.structured_output_manager = StructuredOutputManager(vllm_config)
# Setup scheduler.
@@ -233,9 +222,7 @@ class EngineCore:
enable_envs_cache()
@instrument(span_name="Prepare model")
def _initialize_kv_caches(
self, vllm_config: VllmConfig
) -> tuple[int, int, KVCacheConfig]:
def _initialize_kv_caches(self, vllm_config: VllmConfig) -> KVCacheConfig:
start = time.time()
# Get all kv cache needed by the model
@@ -276,8 +263,14 @@ class EngineCore:
self.collective_rpc("update_max_model_len", args=(max_model_len_after,))
scheduler_kv_cache_config = generate_scheduler_kv_cache_config(kv_cache_configs)
num_gpu_blocks = scheduler_kv_cache_config.num_blocks
num_cpu_blocks = 0
vllm_config.cache_config.num_gpu_blocks = scheduler_kv_cache_config.num_blocks
kv_cache_groups = scheduler_kv_cache_config.kv_cache_groups
if kv_cache_groups:
vllm_config.cache_config.block_size = min(
g.kv_cache_spec.block_size for g in kv_cache_groups
)
vllm_config.validate_block_size()
# Initialize kv cache and warm up the execution
self.model_executor.initialize_from_config(kv_cache_configs)
@@ -288,7 +281,7 @@ class EngineCore:
elapsed,
scope="local",
)
return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config
return scheduler_kv_cache_config
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.model_executor.supported_tasks

View File

@@ -203,21 +203,17 @@ class Worker(WorkerBase):
self.model_runner.init_fp8_kv_scales()
def _maybe_get_memory_pool_context(self, tag: str) -> AbstractContextManager:
if self.vllm_config.model_config.enable_sleep_mode:
from vllm.device_allocator.cumem import CuMemAllocator
allocator = CuMemAllocator.get_instance()
if tag == "weights":
assert allocator.get_current_usage() == 0, (
"Sleep mode can only be used for one instance per process."
)
return allocator.use_memory_pool(tag=tag)
else:
if not self.vllm_config.model_config.enable_sleep_mode:
return nullcontext()
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
from vllm.device_allocator.cumem import CuMemAllocator
allocator = CuMemAllocator.get_instance()
if tag == "weights":
assert allocator.get_current_usage() == 0, (
"Sleep mode can only be used for one instance per process."
)
return allocator.use_memory_pool(tag=tag)
@instrument(span_name="Init device")
def init_device(self):

View File

@@ -104,10 +104,6 @@ class WorkerBase:
"""
raise NotImplementedError
def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
"""Initialize the KV cache with the given size in blocks."""
raise NotImplementedError
def reset_mm_cache(self) -> None:
reset_fn = getattr(self.model_runner, "reset_mm_cache", None)
if callable(reset_fn):