[KVCache] Make KVCacheSpec hashable (#21791)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Author: Chen Zhang
Date: 2025-07-29 04:58:29 -07:00
Committed by: GitHub
Parent: 2470419119
Commit: 755fa8b657
5 changed files with 100 additions and 88 deletions
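The crux of the change: every KV cache spec dataclass becomes `@dataclass(frozen=True)`. A frozen dataclass generates `__hash__` alongside `__eq__`, so two specs describing the same KV cache layout compare and hash identically, which is what lets the hand-written `type_id` strings be deleted below. A minimal sketch of that Python behavior, with illustrative class names rather than vLLM's own:

```python
from dataclasses import dataclass


# A plain mutable dataclass gets __eq__ but has __hash__ set to None,
# so instances cannot live in sets or serve as dict keys.
@dataclass
class MutableSpec:
    block_size: int


# A frozen dataclass generates __hash__ from the same fields as __eq__,
# which is what "make KVCacheSpec hashable" relies on.
@dataclass(frozen=True)
class FrozenSpec:
    block_size: int


try:
    {MutableSpec(block_size=16)}
except TypeError as err:
    print(err)  # unhashable type: 'MutableSpec'

# Equal field values -> equal hashes, so duplicates collapse in a set.
print(len({FrozenSpec(block_size=16), FrozenSpec(block_size=16)}))  # 1
```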


@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import copy
-from dataclasses import dataclass
+from dataclasses import dataclass, fields
 from math import prod
 from typing import Optional
@@ -16,7 +16,7 @@ from vllm.utils import cdiv, get_dtype_size
 logger = init_logger(__name__)
 
-@dataclass
+@dataclass(frozen=True)
 class KVCacheSpec:
     """
     A base class for specifying the KV cache format of one layer.
@@ -25,20 +25,6 @@ class KVCacheSpec:
     # number of tokens in a block
     block_size: int
 
-    @property
-    def type_id(self) -> str:
-        """
-        The type identifier of this KV cache.
-        Return different strings for layers with different KV cache type (e.g.,
-        different number of tokens like full attention vs sliding window
-        attention, different KV cache size per token like layers with different
-        number of heads)
-
-        Returns:
-            The type identifier of this KV cache.
-        """
-        raise NotImplementedError
-
     @property
     def page_size_bytes(self) -> int:
         """
@@ -63,13 +49,12 @@ class KVCacheSpec:
         """
         Merge a list of KVCacheSpec objects into a single KVCacheSpec object.
         """
-        assert all(spec.type_id == specs[0].type_id for spec in specs[1:]), (
-            "All layers in the same KV cache group must share the same "
-            "type_id.")
+        assert all(spec == specs[0] for spec in specs[1:]), (
+            "All layers in the same KV cache group must be the same.")
         return copy.deepcopy(specs[0])
 
 
-@dataclass
+@dataclass(frozen=True)
 class AttentionSpec(KVCacheSpec):
     num_kv_heads: int
     head_size: int
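
With value equality in place, the base `merge` no longer compares hand-maintained `type_id` strings; `spec == specs[0]` covers every field at once. A toy re-creation of the new contract (the `ToySpec` class is hypothetical, not from vLLM):

```python
import copy
from dataclasses import dataclass


@dataclass(frozen=True)
class ToySpec:
    block_size: int
    num_kv_heads: int


def merge(specs: list[ToySpec]) -> ToySpec:
    # Field-by-field equality subsumes the old type_id string check.
    assert all(spec == specs[0] for spec in specs[1:]), (
        "All layers in the same KV cache group must be the same.")
    return copy.deepcopy(specs[0])


print(merge([ToySpec(16, 8), ToySpec(16, 8)]))  # ToySpec(block_size=16, ...)
# merge([ToySpec(16, 8), ToySpec(16, 4)])       # AssertionError
```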
@@ -84,7 +69,7 @@ class AttentionSpec(KVCacheSpec):
             * get_dtype_size(self.dtype)
 
 
-@dataclass
+@dataclass(frozen=True)
 class FullAttentionSpec(AttentionSpec):
     sliding_window: Optional[int] = None
     attention_chunk_size: Optional[int] = None
@@ -98,10 +83,6 @@ class FullAttentionSpec(AttentionSpec):
     Default to None for not using sliding window attention.
     """
 
-    @property
-    def type_id(self) -> str:
-        return f"full_attention_{self.block_size}_{self.page_size_bytes}"
-
     def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         max_model_len = vllm_config.model_config.max_model_len
         return cdiv(max_model_len, self.block_size) * self.page_size_bytes
@@ -123,15 +104,28 @@ class FullAttentionSpec(AttentionSpec):
         Merge a list of FullAttentionSpec objects into a single
         FullAttentionSpec object.
         """
-        merged_spec = super().merge(specs)
+        assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
+            "All attention layers in the same KV cache group must be "
+            "FullAttentionSpec.")
         sliding_window = set(spec.sliding_window for spec in specs
                              if spec.sliding_window is not None)
         attention_chunk_size = set(spec.attention_chunk_size for spec in specs
                                    if spec.attention_chunk_size is not None)
-        merged_spec.sliding_window = cls.merge_window_sizes(sliding_window)
-        merged_spec.attention_chunk_size = (
-            cls.merge_window_sizes(attention_chunk_size))
+        merged_spec = cls(
+            block_size=specs[0].block_size,
+            num_kv_heads=specs[0].num_kv_heads,
+            head_size=specs[0].head_size,
+            dtype=specs[0].dtype,
+            use_mla=specs[0].use_mla,
+            sliding_window=cls.merge_window_sizes(sliding_window),
+            attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
+        )
+        for spec in specs:
+            for f in fields(AttentionSpec):
+                assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
+                    "All attention layers in the same KV cache group must have "
+                    "the same attention spec.")
         assert (
             (merged_spec.sliding_window is not None) +
             (merged_spec.attention_chunk_size is not None) <= 1
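
The rewrite of `FullAttentionSpec.merge` is forced by freezing: the old code mutated the result of `super().merge(specs)`, and assigning to a field of a frozen instance raises `FrozenInstanceError`, so the merged spec must now be built in a single constructor call and then validated field by field via `fields(AttentionSpec)`. A toy demonstration of the constraint (the spec class is hypothetical):

```python
import dataclasses
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ToyAttentionSpec:
    block_size: int
    sliding_window: Optional[int] = None


spec = ToyAttentionSpec(block_size=16)
try:
    spec.sliding_window = 1024  # the mutation the old merge() performed
except dataclasses.FrozenInstanceError as err:
    print(err)  # cannot assign to field 'sliding_window'

# The pattern the diff adopts instead: pass the merged value up front.
merged = ToyAttentionSpec(block_size=spec.block_size, sliding_window=1024)
print(merged)
```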
@@ -140,16 +134,10 @@ class FullAttentionSpec(AttentionSpec):
         return merged_spec
 
 
-@dataclass
+@dataclass(frozen=True)
 class ChunkedLocalAttentionSpec(AttentionSpec):
     attention_chunk_size: int
 
-    @property
-    def type_id(self) -> str:
-        return (
-            f"local_attention_{self.attention_chunk_size}_{self.block_size}_{self.page_size_bytes}"
-        )  # noqa
-
     def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         max_model_len = vllm_config.model_config.max_model_len
         max_num_batched_tokens = (
@@ -165,17 +153,13 @@ class ChunkedLocalAttentionSpec(AttentionSpec):
         return cdiv(num_tokens, self.block_size) * self.page_size_bytes
 
 
-@dataclass
+@dataclass(frozen=True)
 class SlidingWindowSpec(AttentionSpec):
     sliding_window: int
 
     def __post_init__(self):
         assert not self.use_mla, "MLA is not supported for sliding window"
 
-    @property
-    def type_id(self) -> str:
-        return f"sliding_window_{self.sliding_window}_{self.block_size}_{self.page_size_bytes}"  # noqa
-
     def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         max_model_len = vllm_config.model_config.max_model_len
         max_num_batched_tokens = (
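
Note that `SlidingWindowSpec.__post_init__` survives unchanged: frozen dataclasses still run `__post_init__`, and a bare assertion reads fields without assigning to them, so only `__post_init__` bodies that write attributes (as in `MambaSpec` below) had to go. A quick check of that distinction, using a hypothetical class:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToySlidingSpec:
    sliding_window: int
    use_mla: bool = False

    def __post_init__(self):
        # Validation is fine on a frozen dataclass; only assignment is not.
        assert not self.use_mla, "MLA is not supported for sliding window"


ToySlidingSpec(sliding_window=1024)                  # constructs fine
# ToySlidingSpec(sliding_window=1024, use_mla=True)  # AssertionError
```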
@@ -195,23 +179,17 @@ class SlidingWindowSpec(AttentionSpec):
         return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
 
 
-@dataclass
+@dataclass(frozen=True)
 class MambaSpec(KVCacheSpec):
     shapes: tuple[tuple[int, ...], ...]
     dtype: torch.dtype
     page_size_padded: Optional[int] = None
     mamba_type: str = "mamba2"
 
-    def __post_init__(self):
-        self.num_elements = sum(prod(shape) for shape in self.shapes)
-
-    @property
-    def type_id(self) -> str:
-        return f"mamba_{self.shapes}_{self.dtype}_{self.mamba_type}"
-
     @property
     def page_size_bytes(self) -> int:
-        page_size = self.num_elements * get_dtype_size(self.dtype)
+        num_elements = sum(prod(shape) for shape in self.shapes)
+        page_size = num_elements * get_dtype_size(self.dtype)
         if self.page_size_padded is not None:
             assert self.page_size_padded >= page_size
             return self.page_size_padded
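
`MambaSpec` loses its `__post_init__` for the same reason: caching `self.num_elements = ...` is an attribute write, which a frozen dataclass rejects unless routed through `object.__setattr__`. The diff instead recomputes the element count inside `page_size_bytes`. A toy version of the recompute-on-demand pattern (names are illustrative):

```python
from dataclasses import dataclass
from math import prod


@dataclass(frozen=True)
class ToyMambaSpec:
    shapes: tuple[tuple[int, ...], ...]

    @property
    def num_elements(self) -> int:
        # Recomputed per access instead of cached in __post_init__;
        # object.__setattr__ would be the escape hatch if caching mattered.
        return sum(prod(shape) for shape in self.shapes)


spec = ToyMambaSpec(shapes=((128, 64), (16,)))
print(spec.num_elements)  # 8192 + 16 = 8208
```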