[Feat][Core] Support multiple KV cache groups in Hybrid KV Coordinator (#31707)

Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
This commit is contained in:
Yifan Qiao
2026-01-09 10:53:20 -08:00
committed by GitHub
parent d5ec6c056f
commit cd4a95e3aa
2 changed files with 325 additions and 128 deletions

View File

@@ -35,6 +35,7 @@ from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
MambaSpec,
SlidingWindowSpec,
)
@@ -106,8 +107,66 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
def make_kv_cache_config_hybrid_model(
block_size: int, num_blocks: int
block_size: int, num_blocks: int, second_spec_type: str = "sliding_window"
) -> KVCacheConfig:
if second_spec_type == "sliding_window":
second_spec = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=2 * block_size,
)
elif second_spec_type == "mamba":
second_spec = MambaSpec(
block_size=block_size,
shapes=(1, 1),
dtypes=(torch.float32,),
)
return KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer1"],
FullAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
),
),
KVCacheGroupSpec(
["layer2"],
second_spec,
),
KVCacheGroupSpec(
["layer3"],
second_spec,
),
],
)
def make_kv_cache_config_three_types(
block_size: int, num_blocks: int, third_spec_type: str = "mamba"
) -> KVCacheConfig:
if third_spec_type == "mamba":
third_spec = MambaSpec(
block_size=block_size,
shapes=(1, 1),
dtypes=(torch.float32,),
)
elif third_spec_type == "sliding_window":
third_spec = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=4 * block_size,
)
return KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[],
@@ -133,13 +192,7 @@ def make_kv_cache_config_hybrid_model(
),
KVCacheGroupSpec(
["layer3"],
SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=2 * block_size,
),
third_spec,
),
],
)
@@ -424,6 +477,184 @@ def test_prefill_hybrid_model():
)
def _make_hybrid_kv_cache_config(
    block_size: int, num_blocks: int, spec_types: list[str]
) -> KVCacheConfig:
    """
    Build a KVCacheConfig with one single-layer KV cache group per entry of
    `spec_types`.

    Group `i` gets the layer name `layer{i}` and a freshly constructed spec
    object (a new instance is created for every group, even when the same
    spec type is requested more than once).

    Args:
        block_size: The block size used by every spec.
        num_blocks: The number of blocks in the KV cache.
        spec_types: Spec type name for each group, in group-id order.
            Supported names:
            - "full": FullAttentionSpec
            - "sliding_window": SlidingWindowSpec with window=2*block_size
            - "sliding_window_large": SlidingWindowSpec with window=4*block_size
            - "mamba": MambaSpec

    Raises:
        KeyError: If an entry of `spec_types` is not a supported name.
    """

    def _full_attention():
        return FullAttentionSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
        )

    def _sliding_window(window_in_blocks: int):
        # The window is expressed as a whole number of blocks.
        return SlidingWindowSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
            sliding_window=window_in_blocks * block_size,
        )

    def _mamba():
        return MambaSpec(
            block_size=block_size,
            shapes=(1, 1),
            dtypes=(torch.float32,),
        )

    # Builders are invoked lazily so that only the requested specs are
    # constructed, once per group.
    builders = {
        "full": _full_attention,
        "sliding_window": lambda: _sliding_window(2),
        "sliding_window_large": lambda: _sliding_window(4),
        "mamba": _mamba,
    }

    groups = []
    for group_id, type_name in enumerate(spec_types):
        groups.append(KVCacheGroupSpec([f"layer{group_id}"], builders[type_name]()))

    return KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=[],
        kv_cache_groups=groups,
    )
# Parametrization table for `test_prefill_hybrid_model_combinations`.
#
# Each entry is a list of spec type names (see `_make_hybrid_kv_cache_config`),
# one per KV cache group in group-id order. Together the cases cover:
# - Varying number of groups (2, 3, or 4)
# - 0, 1, or 2 full attention groups
# - Sliding window with different window sizes
# - Interleaved group IDs (full attn and other types mixed)
# - Mamba spec combinations
#
# Every spec list must be unique: a repeated list only re-runs an identical
# scenario under a different id.
_HYBRID_MODEL_TEST_CASES = [
    # 2 groups: 1 full + 1 other
    pytest.param(["full", "sliding_window"], id="2g-full+sw"),
    pytest.param(["full", "mamba"], id="2g-full+mamba"),
    # 2 groups: 0 full (all other types)
    pytest.param(["sliding_window", "mamba"], id="2g-sw+mamba"),
    pytest.param(["sliding_window", "sliding_window_large"], id="2g-sw+sw_large"),
    # 3 groups: 1 full + 2 others (same type)
    pytest.param(["full", "sliding_window", "sliding_window"], id="3g-full+2sw"),
    pytest.param(["full", "mamba", "mamba"], id="3g-full+2mamba"),
    # 3 groups: 1 full + 2 others (different types)
    pytest.param(["full", "sliding_window", "mamba"], id="3g-full+sw+mamba"),
    pytest.param(
        ["full", "sliding_window", "sliding_window_large"],
        id="3g-full+sw+sw_large",
    ),
    # 3 groups: 2 full + 1 other
    pytest.param(["full", "full", "sliding_window"], id="3g-2full+sw"),
    pytest.param(["full", "full", "mamba"], id="3g-2full+mamba"),
    # 4 groups: interleaved (full, other, full, other) with the same "other"
    # spec, so groups sharing a spec have non-contiguous group IDs.
    pytest.param(
        ["full", "sliding_window", "full", "sliding_window"],
        id="4g-interleaved-full+sw",
    ),
    pytest.param(
        ["full", "mamba", "full", "mamba"],
        id="4g-interleaved-full+mamba",
    ),
    # 4 groups: interleaved with different sliding windows
    pytest.param(
        ["full", "sliding_window", "full", "sliding_window_large"],
        id="4g-interleaved-full+sw_mixed",
    ),
    # 4 groups: 0 full (all other types)
    pytest.param(
        ["sliding_window", "mamba", "sliding_window_large", "mamba"],
        id="4g-sw+mamba+sw_large+mamba",
    ),
    # 4 groups: 2 full + 2 others (grouped)
    pytest.param(
        ["full", "full", "sliding_window", "mamba"],
        id="4g-2full+sw+mamba",
    ),
]
@pytest.mark.parametrize("spec_types", _HYBRID_MODEL_TEST_CASES)
def test_prefill_hybrid_model_combinations(spec_types: list[str]):
    """
    Check prefix caching on hybrid models built from `spec_types`.

    One KV cache group is created per entry of `spec_types`. A first request
    populates the cache; a second request sharing its first three full blocks
    must then report a hit of exactly those three blocks, with blocks returned
    for every group.

    The parametrization covers:
    - Various combinations (full attn + other attn types)
    - Varying number of groups (2, 3, or 4)
    - 0, 1, or 2 full attention groups in the combination
    - Two sliding_window attn groups with different window sizes
    - Interleaved group IDs (full attn and other types alternating)
    - Mamba spec with other attention types
    """
    block_size = 16
    num_groups = len(spec_types)
    # Scale the pool with the group count so that every group can be served.
    num_blocks = 10 * num_groups

    manager = KVCacheManager(
        _make_hybrid_kv_cache_config(block_size, num_blocks, spec_types),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )
    hash_fn = sha256

    # Shared prefix: 3 complete blocks (48 tokens), block k filled with k.
    common_token_ids = []
    for token_id in range(3):
        common_token_ids.extend([token_id] * block_size)
    # 7 extra tokens that leave the last block of the first request partial.
    all_token_ids = common_token_ids + [3] * 7

    # --- Request 0: the cache is empty, so nothing can be reused. ---
    req0 = make_request("0", all_token_ids, block_size, hash_fn)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert len(req0.block_hashes) == 3
    assert not computed_blocks.blocks[0]  # No cache hit initially
    assert num_computed_tokens == 0

    num_hit_tokens = len(computed_blocks.blocks[0]) * block_size
    blocks = manager.allocate_slots(req0, 55, num_hit_tokens, computed_blocks)
    assert blocks is not None
    # One block-id list is expected per KV cache group.
    assert len(blocks.get_block_ids()) == num_groups

    # --- Request 1: same prefix, different tail -> prefix must be reused. ---
    req1_token_ids = common_token_ids + [4] * 5
    req1 = make_request("1", req1_token_ids, block_size, hash_fn)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    # All three shared blocks are hit, and hit blocks exist for every group.
    assert num_computed_tokens == 3 * block_size
    assert len(computed_blocks.blocks) == num_groups

    # Only the tokens that were not served from the cache need new slots.
    num_new_tokens = len(common_token_ids) + 5 - num_computed_tokens
    blocks = manager.allocate_slots(
        req1,
        num_new_tokens,
        num_computed_tokens,
        computed_blocks,
    )
    assert blocks is not None
    assert len(blocks.get_block_ids()) == num_groups

    manager.free(req0)
    manager.free(req1)
def test_prefill_plp():
"""Test prefill with APC and some prompt logprobs (plp) requests.

View File

@@ -14,7 +14,7 @@ from vllm.v1.core.kv_cache_utils import (
)
from vllm.v1.core.single_type_kv_cache_manager import (
CrossAttentionManager,
FullAttentionManager,
SingleTypeKVCacheManager,
get_manager_for_kv_cache_spec,
)
from vllm.v1.kv_cache_interface import (
@@ -354,9 +354,6 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
"""
KV cache coordinator for hybrid models with multiple KV cache types, and
thus multiple kv cache groups.
To simplify `find_longest_cache_hit`, it only supports the combination of
two types of KV cache groups, and one of them must be full attention.
May extend to more general cases in the future.
"""
def __init__(
@@ -397,70 +394,46 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
def verify_and_split_kv_cache_groups(self) -> None:
"""
Verifies that the model has exactly two types of KV cache groups, and
one of them is full attention. Then, split the kv cache groups into full
attention groups and other groups.
Groups KV cache groups by their spec type for efficient batch processing
during cache hit lookup.
"""
full_attention_spec: FullAttentionSpec | None = None
other_spec: KVCacheSpec | None = None
self.full_attention_group_ids: list[int] = []
self.other_group_ids: list[int] = []
attention_groups: list[
tuple[KVCacheSpec, list[int], type[SingleTypeKVCacheManager]]
] = []
for i, g in enumerate(self.kv_cache_config.kv_cache_groups):
if isinstance(g.kv_cache_spec, FullAttentionSpec):
if full_attention_spec is None:
full_attention_spec = g.kv_cache_spec
else:
assert full_attention_spec == g.kv_cache_spec, (
"HybridKVCacheCoordinator assumes exactly one type of "
"full attention groups now."
manager_cls = self.single_type_managers[i].__class__
spec = g.kv_cache_spec
# Try to find an existing group with the same spec
for existing_spec, group_ids, existing_cls in attention_groups:
if existing_spec == spec:
assert manager_cls is existing_cls, (
"Expected same manager class for identical KV cache specs."
)
self.full_attention_group_ids.append(i)
group_ids.append(i)
break
else:
if other_spec is None:
other_spec = g.kv_cache_spec
else:
assert other_spec == g.kv_cache_spec, (
"HybridKVCacheCoordinator assumes "
"exactly one other type of groups now."
)
self.other_group_ids.append(i)
attention_groups.append((spec, [i], manager_cls))
assert full_attention_spec is not None, (
"HybridKVCacheCoordinator assumes exactly one type of full "
"attention groups now."
)
assert other_spec is not None, (
"HybridKVCacheCoordinator assumes exactly one type of other groups now."
assert len(attention_groups) > 1, (
"HybridKVCacheCoordinator requires at least two attention groups."
)
self.full_attention_manager_cls = FullAttentionManager
self.other_attention_cls = self.single_type_managers[
self.other_group_ids[0]
].__class__
self.full_attention_spec = full_attention_spec
self.other_spec = other_spec
self.full_attention_block_size = self.full_attention_spec.block_size
self.other_block_size = self.other_spec.block_size
# The LCM of the block sizes of full attention and other attention.
# Put full attention first: its efficient left-to-right scan provides
# a tighter initial bound, reducing work for subsequent groups.
self.attention_groups = sorted(
attention_groups,
key=lambda x: not isinstance(x[0], FullAttentionSpec),
)
# The LCM of the block sizes of all attention types.
# The cache hit length must be a multiple of the LCM of the block sizes
# to make sure the cache hit length is a multiple of the block size of
# each attention type. Requiring this because we don't support partial
# block cache hit yet.
self.lcm_block_size = lcm(self.full_attention_block_size, self.other_block_size)
if max(self.full_attention_group_ids) < min(self.other_group_ids):
self.full_attn_first = True
elif max(self.other_group_ids) < min(self.full_attention_group_ids):
self.full_attn_first = False
else:
raise ValueError(
"HybridKVCacheCoordinator assumes the full "
"attention group ids and other attention group ids "
"do not interleave, either full attention group ids "
"are before other attention group ids or vice versa."
"This is for simplifying merging hit_blocks_full_attn and "
"hit_blocks_other_attn to hit_blocks."
)
block_sizes = [spec.block_size for spec, _, _ in attention_groups]
self.lcm_block_size = lcm(*block_sizes)
def find_longest_cache_hit(
self,
@@ -468,7 +441,12 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
max_cache_hit_length: int,
) -> tuple[tuple[list[KVCacheBlock], ...], int]:
"""
Find the longest cache hit for the request.
Find the longest cache hit using an iterative fixed-point algorithm.
Each attention type either accepts the current candidate length or
reduces it. If any type reduces the length, restart checks over all
types. This converges because length monotonically decreases and is
bounded below by 0.
Args:
block_hashes: The block hashes of the request.
@@ -476,75 +454,63 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
Returns:
A tuple containing:
- A list of the cache hit blocks for each single type manager.
- A tuple of the cache hit blocks for each single type manager.
- The number of tokens of the longest cache hit.
"""
# First, find the longest cache hit for full attention.
if self.full_attention_spec.block_size == self.hash_block_size:
# Common case.
full_attention_block_hashes: BlockHashList = block_hashes
else:
# block_size is a multiple of hash_block_size. This happens when different
# KV cache groups have different block sizes. In this case, we need to
# recalculate block_hashes at the granularity of block_size, using the
# original block_hashes (at the granularity of hash_block_size).
full_attention_block_hashes = BlockHashListWithBlockSize(
block_hashes, self.hash_block_size, self.full_attention_spec.block_size
def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList:
if kv_cache_spec.block_size == self.hash_block_size:
return block_hashes
return BlockHashListWithBlockSize(
block_hashes, self.hash_block_size, kv_cache_spec.block_size
)
hit_blocks_full_attn = self.full_attention_manager_cls.find_longest_cache_hit(
block_hashes=full_attention_block_hashes,
max_length=max_cache_hit_length,
kv_cache_group_ids=self.full_attention_group_ids,
block_pool=self.block_pool,
kv_cache_spec=self.full_attention_spec,
use_eagle=self.use_eagle,
alignment_tokens=self.lcm_block_size,
)
hit_length = len(hit_blocks_full_attn[0]) * self.full_attention_block_size
# Next, find the cache hit for the other attention WITHIN
# the cache hit of full attention.
if self.other_spec.block_size == self.hash_block_size:
# Common case.
other_block_hashes: BlockHashList = block_hashes
else:
# Similar to the full attention case, here we need to recalculate
# block_hashes at the granularity of block_size, using the original
# block_hashes (at the granularity of hash_block_size).
other_block_hashes = BlockHashListWithBlockSize(
block_hashes, self.hash_block_size, self.other_spec.block_size
)
hit_blocks_other_attn = self.other_attention_cls.find_longest_cache_hit(
block_hashes=other_block_hashes,
max_length=hit_length,
kv_cache_group_ids=self.other_group_ids,
block_pool=self.block_pool,
kv_cache_spec=self.other_spec,
use_eagle=self.use_eagle,
alignment_tokens=self.lcm_block_size,
)
hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size
num_groups = len(self.kv_cache_config.kv_cache_groups)
hit_length = max_cache_hit_length
hit_blocks_by_group: list[list[KVCacheBlock] | None] = [None] * num_groups
# NOTE: the prefix cache hit length must be a multiple of block_size as
# we don't support partial block cache hit yet. The cache hit length
# of other attention is ensured to be a multiple of the block size of
# full attention layers in current implementation, because hit_length is
# a multiple of other attention's block size, and other attention's
# block size is a multiple of full attention's block size (verified in
# `verify_and_split_kv_cache_groups`).
assert hit_length % self.full_attention_block_size == 0
while True:
curr_hit_length = hit_length
# Truncate the full attention cache hit to the length of the
# cache hit of the other attention.
for group_hit_blocks in hit_blocks_full_attn:
del group_hit_blocks[hit_length // self.full_attention_block_size :]
for spec, group_ids, manager_cls in self.attention_groups:
is_full_attn = isinstance(spec, FullAttentionSpec)
# Merge the hit blocks of full attention and other attention.
if self.full_attn_first:
hit_blocks = hit_blocks_full_attn + hit_blocks_other_attn
else:
hit_blocks = hit_blocks_other_attn + hit_blocks_full_attn
return hit_blocks, hit_length
# Full attention: reuse cached blocks (downward-closed property)
cached_blocks = hit_blocks_by_group[group_ids[0]]
if is_full_attn and cached_blocks is not None:
# For full attention, we only need to compute the cache hit
# length once. Starting from the second iteration, if the
# curr_hit_length is reduced by other groups, we can simply
# keep the first (curr_hit_length // block_size) blocks from
# the last iteration.
num_blocks = curr_hit_length // spec.block_size
curr_hit_length = num_blocks * spec.block_size
for group_id in group_ids:
blocks = hit_blocks_by_group[group_id]
assert blocks is not None
del blocks[num_blocks:]
else:
hit_blocks = manager_cls.find_longest_cache_hit(
block_hashes=_get_block_hashes(spec),
max_length=curr_hit_length,
kv_cache_group_ids=group_ids,
block_pool=self.block_pool,
kv_cache_spec=spec,
use_eagle=self.use_eagle,
alignment_tokens=self.lcm_block_size,
)
curr_hit_length = len(hit_blocks[0]) * spec.block_size
for group_id, blocks in zip(group_ids, hit_blocks):
hit_blocks_by_group[group_id] = blocks
if curr_hit_length < hit_length:
hit_length = curr_hit_length
else:
break
return tuple(
blocks if blocks is not None else [] for blocks in hit_blocks_by_group
), hit_length
def get_kv_cache_coordinator(