[Fix] prefix cache hit rate == 0 bug with gpt-oss style models (#33524)

Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
This commit is contained in:
Yifan Qiao
2026-02-01 17:59:58 -08:00
committed by GitHub
parent 7320ca3942
commit a01ef3fa51
2 changed files with 366 additions and 47 deletions

View File

@@ -107,7 +107,10 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
def make_kv_cache_config_hybrid_model(
block_size: int, num_blocks: int, second_spec_type: str = "sliding_window"
block_size: int,
num_blocks: int,
sliding_window_blocks: int,
second_spec_type: str = "sliding_window",
) -> KVCacheConfig:
if second_spec_type == "sliding_window":
second_spec = SlidingWindowSpec(
@@ -115,7 +118,7 @@ def make_kv_cache_config_hybrid_model(
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=2 * block_size,
sliding_window=sliding_window_blocks * block_size,
)
elif second_spec_type == "mamba":
second_spec = MambaSpec(
@@ -325,7 +328,7 @@ def test_prefill(hash_fn):
def test_prefill_hybrid_model():
block_size = 16
manager = KVCacheManager(
make_kv_cache_config_hybrid_model(block_size, 21),
make_kv_cache_config_hybrid_model(block_size, 21, 2),
max_model_len=8192,
enable_caching=True,
hash_block_size=block_size,
@@ -334,7 +337,8 @@ def test_prefill_hybrid_model():
hash_fn = sha256
# Complete 3 blocks (48 tokens)
common_token_ids = [i for i in range(3) for _ in range(block_size)]
num_full_blocks = 3
common_token_ids = [i for i in range(num_full_blocks) for _ in range(block_size)]
# Fully cache miss
# Incomplete 1 block (7 tokens)
@@ -375,6 +379,7 @@ def test_prefill_hybrid_model():
# Cache hit in the common prefix
# Incomplete 1 block (5 tokens)
unique_token_ids = [3] * 5
all_token_ids = common_token_ids + unique_token_ids
req1 = make_request("1", common_token_ids + unique_token_ids, block_size, hash_fn)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(req1.block_hashes) == 3
@@ -394,34 +399,13 @@ def test_prefill_hybrid_model():
manager.free(req0)
manager.free(req1)
cached_block_hash_to_block_bak = copy.copy(
manager.block_pool.cached_block_hash_to_block._cache
)
def test_partial_request_hit(
request_id: str,
hash_to_evict: list[BlockHashWithGroupId],
expect_hit_length: int,
):
req = make_request(
request_id, common_token_ids + unique_token_ids, block_size, sha256
)
for hash_with_group_id in hash_to_evict:
manager.block_pool.cached_block_hash_to_block._cache.pop(hash_with_group_id)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert len(req.block_hashes) == 3
assert num_computed_tokens == expect_hit_length * block_size
for block_per_group in computed_blocks.blocks:
assert len(block_per_group) == num_computed_tokens // block_size
for hash_with_group_id in hash_to_evict:
manager.block_pool.cached_block_hash_to_block._cache[hash_with_group_id] = (
cached_block_hash_to_block_bak[hash_with_group_id]
)
manager.free(req)
# Evict the blocks outside sliding window, does not affect the hit length.
test_partial_request_hit(
_test_partial_request_hit(
manager,
block_size,
num_full_blocks,
"2",
all_token_ids,
[
make_block_hash_with_group_id(block_hashes[0], 1),
make_block_hash_with_group_id(block_hashes[0], 2),
@@ -430,13 +414,23 @@ def test_prefill_hybrid_model():
)
# Evict the first block of full attention, makes total cache miss.
test_partial_request_hit(
"3", [make_block_hash_with_group_id(block_hashes[0], 0)], 0
_test_partial_request_hit(
manager,
block_size,
num_full_blocks,
"3",
all_token_ids,
[make_block_hash_with_group_id(block_hashes[0], 0)],
0,
)
# Evict the last block of all layers, reduces the hit length to 2.
test_partial_request_hit(
_test_partial_request_hit(
manager,
block_size,
num_full_blocks,
"4",
all_token_ids,
[
make_block_hash_with_group_id(block_hashes[2], 0),
make_block_hash_with_group_id(block_hashes[2], 1),
@@ -446,18 +440,36 @@ def test_prefill_hybrid_model():
)
# Evict the last block of full attention, reduces the hit length to 2.
test_partial_request_hit(
"5", [make_block_hash_with_group_id(block_hashes[2], 0)], 2
_test_partial_request_hit(
manager,
block_size,
num_full_blocks,
"5",
all_token_ids,
[make_block_hash_with_group_id(block_hashes[2], 0)],
2,
)
# Evict the last block of sliding window, reduces the hit length to 2.
test_partial_request_hit(
"6", [make_block_hash_with_group_id(block_hashes[2], 1)], 2
_test_partial_request_hit(
manager,
block_size,
num_full_blocks,
"6",
all_token_ids,
[make_block_hash_with_group_id(block_hashes[2], 1)],
2,
)
# Evict the last block of sliding window, reduces the hit length to 2.
test_partial_request_hit(
"7", [make_block_hash_with_group_id(block_hashes[2], 2)], 2
_test_partial_request_hit(
manager,
block_size,
num_full_blocks,
"7",
all_token_ids,
[make_block_hash_with_group_id(block_hashes[2], 2)],
2,
)
# Evict different set of blocks for full attention and sliding window makes
@@ -466,8 +478,12 @@ def test_prefill_hybrid_model():
# The cache hit length of sliding window is 2 * block_size.
# Then it is cache miss as the two type of layers
# have different hit length.
test_partial_request_hit(
_test_partial_request_hit(
manager,
block_size,
num_full_blocks,
"8",
all_token_ids,
[
make_block_hash_with_group_id(block_hashes[2], 0),
make_block_hash_with_group_id(block_hashes[0], 1),
@@ -477,6 +493,214 @@ def test_prefill_hybrid_model():
)
def test_prefill_hybrid_model_eagle():
    """Prefix caching for a hybrid model (full attn + sliding window) with EAGLE.

    Builds a 3-group hybrid KV cache config (sliding window spanning 3 blocks
    per ``make_kv_cache_config_hybrid_model``) with ``use_eagle=True``, then:
      1. runs a fully-cache-miss request and checks per-block hash metadata,
      2. runs a second request sharing the 6-block common prefix and checks the
         hit length (4 blocks here — presumably reduced from full-attn's hit by
         the EAGLE last-block drop; confirm against the coordinator logic),
      3. after freeing both requests, simulates targeted evictions via
         ``_test_partial_request_hit`` and checks the resulting hit lengths.
    """
    block_size = 16
    kv_cache_config = make_kv_cache_config_hybrid_model(block_size, 31, 3)
    manager = KVCacheManager(
        kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        use_eagle=True,
    )
    hash_fn = sha256
    # Complete 6 blocks (96 tokens)
    num_full_blocks = 6
    common_token_ids = [i for i in range(num_full_blocks) for _ in range(block_size)]
    # Fully cache miss
    # Incomplete 1 block (7 tokens)
    unique_token_ids = [6] * 7
    all_token_ids = common_token_ids + unique_token_ids
    req0 = make_request("0", all_token_ids, block_size, hash_fn)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert len(req0.block_hashes) == len(all_token_ids) // block_size
    assert not computed_blocks.blocks[0]
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(
        req0, len(all_token_ids), num_computed_tokens, computed_blocks
    )
    # 7 blocks (6 full + 1 partial) per group, allocated sequentially per group.
    block_ids = (
        [1, 2, 3, 4, 5, 6, 7],
        [8, 9, 10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19, 20, 21],
    )
    assert blocks is not None and blocks.get_block_ids() == block_ids
    # Check full block metadata: each full block carries the chained prefix
    # hash of its tokens, tagged with its group id.
    parent_block_hash = None
    for i, full_block_ids in enumerate(zip(*(row[:-1] for row in block_ids))):
        block_tokens = tuple(all_token_ids[i * block_size : (i + 1) * block_size])
        block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_tokens)
        for group_id, block_id in enumerate(full_block_ids):
            blk_hash = manager.block_pool.blocks[block_id].block_hash
            assert blk_hash is not None
            assert get_block_hash(blk_hash) == block_hash
            assert get_group_id(blk_hash) == group_id
            assert manager.block_pool.blocks[block_id].ref_cnt == 1
        parent_block_hash = block_hash
    # Check partial block metadata: the trailing partial block is unhashed.
    for partial_block_id in (row[-1] for row in block_ids):
        assert manager.block_pool.blocks[partial_block_id].block_hash is None
        assert manager.block_pool.blocks[partial_block_id].ref_cnt == 1
    # Cache hit in the common prefix
    # Incomplete 1 block (5 tokens)
    unique_token_ids = [6] * 5
    all_token_ids = common_token_ids + unique_token_ids
    req1 = make_request("1", all_token_ids, block_size, hash_fn)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert len(req1.block_hashes) == num_full_blocks
    # Block id 0 is the null block for positions outside the sliding window.
    assert computed_blocks.get_block_ids() == (
        [1, 2, 3, 4],
        [0, 9, 10, 11],
        [0, 16, 17, 18],
    )
    assert num_computed_tokens == 4 * block_size
    num_new_tokens = len(all_token_ids) - num_computed_tokens
    blocks = manager.allocate_slots(
        req1, num_new_tokens, num_computed_tokens, computed_blocks
    )
    assert blocks is not None and blocks.get_block_ids() == (
        [22, 23, 24],
        [25, 26, 27],
        [28, 29, 30],
    )
    # Shared (non-null) prefix blocks are now referenced by both requests.
    for block_per_group in computed_blocks.blocks:
        for block in block_per_group:
            if block != manager.block_pool.null_block:
                assert block.ref_cnt == 2
    block_hashes = req1.block_hashes
    manager.free(req0)
    manager.free(req1)
    # Evict the blocks outside sliding window, does not affect the hit length.
    _test_partial_request_hit(
        manager,
        block_size,
        num_full_blocks,
        "2",
        all_token_ids,
        [
            make_block_hash_with_group_id(block_hashes[0], 1),
            make_block_hash_with_group_id(block_hashes[0], 2),
        ],
        4,
    )
    # Evict the first block of full attention, makes total cache miss.
    _test_partial_request_hit(
        manager,
        block_size,
        num_full_blocks,
        "3",
        all_token_ids,
        [make_block_hash_with_group_id(block_hashes[0], 0)],
        0,
    )
    # Evict the last block of all layers, reduces the hit length to 3.
    _test_partial_request_hit(
        manager,
        block_size,
        num_full_blocks,
        "4",
        all_token_ids,
        [
            make_block_hash_with_group_id(block_hashes[-1], 0),
            make_block_hash_with_group_id(block_hashes[-1], 1),
            make_block_hash_with_group_id(block_hashes[-1], 2),
        ],
        3,
    )
    # Evict the last block of full attention, reduces the hit length to 3.
    _test_partial_request_hit(
        manager,
        block_size,
        num_full_blocks,
        "5",
        all_token_ids,
        [make_block_hash_with_group_id(block_hashes[-1], 0)],
        3,
    )
    # Since the last block of full attention is dropped for eagle, evict
    # the second last block of sliding window, reduces the hit length to 3.
    _test_partial_request_hit(
        manager,
        block_size,
        num_full_blocks,
        "6",
        all_token_ids,
        [make_block_hash_with_group_id(block_hashes[-2], 1)],
        3,
    )
    # Since the last block of full attention is dropped for eagle, evict
    # the second last block of sliding window, reduces the hit length to 3.
    # NOTE(review): same scenario as request "6" but for group 2 — presumably
    # the second non-full-attn group; confirm against the hybrid config.
    _test_partial_request_hit(
        manager,
        block_size,
        num_full_blocks,
        "7",
        all_token_ids,
        [make_block_hash_with_group_id(block_hashes[-2], 2)],
        3,
    )
    # Evict different set of blocks for full attention and sliding window makes
    # total cache miss.
    # The cache hit length of full attention is 4 * block_size.
    # The cache hit length of sliding window is 3 * block_size.
    # Then it is cache miss as the two type of layers
    # have different hit length.
    _test_partial_request_hit(
        manager,
        block_size,
        num_full_blocks,
        "8",
        all_token_ids,
        [
            make_block_hash_with_group_id(block_hashes[-1], 0),
            make_block_hash_with_group_id(block_hashes[0], 1),
            make_block_hash_with_group_id(block_hashes[0], 2),
        ],
        0,
    )
def _test_partial_request_hit(
    manager: KVCacheManager,
    block_size: int,
    num_full_blocks: int,
    request_id: str,
    prompt_token_ids: list[int],
    hash_to_evict: list[BlockHashWithGroupId],
    expect_hit_length: int,
) -> None:
    """Check the prefix-cache hit length after evicting selected blocks.

    Temporarily removes the ``hash_to_evict`` entries from the block pool's
    hash->block cache, issues a fresh request with ``prompt_token_ids``, and
    asserts that the computed (cache-hit) prefix is exactly
    ``expect_hit_length`` blocks long in every KV cache group. The evicted
    entries are restored and the request freed before returning, so the
    manager's cache state is unchanged for subsequent calls.
    """
    # Snapshot the private cache so the evicted entries can be restored.
    cached_block_hash_to_block_bak = copy.copy(
        manager.block_pool.cached_block_hash_to_block._cache
    )
    req = make_request(request_id, prompt_token_ids, block_size, sha256)
    # Simulate eviction of the selected (hash, group) entries.
    for hash_with_group_id in hash_to_evict:
        manager.block_pool.cached_block_hash_to_block._cache.pop(hash_with_group_id)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
    assert len(req.block_hashes) == num_full_blocks
    assert num_computed_tokens == expect_hit_length * block_size
    # Every group must report the same hit length.
    for block_per_group in computed_blocks.blocks:
        assert len(block_per_group) == num_computed_tokens // block_size
    # Restore the evicted entries so later calls observe the original cache.
    for hash_with_group_id in hash_to_evict:
        manager.block_pool.cached_block_hash_to_block._cache[hash_with_group_id] = (
            cached_block_hash_to_block_bak[hash_with_group_id]
        )
    manager.free(req)
def _make_hybrid_kv_cache_config(
block_size: int, num_blocks: int, spec_types: list[str]
) -> KVCacheConfig:
@@ -655,6 +879,85 @@ def test_prefill_hybrid_model_combinations(spec_types: list[str]):
manager.free(req1)
# Test cases with eagle enabled: Only test a single simple case for now.
# - 2 groups: 1 full + 1 other
# Each param is (spec_types, expect_hit_length_in_blocks).
_EAGLE_HYBRID_MODEL_TEST_CASES = [
    # 2 groups: 1 full + 1 other
    pytest.param(["full", "sliding_window"], 2, id="2g-full+sw"),
]
@pytest.mark.parametrize("spec_types,expect_hit_length", _EAGLE_HYBRID_MODEL_TEST_CASES)
def test_prefill_hybrid_model_combinations_eagle(
    spec_types: list[str], expect_hit_length: int
):
    """
    Test prefix caching with hybrid models (1 full attn + 1 other) with EAGLE.
    More complex hybrid models with EAGLE are not yet supported (see issue #32802).
    """
    block_size = 16
    num_groups = len(spec_types)
    # Allocate enough blocks for all groups
    num_blocks = 10 * num_groups
    kv_cache_config = _make_hybrid_kv_cache_config(block_size, num_blocks, spec_types)
    manager = KVCacheManager(
        kv_cache_config,
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        use_eagle=True,
    )
    hash_fn = sha256
    # Complete 4 blocks (64 tokens)
    num_full_blocks = 4
    common_token_ids = [i for i in range(num_full_blocks) for _ in range(block_size)]
    # Plus an incomplete trailing block (7 tokens).
    unique_token_ids = [4] * 7
    all_token_ids = common_token_ids + unique_token_ids
    # First request: no cache hit initially
    req0 = make_request("0", all_token_ids, block_size, hash_fn)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert len(req0.block_hashes) == num_full_blocks
    assert not computed_blocks.blocks[0]  # No cache hit initially
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(
        req0, len(all_token_ids), num_computed_tokens, computed_blocks
    )
    assert blocks is not None
    # Should have blocks for all groups
    assert len(blocks.get_block_ids()) == num_groups
    # Second request: should hit cached blocks for common prefix
    all_token_ids = common_token_ids + [6] * 5
    req1 = make_request("1", all_token_ids, block_size, hash_fn)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    # Should hit cached blocks for all groups
    assert num_computed_tokens == expect_hit_length * block_size
    assert len(computed_blocks.blocks) == num_groups
    # Verify each group has the correct number of computed blocks
    for block_per_group in computed_blocks.blocks:
        assert len(block_per_group) == expect_hit_length
    # Allocate and verify blocks for second request
    blocks = manager.allocate_slots(
        req1,
        len(all_token_ids) - num_computed_tokens,
        num_computed_tokens,
        computed_blocks,
    )
    assert blocks is not None
    assert len(blocks.get_block_ids()) == num_groups
    manager.free(req0)
    manager.free(req1)
def test_prefill_plp():
"""Test prefill with APC and some prompt logprobs (plp) requests.

View File

@@ -479,6 +479,16 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
hit_length = max_cache_hit_length
hit_blocks_by_group: list[list[KVCacheBlock] | None] = [None] * num_groups
# Simple hybrid (1 full attn + 1 other): one iteration suffices.
# Full attn is always first if it exists. This avoids EAGLE drops
# being applied multiple times to non-full-attn groups.
# FIXME (yifan): However, for complex hybrid models with multiple attn
# groups, we still have the EAGLE spiral block dropping problem. See
# discussion in issue https://github.com/vllm-project/vllm/issues/32802.
is_simple_hybrid = len(self.attention_groups) == 2 and isinstance(
self.attention_groups[0][0], FullAttentionSpec
)
while True:
curr_hit_length = hit_length
@@ -495,10 +505,6 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
# the last iteration.
num_blocks = curr_hit_length // spec.block_size
curr_hit_length = num_blocks * spec.block_size
for group_id in group_ids:
blocks = hit_blocks_by_group[group_id]
assert blocks is not None
del blocks[num_blocks:]
else:
hit_blocks = manager_cls.find_longest_cache_hit(
block_hashes=_get_block_hashes(spec),
@@ -513,10 +519,20 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
for group_id, blocks in zip(group_ids, hit_blocks):
hit_blocks_by_group[group_id] = blocks
if curr_hit_length < hit_length:
hit_length = curr_hit_length
else:
if curr_hit_length >= hit_length:
break
hit_length = curr_hit_length
# Simple hybrid: exit after one iteration
if is_simple_hybrid:
break
# Truncate full attention blocks to final hit_length (if present)
spec, group_ids, _ = self.attention_groups[0]
if isinstance(spec, FullAttentionSpec):
num_blocks = hit_length // spec.block_size
for group_id in group_ids:
if (blks := hit_blocks_by_group[group_id]) is not None:
del blks[num_blocks:]
return tuple(
blocks if blocks is not None else [] for blocks in hit_blocks_by_group