[KVCache] Make KVCacheSpec hashable (#21791)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import copy
-from dataclasses import dataclass
+from dataclasses import dataclass, fields
 from math import prod
 from typing import Optional

@@ -16,7 +16,7 @@ from vllm.utils import cdiv, get_dtype_size
 logger = init_logger(__name__)


-@dataclass
+@dataclass(frozen=True)
 class KVCacheSpec:
     """
     A base class for specifying the KV cache format of one layer.
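With the dataclass default of `eq=True`, adding `frozen=True` makes Python generate `__hash__` alongside `__eq__`, which is what makes the specs hashable. A minimal sketch of the behavior this commit relies on; `Spec` is a hypothetical stand-in, not the vLLM class:

```python
from dataclasses import dataclass

# Hypothetical stand-in for the pattern applied in this commit: with the
# default eq=True, frozen=True makes the dataclass generate __hash__ as
# well as __eq__, so instances work as dict keys and set members.
@dataclass(frozen=True)
class Spec:
    block_size: int

a, b = Spec(block_size=16), Spec(block_size=16)
assert a == b                # field-wise equality
assert hash(a) == hash(b)    # equal values hash equally
assert len({a, b}) == 1      # deduplicates in a set
# A plain @dataclass (eq=True, frozen=False) sets __hash__ = None,
# which is why the unfrozen specs were unhashable.
```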
@@ -25,20 +25,6 @@ class KVCacheSpec:
     # number of tokens in a block
     block_size: int

-    @property
-    def type_id(self) -> str:
-        """
-        The type identifier of this KV cache.
-        Return different strings for layers with different KV cache type (e.g.,
-        different number of tokens like full attention vs sliding window
-        attention, different KV cache size per token like layers with different
-        number of heads)
-
-        Returns:
-            The type identifier of this KV cache.
-        """
-        raise NotImplementedError
-
     @property
     def page_size_bytes(self) -> int:
         """
@@ -63,13 +49,12 @@
         """
         Merge a list of KVCacheSpec objects into a single KVCacheSpec object.
         """
-        assert all(spec.type_id == specs[0].type_id for spec in specs[1:]), (
-            "All layers in the same KV cache group must share the same "
-            "type_id.")
+        assert all(spec == specs[0] for spec in specs[1:]), (
+            "All layers in the same KV cache group must be the same.")
         return copy.deepcopy(specs[0])


-@dataclass
+@dataclass(frozen=True)
 class AttentionSpec(KVCacheSpec):
     num_kv_heads: int
     head_size: int
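Because every field participates in the generated `__eq__`, plain equality subsumes the old hand-maintained `type_id` strings, which encoded only a subset of the fields. A sketch under the same stand-in assumption (`Spec` is hypothetical):

```python
from dataclasses import dataclass

# Hypothetical stand-in, not the vLLM definition.
@dataclass(frozen=True)
class Spec:
    block_size: int
    num_kv_heads: int

layers = [Spec(16, 8), Spec(16, 8), Spec(16, 8)]

# Old check: compare hand-built type_id strings covering some fields.
# New check: direct equality, which compares every field at once.
assert all(spec == layers[0] for spec in layers[1:])

# Hashability also lets callers group layers without string keys:
assert len(set(layers)) == 1
```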
@@ -84,7 +69,7 @@ class AttentionSpec(KVCacheSpec):
             * get_dtype_size(self.dtype)


-@dataclass
+@dataclass(frozen=True)
 class FullAttentionSpec(AttentionSpec):
     sliding_window: Optional[int] = None
     attention_chunk_size: Optional[int] = None
@@ -98,10 +83,6 @@ class FullAttentionSpec(AttentionSpec):
     Default to None for not using sliding window attention.
     """

-    @property
-    def type_id(self) -> str:
-        return f"full_attention_{self.block_size}_{self.page_size_bytes}"
-
     def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         max_model_len = vllm_config.model_config.max_model_len
         return cdiv(max_model_len, self.block_size) * self.page_size_bytes
@@ -123,15 +104,28 @@
         Merge a list of FullAttentionSpec objects into a single
         FullAttentionSpec object.
         """
-        merged_spec = super().merge(specs)
         assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
             "All attention layers in the same KV cache group must be "
             "FullAttentionSpec.")

         sliding_window = set(spec.sliding_window for spec in specs
                              if spec.sliding_window is not None)
         attention_chunk_size = set(spec.attention_chunk_size for spec in specs
                                    if spec.attention_chunk_size is not None)
-
-        merged_spec.sliding_window = cls.merge_window_sizes(sliding_window)
-        merged_spec.attention_chunk_size = (
-            cls.merge_window_sizes(attention_chunk_size))
+        merged_spec = cls(
+            block_size=specs[0].block_size,
+            num_kv_heads=specs[0].num_kv_heads,
+            head_size=specs[0].head_size,
+            dtype=specs[0].dtype,
+            use_mla=specs[0].use_mla,
+            sliding_window=cls.merge_window_sizes(sliding_window),
+            attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
+        )
+        for spec in specs:
+            for f in fields(AttentionSpec):
+                assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
+                    "All attention layers in the same KV cache group must have "
+                    "the same attention spec.")
         assert (
             (merged_spec.sliding_window is not None) +
             (merged_spec.attention_chunk_size is not None) <= 1
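A frozen dataclass rejects attribute assignment after `__init__`, so `merge` can no longer patch fields onto the deep copy returned by `super().merge`; it now builds the merged spec through the constructor and then verifies the shared `AttentionSpec` fields with `dataclasses.fields`. A sketch of the failure mode being avoided (`Spec` again hypothetical):

```python
import dataclasses
from dataclasses import dataclass
from typing import Optional

# Hypothetical stand-in illustrating why merge() was restructured.
@dataclass(frozen=True)
class Spec:
    block_size: int
    sliding_window: Optional[int] = None

spec = Spec(block_size=16)
try:
    # The old pattern: mutate a (deep-)copied spec in place. On a frozen
    # dataclass this raises at runtime, so it is no longer an option.
    spec.sliding_window = 1024
except dataclasses.FrozenInstanceError:
    pass

# The new pattern: construct a fresh instance with the merged values.
merged = Spec(block_size=spec.block_size, sliding_window=1024)
assert merged.sliding_window == 1024
```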
@@ -140,16 +134,10 @@ class FullAttentionSpec(AttentionSpec):
         return merged_spec


-@dataclass
+@dataclass(frozen=True)
 class ChunkedLocalAttentionSpec(AttentionSpec):
     attention_chunk_size: int

-    @property
-    def type_id(self) -> str:
-        return (
-            f"local_attention_{self.attention_chunk_size}_{self.block_size}_{self.page_size_bytes}"
-        )  # noqa
-
     def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         max_model_len = vllm_config.model_config.max_model_len
         max_num_batched_tokens = (
@@ -165,17 +153,13 @@ class ChunkedLocalAttentionSpec(AttentionSpec):
         return cdiv(num_tokens, self.block_size) * self.page_size_bytes


-@dataclass
+@dataclass(frozen=True)
 class SlidingWindowSpec(AttentionSpec):
     sliding_window: int

     def __post_init__(self):
         assert not self.use_mla, "MLA is not supported for sliding window"

-    @property
-    def type_id(self) -> str:
-        return f"sliding_window_{self.sliding_window}_{self.block_size}_{self.page_size_bytes}"  # noqa
-
     def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
         max_model_len = vllm_config.model_config.max_model_len
         max_num_batched_tokens = (
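Note that `SlidingWindowSpec.__post_init__` survives the change: `__post_init__` still runs on frozen dataclasses and may freely read fields for validation; only attribute assignment is intercepted. A small illustration with a hypothetical `WindowSpec`:

```python
from dataclasses import dataclass

# Hypothetical stand-in: __post_init__ on a frozen dataclass is fine as
# long as it only *reads* fields; the frozen __setattr__ blocks writes.
@dataclass(frozen=True)
class WindowSpec:
    sliding_window: int
    use_mla: bool = False

    def __post_init__(self):
        assert not self.use_mla, "MLA is not supported for sliding window"

spec = WindowSpec(sliding_window=4096)  # validated at construction time
```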
@@ -195,23 +179,17 @@ class SlidingWindowSpec(AttentionSpec):
         return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes


-@dataclass
+@dataclass(frozen=True)
 class MambaSpec(KVCacheSpec):
     shapes: tuple[tuple[int, ...], ...]
     dtype: torch.dtype
     page_size_padded: Optional[int] = None
     mamba_type: str = "mamba2"

-    def __post_init__(self):
-        self.num_elements = sum(prod(shape) for shape in self.shapes)
-
-    @property
-    def type_id(self) -> str:
-        return f"mamba_{self.shapes}_{self.dtype}_{self.mamba_type}"
-
     @property
     def page_size_bytes(self) -> int:
-        page_size = self.num_elements * get_dtype_size(self.dtype)
+        num_elements = sum(prod(shape) for shape in self.shapes)
+        page_size = num_elements * get_dtype_size(self.dtype)
         if self.page_size_padded is not None:
             assert self.page_size_padded >= page_size
             return self.page_size_padded
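The removed `MambaSpec.__post_init__` is the one casualty of freezing: caching `self.num_elements` there would now raise `FrozenInstanceError`, so the element count is recomputed inside `page_size_bytes` instead. A sketch with a hypothetical `CachedSpec`:

```python
from dataclasses import FrozenInstanceError, dataclass
from math import prod

# Hypothetical stand-in reproducing the removed pattern: caching a
# derived value via plain assignment in __post_init__ goes through the
# frozen __setattr__ and raises.
@dataclass(frozen=True)
class CachedSpec:
    shapes: tuple[tuple[int, ...], ...]

    def __post_init__(self):
        self.num_elements = sum(prod(shape) for shape in self.shapes)

try:
    CachedSpec(shapes=((2, 3), (4,)))
except FrozenInstanceError:
    pass  # hence the commit recomputes the count inside page_size_bytes

# An alternative (not taken here) is object.__setattr__(self, ...),
# which bypasses the frozen check at the cost of hiding a mutation.
```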