Reapply [Attention] Refactor check_and_update_config (#35122)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Matthew Bonanni
2026-03-09 10:17:14 -04:00
committed by GitHub
parent 5578f2a4d3
commit 77a73458e3
32 changed files with 311 additions and 279 deletions


@@ -2,16 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import field
from typing import Literal
from typing import ClassVar, Literal
from pydantic import Field, SkipValidation, field_validator
from pydantic import Field, SkipValidation, field_validator, model_validator
from vllm.config.utils import config
from vllm.logger import init_logger
logger = init_logger(__name__)
BlockSize = Literal[1, 8, 16, 32, 64, 128, 256]
CacheDType = Literal[
    "auto",
    "bfloat16",
@@ -31,12 +30,13 @@ KVOffloadingBackend = Literal["native", "lmcache"]
class CacheConfig:
    """Configuration for the KV cache."""

    block_size: SkipValidation[BlockSize] = None  # type: ignore[assignment]
    """Size of a contiguous cache block in number of tokens.

    This config has no static default. If left unspecified by the user, it will
    be set in `Platform.check_and_update_config()` based on the current
    platform."""

    DEFAULT_BLOCK_SIZE: ClassVar[int] = 16

    block_size: SkipValidation[int] = None  # type: ignore[assignment]
    """Size of a contiguous cache block in number of tokens.
    Accepts None (meaning "use default"). After construction, always int."""

    user_specified_block_size: bool = field(default=False, init=False)
    """Whether block_size was explicitly provided. Derived automatically."""

    gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1)
    """The fraction of GPU memory to be used for the model executor, which can
    range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
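
Taken together, these fields replace the platform-injected default with one
resolved inside the config itself. A minimal standalone sketch of the intended
behavior (pydantic v2 assumed; MiniCacheConfig is a hypothetical stand-in, not
vLLM code):

from dataclasses import field
from typing import ClassVar, Optional

from pydantic import model_validator
from pydantic.dataclasses import dataclass

@dataclass
class MiniCacheConfig:
    DEFAULT_BLOCK_SIZE: ClassVar[int] = 16

    block_size: Optional[int] = None
    user_specified_block_size: bool = field(default=False, init=False)

    @model_validator(mode="after")
    def _apply_block_size_default(self) -> "MiniCacheConfig":
        # None means "use default"; anything else marks a user choice.
        # NB: unguarded; see the _block_size_resolved guard below for why
        # the real validator must also be idempotent.
        if self.block_size is None:
            self.block_size = self.DEFAULT_BLOCK_SIZE
        else:
            self.user_specified_block_size = True
        return self

assert MiniCacheConfig().block_size == 16
assert MiniCacheConfig(block_size=32).user_specified_block_size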
@@ -169,6 +169,8 @@ class CacheConfig:
"prefix_caching_hash_algo",
"cpu_kvcache_space_bytes",
"mamba_page_size_padded",
"user_specified_block_size",
"_block_size_resolved",
# Post-init/derived counters
"num_gpu_blocks",
"num_cpu_blocks",
@@ -186,6 +188,22 @@ class CacheConfig:
        # metrics info
        return {key: str(value) for key, value in self.__dict__.items()}

    _block_size_resolved: bool = field(default=False, init=False)
    """Guard against pydantic re-running _apply_block_size_default."""

    @model_validator(mode="after")
    def _apply_block_size_default(self) -> "CacheConfig":
        # Pydantic re-runs validators when CacheConfig is nested inside
        # another pydantic model (e.g. VllmConfig). Guard against that.
        if self._block_size_resolved:
            return self
        object.__setattr__(self, "_block_size_resolved", True)
        if self.block_size is None:
            object.__setattr__(self, "block_size", self.DEFAULT_BLOCK_SIZE)
        else:
            object.__setattr__(self, "user_specified_block_size", True)
        return self

    @field_validator("cache_dtype", mode="after")
    @classmethod
    def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType:
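
Why the guard is needed: an after-mode model validator can run more than once
on what is logically the same object. A toy reproduction with plain pydantic
BaseModels (not vLLM code; revalidate_instances="always" stands in for
whatever triggers re-validation when CacheConfig is nested inside VllmConfig):

from pydantic import BaseModel, ConfigDict, model_validator

class Inner(BaseModel):
    model_config = ConfigDict(revalidate_instances="always")

    block_size: int | None = None
    resolved: bool = False  # mirrors _block_size_resolved

    @model_validator(mode="after")
    def apply_default(self) -> "Inner":
        if self.resolved:  # guard: makes the validator idempotent
            return self
        self.resolved = True
        if self.block_size is None:
            self.block_size = 16
        return self

class Outer(BaseModel):
    inner: Inner

cfg = Inner()           # validator runs once, default applied
out = Outer(inner=cfg)  # Inner is re-validated; the guard makes it a no-op
assert out.inner.block_size == 16

Without the guard, any non-idempotent work in the validator (such as flagging
a value as user-specified) would fire on the second run as well.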


@@ -1026,32 +1026,6 @@ class VllmConfig:
        )
        current_platform.check_and_update_config(self)

        # If DCP, ensure the block size is right.
        if self.parallel_config.decode_context_parallel_size > 1:
            if self.parallel_config.dcp_kv_cache_interleave_size > 1 and (
                self.parallel_config.cp_kv_cache_interleave_size
                != self.parallel_config.dcp_kv_cache_interleave_size
            ):
                self.parallel_config.cp_kv_cache_interleave_size = (
                    self.parallel_config.dcp_kv_cache_interleave_size
                )
                logger.warning_once(
                    "cp_kv_cache_interleave_size is overridden by dcp_kv_cache"
                    "_interleave_size. And dcp-kv-cache-interleave-size will be "
                    "deprecated when PCP is fully supported."
                )
            assert (
                self.parallel_config.cp_kv_cache_interleave_size
                <= self.cache_config.block_size
                and self.cache_config.block_size
                % self.parallel_config.cp_kv_cache_interleave_size
                == 0
            ), (
                f"Block_size({self.cache_config.block_size}) should be greater "
                "than or equal to and divisible by cp_kv_cache_interleave_size "
                f"({self.parallel_config.cp_kv_cache_interleave_size})."
            )

        # Do this after all the updates to compilation_config.mode
        effective_dp_size = (
            self.parallel_config.data_parallel_size
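
This DCP block moves into validate_block_size (added further down) so that it
runs only after the platform has finalized block_size. The arithmetic it
enforces, as a standalone check (hypothetical helper name, not vLLM code):

def interleave_compatible(block_size: int, interleave: int) -> bool:
    # block_size must be >= the interleave size and divisible by it.
    return interleave <= block_size and block_size % interleave == 0

assert interleave_compatible(16, 8)        # 16 >= 8 and 16 % 8 == 0
assert not interleave_compatible(16, 32)   # interleave exceeds block size
assert not interleave_compatible(16, 6)    # 16 % 6 != 0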
@@ -1219,26 +1193,6 @@ class VllmConfig:
        # Default to enable HMA if not explicitly disabled by user or logic above.
        self.scheduler_config.disable_hybrid_kv_cache_manager = False

        if self.cache_config.mamba_cache_mode == "align":
            assert (
                self.cache_config.block_size
                <= self.scheduler_config.max_num_batched_tokens
            ), (
                "In Mamba cache align mode, block_size "
                f"({self.cache_config.block_size}) must be <= "
                "max_num_batched_tokens "
                f"({self.scheduler_config.max_num_batched_tokens})."
            )

        if self.scheduler_config.long_prefill_token_threshold > 0:
            assert (
                self.scheduler_config.long_prefill_token_threshold
                >= self.cache_config.block_size
            )
            assert not self.scheduler_config.disable_chunked_mm_input, (
                "Chunked MM input is required because we need the flexibility to "
                "schedule a multiple of block_size tokens even if they are in the "
                "middle of a mm input"
            )

        if self.compilation_config.debug_dump_path:
            self.compilation_config.debug_dump_path = (
                self.compilation_config.debug_dump_path.absolute().expanduser()
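
These scheduler-side constraints likewise move into validate_block_size. Their
boolean forms, as a quick standalone reference (hypothetical helper names, not
vLLM code):

def mamba_align_ok(block_size: int, max_num_batched_tokens: int) -> bool:
    # Align mode requires block_size <= max_num_batched_tokens.
    return block_size <= max_num_batched_tokens

def long_prefill_ok(threshold: int, block_size: int) -> bool:
    # A nonzero long-prefill threshold must cover at least one full block.
    return threshold == 0 or threshold >= block_size

assert mamba_align_ok(16, 8192)
assert not mamba_align_ok(4096, 2048)
assert long_prefill_ok(0, 16)       # threshold disabled: no constraint
assert not long_prefill_ok(8, 16)   # smaller than one block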
@@ -1673,6 +1627,53 @@ class VllmConfig:
f"compilation_config={self.compilation_config!r}"
)
def validate_block_size(self) -> None:
"""Validate block_size against DCP and mamba constraints.
Called after Platform.update_block_size_for_backend() has
finalised block_size.
"""
block_size = self.cache_config.block_size
# DCP interleave-size compatibility
if self.parallel_config.decode_context_parallel_size > 1:
if self.parallel_config.dcp_kv_cache_interleave_size > 1 and (
self.parallel_config.cp_kv_cache_interleave_size
!= self.parallel_config.dcp_kv_cache_interleave_size
):
self.parallel_config.cp_kv_cache_interleave_size = (
self.parallel_config.dcp_kv_cache_interleave_size
)
logger.warning_once(
"cp_kv_cache_interleave_size is overridden by dcp_kv_cache"
"_interleave_size. And dcp-kv-cache-interleave-size will be "
"deprecated when PCP is fully supported."
)
assert (
self.parallel_config.cp_kv_cache_interleave_size <= block_size
and block_size % self.parallel_config.cp_kv_cache_interleave_size == 0
), (
f"Block_size({block_size}) should be greater "
"than or equal to and divisible by cp_kv_cache_interleave_size "
f"({self.parallel_config.cp_kv_cache_interleave_size})."
)
# Mamba cache align-mode constraints
if self.cache_config.mamba_cache_mode == "align":
assert block_size <= self.scheduler_config.max_num_batched_tokens, (
"In Mamba cache align mode, block_size "
f"({block_size}) must be <= "
"max_num_batched_tokens "
f"({self.scheduler_config.max_num_batched_tokens})."
)
if self.scheduler_config.long_prefill_token_threshold > 0:
assert self.scheduler_config.long_prefill_token_threshold >= block_size
assert not self.scheduler_config.disable_chunked_mm_input, (
"Chunked MM input is required because we need the flexibility "
"to schedule a multiple of block_size tokens even if they are "
"in the middle of a mm input"
)
@model_validator(mode="after")
def validate_mamba_block_size(self) -> "VllmConfig":
if self.model_config is None: