[Feat][Core] Support multiple KV cache groups in Hybrid KV Coordinator (#31707)
Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
This commit is contained in:
@@ -35,6 +35,7 @@ from vllm.v1.kv_cache_interface import (
|
||||
FullAttentionSpec,
|
||||
KVCacheConfig,
|
||||
KVCacheGroupSpec,
|
||||
MambaSpec,
|
||||
SlidingWindowSpec,
|
||||
)
|
||||
|
||||
@@ -106,8 +107,66 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
|
||||
|
||||
|
||||
def make_kv_cache_config_hybrid_model(
    block_size: int, num_blocks: int, second_spec_type: str = "sliding_window"
) -> KVCacheConfig:
    """Build a three-group hybrid-model KVCacheConfig for tests.

    Group 0 ("layer1") is full attention; groups 1 and 2 ("layer2",
    "layer3") share a single spec instance of ``second_spec_type``.

    Args:
        block_size: KV cache block size in tokens.
        num_blocks: Total number of KV cache blocks.
        second_spec_type: Either "sliding_window" (window = 2 * block_size)
            or "mamba".

    Returns:
        A KVCacheConfig with three KV cache groups and no backing tensors.

    Raises:
        ValueError: If ``second_spec_type`` is not a supported type.
    """
    if second_spec_type == "sliding_window":
        second_spec = SlidingWindowSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
            sliding_window=2 * block_size,
        )
    elif second_spec_type == "mamba":
        second_spec = MambaSpec(
            block_size=block_size,
            shapes=(1, 1),
            dtypes=(torch.float32,),
        )
    else:
        # Fail fast with a clear message instead of hitting an
        # UnboundLocalError on `second_spec` at the return below.
        raise ValueError(
            f"Unsupported second_spec_type: {second_spec_type!r} "
            "(expected 'sliding_window' or 'mamba')"
        )

    return KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(
                ["layer1"],
                FullAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=1,
                    head_size=1,
                    dtype=torch.float32,
                ),
            ),
            KVCacheGroupSpec(
                ["layer2"],
                second_spec,
            ),
            KVCacheGroupSpec(
                ["layer3"],
                second_spec,
            ),
        ],
    )
|
||||
|
||||
|
||||
def make_kv_cache_config_three_types(
|
||||
block_size: int, num_blocks: int, third_spec_type: str = "mamba"
|
||||
) -> KVCacheConfig:
|
||||
if third_spec_type == "mamba":
|
||||
third_spec = MambaSpec(
|
||||
block_size=block_size,
|
||||
shapes=(1, 1),
|
||||
dtypes=(torch.float32,),
|
||||
)
|
||||
elif third_spec_type == "sliding_window":
|
||||
third_spec = SlidingWindowSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=4 * block_size,
|
||||
)
|
||||
|
||||
return KVCacheConfig(
|
||||
num_blocks=num_blocks,
|
||||
kv_cache_tensors=[],
|
||||
@@ -133,13 +192,7 @@ def make_kv_cache_config_hybrid_model(
|
||||
),
|
||||
KVCacheGroupSpec(
|
||||
["layer3"],
|
||||
SlidingWindowSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=2 * block_size,
|
||||
),
|
||||
third_spec,
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -424,6 +477,184 @@ def test_prefill_hybrid_model():
|
||||
)
|
||||
|
||||
|
||||
def _make_hybrid_kv_cache_config(
    block_size: int, num_blocks: int, spec_types: list[str]
) -> KVCacheConfig:
    """
    Create a KVCacheConfig with the specified spec types.

    Each entry of ``spec_types`` becomes one KV cache group whose single
    layer is named ``layer{i}`` (i is the group index).

    Args:
        block_size: The block size for KV cache.
        num_blocks: The number of blocks in the KV cache.
        spec_types: List of spec type strings, one per group. Supported types:
            - "full": FullAttentionSpec
            - "sliding_window": SlidingWindowSpec with window=2*block_size
            - "sliding_window_large": SlidingWindowSpec with window=4*block_size
            - "mamba": MambaSpec

    Raises:
        ValueError: If any entry of ``spec_types`` is not a supported type.
    """
    spec_map = {
        "full": lambda: FullAttentionSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
        ),
        "sliding_window": lambda: SlidingWindowSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
            sliding_window=2 * block_size,
        ),
        "sliding_window_large": lambda: SlidingWindowSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=1,
            dtype=torch.float32,
            sliding_window=4 * block_size,
        ),
        "mamba": lambda: MambaSpec(
            block_size=block_size,
            shapes=(1, 1),
            dtypes=(torch.float32,),
        ),
    }

    # Validate up front so a typo in a test case raises a clear ValueError
    # rather than an opaque KeyError from spec_map[spec_type].
    unknown = [s for s in spec_types if s not in spec_map]
    if unknown:
        raise ValueError(
            f"Unsupported spec types {unknown}; expected one of {sorted(spec_map)}"
        )

    kv_cache_groups = [
        KVCacheGroupSpec([f"layer{i}"], spec_map[spec_type]())
        for i, spec_type in enumerate(spec_types)
    ]

    return KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=[],
        kv_cache_groups=kv_cache_groups,
    )
|
||||
|
||||
|
||||
# Test cases covering various combinations of KV cache spec types:
# - Varying number of groups (2, 3, or 4)
# - 0, 1, or 2 full attention groups
# - Sliding window with different window sizes
# - Interleaved group IDs (full attn and other types mixed)
# - Mamba spec combinations
_HYBRID_MODEL_TEST_CASES = [
    # 2 groups: 1 full + 1 other
    pytest.param(["full", "sliding_window"], id="2g-full+sw"),
    pytest.param(["full", "mamba"], id="2g-full+mamba"),
    # 2 groups: 0 full (all other types)
    pytest.param(["sliding_window", "mamba"], id="2g-sw+mamba"),
    pytest.param(["sliding_window", "sliding_window_large"], id="2g-sw+sw_large"),
    # 3 groups: 1 full + 2 others (same type)
    pytest.param(["full", "sliding_window", "sliding_window"], id="3g-full+2sw"),
    pytest.param(["full", "mamba", "mamba"], id="3g-full+2mamba"),
    # 3 groups: 1 full + 2 others (different types)
    pytest.param(["full", "sliding_window", "mamba"], id="3g-full+sw+mamba"),
    pytest.param(
        ["full", "sliding_window", "sliding_window_large"],
        id="3g-full+sw+sw_large",
    ),
    # 3 groups: 2 full + 1 other
    pytest.param(["full", "full", "sliding_window"], id="3g-2full+sw"),
    pytest.param(["full", "full", "mamba"], id="3g-2full+mamba"),
    # 4 groups: interleaved (full, other, full, other)
    pytest.param(
        ["full", "sliding_window", "full", "sliding_window_large"],
        id="4g-interleaved-full+sw+sw_large",
    ),
    pytest.param(
        ["full", "mamba", "full", "mamba"],
        id="4g-interleaved-full+mamba",
    ),
    # 4 groups: interleaved with the sliding-window sizes swapped
    # (previously an exact duplicate of "4g-interleaved-full+sw+sw_large").
    pytest.param(
        ["full", "sliding_window_large", "full", "sliding_window"],
        id="4g-interleaved-full+sw_mixed",
    ),
    # 4 groups: 0 full (all other types)
    pytest.param(
        ["sliding_window", "mamba", "sliding_window_large", "mamba"],
        id="4g-sw+mamba+sw_large+mamba",
    ),
    # 4 groups: 2 full + 2 others (grouped)
    pytest.param(
        ["full", "full", "sliding_window", "mamba"],
        id="4g-2full+sw+mamba",
    ),
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec_types", _HYBRID_MODEL_TEST_CASES)
def test_prefill_hybrid_model_combinations(spec_types: list[str]):
    """
    Exercise prefix caching on hybrid models built from various KV cache
    spec combinations.

    Covers 2-4 group configurations with 0, 1, or 2 full-attention
    groups, sliding windows of different sizes, mamba specs, and
    interleaved group orderings (see _HYBRID_MODEL_TEST_CASES).
    """
    block_size = 16
    group_count = len(spec_types)
    # Give each group a comfortable number of blocks to allocate from.
    total_blocks = 10 * group_count

    cache_cfg = _make_hybrid_kv_cache_config(block_size, total_blocks, spec_types)
    manager = KVCacheManager(
        cache_cfg,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
    )

    hash_fn = sha256

    # Three full blocks (3 * block_size tokens) of shared prefix ...
    prefix_tokens: list[int] = []
    for token_id in range(3):
        prefix_tokens.extend([token_id] * block_size)
    # ... plus a partial tail that never fills a fourth block.
    tail_tokens = [3] * 7

    # Request 0: cold cache, so nothing is computed yet.
    req0 = make_request("0", prefix_tokens + tail_tokens, block_size, hash_fn)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)

    assert len(req0.block_hashes) == 3
    assert not computed_blocks.blocks[0]  # no cache hit on a cold cache
    assert num_computed_tokens == 0

    blocks = manager.allocate_slots(
        req0, 55, len(computed_blocks.blocks[0]) * block_size, computed_blocks
    )
    assert blocks is not None
    # One block-id list per KV cache group.
    assert len(blocks.get_block_ids()) == group_count

    # Request 1 shares the 3-block prefix, so every group should hit it.
    req1 = make_request("1", prefix_tokens + [4] * 5, block_size, hash_fn)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)

    assert num_computed_tokens == 3 * block_size
    assert len(computed_blocks.blocks) == group_count

    # Allocate the remaining (uncached) tokens and verify per-group blocks.
    blocks = manager.allocate_slots(
        req1,
        len(prefix_tokens) + 5 - num_computed_tokens,
        num_computed_tokens,
        computed_blocks,
    )
    assert blocks is not None
    assert len(blocks.get_block_ids()) == group_count

    manager.free(req0)
    manager.free(req1)
|
||||
|
||||
|
||||
def test_prefill_plp():
|
||||
"""Test prefill with APC and some prompt logprobs (plp) requests.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user