[KVCache] Make KVCacheSpec hashable (#21791)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -7,7 +7,8 @@ from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
|
||||
from vllm.v1.core.single_type_kv_cache_manager import (
|
||||
FullAttentionManager, get_manager_for_kv_cache_spec)
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheSpec)
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
@@ -258,44 +259,40 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
||||
one of them is full attention. Then, split the kv cache groups into full
|
||||
attention groups and other groups.
|
||||
"""
|
||||
full_attention_type_id: Optional[str] = None
|
||||
other_type_id: Optional[str] = None
|
||||
full_attention_spec: Optional[FullAttentionSpec] = None
|
||||
other_spec: Optional[KVCacheSpec] = None
|
||||
self.full_attention_group_ids: list[int] = []
|
||||
self.other_group_ids: list[int] = []
|
||||
for i, g in enumerate(self.kv_cache_config.kv_cache_groups):
|
||||
if isinstance(g.kv_cache_spec, FullAttentionSpec):
|
||||
if full_attention_type_id is None:
|
||||
full_attention_type_id = g.kv_cache_spec.type_id
|
||||
if full_attention_spec is None:
|
||||
full_attention_spec = g.kv_cache_spec
|
||||
else:
|
||||
assert full_attention_type_id == g.kv_cache_spec.type_id, (
|
||||
assert full_attention_spec == g.kv_cache_spec, (
|
||||
"HybridKVCacheCoordinator assumes exactly one type of "
|
||||
"full attention groups now.")
|
||||
self.full_attention_group_ids.append(i)
|
||||
else:
|
||||
if other_type_id is None:
|
||||
other_type_id = g.kv_cache_spec.type_id
|
||||
if other_spec is None:
|
||||
other_spec = g.kv_cache_spec
|
||||
else:
|
||||
assert other_type_id == g.kv_cache_spec.type_id, (
|
||||
assert other_spec == g.kv_cache_spec, (
|
||||
"HybridKVCacheCoordinator assumes "
|
||||
"exactly one other type of groups now.")
|
||||
self.other_group_ids.append(i)
|
||||
|
||||
assert full_attention_type_id is not None, (
|
||||
assert full_attention_spec is not None, (
|
||||
"HybridKVCacheCoordinator assumes exactly one type of full "
|
||||
"attention groups now.")
|
||||
assert other_type_id is not None, (
|
||||
assert other_spec is not None, (
|
||||
"HybridKVCacheCoordinator assumes exactly one type of other "
|
||||
"groups now.")
|
||||
|
||||
self.full_attention_manager_cls = FullAttentionManager
|
||||
self.other_attention_cls = self.single_type_managers[
|
||||
self.other_group_ids[0]].__class__
|
||||
|
||||
self.full_attention_spec = self.kv_cache_config.kv_cache_groups[
|
||||
self.full_attention_group_ids[0]].kv_cache_spec
|
||||
self.other_spec = self.kv_cache_config.kv_cache_groups[
|
||||
self.other_group_ids[0]].kv_cache_spec
|
||||
|
||||
self.full_attention_spec = full_attention_spec
|
||||
self.other_spec = other_spec
|
||||
self.full_attention_block_size = self.full_attention_spec.block_size
|
||||
self.other_block_size = self.other_spec.block_size
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
import os
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import astuple, dataclass
|
||||
from typing import Any, Callable, NamedTuple, Optional
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
@@ -727,7 +727,9 @@ def create_kv_cache_group_specs(
|
||||
|
||||
def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
|
||||
"""
|
||||
Whether all layers in the given KVCacheSpec have the same type of KV cache.
|
||||
Whether all layers in the given KVCacheSpec have the same KV cache spec.
|
||||
Note that we regard FullAttentionSpec with and without sliding window as
|
||||
the same type.
|
||||
|
||||
Args:
|
||||
kv_cache_spec: The kv cache spec of each attention layer in the model
|
||||
@@ -736,8 +738,12 @@ def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
|
||||
True if all layers have the same type, False otherwise.
|
||||
"""
|
||||
|
||||
layer_keys = set(layer.type_id for layer in kv_cache_spec.values())
|
||||
return len(layer_keys) == 1
|
||||
try:
|
||||
kv_cache_spec_values = list(kv_cache_spec.values())
|
||||
_ = kv_cache_spec_values[0].merge(kv_cache_spec_values)
|
||||
except AssertionError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_max_concurrency_for_kv_cache_config(
|
||||
@@ -928,12 +934,12 @@ def _get_kv_cache_config_uniform_page_size(
|
||||
Returns:
|
||||
The generated KVCacheConfig
|
||||
"""
|
||||
# Group all layers by type_id.
|
||||
# Group all layers by kv_cache_spec.
|
||||
# E.g., 2 full attention layers and 3 sliding window attention layers,
|
||||
# -> (full.0, full.1), (sw.0, sw.1, sw.2).
|
||||
same_type_layers: dict[str, list[str]] = defaultdict(list)
|
||||
same_type_layers: dict[KVCacheSpec, list[str]] = defaultdict(list)
|
||||
for layer_name, layer_spec in kv_cache_spec.items():
|
||||
same_type_layers[layer_spec.type_id].append(layer_name)
|
||||
same_type_layers[layer_spec].append(layer_name)
|
||||
|
||||
# Split each group into smaller groups, to make the number of layers in each
|
||||
# group identical. Add padding to the last group of each type if necessary.
|
||||
@@ -1017,12 +1023,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
|
||||
kv_cache_spec: The kv cache spec of each attention layer in the model
|
||||
"""
|
||||
|
||||
def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
|
||||
type_ids = set(layer_spec.type_id
|
||||
for layer_spec in kv_cache_spec.values())
|
||||
return len(type_ids) > 1
|
||||
|
||||
if not is_hybrid(kv_cache_spec):
|
||||
if is_kv_cache_type_uniform(kv_cache_spec):
|
||||
return
|
||||
|
||||
logger.warning(
|
||||
@@ -1060,7 +1061,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
|
||||
attention_chunk_size=spec.attention_chunk_size,
|
||||
)
|
||||
|
||||
if is_hybrid(kv_cache_spec):
|
||||
if not is_kv_cache_type_uniform(kv_cache_spec):
|
||||
raise ValueError("Hybrid KV cache manager is disabled but failed to "
|
||||
"convert the KV cache specs to one unified type.")
|
||||
|
||||
@@ -1119,11 +1120,11 @@ def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]):
|
||||
in-place modified to make them consistent.
|
||||
"""
|
||||
|
||||
# Sort the kv cache groups by the type_id of their KV cache spec.
|
||||
# Sort the kv cache groups by their KV cache spec.
|
||||
# This can avoid the inconsistency caused by the order of groups.
|
||||
for kv_cache_config in kv_cache_configs:
|
||||
kv_cache_config.kv_cache_groups.sort(
|
||||
key=lambda x: x.kv_cache_spec.type_id)
|
||||
kv_cache_config.kv_cache_groups.sort(key=lambda x: (type(
|
||||
x.kv_cache_spec).__name__, astuple(x.kv_cache_spec)))
|
||||
|
||||
# Verify that the groups of each rank are the same.
|
||||
for kv_cache_config in kv_cache_configs[1:]:
|
||||
|
||||
Reference in New Issue
Block a user