[KVCache] Make KVCacheSpec hashable (#21791)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>

Author: Chen Zhang
Date: 2025-07-29 04:58:29 -07:00
Committed by: GitHub
Parent: 2470419119
Commit: 755fa8b657

5 changed files with 100 additions and 88 deletions
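For context on the title: the change works because frozen dataclasses derive value-based __eq__ and __hash__, which lets a spec serve directly as a dict key or set member. A minimal sketch of that Python mechanism, using a toy class rather than vLLM's actual KVCacheSpec definition:

from dataclasses import dataclass


@dataclass(frozen=True)
class ToySpec:
    # Stand-in for a KV cache spec; frozen=True derives __eq__ and
    # __hash__ from the field values.
    block_size: int
    num_kv_heads: int
    head_size: int


a = ToySpec(block_size=16, num_kv_heads=8, head_size=128)
b = ToySpec(block_size=16, num_kv_heads=8, head_size=128)
assert a == b and hash(a) == hash(b)  # equal fields -> equal hash

groups: dict[ToySpec, list[str]] = {a: ["layer.0"]}
groups[b].append("layer.1")  # b hashes into a's bucket
assert groups[a] == ["layer.0", "layer.1"]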

View File

@@ -7,7 +7,8 @@ from vllm.v1.core.block_pool import BlockPool
 from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
 from vllm.v1.core.single_type_kv_cache_manager import (
     FullAttentionManager, get_manager_for_kv_cache_spec)
-from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
+from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
+                                        KVCacheSpec)
 from vllm.v1.request import Request
@@ -258,44 +259,40 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
         one of them is full attention. Then, split the kv cache groups into full
         attention groups and other groups.
         """
-        full_attention_type_id: Optional[str] = None
-        other_type_id: Optional[str] = None
+        full_attention_spec: Optional[FullAttentionSpec] = None
+        other_spec: Optional[KVCacheSpec] = None
         self.full_attention_group_ids: list[int] = []
         self.other_group_ids: list[int] = []
         for i, g in enumerate(self.kv_cache_config.kv_cache_groups):
             if isinstance(g.kv_cache_spec, FullAttentionSpec):
-                if full_attention_type_id is None:
-                    full_attention_type_id = g.kv_cache_spec.type_id
+                if full_attention_spec is None:
+                    full_attention_spec = g.kv_cache_spec
                 else:
-                    assert full_attention_type_id == g.kv_cache_spec.type_id, (
+                    assert full_attention_spec == g.kv_cache_spec, (
                         "HybridKVCacheCoordinator assumes exactly one type of "
                         "full attention groups now.")
                 self.full_attention_group_ids.append(i)
             else:
-                if other_type_id is None:
-                    other_type_id = g.kv_cache_spec.type_id
+                if other_spec is None:
+                    other_spec = g.kv_cache_spec
                 else:
-                    assert other_type_id == g.kv_cache_spec.type_id, (
+                    assert other_spec == g.kv_cache_spec, (
                         "HybridKVCacheCoordinator assumes "
                         "exactly one other type of groups now.")
                 self.other_group_ids.append(i)
-        assert full_attention_type_id is not None, (
+        assert full_attention_spec is not None, (
             "HybridKVCacheCoordinator assumes exactly one type of full "
             "attention groups now.")
-        assert other_type_id is not None, (
+        assert other_spec is not None, (
             "HybridKVCacheCoordinator assumes exactly one type of other "
             "groups now.")
         self.full_attention_manager_cls = FullAttentionManager
         self.other_attention_cls = self.single_type_managers[
             self.other_group_ids[0]].__class__
-        self.full_attention_spec = self.kv_cache_config.kv_cache_groups[
-            self.full_attention_group_ids[0]].kv_cache_spec
-        self.other_spec = self.kv_cache_config.kv_cache_groups[
-            self.other_group_ids[0]].kv_cache_spec
+        self.full_attention_spec = full_attention_spec
+        self.other_spec = other_spec
         self.full_attention_block_size = self.full_attention_spec.block_size
         self.other_block_size = self.other_spec.block_size
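Comparing whole specs here is stricter than the old type_id string comparison: two specs can share a coarse string id while differing in a field that matters. A toy illustration (the classes and the type_id format below are illustrative, not vLLM's real ones):

from dataclasses import dataclass


@dataclass(frozen=True)
class ToyFullAttentionSpec:
    block_size: int
    head_size: int

    @property
    def type_id(self) -> str:
        # A coarse string key, similar in spirit to the pre-change code.
        return f"full_attention_{self.block_size}"


x = ToyFullAttentionSpec(block_size=16, head_size=64)
y = ToyFullAttentionSpec(block_size=16, head_size=128)
assert x.type_id == y.type_id  # the string check alone would pass
assert x != y                  # value equality catches the difference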

View File

@@ -5,7 +5,7 @@
 import os
 from collections import defaultdict, deque
 from collections.abc import Iterable, Sequence
-from dataclasses import dataclass
+from dataclasses import astuple, dataclass
 from typing import Any, Callable, NamedTuple, Optional
 
 from vllm.config import VllmConfig
@@ -727,7 +727,9 @@ def create_kv_cache_group_specs(
 def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
     """
-    Whether all layers in the given KVCacheSpec have the same type of KV cache.
+    Whether all layers in the given KVCacheSpec have the same KV cache spec.
+    Note that we regard FullAttentionSpec with and without sliding window as
+    the same type.
 
     Args:
         kv_cache_spec: The kv cache spec of each attention layer in the model
@@ -736,8 +738,12 @@ def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
     Returns:
         True if all layers have the same type, False otherwise.
     """
-    layer_keys = set(layer.type_id for layer in kv_cache_spec.values())
-    return len(layer_keys) == 1
+    try:
+        kv_cache_spec_values = list(kv_cache_spec.values())
+        _ = kv_cache_spec_values[0].merge(kv_cache_spec_values)
+    except AssertionError:
+        return False
+    return True
 
 
 def get_max_concurrency_for_kv_cache_config(
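The uniformity check now delegates to KVCacheSpec.merge(), which asserts that the specs being merged are compatible; the AssertionError is translated into False. A hedged sketch of that contract (the real merge is more involved, e.g. it reconciles FullAttentionSpec with and without sliding window, per the docstring above):

from dataclasses import dataclass


@dataclass(frozen=True)
class ToySpec:
    block_size: int

    @classmethod
    def merge(cls, specs: list["ToySpec"]) -> "ToySpec":
        # Merging asserts that all specs agree; a mismatch raises
        # AssertionError, which the caller turns into "not uniform".
        assert all(s == specs[0] for s in specs), "inconsistent specs"
        return specs[0]


assert ToySpec.merge([ToySpec(16), ToySpec(16)]) == ToySpec(16)
try:
    ToySpec.merge([ToySpec(16), ToySpec(32)])
except AssertionError:
    print("mixed specs -> not uniform")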
@@ -928,12 +934,12 @@ def _get_kv_cache_config_uniform_page_size(
     Returns:
         The generated KVCacheConfig
     """
-    # Group all layers by type_id.
+    # Group all layers by kv_cache_spec.
     # E.g., 2 full attention layers and 3 sliding window attention layers,
     # -> (full.0, full.1), (sw.0, sw.1, sw.2).
-    same_type_layers: dict[str, list[str]] = defaultdict(list)
+    same_type_layers: dict[KVCacheSpec, list[str]] = defaultdict(list)
     for layer_name, layer_spec in kv_cache_spec.items():
-        same_type_layers[layer_spec.type_id].append(layer_name)
+        same_type_layers[layer_spec].append(layer_name)
 
     # Split each group into smaller groups, to make the number of layers in each
     # group identical. Add padding to the last group of each type if necessary.
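Since specs are hashable, they can key the defaultdict directly, so layers with identical specs collapse into one group. A self-contained toy version of the grouping step (illustrative layer names and spec class):

from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class ToySpec:
    kind: str        # e.g. "full" or "sliding_window"
    block_size: int


layers = {
    "layers.0.attn": ToySpec("full", 16),
    "layers.1.attn": ToySpec("sliding_window", 16),
    "layers.2.attn": ToySpec("full", 16),
}

same_type_layers: dict[ToySpec, list[str]] = defaultdict(list)
for layer_name, layer_spec in layers.items():
    same_type_layers[layer_spec].append(layer_name)

# Grouped result:
# {ToySpec(kind='full', block_size=16): ['layers.0.attn', 'layers.2.attn'],
#  ToySpec(kind='sliding_window', block_size=16): ['layers.1.attn']}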
@@ -1017,12 +1023,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
         kv_cache_spec: The kv cache spec of each attention layer in the model
     """
 
-    def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
-        type_ids = set(layer_spec.type_id
-                       for layer_spec in kv_cache_spec.values())
-        return len(type_ids) > 1
-
-    if not is_hybrid(kv_cache_spec):
+    if is_kv_cache_type_uniform(kv_cache_spec):
         return
 
     logger.warning(
@@ -1060,7 +1061,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
                 attention_chunk_size=spec.attention_chunk_size,
             )
 
-    if is_hybrid(kv_cache_spec):
+    if not is_kv_cache_type_uniform(kv_cache_spec):
         raise ValueError("Hybrid KV cache manager is disabled but failed to "
                          "convert the KV cache specs to one unified type.")
@@ -1119,11 +1120,11 @@ def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]):
     in-place modified to make them consistent.
     """
 
-    # Sort the kv cache groups by the type_id of their KV cache spec.
+    # Sort the kv cache groups by their KV cache spec.
     # This can avoid the inconsistency caused by the order of groups.
     for kv_cache_config in kv_cache_configs:
-        kv_cache_config.kv_cache_groups.sort(
-            key=lambda x: x.kv_cache_spec.type_id)
+        kv_cache_config.kv_cache_groups.sort(key=lambda x: (type(
+            x.kv_cache_spec).__name__, astuple(x.kv_cache_spec)))
 
     # Verify that the groups of each rank are the same.
     for kv_cache_config in kv_cache_configs[1:]:
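The new sort key pairs the class name with dataclasses.astuple() of the spec's fields, giving a total, deterministic order across ranks now that specs are no longer plain strings. A small runnable sketch with toy spec classes (not vLLM's):

from dataclasses import astuple, dataclass


@dataclass(frozen=True)
class FullToy:
    block_size: int


@dataclass(frozen=True)
class SlidingToy:
    block_size: int
    window: int


specs = [SlidingToy(16, 1024), FullToy(16), SlidingToy(16, 512)]
specs.sort(key=lambda s: (type(s).__name__, astuple(s)))
# Class name sorts first ('FullToy' < 'SlidingToy'), then field tuples.
assert specs == [FullToy(16), SlidingToy(16, 512), SlidingToy(16, 1024)]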