Reapply [Attention] Refactor check_and_update_config (#35122)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -2,16 +2,15 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import field
|
||||
from typing import Literal
|
||||
from typing import ClassVar, Literal
|
||||
|
||||
from pydantic import Field, SkipValidation, field_validator
|
||||
from pydantic import Field, SkipValidation, field_validator, model_validator
|
||||
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
BlockSize = Literal[1, 8, 16, 32, 64, 128, 256]
|
||||
CacheDType = Literal[
|
||||
"auto",
|
||||
"bfloat16",
|
||||
@@ -31,12 +30,13 @@ KVOffloadingBackend = Literal["native", "lmcache"]
|
||||
class CacheConfig:
|
||||
"""Configuration for the KV cache."""
|
||||
|
||||
block_size: SkipValidation[BlockSize] = None # type: ignore[assignment]
|
||||
"""Size of a contiguous cache block in number of tokens.
|
||||
DEFAULT_BLOCK_SIZE: ClassVar[int] = 16
|
||||
|
||||
This config has no static default. If left unspecified by the user, it will
|
||||
be set in `Platform.check_and_update_config()` based on the current
|
||||
platform."""
|
||||
block_size: SkipValidation[int] = None # type: ignore[assignment]
|
||||
"""Size of a contiguous cache block in number of tokens.
|
||||
Accepts None (meaning "use default"). After construction, always int."""
|
||||
user_specified_block_size: bool = field(default=False, init=False)
|
||||
"""Whether block_size was explicitly provided. Derived automatically."""
|
||||
gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1)
|
||||
"""The fraction of GPU memory to be used for the model executor, which can
|
||||
range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
|
||||
@@ -169,6 +169,8 @@ class CacheConfig:
|
||||
"prefix_caching_hash_algo",
|
||||
"cpu_kvcache_space_bytes",
|
||||
"mamba_page_size_padded",
|
||||
"user_specified_block_size",
|
||||
"_block_size_resolved",
|
||||
# Post-init/derived counters
|
||||
"num_gpu_blocks",
|
||||
"num_cpu_blocks",
|
||||
@@ -186,6 +188,22 @@ class CacheConfig:
|
||||
# metrics info
|
||||
return {key: str(value) for key, value in self.__dict__.items()}
|
||||
|
||||
_block_size_resolved: bool = field(default=False, init=False)
|
||||
"""Guard against pydantic re-running _apply_block_size_default."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _apply_block_size_default(self) -> "CacheConfig":
|
||||
# Pydantic re-runs validators when CacheConfig is nested inside
|
||||
# another pydantic model (e.g. VllmConfig). Guard against that.
|
||||
if self._block_size_resolved:
|
||||
return self
|
||||
object.__setattr__(self, "_block_size_resolved", True)
|
||||
if self.block_size is None:
|
||||
object.__setattr__(self, "block_size", self.DEFAULT_BLOCK_SIZE)
|
||||
else:
|
||||
object.__setattr__(self, "user_specified_block_size", True)
|
||||
return self
|
||||
|
||||
@field_validator("cache_dtype", mode="after")
|
||||
@classmethod
|
||||
def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType:
|
||||
|
||||
@@ -1026,32 +1026,6 @@ class VllmConfig:
|
||||
)
|
||||
current_platform.check_and_update_config(self)
|
||||
|
||||
# If DCP, ensure the block size is right.
|
||||
if self.parallel_config.decode_context_parallel_size > 1:
|
||||
if self.parallel_config.dcp_kv_cache_interleave_size > 1 and (
|
||||
self.parallel_config.cp_kv_cache_interleave_size
|
||||
!= self.parallel_config.dcp_kv_cache_interleave_size
|
||||
):
|
||||
self.parallel_config.cp_kv_cache_interleave_size = (
|
||||
self.parallel_config.dcp_kv_cache_interleave_size
|
||||
)
|
||||
logger.warning_once(
|
||||
"cp_kv_cache_interleave_size is overridden by dcp_kv_cache"
|
||||
"_interleave_size. And dcp-kv-cache-interleave-size will be "
|
||||
"deprecated when PCP is fully supported."
|
||||
)
|
||||
assert (
|
||||
self.parallel_config.cp_kv_cache_interleave_size
|
||||
<= self.cache_config.block_size
|
||||
and self.cache_config.block_size
|
||||
% self.parallel_config.cp_kv_cache_interleave_size
|
||||
== 0
|
||||
), (
|
||||
f"Block_size({self.cache_config.block_size}) should be greater "
|
||||
"than or equal to and divisible by cp_kv_cache_interleave_size "
|
||||
f"({self.parallel_config.cp_kv_cache_interleave_size})."
|
||||
)
|
||||
|
||||
# Do this after all the updates to compilation_config.mode
|
||||
effective_dp_size = (
|
||||
self.parallel_config.data_parallel_size
|
||||
@@ -1219,26 +1193,6 @@ class VllmConfig:
|
||||
# Default to enable HMA if not explicitly disabled by user or logic above.
|
||||
self.scheduler_config.disable_hybrid_kv_cache_manager = False
|
||||
|
||||
if self.cache_config.mamba_cache_mode == "align":
|
||||
assert (
|
||||
self.cache_config.block_size
|
||||
<= self.scheduler_config.max_num_batched_tokens
|
||||
), (
|
||||
"In Mamba cache align mode, block_size "
|
||||
f"({self.cache_config.block_size}) must be <= "
|
||||
"max_num_batched_tokens "
|
||||
f"({self.scheduler_config.max_num_batched_tokens})."
|
||||
)
|
||||
if self.scheduler_config.long_prefill_token_threshold > 0:
|
||||
assert (
|
||||
self.scheduler_config.long_prefill_token_threshold
|
||||
>= self.cache_config.block_size
|
||||
)
|
||||
assert not self.scheduler_config.disable_chunked_mm_input, (
|
||||
"Chunked MM input is required because we need the flexibility to "
|
||||
"schedule a multiple of block_size tokens even if they are in the "
|
||||
"middle of a mm input"
|
||||
)
|
||||
if self.compilation_config.debug_dump_path:
|
||||
self.compilation_config.debug_dump_path = (
|
||||
self.compilation_config.debug_dump_path.absolute().expanduser()
|
||||
@@ -1673,6 +1627,53 @@ class VllmConfig:
|
||||
f"compilation_config={self.compilation_config!r}"
|
||||
)
|
||||
|
||||
def validate_block_size(self) -> None:
|
||||
"""Validate block_size against DCP and mamba constraints.
|
||||
|
||||
Called after Platform.update_block_size_for_backend() has
|
||||
finalised block_size.
|
||||
"""
|
||||
block_size = self.cache_config.block_size
|
||||
|
||||
# DCP interleave-size compatibility
|
||||
if self.parallel_config.decode_context_parallel_size > 1:
|
||||
if self.parallel_config.dcp_kv_cache_interleave_size > 1 and (
|
||||
self.parallel_config.cp_kv_cache_interleave_size
|
||||
!= self.parallel_config.dcp_kv_cache_interleave_size
|
||||
):
|
||||
self.parallel_config.cp_kv_cache_interleave_size = (
|
||||
self.parallel_config.dcp_kv_cache_interleave_size
|
||||
)
|
||||
logger.warning_once(
|
||||
"cp_kv_cache_interleave_size is overridden by dcp_kv_cache"
|
||||
"_interleave_size. And dcp-kv-cache-interleave-size will be "
|
||||
"deprecated when PCP is fully supported."
|
||||
)
|
||||
assert (
|
||||
self.parallel_config.cp_kv_cache_interleave_size <= block_size
|
||||
and block_size % self.parallel_config.cp_kv_cache_interleave_size == 0
|
||||
), (
|
||||
f"Block_size({block_size}) should be greater "
|
||||
"than or equal to and divisible by cp_kv_cache_interleave_size "
|
||||
f"({self.parallel_config.cp_kv_cache_interleave_size})."
|
||||
)
|
||||
|
||||
# Mamba cache align-mode constraints
|
||||
if self.cache_config.mamba_cache_mode == "align":
|
||||
assert block_size <= self.scheduler_config.max_num_batched_tokens, (
|
||||
"In Mamba cache align mode, block_size "
|
||||
f"({block_size}) must be <= "
|
||||
"max_num_batched_tokens "
|
||||
f"({self.scheduler_config.max_num_batched_tokens})."
|
||||
)
|
||||
if self.scheduler_config.long_prefill_token_threshold > 0:
|
||||
assert self.scheduler_config.long_prefill_token_threshold >= block_size
|
||||
assert not self.scheduler_config.disable_chunked_mm_input, (
|
||||
"Chunked MM input is required because we need the flexibility "
|
||||
"to schedule a multiple of block_size tokens even if they are "
|
||||
"in the middle of a mm input"
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_mamba_block_size(self) -> "VllmConfig":
|
||||
if self.model_config is None:
|
||||
|
||||
Reference in New Issue
Block a user