[v1][KVCacheManager] pass num_new_computed_tokens to kv cache manager (#18001)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -81,7 +81,9 @@ def test_prefill(hash_algo):
|
||||
assert len(manager.req_to_block_hashes[req0.request_id]) == 3
|
||||
assert not computed_blocks.blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req0, 55, computed_blocks)
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [1, 2, 3, 4]
|
||||
|
||||
# Check full block metadata
|
||||
@@ -108,7 +110,9 @@ def test_prefill(hash_algo):
|
||||
assert computed_blocks.get_block_ids() == [1, 2, 3]
|
||||
assert num_computed_tokens == 3 * 16
|
||||
num_new_tokens = 53 - 3 * 16
|
||||
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
|
||||
blocks = manager.allocate_slots(req1, num_new_tokens,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [5]
|
||||
for block in computed_blocks.blocks:
|
||||
assert block.ref_cnt == 2
|
||||
@@ -140,7 +144,9 @@ def test_prefill(hash_algo):
|
||||
assert computed_blocks.get_block_ids() == [1, 2, 3]
|
||||
assert num_computed_tokens == 3 * 16
|
||||
num_new_tokens = 53 - 3 * 16
|
||||
blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
|
||||
blocks = manager.allocate_slots(req2, num_new_tokens,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [6]
|
||||
|
||||
# Although we only have 6 free blocks, we have 8 blocks in
|
||||
@@ -161,7 +167,9 @@ def test_prefill(hash_algo):
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
|
||||
assert not computed_blocks.blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req3, 16 * 10, computed_blocks)
|
||||
blocks = manager.allocate_slots(req3, 16 * 10,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
# This block ID order also checks the eviction order.
|
||||
assert blocks.get_block_ids() == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1]
|
||||
assert manager.block_pool.free_block_queue.num_free_blocks == 0
|
||||
@@ -197,7 +205,9 @@ def test_prefill_plp():
|
||||
assert len(manager.req_to_block_hashes[req0.request_id]) == 0
|
||||
assert not computed_blocks.blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req0, 55, computed_blocks)
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [1, 2, 3, 4]
|
||||
req0_block_hashes = [b.block_hash for b in blocks.blocks]
|
||||
|
||||
@@ -226,7 +236,9 @@ def test_prefill_plp():
|
||||
assert computed_blocks.get_block_ids() == [1, 2, 3]
|
||||
assert num_computed_tokens == 3 * 16
|
||||
num_new_tokens = 53 - 3 * 16
|
||||
blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
|
||||
blocks = manager.allocate_slots(req1, num_new_tokens,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [5]
|
||||
for block in computed_blocks.blocks:
|
||||
assert block.ref_cnt == 2
|
||||
@@ -259,7 +271,9 @@ def test_prefill_plp():
|
||||
assert len(manager.req_to_block_hashes[req2.request_id]) == 0
|
||||
assert not computed_blocks.blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req2, 55, computed_blocks)
|
||||
blocks = manager.allocate_slots(req2, 55,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
block_ids = blocks.get_block_ids()
|
||||
# Duplicate cached blocks have different ids but same hashes vs request #0
|
||||
assert [b.block_hash for b in blocks.blocks] == req0_block_hashes
|
||||
@@ -290,14 +304,18 @@ def test_decode():
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert not computed_blocks.blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req0, 55, computed_blocks)
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [1, 2, 3, 4]
|
||||
|
||||
# Append slots without allocating a new block.
|
||||
req0.num_computed_tokens = 55
|
||||
for _ in range(4):
|
||||
req0.append_output_token_ids(8)
|
||||
new_blocks = manager.allocate_slots(req0, 4)
|
||||
new_blocks = manager.allocate_slots(req0, 4,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert new_blocks is not None and len(new_blocks.blocks) == 0
|
||||
assert manager.single_type_manager.req_to_blocks[
|
||||
req0.request_id][-1].block_hash is None
|
||||
@@ -308,7 +326,9 @@ def test_decode():
|
||||
# the preallocated block.
|
||||
for _ in range(9 + 10):
|
||||
req0.append_output_token_ids(7)
|
||||
new_blocks = manager.allocate_slots(req0, 19)
|
||||
new_blocks = manager.allocate_slots(req0, 19,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert new_blocks is not None and len(new_blocks.blocks) == 1
|
||||
assert manager.single_type_manager.req_to_blocks[
|
||||
req0.request_id][-2].block_hash is not None
|
||||
@@ -328,7 +348,9 @@ def test_evict():
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert not computed_blocks.blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks)
|
||||
blocks = manager.allocate_slots(req0, 5 * 16 + 7,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks) == 6 # 5 full + 1 partial
|
||||
|
||||
# 3 blocks.
|
||||
@@ -337,7 +359,9 @@ def test_evict():
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert not computed_blocks.blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks)
|
||||
blocks = manager.allocate_slots(req1, 3 * 16,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks) == 3 # 3 full blocks
|
||||
last_token_id += 3 * 16
|
||||
|
||||
@@ -357,7 +381,9 @@ def test_evict():
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert computed_blocks.get_block_ids() == [1, 2]
|
||||
assert num_computed_tokens == 2 * 16
|
||||
blocks = manager.allocate_slots(req2, 3, computed_blocks)
|
||||
blocks = manager.allocate_slots(req2, 3,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [10]
|
||||
assert manager.block_pool.free_block_queue.num_free_blocks == 7
|
||||
|
||||
@@ -380,7 +406,9 @@ def test_hash_block_correct_reuse():
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||
assert not computed_blocks.blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req, num_tokens, computed_blocks)
|
||||
blocks = manager.allocate_slots(req, num_tokens,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks) == 1
|
||||
|
||||
# Deallocate the block.
|
||||
@@ -392,7 +420,9 @@ def test_hash_block_correct_reuse():
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||
assert not computed_blocks.blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks)
|
||||
blocks = manager.allocate_slots(req, num_tokens - 1,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks) == 1
|
||||
|
||||
assert manager.block_pool.blocks[
|
||||
@@ -417,7 +447,9 @@ def test_computed_blocks_not_evicted():
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert not computed_blocks.blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req0, num_tokens, computed_blocks)
|
||||
blocks = manager.allocate_slots(req0, num_tokens,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks) == 1
|
||||
assert blocks.blocks[0].block_id == 1
|
||||
|
||||
@@ -426,7 +458,9 @@ def test_computed_blocks_not_evicted():
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert not computed_blocks.blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req1, num_tokens, computed_blocks)
|
||||
blocks = manager.allocate_slots(req1, num_tokens,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks) == 1
|
||||
assert blocks.blocks[0].block_id == 2
|
||||
|
||||
@@ -443,6 +477,7 @@ def test_computed_blocks_not_evicted():
|
||||
assert num_computed_tokens == block_size
|
||||
|
||||
blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks) == 1
|
||||
assert blocks.blocks[0].block_id == 2
|
||||
@@ -464,7 +499,9 @@ def test_basic_prefix_caching_disabled():
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert not computed_blocks.blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req1, 10, computed_blocks)
|
||||
blocks = manager.allocate_slots(req1, 10,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks) == 3
|
||||
|
||||
# Free the blocks.
|
||||
@@ -475,7 +512,9 @@ def test_basic_prefix_caching_disabled():
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert not computed_blocks.blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req2, 16, computed_blocks)
|
||||
blocks = manager.allocate_slots(req2, 16,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert len(blocks.blocks) == 4
|
||||
|
||||
# New requests should not have any blocks.
|
||||
@@ -483,7 +522,9 @@ def test_basic_prefix_caching_disabled():
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
|
||||
assert not computed_blocks.blocks
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req3, 4, computed_blocks)
|
||||
blocks = manager.allocate_slots(req3, 4,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert not blocks
|
||||
|
||||
|
||||
@@ -581,14 +622,18 @@ def test_mm_prefix_caching():
|
||||
assert block_hashes[1].extra_keys == ("aaa", "bbb")
|
||||
assert block_hashes[2].extra_keys == ("bbb", )
|
||||
|
||||
blocks = manager.allocate_slots(req0, 59, computed_blocks)
|
||||
blocks = manager.allocate_slots(req0, 59,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [1, 2, 3, 4]
|
||||
req0.num_computed_tokens = 59
|
||||
|
||||
# Append slots without allocating a new block.
|
||||
for _ in range(5):
|
||||
req0.append_output_token_ids(8)
|
||||
new_blocks = manager.allocate_slots(req0, 5)
|
||||
new_blocks = manager.allocate_slots(req0, 5,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert new_blocks is not None and len(new_blocks.blocks) == 0
|
||||
|
||||
# The just completed block should have hashes with extra keys.
|
||||
@@ -638,14 +683,18 @@ def test_cache_key_salting():
|
||||
assert block_hashes[1].extra_keys is None
|
||||
assert block_hashes[2].extra_keys is None
|
||||
|
||||
blocks = manager.allocate_slots(req0, 59, computed_blocks)
|
||||
blocks = manager.allocate_slots(req0, 59,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [1, 2, 3, 4]
|
||||
req0.num_computed_tokens = 59
|
||||
|
||||
# Append slots without allocating a new block.
|
||||
for _ in range(5):
|
||||
req0.append_output_token_ids(8)
|
||||
new_blocks = manager.allocate_slots(req0, 5)
|
||||
new_blocks = manager.allocate_slots(req0, 5,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert new_blocks is not None and len(new_blocks.blocks) == 0
|
||||
|
||||
# Now one more block that should not have extra keys.
|
||||
@@ -691,7 +740,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert not computed_blocks.blocks
|
||||
assert num_computed_tokens == 0
|
||||
manager.allocate_slots(req0, 48, computed_blocks)
|
||||
manager.allocate_slots(req0, 48,
|
||||
len(computed_blocks.blocks) * 16, computed_blocks)
|
||||
block_part0 = manager.single_type_manager.req_to_blocks[req0.request_id]
|
||||
|
||||
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
|
||||
@@ -699,7 +749,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert computed_blocks.blocks == block_part0
|
||||
assert num_computed_tokens == 3 * 16
|
||||
manager.allocate_slots(req1, 48, computed_blocks)
|
||||
manager.allocate_slots(req1, 48,
|
||||
len(computed_blocks.blocks) * 16, computed_blocks)
|
||||
block_part1 = manager.single_type_manager.req_to_blocks[req1.request_id]
|
||||
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
|
||||
# | Req1-5(F)| ... |
|
||||
@@ -713,7 +764,8 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert not computed_blocks.blocks
|
||||
assert num_computed_tokens == 0
|
||||
manager.allocate_slots(req2, block_size * 2, computed_blocks)
|
||||
manager.allocate_slots(req2, block_size * 2,
|
||||
len(computed_blocks.blocks) * 16, computed_blocks)
|
||||
|
||||
# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
|
||||
# but it cannot be allocated due to insufficient free blocks (2).
|
||||
@@ -724,7 +776,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
assert computed_blocks.blocks == block_part1
|
||||
assert num_computed_tokens == 6 * 16
|
||||
# Req3 cannot be allocated.
|
||||
assert manager.allocate_slots(req3, 48, computed_blocks) is None
|
||||
assert manager.allocate_slots(req3, 48,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks) is None
|
||||
# Block 0-2 are used by Req 1.
|
||||
assert {block.ref_cnt for block in block_part1[:3]} == {1}
|
||||
# Block 3-5 are free.
|
||||
@@ -751,7 +805,9 @@ def test_reset_prefix_cache():
|
||||
computed_blocks, _ = manager.get_computed_blocks(req1)
|
||||
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
|
||||
assert len(computed_blocks.blocks) == 3
|
||||
blocks = manager.allocate_slots(req1, 7, computed_blocks)
|
||||
blocks = manager.allocate_slots(req1, 7,
|
||||
len(computed_blocks.blocks) * 16,
|
||||
computed_blocks)
|
||||
assert blocks.get_block_ids() == [5]
|
||||
|
||||
# Failed to reset prefix cache because some blocks are not freed yet.
|
||||
@@ -782,7 +838,8 @@ def test_prefix_cache_stats_disabled():
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||
assert not computed_blocks.blocks
|
||||
assert num_computed_tokens == 0
|
||||
manager.allocate_slots(req, 16, computed_blocks)
|
||||
manager.allocate_slots(req, 16,
|
||||
len(computed_blocks.blocks) * 16, computed_blocks)
|
||||
manager.reset_prefix_cache()
|
||||
|
||||
# Ensure prefix_cache_stats remains None
|
||||
@@ -860,7 +917,8 @@ def test_eagle_enabled_removes_last_block():
|
||||
|
||||
# Prime the cache
|
||||
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||
manager.allocate_slots(req, len(token_ids), computed_blocks)
|
||||
manager.allocate_slots(req, len(token_ids),
|
||||
len(computed_blocks.blocks) * 16, computed_blocks)
|
||||
manager.free(req)
|
||||
|
||||
# New request with same tokens + Eagle enabled
|
||||
@@ -889,7 +947,8 @@ def test_eagle_with_partial_blocks():
|
||||
|
||||
# Prime the cache
|
||||
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||
manager.allocate_slots(req, len(token_ids), computed_blocks)
|
||||
manager.allocate_slots(req, len(token_ids),
|
||||
len(computed_blocks.blocks) * 16, computed_blocks)
|
||||
manager.free(req)
|
||||
|
||||
# New request with Eagle enabled
|
||||
@@ -928,7 +987,8 @@ def test_eagle_with_sliding_window():
|
||||
|
||||
# Prime the cache
|
||||
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||
manager.allocate_slots(req, len(token_ids), computed_blocks)
|
||||
manager.allocate_slots(req, len(token_ids),
|
||||
len(computed_blocks.blocks) * 16, computed_blocks)
|
||||
# record the block hash of the first block in the request for later use
|
||||
block_hash_first_block = manager.req_to_block_hashes[req.request_id][0]
|
||||
assert block_hash_first_block is not None
|
||||
|
||||
Reference in New Issue
Block a user