[Feat][Core] Support multiple KV cache groups in Hybrid KV Coordinator (#31707)

Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
This commit is contained in:
Yifan Qiao
2026-01-09 10:53:20 -08:00
committed by GitHub
parent d5ec6c056f
commit cd4a95e3aa
2 changed files with 325 additions and 128 deletions

View File

@@ -35,6 +35,7 @@ from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
MambaSpec,
SlidingWindowSpec,
)
@@ -106,8 +107,66 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
def make_kv_cache_config_hybrid_model(
    block_size: int, num_blocks: int, second_spec_type: str = "sliding_window"
) -> KVCacheConfig:
    """Build a KVCacheConfig for a hybrid model with three layers.

    Layer 1 uses full attention; layers 2 and 3 share a second spec whose
    type is selected by ``second_spec_type``.

    Args:
        block_size: Block size used by every KV cache group.
        num_blocks: Total number of blocks in the KV cache.
        second_spec_type: Spec type for layers 2/3; either "sliding_window"
            (window of ``2 * block_size``) or "mamba".

    Raises:
        ValueError: If ``second_spec_type`` is not a supported type.
    """
    if second_spec_type == "sliding_window":
        second_spec = SlidingWindowSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
            sliding_window=2 * block_size,
        )
    elif second_spec_type == "mamba":
        second_spec = MambaSpec(
            block_size=block_size,
            shapes=(1, 1),
            dtypes=(torch.float32,),
        )
    else:
        # Fail loudly instead of hitting a NameError on `second_spec` below.
        raise ValueError(f"Unsupported second_spec_type: {second_spec_type!r}")
    return KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(
                ["layer1"],
                FullAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=1,
                    head_size=1,
                    dtype=torch.float32,
                ),
            ),
            # layer2 and layer3 deliberately share the same spec instance.
            KVCacheGroupSpec(
                ["layer2"],
                second_spec,
            ),
            KVCacheGroupSpec(
                ["layer3"],
                second_spec,
            ),
        ],
    )
def make_kv_cache_config_three_types(
block_size: int, num_blocks: int, third_spec_type: str = "mamba"
) -> KVCacheConfig:
if third_spec_type == "mamba":
third_spec = MambaSpec(
block_size=block_size,
shapes=(1, 1),
dtypes=(torch.float32,),
)
elif third_spec_type == "sliding_window":
third_spec = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=4 * block_size,
)
return KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[],
@@ -133,13 +192,7 @@ def make_kv_cache_config_hybrid_model(
),
KVCacheGroupSpec(
["layer3"],
SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=2 * block_size,
),
third_spec,
),
],
)
@@ -424,6 +477,184 @@ def test_prefill_hybrid_model():
)
def _make_hybrid_kv_cache_config(
    block_size: int, num_blocks: int, spec_types: list[str]
) -> KVCacheConfig:
    """
    Create a KVCacheConfig with the specified spec types.

    One KV cache group named ``layer{i}`` is created per entry of
    ``spec_types``, in order.

    Args:
        block_size: The block size for KV cache.
        num_blocks: The number of blocks in the KV cache.
        spec_types: List of spec type strings. Supported types:
            - "full": FullAttentionSpec
            - "sliding_window": SlidingWindowSpec with window=2*block_size
            - "sliding_window_large": SlidingWindowSpec with window=4*block_size
            - "mamba": MambaSpec

    Raises:
        ValueError: If any entry of ``spec_types`` is not a supported type.
    """
    # Lazy factories so a spec instance is built per group, not shared.
    spec_map = {
        "full": lambda: FullAttentionSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
        ),
        "sliding_window": lambda: SlidingWindowSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
            sliding_window=2 * block_size,
        ),
        "sliding_window_large": lambda: SlidingWindowSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
            sliding_window=4 * block_size,
        ),
        "mamba": lambda: MambaSpec(
            block_size=block_size,
            shapes=(1, 1),
            dtypes=(torch.float32,),
        ),
    }
    kv_cache_groups = []
    for i, spec_type in enumerate(spec_types):
        if spec_type not in spec_map:
            # Surface a clear error instead of a bare KeyError from the dict.
            raise ValueError(
                f"Unknown spec type: {spec_type!r}; "
                f"expected one of {sorted(spec_map)}"
            )
        kv_cache_groups.append(KVCacheGroupSpec([f"layer{i}"], spec_map[spec_type]()))
    return KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=[],
        kv_cache_groups=kv_cache_groups,
    )
# Test cases covering various combinations of KV cache spec types:
# - Varying number of groups (2, 3, or 4)
# - 0, 1, or 2 full attention groups
# - Sliding window with different window sizes
# - Interleaved group IDs (full attn and other types mixed)
# - Mamba spec combinations
_HYBRID_MODEL_TEST_CASES = [
    # 2 groups: 1 full + 1 other
    pytest.param(["full", "sliding_window"], id="2g-full+sw"),
    pytest.param(["full", "mamba"], id="2g-full+mamba"),
    # 2 groups: 0 full (all other types)
    pytest.param(["sliding_window", "mamba"], id="2g-sw+mamba"),
    pytest.param(["sliding_window", "sliding_window_large"], id="2g-sw+sw_large"),
    # 3 groups: 1 full + 2 others (same type)
    pytest.param(["full", "sliding_window", "sliding_window"], id="3g-full+2sw"),
    pytest.param(["full", "mamba", "mamba"], id="3g-full+2mamba"),
    # 3 groups: 1 full + 2 others (different types)
    pytest.param(["full", "sliding_window", "mamba"], id="3g-full+sw+mamba"),
    pytest.param(
        ["full", "sliding_window", "sliding_window_large"],
        id="3g-full+sw+sw_large",
    ),
    # 3 groups: 2 full + 1 other
    pytest.param(["full", "full", "sliding_window"], id="3g-2full+sw"),
    pytest.param(["full", "full", "mamba"], id="3g-2full+mamba"),
    # 4 groups: interleaved (full, other, full, other)
    pytest.param(
        ["full", "sliding_window", "full", "sliding_window_large"],
        id="4g-interleaved-full+sw+sw_large",
    ),
    pytest.param(
        ["full", "mamba", "full", "mamba"],
        id="4g-interleaved-full+mamba",
    ),
    # 4 groups: interleaved with the sliding-window sizes in the opposite
    # order (previously an exact duplicate of the case above it).
    pytest.param(
        ["full", "sliding_window_large", "full", "sliding_window"],
        id="4g-interleaved-full+sw_mixed",
    ),
    # 4 groups: 0 full (all other types)
    pytest.param(
        ["sliding_window", "mamba", "sliding_window_large", "mamba"],
        id="4g-sw+mamba+sw_large+mamba",
    ),
    # 4 groups: 2 full + 2 others (grouped)
    pytest.param(
        ["full", "full", "sliding_window", "mamba"],
        id="4g-2full+sw+mamba",
    ),
]
@pytest.mark.parametrize("spec_types", _HYBRID_MODEL_TEST_CASES)
def test_prefill_hybrid_model_combinations(spec_types: list[str]):
    """
    Exercise prefix caching on hybrid models built from an arbitrary mix of
    KV cache spec types.

    A single parametrized test standing in for many bespoke ones:
    - full attention combined with other attention types
    - 2, 3, or 4 KV cache groups per model
    - 0, 1, or 2 full-attention groups in the mix
    - two sliding-window groups with distinct window sizes
    - full attention interleaved with the other types
    - mamba alongside other attention types
    """
    block_size = 16
    group_count = len(spec_types)
    # 10 blocks per group is ample for the short prompts used here.
    kv_cache_config = _make_hybrid_kv_cache_config(
        block_size, 10 * group_count, spec_types
    )
    manager = KVCacheManager(
        kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )
    hasher = sha256
    # Shared prefix: three complete blocks (tokens 0, 1, 2, each repeated
    # block_size times = 48 tokens), then 7 trailing unique tokens.
    shared_prefix: list[int] = []
    for token in range(3):
        shared_prefix.extend([token] * block_size)
    prompt0 = shared_prefix + [3] * 7

    # First request: the cache is cold, so nothing is computed yet.
    req0 = make_request("0", prompt0, block_size, hasher)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert num_computed_tokens == 0
    assert not computed_blocks.blocks[0]  # cold cache
    assert len(req0.block_hashes) == 3
    first_alloc = manager.allocate_slots(
        req0, 55, len(computed_blocks.blocks[0]) * block_size, computed_blocks
    )
    assert first_alloc is not None
    # Every group gets its own block list.
    assert len(first_alloc.get_block_ids()) == group_count

    # Second request shares the prefix and must hit the cache in all groups.
    req1 = make_request("1", shared_prefix + [4] * 5, block_size, hasher)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert len(computed_blocks.blocks) == group_count
    assert num_computed_tokens == 3 * block_size
    second_alloc = manager.allocate_slots(
        req1,
        len(shared_prefix) + 5 - num_computed_tokens,
        num_computed_tokens,
        computed_blocks,
    )
    assert second_alloc is not None
    assert len(second_alloc.get_block_ids()) == group_count

    manager.free(req0)
    manager.free(req1)
def test_prefill_plp():
"""Test prefill with APC and some prompt logprobs (plp) requests.