[Fix] prefix cache hit rate == 0 bug with gpt-oss style models (#33524)
Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
This commit is contained in:
@@ -107,7 +107,10 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
|
||||
|
||||
|
||||
def make_kv_cache_config_hybrid_model(
|
||||
block_size: int, num_blocks: int, second_spec_type: str = "sliding_window"
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
sliding_window_blocks: int,
|
||||
second_spec_type: str = "sliding_window",
|
||||
) -> KVCacheConfig:
|
||||
if second_spec_type == "sliding_window":
|
||||
second_spec = SlidingWindowSpec(
|
||||
@@ -115,7 +118,7 @@ def make_kv_cache_config_hybrid_model(
|
||||
num_kv_heads=1,
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=2 * block_size,
|
||||
sliding_window=sliding_window_blocks * block_size,
|
||||
)
|
||||
elif second_spec_type == "mamba":
|
||||
second_spec = MambaSpec(
|
||||
@@ -325,7 +328,7 @@ def test_prefill(hash_fn):
|
||||
def test_prefill_hybrid_model():
|
||||
block_size = 16
|
||||
manager = KVCacheManager(
|
||||
make_kv_cache_config_hybrid_model(block_size, 21),
|
||||
make_kv_cache_config_hybrid_model(block_size, 21, 2),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
@@ -334,7 +337,8 @@ def test_prefill_hybrid_model():
|
||||
hash_fn = sha256
|
||||
|
||||
# Complete 3 blocks (48 tokens)
|
||||
common_token_ids = [i for i in range(3) for _ in range(block_size)]
|
||||
num_full_blocks = 3
|
||||
common_token_ids = [i for i in range(num_full_blocks) for _ in range(block_size)]
|
||||
|
||||
# Fully cache miss
|
||||
# Incomplete 1 block (7 tokens)
|
||||
@@ -375,6 +379,7 @@ def test_prefill_hybrid_model():
|
||||
# Cache hit in the common prefix
|
||||
# Incomplete 1 block (5 tokens)
|
||||
unique_token_ids = [3] * 5
|
||||
all_token_ids = common_token_ids + unique_token_ids
|
||||
req1 = make_request("1", common_token_ids + unique_token_ids, block_size, hash_fn)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert len(req1.block_hashes) == 3
|
||||
@@ -394,34 +399,13 @@ def test_prefill_hybrid_model():
|
||||
manager.free(req0)
|
||||
manager.free(req1)
|
||||
|
||||
cached_block_hash_to_block_bak = copy.copy(
|
||||
manager.block_pool.cached_block_hash_to_block._cache
|
||||
)
|
||||
|
||||
def test_partial_request_hit(
|
||||
request_id: str,
|
||||
hash_to_evict: list[BlockHashWithGroupId],
|
||||
expect_hit_length: int,
|
||||
):
|
||||
req = make_request(
|
||||
request_id, common_token_ids + unique_token_ids, block_size, sha256
|
||||
)
|
||||
for hash_with_group_id in hash_to_evict:
|
||||
manager.block_pool.cached_block_hash_to_block._cache.pop(hash_with_group_id)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||
assert len(req.block_hashes) == 3
|
||||
assert num_computed_tokens == expect_hit_length * block_size
|
||||
for block_per_group in computed_blocks.blocks:
|
||||
assert len(block_per_group) == num_computed_tokens // block_size
|
||||
for hash_with_group_id in hash_to_evict:
|
||||
manager.block_pool.cached_block_hash_to_block._cache[hash_with_group_id] = (
|
||||
cached_block_hash_to_block_bak[hash_with_group_id]
|
||||
)
|
||||
manager.free(req)
|
||||
|
||||
# Evict the blocks outside sliding window, does not affect the hit length.
|
||||
test_partial_request_hit(
|
||||
_test_partial_request_hit(
|
||||
manager,
|
||||
block_size,
|
||||
num_full_blocks,
|
||||
"2",
|
||||
all_token_ids,
|
||||
[
|
||||
make_block_hash_with_group_id(block_hashes[0], 1),
|
||||
make_block_hash_with_group_id(block_hashes[0], 2),
|
||||
@@ -430,13 +414,23 @@ def test_prefill_hybrid_model():
|
||||
)
|
||||
|
||||
# Evict the first block of full attention, makes total cache miss.
|
||||
test_partial_request_hit(
|
||||
"3", [make_block_hash_with_group_id(block_hashes[0], 0)], 0
|
||||
_test_partial_request_hit(
|
||||
manager,
|
||||
block_size,
|
||||
num_full_blocks,
|
||||
"3",
|
||||
all_token_ids,
|
||||
[make_block_hash_with_group_id(block_hashes[0], 0)],
|
||||
0,
|
||||
)
|
||||
|
||||
# Evict the last block of all layers, reduces the hit length to 2.
|
||||
test_partial_request_hit(
|
||||
_test_partial_request_hit(
|
||||
manager,
|
||||
block_size,
|
||||
num_full_blocks,
|
||||
"4",
|
||||
all_token_ids,
|
||||
[
|
||||
make_block_hash_with_group_id(block_hashes[2], 0),
|
||||
make_block_hash_with_group_id(block_hashes[2], 1),
|
||||
@@ -446,18 +440,36 @@ def test_prefill_hybrid_model():
|
||||
)
|
||||
|
||||
# Evict the last block of full attention, reduces the hit length to 2.
|
||||
test_partial_request_hit(
|
||||
"5", [make_block_hash_with_group_id(block_hashes[2], 0)], 2
|
||||
_test_partial_request_hit(
|
||||
manager,
|
||||
block_size,
|
||||
num_full_blocks,
|
||||
"5",
|
||||
all_token_ids,
|
||||
[make_block_hash_with_group_id(block_hashes[2], 0)],
|
||||
2,
|
||||
)
|
||||
|
||||
# Evict the last block of sliding window, reduces the hit length to 2.
|
||||
test_partial_request_hit(
|
||||
"6", [make_block_hash_with_group_id(block_hashes[2], 1)], 2
|
||||
_test_partial_request_hit(
|
||||
manager,
|
||||
block_size,
|
||||
num_full_blocks,
|
||||
"6",
|
||||
all_token_ids,
|
||||
[make_block_hash_with_group_id(block_hashes[2], 1)],
|
||||
2,
|
||||
)
|
||||
|
||||
# Evict the last block of sliding window, reduces the hit length to 2.
|
||||
test_partial_request_hit(
|
||||
"7", [make_block_hash_with_group_id(block_hashes[2], 2)], 2
|
||||
_test_partial_request_hit(
|
||||
manager,
|
||||
block_size,
|
||||
num_full_blocks,
|
||||
"7",
|
||||
all_token_ids,
|
||||
[make_block_hash_with_group_id(block_hashes[2], 2)],
|
||||
2,
|
||||
)
|
||||
|
||||
# Evict different set of blocks for full attention and sliding window makes
|
||||
@@ -466,8 +478,12 @@ def test_prefill_hybrid_model():
|
||||
# The cache hit length of sliding window is 2 * block_size.
|
||||
# Then it is cache miss as the two type of layers
|
||||
# have different hit length.
|
||||
test_partial_request_hit(
|
||||
_test_partial_request_hit(
|
||||
manager,
|
||||
block_size,
|
||||
num_full_blocks,
|
||||
"8",
|
||||
all_token_ids,
|
||||
[
|
||||
make_block_hash_with_group_id(block_hashes[2], 0),
|
||||
make_block_hash_with_group_id(block_hashes[0], 1),
|
||||
@@ -477,6 +493,214 @@ def test_prefill_hybrid_model():
|
||||
)
|
||||
|
||||
|
||||
def test_prefill_hybrid_model_eagle():
|
||||
block_size = 16
|
||||
kv_cache_config = make_kv_cache_config_hybrid_model(block_size, 31, 3)
|
||||
manager = KVCacheManager(
|
||||
kv_cache_config,
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
use_eagle=True,
|
||||
)
|
||||
|
||||
hash_fn = sha256
|
||||
|
||||
# Complete 6 blocks (96 tokens)
|
||||
num_full_blocks = 6
|
||||
common_token_ids = [i for i in range(num_full_blocks) for _ in range(block_size)]
|
||||
|
||||
# Fully cache miss
|
||||
# Incomplete 1 block (7 tokens)
|
||||
unique_token_ids = [6] * 7
|
||||
all_token_ids = common_token_ids + unique_token_ids
|
||||
req0 = make_request("0", all_token_ids, block_size, hash_fn)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert len(req0.block_hashes) == len(all_token_ids) // block_size
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(
|
||||
req0, len(all_token_ids), num_computed_tokens, computed_blocks
|
||||
)
|
||||
block_ids = (
|
||||
[1, 2, 3, 4, 5, 6, 7],
|
||||
[8, 9, 10, 11, 12, 13, 14],
|
||||
[15, 16, 17, 18, 19, 20, 21],
|
||||
)
|
||||
assert blocks is not None and blocks.get_block_ids() == block_ids
|
||||
|
||||
# Check full block metadata
|
||||
parent_block_hash = None
|
||||
for i, full_block_ids in enumerate(zip(*(row[:-1] for row in block_ids))):
|
||||
block_tokens = tuple(all_token_ids[i * block_size : (i + 1) * block_size])
|
||||
block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens)
|
||||
for group_id, block_id in enumerate(full_block_ids):
|
||||
blk_hash = manager.block_pool.blocks[block_id].block_hash
|
||||
assert blk_hash is not None
|
||||
assert get_block_hash(blk_hash) == block_hash
|
||||
assert get_group_id(blk_hash) == group_id
|
||||
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
||||
parent_block_hash = block_hash
|
||||
|
||||
# Check partial block metadata
|
||||
for partial_block_id in (row[-1] for row in block_ids):
|
||||
assert manager.block_pool.blocks[partial_block_id].block_hash is None
|
||||
assert manager.block_pool.blocks[partial_block_id].ref_cnt == 1
|
||||
|
||||
# Cache hit in the common prefix
|
||||
# Incomplete 1 block (5 tokens)
|
||||
unique_token_ids = [6] * 5
|
||||
all_token_ids = common_token_ids + unique_token_ids
|
||||
req1 = make_request("1", all_token_ids, block_size, hash_fn)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert len(req1.block_hashes) == num_full_blocks
|
||||
assert computed_blocks.get_block_ids() == (
|
||||
[1, 2, 3, 4],
|
||||
[0, 9, 10, 11],
|
||||
[0, 16, 17, 18],
|
||||
)
|
||||
assert num_computed_tokens == 4 * block_size
|
||||
num_new_tokens = len(all_token_ids) - num_computed_tokens
|
||||
blocks = manager.allocate_slots(
|
||||
req1, num_new_tokens, num_computed_tokens, computed_blocks
|
||||
)
|
||||
assert blocks is not None and blocks.get_block_ids() == (
|
||||
[22, 23, 24],
|
||||
[25, 26, 27],
|
||||
[28, 29, 30],
|
||||
)
|
||||
for block_per_group in computed_blocks.blocks:
|
||||
for block in block_per_group:
|
||||
if block != manager.block_pool.null_block:
|
||||
assert block.ref_cnt == 2
|
||||
|
||||
block_hashes = req1.block_hashes
|
||||
manager.free(req0)
|
||||
manager.free(req1)
|
||||
|
||||
# Evict the blocks outside sliding window, does not affect the hit length.
|
||||
_test_partial_request_hit(
|
||||
manager,
|
||||
block_size,
|
||||
num_full_blocks,
|
||||
"2",
|
||||
all_token_ids,
|
||||
[
|
||||
make_block_hash_with_group_id(block_hashes[0], 1),
|
||||
make_block_hash_with_group_id(block_hashes[0], 2),
|
||||
],
|
||||
4,
|
||||
)
|
||||
|
||||
# Evict the first block of full attention, makes total cache miss.
|
||||
_test_partial_request_hit(
|
||||
manager,
|
||||
block_size,
|
||||
num_full_blocks,
|
||||
"3",
|
||||
all_token_ids,
|
||||
[make_block_hash_with_group_id(block_hashes[0], 0)],
|
||||
0,
|
||||
)
|
||||
|
||||
# Evict the last block of all layers, reduces the hit length to 3.
|
||||
_test_partial_request_hit(
|
||||
manager,
|
||||
block_size,
|
||||
num_full_blocks,
|
||||
"4",
|
||||
all_token_ids,
|
||||
[
|
||||
make_block_hash_with_group_id(block_hashes[-1], 0),
|
||||
make_block_hash_with_group_id(block_hashes[-1], 1),
|
||||
make_block_hash_with_group_id(block_hashes[-1], 2),
|
||||
],
|
||||
3,
|
||||
)
|
||||
|
||||
# Evict the last block of full attention, reduces the hit length to 3.
|
||||
_test_partial_request_hit(
|
||||
manager,
|
||||
block_size,
|
||||
num_full_blocks,
|
||||
"5",
|
||||
all_token_ids,
|
||||
[make_block_hash_with_group_id(block_hashes[-1], 0)],
|
||||
3,
|
||||
)
|
||||
|
||||
# Since the last block of full attention is dropped for eagle, evict
|
||||
# the second last block of sliding window, reduces the hit length to 3.
|
||||
_test_partial_request_hit(
|
||||
manager,
|
||||
block_size,
|
||||
num_full_blocks,
|
||||
"6",
|
||||
all_token_ids,
|
||||
[make_block_hash_with_group_id(block_hashes[-2], 1)],
|
||||
3,
|
||||
)
|
||||
|
||||
# Since the last block of full attention is dropped for eagle, evict
|
||||
# the second last block of sliding window, reduces the hit length to 3.
|
||||
_test_partial_request_hit(
|
||||
manager,
|
||||
block_size,
|
||||
num_full_blocks,
|
||||
"7",
|
||||
all_token_ids,
|
||||
[make_block_hash_with_group_id(block_hashes[-2], 2)],
|
||||
3,
|
||||
)
|
||||
|
||||
# Evict different set of blocks for full attention and sliding window makes
|
||||
# total cache miss.
|
||||
# The cache hit length of full attention is 4 * block_size.
|
||||
# The cache hit length of sliding window is 3 * block_size.
|
||||
# Then it is cache miss as the two type of layers
|
||||
# have different hit length.
|
||||
_test_partial_request_hit(
|
||||
manager,
|
||||
block_size,
|
||||
num_full_blocks,
|
||||
"8",
|
||||
all_token_ids,
|
||||
[
|
||||
make_block_hash_with_group_id(block_hashes[-1], 0),
|
||||
make_block_hash_with_group_id(block_hashes[0], 1),
|
||||
make_block_hash_with_group_id(block_hashes[0], 2),
|
||||
],
|
||||
0,
|
||||
)
|
||||
|
||||
|
||||
def _test_partial_request_hit(
|
||||
manager: KVCacheManager,
|
||||
block_size: int,
|
||||
num_full_blocks,
|
||||
request_id: str,
|
||||
prompt_token_ids: list[int],
|
||||
hash_to_evict: list[BlockHashWithGroupId],
|
||||
expect_hit_length: int,
|
||||
):
|
||||
cached_block_hash_to_block_bak = copy.copy(
|
||||
manager.block_pool.cached_block_hash_to_block._cache
|
||||
)
|
||||
req = make_request(request_id, prompt_token_ids, block_size, sha256)
|
||||
for hash_with_group_id in hash_to_evict:
|
||||
manager.block_pool.cached_block_hash_to_block._cache.pop(hash_with_group_id)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||
assert len(req.block_hashes) == num_full_blocks
|
||||
assert num_computed_tokens == expect_hit_length * block_size
|
||||
for block_per_group in computed_blocks.blocks:
|
||||
assert len(block_per_group) == num_computed_tokens // block_size
|
||||
for hash_with_group_id in hash_to_evict:
|
||||
manager.block_pool.cached_block_hash_to_block._cache[hash_with_group_id] = (
|
||||
cached_block_hash_to_block_bak[hash_with_group_id]
|
||||
)
|
||||
manager.free(req)
|
||||
|
||||
|
||||
def _make_hybrid_kv_cache_config(
|
||||
block_size: int, num_blocks: int, spec_types: list[str]
|
||||
) -> KVCacheConfig:
|
||||
@@ -655,6 +879,85 @@ def test_prefill_hybrid_model_combinations(spec_types: list[str]):
|
||||
manager.free(req1)
|
||||
|
||||
|
||||
# Test cases with eagle enabled: Only test a single simple case for now.
|
||||
# - 2 groups: 1 full + 1 other
|
||||
_EAGLE_HYBRID_MODEL_TEST_CASES = [
|
||||
# 2 groups: 1 full + 1 other
|
||||
pytest.param(["full", "sliding_window"], 2, id="2g-full+sw"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec_types,expect_hit_length", _EAGLE_HYBRID_MODEL_TEST_CASES)
|
||||
def test_prefill_hybrid_model_combinations_eagle(
|
||||
spec_types: list[str], expect_hit_length: int
|
||||
):
|
||||
"""
|
||||
Test prefix caching with hybrid models (1 full attn + 1 other) with EAGLE.
|
||||
More complex hybrid models with EAGLE are not yet supported (see issue #32802).
|
||||
"""
|
||||
block_size = 16
|
||||
num_groups = len(spec_types)
|
||||
# Allocate enough blocks for all groups
|
||||
num_blocks = 10 * num_groups
|
||||
|
||||
kv_cache_config = _make_hybrid_kv_cache_config(block_size, num_blocks, spec_types)
|
||||
manager = KVCacheManager(
|
||||
kv_cache_config,
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
hash_block_size=block_size,
|
||||
use_eagle=True,
|
||||
)
|
||||
|
||||
hash_fn = sha256
|
||||
|
||||
# Complete 3 blocks (48 tokens)
|
||||
num_full_blocks = 4
|
||||
common_token_ids = [i for i in range(num_full_blocks) for _ in range(block_size)]
|
||||
unique_token_ids = [4] * 7
|
||||
all_token_ids = common_token_ids + unique_token_ids
|
||||
|
||||
# First request: no cache hit initially
|
||||
req0 = make_request("0", all_token_ids, block_size, hash_fn)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
|
||||
assert len(req0.block_hashes) == num_full_blocks
|
||||
assert not computed_blocks.blocks[0] # No cache hit initially
|
||||
assert num_computed_tokens == 0
|
||||
|
||||
blocks = manager.allocate_slots(
|
||||
req0, len(all_token_ids), num_computed_tokens, computed_blocks
|
||||
)
|
||||
assert blocks is not None
|
||||
# Should have blocks for all groups
|
||||
assert len(blocks.get_block_ids()) == num_groups
|
||||
|
||||
# Second request: should hit cached blocks for common prefix
|
||||
all_token_ids = common_token_ids + [6] * 5
|
||||
req1 = make_request("1", all_token_ids, block_size, hash_fn)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
|
||||
# Should hit cached blocks for all groups
|
||||
assert num_computed_tokens == expect_hit_length * block_size
|
||||
assert len(computed_blocks.blocks) == num_groups
|
||||
# Verify each group has the correct number of computed blocks
|
||||
for block_per_group in computed_blocks.blocks:
|
||||
assert len(block_per_group) == expect_hit_length
|
||||
|
||||
# Allocate and verify blocks for second request
|
||||
blocks = manager.allocate_slots(
|
||||
req1,
|
||||
len(all_token_ids) - num_computed_tokens,
|
||||
num_computed_tokens,
|
||||
computed_blocks,
|
||||
)
|
||||
assert blocks is not None
|
||||
assert len(blocks.get_block_ids()) == num_groups
|
||||
|
||||
manager.free(req0)
|
||||
manager.free(req1)
|
||||
|
||||
|
||||
def test_prefill_plp():
|
||||
"""Test prefill with APC and some prompt logprobs (plp) requests.
|
||||
|
||||
|
||||
@@ -479,6 +479,16 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
||||
hit_length = max_cache_hit_length
|
||||
hit_blocks_by_group: list[list[KVCacheBlock] | None] = [None] * num_groups
|
||||
|
||||
# Simple hybrid (1 full attn + 1 other): one iteration suffices.
|
||||
# Full attn is always first if it exists. This avoids EAGLE drops
|
||||
# being applied multiple times to non-full-attn groups.
|
||||
# FIXME (yifan): However, for complex hybrid models with multiple attn
|
||||
# groups, we still have the EAGLE spiral block dropping problem. See
|
||||
# discussion in issue https://github.com/vllm-project/vllm/issues/32802.
|
||||
is_simple_hybrid = len(self.attention_groups) == 2 and isinstance(
|
||||
self.attention_groups[0][0], FullAttentionSpec
|
||||
)
|
||||
|
||||
while True:
|
||||
curr_hit_length = hit_length
|
||||
|
||||
@@ -495,10 +505,6 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
||||
# the last iteration.
|
||||
num_blocks = curr_hit_length // spec.block_size
|
||||
curr_hit_length = num_blocks * spec.block_size
|
||||
for group_id in group_ids:
|
||||
blocks = hit_blocks_by_group[group_id]
|
||||
assert blocks is not None
|
||||
del blocks[num_blocks:]
|
||||
else:
|
||||
hit_blocks = manager_cls.find_longest_cache_hit(
|
||||
block_hashes=_get_block_hashes(spec),
|
||||
@@ -513,10 +519,20 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
||||
for group_id, blocks in zip(group_ids, hit_blocks):
|
||||
hit_blocks_by_group[group_id] = blocks
|
||||
|
||||
if curr_hit_length < hit_length:
|
||||
hit_length = curr_hit_length
|
||||
else:
|
||||
if curr_hit_length >= hit_length:
|
||||
break
|
||||
hit_length = curr_hit_length
|
||||
# Simple hybrid: exit after one iteration
|
||||
if is_simple_hybrid:
|
||||
break
|
||||
|
||||
# Truncate full attention blocks to final hit_length (if present)
|
||||
spec, group_ids, _ = self.attention_groups[0]
|
||||
if isinstance(spec, FullAttentionSpec):
|
||||
num_blocks = hit_length // spec.block_size
|
||||
for group_id in group_ids:
|
||||
if (blks := hit_blocks_by_group[group_id]) is not None:
|
||||
del blks[num_blocks:]
|
||||
|
||||
return tuple(
|
||||
blocks if blocks is not None else [] for blocks in hit_blocks_by_group
|
||||
|
||||
Reference in New Issue
Block a user