[Fix] prefix cache hit rate == 0 bug with gpt-oss style models (#33524)
Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
(cherry picked from commit a01ef3fa51)
This commit is contained in:
@@ -107,7 +107,10 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
|
|||||||
|
|
||||||
|
|
||||||
def make_kv_cache_config_hybrid_model(
|
def make_kv_cache_config_hybrid_model(
|
||||||
block_size: int, num_blocks: int, second_spec_type: str = "sliding_window"
|
block_size: int,
|
||||||
|
num_blocks: int,
|
||||||
|
sliding_window_blocks: int,
|
||||||
|
second_spec_type: str = "sliding_window",
|
||||||
) -> KVCacheConfig:
|
) -> KVCacheConfig:
|
||||||
if second_spec_type == "sliding_window":
|
if second_spec_type == "sliding_window":
|
||||||
second_spec = SlidingWindowSpec(
|
second_spec = SlidingWindowSpec(
|
||||||
@@ -115,7 +118,7 @@ def make_kv_cache_config_hybrid_model(
|
|||||||
num_kv_heads=1,
|
num_kv_heads=1,
|
||||||
head_size=1,
|
head_size=1,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
sliding_window=2 * block_size,
|
sliding_window=sliding_window_blocks * block_size,
|
||||||
)
|
)
|
||||||
elif second_spec_type == "mamba":
|
elif second_spec_type == "mamba":
|
||||||
second_spec = MambaSpec(
|
second_spec = MambaSpec(
|
||||||
@@ -325,7 +328,7 @@ def test_prefill(hash_fn):
|
|||||||
def test_prefill_hybrid_model():
|
def test_prefill_hybrid_model():
|
||||||
block_size = 16
|
block_size = 16
|
||||||
manager = KVCacheManager(
|
manager = KVCacheManager(
|
||||||
make_kv_cache_config_hybrid_model(block_size, 21),
|
make_kv_cache_config_hybrid_model(block_size, 21, 2),
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
enable_caching=True,
|
enable_caching=True,
|
||||||
hash_block_size=block_size,
|
hash_block_size=block_size,
|
||||||
@@ -334,7 +337,8 @@ def test_prefill_hybrid_model():
|
|||||||
hash_fn = sha256
|
hash_fn = sha256
|
||||||
|
|
||||||
# Complete 3 blocks (48 tokens)
|
# Complete 3 blocks (48 tokens)
|
||||||
common_token_ids = [i for i in range(3) for _ in range(block_size)]
|
num_full_blocks = 3
|
||||||
|
common_token_ids = [i for i in range(num_full_blocks) for _ in range(block_size)]
|
||||||
|
|
||||||
# Fully cache miss
|
# Fully cache miss
|
||||||
# Incomplete 1 block (7 tokens)
|
# Incomplete 1 block (7 tokens)
|
||||||
@@ -375,6 +379,7 @@ def test_prefill_hybrid_model():
|
|||||||
# Cache hit in the common prefix
|
# Cache hit in the common prefix
|
||||||
# Incomplete 1 block (5 tokens)
|
# Incomplete 1 block (5 tokens)
|
||||||
unique_token_ids = [3] * 5
|
unique_token_ids = [3] * 5
|
||||||
|
all_token_ids = common_token_ids + unique_token_ids
|
||||||
req1 = make_request("1", common_token_ids + unique_token_ids, block_size, hash_fn)
|
req1 = make_request("1", common_token_ids + unique_token_ids, block_size, hash_fn)
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||||
assert len(req1.block_hashes) == 3
|
assert len(req1.block_hashes) == 3
|
||||||
@@ -394,34 +399,13 @@ def test_prefill_hybrid_model():
|
|||||||
manager.free(req0)
|
manager.free(req0)
|
||||||
manager.free(req1)
|
manager.free(req1)
|
||||||
|
|
||||||
cached_block_hash_to_block_bak = copy.copy(
|
|
||||||
manager.block_pool.cached_block_hash_to_block._cache
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_partial_request_hit(
|
|
||||||
request_id: str,
|
|
||||||
hash_to_evict: list[BlockHashWithGroupId],
|
|
||||||
expect_hit_length: int,
|
|
||||||
):
|
|
||||||
req = make_request(
|
|
||||||
request_id, common_token_ids + unique_token_ids, block_size, sha256
|
|
||||||
)
|
|
||||||
for hash_with_group_id in hash_to_evict:
|
|
||||||
manager.block_pool.cached_block_hash_to_block._cache.pop(hash_with_group_id)
|
|
||||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
|
||||||
assert len(req.block_hashes) == 3
|
|
||||||
assert num_computed_tokens == expect_hit_length * block_size
|
|
||||||
for block_per_group in computed_blocks.blocks:
|
|
||||||
assert len(block_per_group) == num_computed_tokens // block_size
|
|
||||||
for hash_with_group_id in hash_to_evict:
|
|
||||||
manager.block_pool.cached_block_hash_to_block._cache[hash_with_group_id] = (
|
|
||||||
cached_block_hash_to_block_bak[hash_with_group_id]
|
|
||||||
)
|
|
||||||
manager.free(req)
|
|
||||||
|
|
||||||
# Evict the blocks outside sliding window, does not affect the hit length.
|
# Evict the blocks outside sliding window, does not affect the hit length.
|
||||||
test_partial_request_hit(
|
_test_partial_request_hit(
|
||||||
|
manager,
|
||||||
|
block_size,
|
||||||
|
num_full_blocks,
|
||||||
"2",
|
"2",
|
||||||
|
all_token_ids,
|
||||||
[
|
[
|
||||||
make_block_hash_with_group_id(block_hashes[0], 1),
|
make_block_hash_with_group_id(block_hashes[0], 1),
|
||||||
make_block_hash_with_group_id(block_hashes[0], 2),
|
make_block_hash_with_group_id(block_hashes[0], 2),
|
||||||
@@ -430,13 +414,23 @@ def test_prefill_hybrid_model():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Evict the first block of full attention, makes total cache miss.
|
# Evict the first block of full attention, makes total cache miss.
|
||||||
test_partial_request_hit(
|
_test_partial_request_hit(
|
||||||
"3", [make_block_hash_with_group_id(block_hashes[0], 0)], 0
|
manager,
|
||||||
|
block_size,
|
||||||
|
num_full_blocks,
|
||||||
|
"3",
|
||||||
|
all_token_ids,
|
||||||
|
[make_block_hash_with_group_id(block_hashes[0], 0)],
|
||||||
|
0,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Evict the last block of all layers, reduces the hit length to 2.
|
# Evict the last block of all layers, reduces the hit length to 2.
|
||||||
test_partial_request_hit(
|
_test_partial_request_hit(
|
||||||
|
manager,
|
||||||
|
block_size,
|
||||||
|
num_full_blocks,
|
||||||
"4",
|
"4",
|
||||||
|
all_token_ids,
|
||||||
[
|
[
|
||||||
make_block_hash_with_group_id(block_hashes[2], 0),
|
make_block_hash_with_group_id(block_hashes[2], 0),
|
||||||
make_block_hash_with_group_id(block_hashes[2], 1),
|
make_block_hash_with_group_id(block_hashes[2], 1),
|
||||||
@@ -446,18 +440,36 @@ def test_prefill_hybrid_model():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Evict the last block of full attention, reduces the hit length to 2.
|
# Evict the last block of full attention, reduces the hit length to 2.
|
||||||
test_partial_request_hit(
|
_test_partial_request_hit(
|
||||||
"5", [make_block_hash_with_group_id(block_hashes[2], 0)], 2
|
manager,
|
||||||
|
block_size,
|
||||||
|
num_full_blocks,
|
||||||
|
"5",
|
||||||
|
all_token_ids,
|
||||||
|
[make_block_hash_with_group_id(block_hashes[2], 0)],
|
||||||
|
2,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Evict the last block of sliding window, reduces the hit length to 2.
|
# Evict the last block of sliding window, reduces the hit length to 2.
|
||||||
test_partial_request_hit(
|
_test_partial_request_hit(
|
||||||
"6", [make_block_hash_with_group_id(block_hashes[2], 1)], 2
|
manager,
|
||||||
|
block_size,
|
||||||
|
num_full_blocks,
|
||||||
|
"6",
|
||||||
|
all_token_ids,
|
||||||
|
[make_block_hash_with_group_id(block_hashes[2], 1)],
|
||||||
|
2,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Evict the last block of sliding window, reduces the hit length to 2.
|
# Evict the last block of sliding window, reduces the hit length to 2.
|
||||||
test_partial_request_hit(
|
_test_partial_request_hit(
|
||||||
"7", [make_block_hash_with_group_id(block_hashes[2], 2)], 2
|
manager,
|
||||||
|
block_size,
|
||||||
|
num_full_blocks,
|
||||||
|
"7",
|
||||||
|
all_token_ids,
|
||||||
|
[make_block_hash_with_group_id(block_hashes[2], 2)],
|
||||||
|
2,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Evict different set of blocks for full attention and sliding window makes
|
# Evict different set of blocks for full attention and sliding window makes
|
||||||
@@ -466,8 +478,12 @@ def test_prefill_hybrid_model():
|
|||||||
# The cache hit length of sliding window is 2 * block_size.
|
# The cache hit length of sliding window is 2 * block_size.
|
||||||
# Then it is cache miss as the two type of layers
|
# Then it is cache miss as the two type of layers
|
||||||
# have different hit length.
|
# have different hit length.
|
||||||
test_partial_request_hit(
|
_test_partial_request_hit(
|
||||||
|
manager,
|
||||||
|
block_size,
|
||||||
|
num_full_blocks,
|
||||||
"8",
|
"8",
|
||||||
|
all_token_ids,
|
||||||
[
|
[
|
||||||
make_block_hash_with_group_id(block_hashes[2], 0),
|
make_block_hash_with_group_id(block_hashes[2], 0),
|
||||||
make_block_hash_with_group_id(block_hashes[0], 1),
|
make_block_hash_with_group_id(block_hashes[0], 1),
|
||||||
@@ -477,6 +493,214 @@ def test_prefill_hybrid_model():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_prefill_hybrid_model_eagle():
    """Prefix-cache hits for a hybrid model when EAGLE is enabled.

    The manager is built from ``make_kv_cache_config_hybrid_model`` with
    31 blocks and a sliding window of 3 blocks. One request populates the
    cache, a second request sharing the 6-block prefix checks the reported
    hit, and a series of follow-up requests checks the hit length after
    selectively evicting cached block hashes.
    """
    block_size = 16
    manager = KVCacheManager(
        make_kv_cache_config_hybrid_model(block_size, 31, 3),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        use_eagle=True,
    )
    hash_fn = sha256

    # Shared prefix: 6 complete blocks (96 tokens); block i holds token i.
    num_full_blocks = 6
    common_token_ids: list[int] = []
    for token in range(num_full_blocks):
        common_token_ids.extend([token] * block_size)

    # --- Request 0: nothing cached yet, so a full miss. ---
    # Prompt = shared prefix + 7 tokens of an incomplete block.
    prompt0 = common_token_ids + [6] * 7
    req0 = make_request("0", prompt0, block_size, hash_fn)
    hit_blocks, num_hit_tokens = manager.get_computed_blocks(req0)
    assert len(req0.block_hashes) == len(prompt0) // block_size
    assert not hit_blocks.blocks[0]
    assert num_hit_tokens == 0

    allocated = manager.allocate_slots(req0, len(prompt0), num_hit_tokens, hit_blocks)
    # One row of block ids per KV cache group: 6 full blocks + 1 partial each.
    expected_ids = (
        [1, 2, 3, 4, 5, 6, 7],
        [8, 9, 10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19, 20, 21],
    )
    assert allocated is not None
    assert allocated.get_block_ids() == expected_ids

    pool = manager.block_pool

    # Full blocks: each carries the chained hash of its tokens, tagged with
    # its group id, and is referenced once (by req0).
    prev_hash = None
    full_columns = zip(*(ids[:-1] for ids in expected_ids))
    for idx, column in enumerate(full_columns):
        start = idx * block_size
        tokens = tuple(prompt0[start : start + block_size])
        expected_hash = hash_block_tokens(hash_fn, prev_hash, tokens)
        for group_id, block_id in enumerate(column):
            cached = pool.blocks[block_id]
            stored_hash = cached.block_hash
            assert stored_hash is not None
            assert get_block_hash(stored_hash) == expected_hash
            assert get_group_id(stored_hash) == group_id
            assert cached.ref_cnt == 1
        prev_hash = expected_hash

    # Partial (last) block of every group: allocated but not hashed.
    for ids in expected_ids:
        tail = pool.blocks[ids[-1]]
        assert tail.block_hash is None
        assert tail.ref_cnt == 1

    # --- Request 1: hits the cache on the common prefix. ---
    # Prompt = shared prefix + 5 tokens of an incomplete block.
    prompt1 = common_token_ids + [6] * 5
    req1 = make_request("1", prompt1, block_size, hash_fn)
    hit_blocks, num_hit_tokens = manager.get_computed_blocks(req1)
    assert len(req1.block_hashes) == num_full_blocks
    # The leading 0 ids in groups 1 and 2 are presumably the null-block
    # placeholder for blocks outside the sliding window -- the ref-count
    # loop below skips ``null_block`` for that reason.
    assert hit_blocks.get_block_ids() == (
        [1, 2, 3, 4],
        [0, 9, 10, 11],
        [0, 16, 17, 18],
    )
    assert num_hit_tokens == 4 * block_size

    allocated = manager.allocate_slots(
        req1, len(prompt1) - num_hit_tokens, num_hit_tokens, hit_blocks
    )
    assert allocated is not None
    assert allocated.get_block_ids() == (
        [22, 23, 24],
        [25, 26, 27],
        [28, 29, 30],
    )
    # Every non-null hit block is now referenced by both req0 and req1.
    for group_blocks in hit_blocks.blocks:
        for blk in group_blocks:
            if blk != pool.null_block:
                assert blk.ref_cnt == 2

    block_hashes = req1.block_hashes
    manager.free(req0)
    manager.free(req1)

    # --- Partial-hit scenarios after selective eviction. ---
    # Each entry: (request id, [(index into block_hashes, group id), ...],
    #              expected hit length in blocks).
    eviction_cases = [
        # Blocks outside the sliding window: hit length unaffected.
        ("2", [(0, 1), (0, 2)], 4),
        # First block of full attention: total cache miss.
        ("3", [(0, 0)], 0),
        # Last block of all layers: hit length drops to 3.
        ("4", [(-1, 0), (-1, 1), (-1, 2)], 3),
        # Last block of full attention: hit length drops to 3.
        ("5", [(-1, 0)], 3),
        # The last block of full attention is dropped for eagle, so evicting
        # the second-to-last sliding-window block (group 1) drops the hit
        # length to 3.
        ("6", [(-2, 1)], 3),
        # Same as above for the other sliding-window group (group 2).
        ("7", [(-2, 2)], 3),
        # Different blocks evicted for full attention and sliding window:
        # full attention would hit 4 * block_size, sliding window
        # 3 * block_size. The two types of layers disagree on the hit
        # length, so the result is a total cache miss.
        ("8", [(-1, 0), (0, 1), (0, 2)], 0),
    ]
    for request_id, evictions, expected_hit in eviction_cases:
        hashes_to_evict = [
            make_block_hash_with_group_id(block_hashes[hash_idx], group_id)
            for hash_idx, group_id in evictions
        ]
        _test_partial_request_hit(
            manager,
            block_size,
            num_full_blocks,
            request_id,
            prompt1,
            hashes_to_evict,
            expected_hit,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def _test_partial_request_hit(
    manager: KVCacheManager,
    block_size: int,
    num_full_blocks: int,
    request_id: str,
    prompt_token_ids: list[int],
    hash_to_evict: list[BlockHashWithGroupId],
    expect_hit_length: int,
) -> None:
    """Check the prefix-cache hit length after evicting selected block hashes.

    The given hashes are temporarily removed from the block pool's
    hash-to-block cache, a fresh request is looked up, and the entries are
    put back afterwards so the next scenario starts from the same cache.

    Args:
        manager: KV cache manager whose block pool cache is manipulated.
        block_size: Number of tokens per block.
        num_full_blocks: Expected number of block hashes of the request.
        request_id: Id given to the temporary request.
        prompt_token_ids: Prompt of the temporary request.
        hash_to_evict: Cache keys to remove for the duration of the lookup.
            Each key must currently be cached (otherwise ``KeyError``).
        expect_hit_length: Expected hit length, in blocks.
    """
    req = make_request(request_id, prompt_token_ids, block_size, sha256)

    # Keep only the entries we remove instead of copying the whole cache map.
    block_cache = manager.block_pool.cached_block_hash_to_block._cache
    evicted = {key: block_cache.pop(key) for key in hash_to_evict}

    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
    assert len(req.block_hashes) == num_full_blocks
    assert num_computed_tokens == expect_hit_length * block_size
    # Every group must report the same number of hit blocks.
    for block_per_group in computed_blocks.blocks:
        assert len(block_per_group) == num_computed_tokens // block_size

    # Restore the evicted entries (same objects, same order) before freeing.
    manager.block_pool.cached_block_hash_to_block._cache.update(evicted)
    manager.free(req)
|
||||||
|
|
||||||
|
|
||||||
def _make_hybrid_kv_cache_config(
|
def _make_hybrid_kv_cache_config(
|
||||||
block_size: int, num_blocks: int, spec_types: list[str]
|
block_size: int, num_blocks: int, spec_types: list[str]
|
||||||
) -> KVCacheConfig:
|
) -> KVCacheConfig:
|
||||||
@@ -655,6 +879,85 @@ def test_prefill_hybrid_model_combinations(spec_types: list[str]):
|
|||||||
manager.free(req1)
|
manager.free(req1)
|
||||||
|
|
||||||
|
|
||||||
|
# EAGLE-enabled cases. Only a single simple layout is covered for now:
# 2 groups, 1 full attention + 1 other.
# Each entry is (spec_types, expected hit length in blocks).
_EAGLE_HYBRID_MODEL_TEST_CASES = [
    pytest.param(["full", "sliding_window"], 2, id="2g-full+sw"),
]


@pytest.mark.parametrize("spec_types,expect_hit_length", _EAGLE_HYBRID_MODEL_TEST_CASES)
def test_prefill_hybrid_model_combinations_eagle(
    spec_types: list[str], expect_hit_length: int
):
    """
    Prefix caching for hybrid models (1 full attn + 1 other) with EAGLE on.

    A first request fills the cache; a second request sharing the same
    prefix must then hit ``expect_hit_length`` blocks in every group.
    More complex hybrid models with EAGLE are not yet supported
    (see issue #32802).
    """
    block_size = 16
    num_groups = len(spec_types)
    # 10 blocks per group is enough for both requests below.
    num_blocks = 10 * num_groups

    manager = KVCacheManager(
        _make_hybrid_kv_cache_config(block_size, num_blocks, spec_types),
        max_model_len=8192,
        enable_caching=True,
        hash_block_size=block_size,
        use_eagle=True,
    )
    hash_fn = sha256

    # Shared prefix: 4 complete blocks (64 tokens); block i holds token i.
    num_full_blocks = 4
    common_token_ids: list[int] = []
    for token in range(num_full_blocks):
        common_token_ids.extend([token] * block_size)

    # Request 0: shared prefix + 7 tokens of an incomplete block.
    # Nothing is cached yet, so no hit is expected.
    first_prompt = common_token_ids + [4] * 7
    req0 = make_request("0", first_prompt, block_size, hash_fn)
    hit_blocks, num_hit_tokens = manager.get_computed_blocks(req0)

    assert len(req0.block_hashes) == num_full_blocks
    assert not hit_blocks.blocks[0]
    assert num_hit_tokens == 0

    allocated = manager.allocate_slots(
        req0, len(first_prompt), num_hit_tokens, hit_blocks
    )
    assert allocated is not None
    # One list of block ids per group.
    assert len(allocated.get_block_ids()) == num_groups

    # Request 1: shared prefix + 5 different tokens; must hit the cached
    # prefix with the same length in every group.
    second_prompt = common_token_ids + [6] * 5
    req1 = make_request("1", second_prompt, block_size, hash_fn)
    hit_blocks, num_hit_tokens = manager.get_computed_blocks(req1)

    assert num_hit_tokens == expect_hit_length * block_size
    assert len(hit_blocks.blocks) == num_groups
    for group_blocks in hit_blocks.blocks:
        assert len(group_blocks) == expect_hit_length

    # Allocate the remainder of request 1 on top of the hit blocks.
    remaining_tokens = len(second_prompt) - num_hit_tokens
    allocated = manager.allocate_slots(
        req1,
        remaining_tokens,
        num_hit_tokens,
        hit_blocks,
    )
    assert allocated is not None
    assert len(allocated.get_block_ids()) == num_groups

    manager.free(req0)
    manager.free(req1)
|
||||||
|
|
||||||
|
|
||||||
def test_prefill_plp():
|
def test_prefill_plp():
|
||||||
"""Test prefill with APC and some prompt logprobs (plp) requests.
|
"""Test prefill with APC and some prompt logprobs (plp) requests.
|
||||||
|
|
||||||
|
|||||||
@@ -479,6 +479,16 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
|||||||
hit_length = max_cache_hit_length
|
hit_length = max_cache_hit_length
|
||||||
hit_blocks_by_group: list[list[KVCacheBlock] | None] = [None] * num_groups
|
hit_blocks_by_group: list[list[KVCacheBlock] | None] = [None] * num_groups
|
||||||
|
|
||||||
|
# Simple hybrid (1 full attn + 1 other): one iteration suffices.
|
||||||
|
# Full attn is always first if it exists. This avoids EAGLE drops
|
||||||
|
# being applied multiple times to non-full-attn groups.
|
||||||
|
# FIXME (yifan): However, for complex hybrid models with multiple attn
|
||||||
|
# groups, we still have the EAGLE spiral block dropping problem. See
|
||||||
|
# discussion in issue https://github.com/vllm-project/vllm/issues/32802.
|
||||||
|
is_simple_hybrid = len(self.attention_groups) == 2 and isinstance(
|
||||||
|
self.attention_groups[0][0], FullAttentionSpec
|
||||||
|
)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
curr_hit_length = hit_length
|
curr_hit_length = hit_length
|
||||||
|
|
||||||
@@ -495,10 +505,6 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
|||||||
# the last iteration.
|
# the last iteration.
|
||||||
num_blocks = curr_hit_length // spec.block_size
|
num_blocks = curr_hit_length // spec.block_size
|
||||||
curr_hit_length = num_blocks * spec.block_size
|
curr_hit_length = num_blocks * spec.block_size
|
||||||
for group_id in group_ids:
|
|
||||||
blocks = hit_blocks_by_group[group_id]
|
|
||||||
assert blocks is not None
|
|
||||||
del blocks[num_blocks:]
|
|
||||||
else:
|
else:
|
||||||
hit_blocks = manager_cls.find_longest_cache_hit(
|
hit_blocks = manager_cls.find_longest_cache_hit(
|
||||||
block_hashes=_get_block_hashes(spec),
|
block_hashes=_get_block_hashes(spec),
|
||||||
@@ -513,10 +519,20 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
|||||||
for group_id, blocks in zip(group_ids, hit_blocks):
|
for group_id, blocks in zip(group_ids, hit_blocks):
|
||||||
hit_blocks_by_group[group_id] = blocks
|
hit_blocks_by_group[group_id] = blocks
|
||||||
|
|
||||||
if curr_hit_length < hit_length:
|
if curr_hit_length >= hit_length:
|
||||||
hit_length = curr_hit_length
|
|
||||||
else:
|
|
||||||
break
|
break
|
||||||
|
hit_length = curr_hit_length
|
||||||
|
# Simple hybrid: exit after one iteration
|
||||||
|
if is_simple_hybrid:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Truncate full attention blocks to final hit_length (if present)
|
||||||
|
spec, group_ids, _ = self.attention_groups[0]
|
||||||
|
if isinstance(spec, FullAttentionSpec):
|
||||||
|
num_blocks = hit_length // spec.block_size
|
||||||
|
for group_id in group_ids:
|
||||||
|
if (blks := hit_blocks_by_group[group_id]) is not None:
|
||||||
|
del blks[num_blocks:]
|
||||||
|
|
||||||
return tuple(
|
return tuple(
|
||||||
blocks if blocks is not None else [] for blocks in hit_blocks_by_group
|
blocks if blocks is not None else [] for blocks in hit_blocks_by_group
|
||||||
|
|||||||
Reference in New Issue
Block a user