[v1] Refactor KVCacheConfig (#14079)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Author: Chen Zhang
Date: 2025-03-21 19:56:27 +08:00
Committed by: GitHub
Parent: 61e8c18350
Commit: 93a00d7dde
10 changed files with 318 additions and 110 deletions


@@ -11,7 +11,7 @@ logger = init_logger(__name__)
 
 
 @dataclass
-class KVCacheSpecBase:
+class KVCacheSpec:
     """
     A base class for specifying the KV cache format of one layer.
     """
@@ -55,7 +55,7 @@ class KVCacheSpecBase:
 
 
 @dataclass
-class FullAttentionSpec(KVCacheSpecBase):
+class FullAttentionSpec(KVCacheSpec):
     num_kv_heads: int
     head_size: int
     dtype: torch.dtype
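The rename in the two hunks above is mechanical, but the shape of the resulting hierarchy is easier to see in isolation. Below is a minimal, self-contained sketch of the post-refactor classes. It is not copied from the real file: `dtype_size` stands in for `torch.dtype`, and `page_size_bytes` is a plausible reconstruction; only `cdiv(num_tokens, self.block_size) * self.page_size_bytes` comes from the diff itself.

```python
from dataclasses import dataclass


def cdiv(a: int, b: int) -> int:
    """Ceiling division, as used in the hunk above."""
    return -(a // -b)


@dataclass
class KVCacheSpec:
    """Base class for the KV cache format of one layer (post-rename name)."""
    block_size: int


@dataclass
class FullAttentionSpec(KVCacheSpec):
    num_kv_heads: int
    head_size: int
    dtype_size: int  # bytes per element; stands in for torch.dtype

    @property
    def page_size_bytes(self) -> int:
        # One K and one V tensor for every token slot in a block
        # (assumed layout, not taken from the diff).
        return (2 * self.block_size * self.num_kv_heads * self.head_size
                * self.dtype_size)

    def bytes_for_tokens(self, num_tokens: int) -> int:
        # The line shown in the diff: round up to whole blocks.
        return cdiv(num_tokens, self.block_size) * self.page_size_bytes


spec = FullAttentionSpec(block_size=16, num_kv_heads=8, head_size=128,
                         dtype_size=2)
print(spec.page_size_bytes)        # 65536
print(spec.bytes_for_tokens(100))  # 7 blocks of 65536 -> 458752
```

Because `FullAttentionSpec` now subclasses `KVCacheSpec` directly, code that dispatches on the base class keeps working after the rename.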
@@ -76,9 +76,6 @@ class FullAttentionSpec(KVCacheSpecBase):
         return cdiv(num_tokens, self.block_size) * self.page_size_bytes
 
 
-KVCacheSpec = dict[str, KVCacheSpecBase]
-
-
 @dataclass
 class KVCacheTensor:
     """
@@ -89,6 +86,18 @@ class KVCacheTensor:
     size: int  # The size of KV cache Tensor in bytes
 
 
+@dataclass
+class KVCacheGroupSpec:
+    """
+    Represents a group of model layers that share the same KV cache block table.
+    These layers are regarded as one layer in the KV cache manager.
+    """
+    # The names of model layers in this group
+    layer_names: list[str]
+    # The KV cache spec of this manager layer
+    kv_cache_spec: KVCacheSpec
+
+
 @dataclass
 class KVCacheConfig:
     """
@@ -99,17 +108,24 @@ class KVCacheConfig:
     """layer_name -> how to initialize KV cache for that layer"""
     tensors: dict[str, KVCacheTensor]
     """
-    A list of kv-cache groups. Each group includes a set of layers with
-    the same kv-cache spec, and the total page_size of layers inside a group
-    is same across all groups (as the KVCacheManager only supports allocating
-    pages of the same size). For example:
-    1. A model only uses full attention: one group with all layers in the model.
-    2. (not implemented yet) A model with the same number of full attention
-    layers and sliding window attention layers: two groups, one for full
-    attention layers and one for sliding window attention layers.
-    3. (not implemented yet) A model with 2 full attention layers and 4 sliding
-    window attention layers: three groups, (full * 2), (sw * 2), (sw * 2).
+    The kv cache groups of the model.
+    The layers in the models are repeated with some patterns, e.g., a model
+    with 10 full attention layers and 20 sliding window attention layers can be
+    regarded as repeating the pattern (1 * full, 2 * sw) 10 times.
+    The KVCacheManager allocates different block tables for each of the 3 layers
+    in the pattern, and repeats each of them 10 times to generate the
+    block_table for the 30 layers in the model.
+    Therefore, we can group the layers in the model into 3 groups, each of which
+    contains 10 layers in the model.
+    The KVCacheManager allocates the block_table for each group based on its
+    kv_cache spec, and the model runner applies the block table to each layer
+    in the group.
+    For example:
+    1. A model only uses full attention. The pattern is
+    (num_hidden_layers * full), so there is only one group and the block table
+    is shared by all layers.
+    2. (WIP) A model with 10 full attention layers and 20 sliding window
+    attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so
+    there are 3 groups, each of which represents 10 layers in the model.
     """
-    groups: list[list[str]]
-    """the KVCacheSpec of the model"""
-    kv_cache_spec: KVCacheSpec
+    kv_cache_groups: list[KVCacheGroupSpec]
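The docstring's second example (pattern `(1 * full, 2 * sw)` repeated 10 times, giving 3 groups of 10 layers each) can be sketched as follows. This is an illustration of the grouping arithmetic only, not the KVCacheManager's actual algorithm; the layer names and the group-by-pattern-position key are assumptions.

```python
from collections import defaultdict

# 30 layers following the pattern (1 * full, 2 * sw) repeated 10 times.
pattern = ("full", "sw", "sw")
layers = [(f"layers.{i}", pattern[i % len(pattern)]) for i in range(30)]

# Group layers by their position within the repeating pattern, so each
# group holds one layer from every repetition, and all layers in a group
# have the same attention type (hence the same kv-cache spec).
groups: dict[int, list[str]] = defaultdict(list)
for i, (name, _attn_type) in enumerate(layers):
    groups[i % len(pattern)].append(name)

print(len(groups))                        # 3
print([len(g) for g in groups.values()])  # [10, 10, 10]
```

Under this scheme one block table per group suffices: applying group 0's table to `layers.0, layers.3, ..., layers.27` reproduces the "repeat each of the 3 tables 10 times" behavior the docstring describes.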