Improve configs - CacheConfig (#16835)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
vllm/config.py
@@ -1245,22 +1245,70 @@ class ModelConfig:
                 or getattr(self.hf_config, "is_matryoshka", False))


-class CacheConfig:
-    """Configuration for the KV cache.
-
-    Args:
-        block_size: Size of a cache block in number of tokens.
-        gpu_memory_utilization: Fraction of GPU memory to use for the
-            vLLM execution.
-        swap_space: Size of the CPU swap space per GPU (in GiB).
-        cache_dtype: Data type for kv cache storage.
-        is_attention_free: Whether the model is attention-free.
-        num_gpu_blocks_override: Number of GPU blocks to use. This overrides the
-            profiled num_gpu_blocks if specified. Does nothing if None.
-        sliding_window: Sliding window size for the KV cache.
-        enable_prefix_caching: Whether to enable prefix caching.
-        cpu_offload_gb: Size of the CPU offload buffer in GiB.
-    """
+BlockSize = Literal[8, 16, 32, 64, 128]
+CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]
+PrefixCachingHashAlgo = Literal["builtin", "sha256"]
+
+
+@config
+@dataclass
+class CacheConfig:
+    """Configuration for the KV cache."""
+
+    block_size: Optional[BlockSize] = None
+    """Size of a contiguous cache block in number of tokens. This is ignored on
+    neuron devices and set to `--max-model-len`. On CUDA devices, only block
+    sizes up to 32 are supported. On HPU devices, block size defaults to 128.
+    """
+    gpu_memory_utilization: float = 0.9
+    """The fraction of GPU memory to be used for the model executor, which can
+    range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
+    utilization. If unspecified, will use the default value of 0.9. This is a
+    per-instance limit, and only applies to the current vLLM instance. It does
+    not matter if you have another vLLM instance running on the same GPU. For
+    example, if you have two vLLM instances running on the same GPU, you can
+    set the GPU memory utilization to 0.5 for each instance."""
+    swap_space: float = 4
+    """Size of the CPU swap space per GPU (in GiB)."""
+    cache_dtype: CacheDType = "auto"
+    """Data type for kv cache storage. If "auto", will use model data type.
+    CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
+    fp8 (=fp8_e4m3)."""
+    is_attention_free: bool = False
+    """Whether the model is attention-free. This is primarily set in
+    `ModelConfig` and that value should be manually duplicated here."""
+    num_gpu_blocks_override: Optional[int] = None
+    """Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks`
+    if specified. Does nothing if `None`. Used for testing preemption."""
+    sliding_window: Optional[int] = None
+    """Sliding window size for the KV cache. This is primarily set in
+    `ModelConfig` and that value should be manually duplicated here."""
+    enable_prefix_caching: Optional[bool] = None
+    """Whether to enable prefix caching. Disabled by default for V0. Enabled by
+    default for V1."""
+    prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin"
+    """Set the hash algorithm for prefix caching:\n
+    - "builtin" is Python's built-in hash.\n
+    - "sha256" is collision resistant but with certain overheads."""
+    cpu_offload_gb: float = 0
+    """The space in GiB to offload to CPU, per GPU. Default is 0, which means
+    no offloading. Intuitively, this argument can be seen as a virtual way to
+    increase the GPU memory size. For example, if you have one 24 GB GPU and
+    set this to 10, virtually you can think of it as a 34 GB GPU. Then you can
+    load a 13B model with BF16 weight, which requires at least 26GB GPU memory.
+    Note that this requires fast CPU-GPU interconnect, as part of the model is
+    loaded from CPU memory to GPU memory on the fly in each model forward pass.
+    """
+    calculate_kv_scales: bool = False
+    """This enables dynamic calculation of `k_scale` and `v_scale` when
+    kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model
+    checkpoint if available. Otherwise, the scales will default to 1.0."""
+
+    # Will be set after profiling.
+    num_gpu_blocks: Optional[int] = field(default=None, init=False)
+    """The number of blocks to allocate for GPU memory."""
+    num_cpu_blocks: Optional[int] = field(default=None, init=False)
+    """The number of blocks to allocate for CPU memory."""

     def compute_hash(self) -> str:
         """
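Note: with the fields above now plain dataclass attributes, a CacheConfig can be built directly with keyword arguments. A minimal sketch of hypothetical standalone usage (the engine normally builds this object from the engine args; the values are only examples):

from vllm.config import CacheConfig

# Hypothetical direct construction; not part of this commit.
cache_config = CacheConfig(
    block_size=16,                      # one of the BlockSize literals
    gpu_memory_utilization=0.5,         # per-instance fraction of GPU memory
    cache_dtype="fp8_e4m3",             # a CacheDType literal
    enable_prefix_caching=True,
    prefix_caching_hash_algo="sha256",  # a PrefixCachingHashAlgo literal
)

# __post_init__ (see the next hunk) derives swap_space_bytes and runs the
# _verify_* checks.
print(cache_config.swap_space_bytes)    # default 4 GiB expressed in bytes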
@@ -1281,43 +1329,13 @@ class CacheConfig:
                                usedforsecurity=False).hexdigest()
         return hash_str

-    def __init__(
-        self,
-        block_size: int,
-        gpu_memory_utilization: float,
-        swap_space: float,
-        cache_dtype: str,
-        is_attention_free: bool = False,
-        num_gpu_blocks_override: Optional[int] = None,
-        sliding_window: Optional[int] = None,
-        enable_prefix_caching: bool = False,
-        prefix_caching_hash_algo: str = "builtin",
-        cpu_offload_gb: float = 0,
-        calculate_kv_scales: Optional[bool] = None,
-    ) -> None:
-        self.block_size = block_size
-        self.gpu_memory_utilization = gpu_memory_utilization
-        self.swap_space_bytes = swap_space * GiB_bytes
-        self.num_gpu_blocks_override = num_gpu_blocks_override
-        self.cache_dtype = cache_dtype
-        self.is_attention_free = is_attention_free
-        self.sliding_window = sliding_window
-        self.enable_prefix_caching = enable_prefix_caching
-        self.prefix_caching_hash_algo = prefix_caching_hash_algo
-        self.cpu_offload_gb = cpu_offload_gb
-        self.calculate_kv_scales = calculate_kv_scales
+    def __post_init__(self) -> None:
+        self.swap_space_bytes = self.swap_space * GiB_bytes
+
         self._verify_args()
         self._verify_cache_dtype()
         self._verify_prefix_caching()
-
-        # Will be set after profiling.
-        self.num_gpu_blocks: Optional[int] = None
-        self.num_cpu_blocks: Optional[int] = None
-
-        # Set calculate_kv_scales to False if the value is unset.
-        if self.calculate_kv_scales is None:
-            self.calculate_kv_scales = False

     def metrics_info(self):
         # convert cache_config to dict(key: str, value: str) for prometheus
         # metrics info
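Note: the hand-written __init__ is subsumed by the dataclass-generated constructor; __post_init__ now only derives swap_space_bytes and runs validation. A self-contained sketch of the same pattern on a toy class (not vLLM code):

from dataclasses import dataclass, field
from typing import Optional

GiB_bytes = 1 << 30  # stand-in for the constant the diff relies on

@dataclass
class ToyCacheConfig:
    swap_space: float = 4  # GiB
    num_gpu_blocks: Optional[int] = field(default=None, init=False)  # not a ctor arg

    def __post_init__(self) -> None:
        # Runs right after the generated __init__, so derived state lives here.
        self.swap_space_bytes = self.swap_space * GiB_bytes

cfg = ToyCacheConfig(swap_space=2)
assert cfg.swap_space_bytes == 2 * GiB_bytes
assert cfg.num_gpu_blocks is None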
@@ -1336,7 +1354,7 @@ class CacheConfig:
     def _verify_cache_dtype(self) -> None:
         if self.cache_dtype == "auto":
             pass
-        elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
+        elif self.cache_dtype in get_args(CacheDType):
             logger.info(
                 "Using fp8 data type to store kv cache. It reduces the GPU "
                 "memory footprint and boosts the performance. "
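Note: this check (and the prefix-caching check in the next hunk) now derives the allowed values from the Literal alias via typing.get_args instead of repeating the tuple. A standalone sketch of the pattern:

from typing import Literal, get_args

CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]

def check_cache_dtype(value: str) -> None:
    # get_args(CacheDType) -> ("auto", "fp8", "fp8_e4m3", "fp8_e5m2")
    if value not in get_args(CacheDType):
        raise ValueError(f"Unknown kv cache dtype: {value}. "
                         f"Must be one of {get_args(CacheDType)}.")

check_cache_dtype("fp8_e5m2")  # passes
# check_cache_dtype("int8")    # would raise ValueError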
@@ -1354,12 +1372,12 @@ class CacheConfig:
                 "Prefix caching is not supported with sliding window. "
                 "Run with --disable-sliding-window to use prefix caching.")

-        if self.enable_prefix_caching and self.prefix_caching_hash_algo not in (
-                "builtin", "sha256"):
+        if (self.enable_prefix_caching and self.prefix_caching_hash_algo
+                not in get_args(PrefixCachingHashAlgo)):
             raise ValueError(
                 "Unknown prefix caching hash algorithm: "
-                f"{self.prefix_caching_hash_algo}. Must be either "
-                "'builtin' or 'sha256'.")
+                f"{self.prefix_caching_hash_algo}. Must be one of "
+                f"{get_args(PrefixCachingHashAlgo)}.")

     def verify_with_parallel_config(
         self,
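Note: end users typically reach these cache options through the engine arguments rather than by constructing CacheConfig directly. A hedged example (argument names assumed to map onto the fields above via the engine-args layer; not exercised by this commit):

from vllm import LLM

# Sketch only; kv_cache_dtype is assumed to feed cache_dtype.
llm = LLM(
    model="facebook/opt-125m",
    gpu_memory_utilization=0.5,
    swap_space=2,
    kv_cache_dtype="fp8_e4m3",
    enable_prefix_caching=True,
)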