[v1][KVCacheManager] pass num_new_computed_tokens to kv cache manager (#18001)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang
2025-05-14 10:09:39 +08:00
committed by GitHub
parent 40de1ef455
commit f2ae883b67
3 changed files with 119 additions and 53 deletions

View File

@@ -81,7 +81,9 @@ def test_prefill(hash_algo):
assert len(manager.req_to_block_hashes[req0.request_id]) == 3
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4]
# Check full block metadata
@@ -108,7 +110,9 @@ def test_prefill(hash_algo):
assert computed_blocks.get_block_ids() == [1, 2, 3]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [5]
for block in computed_blocks.blocks:
assert block.ref_cnt == 2
@@ -140,7 +144,9 @@ def test_prefill(hash_algo):
assert computed_blocks.get_block_ids() == [1, 2, 3]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
blocks = manager.allocate_slots(req2, num_new_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [6]
# Although we only have 6 free blocks, we have 8 blocks in
@@ -161,7 +167,9 @@ def test_prefill(hash_algo):
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 16 * 10, computed_blocks)
blocks = manager.allocate_slots(req3, 16 * 10,
len(computed_blocks.blocks) * 16,
computed_blocks)
# This block ID order also checks the eviction order.
assert blocks.get_block_ids() == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1]
assert manager.block_pool.free_block_queue.num_free_blocks == 0
@@ -197,7 +205,9 @@ def test_prefill_plp():
assert len(manager.req_to_block_hashes[req0.request_id]) == 0
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4]
req0_block_hashes = [b.block_hash for b in blocks.blocks]
@@ -226,7 +236,9 @@ def test_prefill_plp():
assert computed_blocks.get_block_ids() == [1, 2, 3]
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
blocks = manager.allocate_slots(req1, num_new_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [5]
for block in computed_blocks.blocks:
assert block.ref_cnt == 2
@@ -259,7 +271,9 @@ def test_prefill_plp():
assert len(manager.req_to_block_hashes[req2.request_id]) == 0
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req2, 55, computed_blocks)
blocks = manager.allocate_slots(req2, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
block_ids = blocks.get_block_ids()
# Duplicate cached blocks have different ids but same hashes vs request #0
assert [b.block_hash for b in blocks.blocks] == req0_block_hashes
@@ -290,14 +304,18 @@ def test_decode():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55, computed_blocks)
blocks = manager.allocate_slots(req0, 55,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4]
# Append slots without allocating a new block.
req0.num_computed_tokens = 55
for _ in range(4):
req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 4)
new_blocks = manager.allocate_slots(req0, 4,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks) == 0
assert manager.single_type_manager.req_to_blocks[
req0.request_id][-1].block_hash is None
@@ -308,7 +326,9 @@ def test_decode():
# the preallocated block.
for _ in range(9 + 10):
req0.append_output_token_ids(7)
new_blocks = manager.allocate_slots(req0, 19)
new_blocks = manager.allocate_slots(req0, 19,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks) == 1
assert manager.single_type_manager.req_to_blocks[
req0.request_id][-2].block_hash is not None
@@ -328,7 +348,9 @@ def test_evict():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks)
blocks = manager.allocate_slots(req0, 5 * 16 + 7,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 6 # 5 full + 1 partial
# 3 blocks.
@@ -337,7 +359,9 @@ def test_evict():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks)
blocks = manager.allocate_slots(req1, 3 * 16,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 3 # 3 full blocks
last_token_id += 3 * 16
@@ -357,7 +381,9 @@ def test_evict():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert computed_blocks.get_block_ids() == [1, 2]
assert num_computed_tokens == 2 * 16
blocks = manager.allocate_slots(req2, 3, computed_blocks)
blocks = manager.allocate_slots(req2, 3,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [10]
assert manager.block_pool.free_block_queue.num_free_blocks == 7
@@ -380,7 +406,9 @@ def test_hash_block_correct_reuse():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req, num_tokens, computed_blocks)
blocks = manager.allocate_slots(req, num_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 1
# Deallocate the block.
@@ -392,7 +420,9 @@ def test_hash_block_correct_reuse():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks)
blocks = manager.allocate_slots(req, num_tokens - 1,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 1
assert manager.block_pool.blocks[
@@ -417,7 +447,9 @@ def test_computed_blocks_not_evicted():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, num_tokens, computed_blocks)
blocks = manager.allocate_slots(req0, num_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 1
assert blocks.blocks[0].block_id == 1
@@ -426,7 +458,9 @@ def test_computed_blocks_not_evicted():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, num_tokens, computed_blocks)
blocks = manager.allocate_slots(req1, num_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 1
assert blocks.blocks[0].block_id == 2
@@ -443,6 +477,7 @@ def test_computed_blocks_not_evicted():
assert num_computed_tokens == block_size
blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 1
assert blocks.blocks[0].block_id == 2
@@ -464,7 +499,9 @@ def test_basic_prefix_caching_disabled():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req1, 10, computed_blocks)
blocks = manager.allocate_slots(req1, 10,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 3
# Free the blocks.
@@ -475,7 +512,9 @@ def test_basic_prefix_caching_disabled():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req2, 16, computed_blocks)
blocks = manager.allocate_slots(req2, 16,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert len(blocks.blocks) == 4
# New requests should not have any blocks.
@@ -483,7 +522,9 @@ def test_basic_prefix_caching_disabled():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req3, 4, computed_blocks)
blocks = manager.allocate_slots(req3, 4,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert not blocks
@@ -581,14 +622,18 @@ def test_mm_prefix_caching():
assert block_hashes[1].extra_keys == ("aaa", "bbb")
assert block_hashes[2].extra_keys == ("bbb", )
blocks = manager.allocate_slots(req0, 59, computed_blocks)
blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4]
req0.num_computed_tokens = 59
# Append slots without allocating a new block.
for _ in range(5):
req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 5)
new_blocks = manager.allocate_slots(req0, 5,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks) == 0
# The just completed block should have hashes with extra keys.
@@ -638,14 +683,18 @@ def test_cache_key_salting():
assert block_hashes[1].extra_keys is None
assert block_hashes[2].extra_keys is None
blocks = manager.allocate_slots(req0, 59, computed_blocks)
blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [1, 2, 3, 4]
req0.num_computed_tokens = 59
# Append slots without allocating a new block.
for _ in range(5):
req0.append_output_token_ids(8)
new_blocks = manager.allocate_slots(req0, 5)
new_blocks = manager.allocate_slots(req0, 5,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks) == 0
# Now one more block that should not have extra keys.
@@ -691,7 +740,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
manager.allocate_slots(req0, 48, computed_blocks)
manager.allocate_slots(req0, 48,
len(computed_blocks.blocks) * 16, computed_blocks)
block_part0 = manager.single_type_manager.req_to_blocks[req0.request_id]
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
@@ -699,7 +749,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert computed_blocks.blocks == block_part0
assert num_computed_tokens == 3 * 16
manager.allocate_slots(req1, 48, computed_blocks)
manager.allocate_slots(req1, 48,
len(computed_blocks.blocks) * 16, computed_blocks)
block_part1 = manager.single_type_manager.req_to_blocks[req1.request_id]
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Req1-5(F)| ... |
@@ -713,7 +764,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
manager.allocate_slots(req2, block_size * 2, computed_blocks)
manager.allocate_slots(req2, block_size * 2,
len(computed_blocks.blocks) * 16, computed_blocks)
# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
# but it cannot be allocated due to insufficient free blocks (2).
@@ -724,7 +776,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
assert computed_blocks.blocks == block_part1
assert num_computed_tokens == 6 * 16
# Req3 cannot be allocated.
assert manager.allocate_slots(req3, 48, computed_blocks) is None
assert manager.allocate_slots(req3, 48,
len(computed_blocks.blocks) * 16,
computed_blocks) is None
# Block 0-2 are used by Req 1.
assert {block.ref_cnt for block in block_part1[:3]} == {1}
# Block 3-5 are free.
@@ -751,7 +805,9 @@ def test_reset_prefix_cache():
computed_blocks, _ = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert len(computed_blocks.blocks) == 3
blocks = manager.allocate_slots(req1, 7, computed_blocks)
blocks = manager.allocate_slots(req1, 7,
len(computed_blocks.blocks) * 16,
computed_blocks)
assert blocks.get_block_ids() == [5]
# Failed to reset prefix cache because some blocks are not freed yet.
@@ -782,7 +838,8 @@ def test_prefix_cache_stats_disabled():
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks
assert num_computed_tokens == 0
manager.allocate_slots(req, 16, computed_blocks)
manager.allocate_slots(req, 16,
len(computed_blocks.blocks) * 16, computed_blocks)
manager.reset_prefix_cache()
# Ensure prefix_cache_stats remains None
@@ -860,7 +917,8 @@ def test_eagle_enabled_removes_last_block():
# Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req)
manager.allocate_slots(req, len(token_ids), computed_blocks)
manager.allocate_slots(req, len(token_ids),
len(computed_blocks.blocks) * 16, computed_blocks)
manager.free(req)
# New request with same tokens + Eagle enabled
@@ -889,7 +947,8 @@ def test_eagle_with_partial_blocks():
# Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req)
manager.allocate_slots(req, len(token_ids), computed_blocks)
manager.allocate_slots(req, len(token_ids),
len(computed_blocks.blocks) * 16, computed_blocks)
manager.free(req)
# New request with Eagle enabled
@@ -928,7 +987,8 @@ def test_eagle_with_sliding_window():
# Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req)
manager.allocate_slots(req, len(token_ids), computed_blocks)
manager.allocate_slots(req, len(token_ids),
len(computed_blocks.blocks) * 16, computed_blocks)
# record the block hash of the first block in the request for later use
block_hash_first_block = manager.req_to_block_hashes[req.request_id][0]
assert block_hash_first_block is not None