[V1] Prompt logprobs + APC compatibility; prompt logprobs reqs cannot fill APC (#13949)
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Compare behavior with and without prefix caching."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
@@ -15,7 +17,8 @@ from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
|
||||
def make_request(request_id,
|
||||
prompt_token_ids,
|
||||
mm_positions=None,
|
||||
mm_hashes=None):
|
||||
mm_hashes=None,
|
||||
prompt_logprobs: Optional[int] = None):
|
||||
if mm_positions is None:
|
||||
multi_modal_inputs = None
|
||||
else:
|
||||
@@ -28,7 +31,8 @@ def make_request(request_id,
|
||||
multi_modal_inputs=multi_modal_inputs,
|
||||
multi_modal_hashes=mm_hashes,
|
||||
multi_modal_placeholders=mm_positions,
|
||||
sampling_params=SamplingParams(max_tokens=17),
|
||||
sampling_params=SamplingParams(max_tokens=17,
|
||||
prompt_logprobs=prompt_logprobs),
|
||||
eos_token_id=100,
|
||||
arrival_time=0,
|
||||
lora_request=None,
|
||||
@@ -144,6 +148,110 @@ def test_prefill():
|
||||
assert manager.block_pool.free_block_queue.free_list_tail is None
|
||||
|
||||
|
||||
def test_prefill_plp():
    """Exercise prefill with prefix caching (APC) and prompt logprobs (plp).

    Three requests sharing a 48-token prefix are scheduled in turn:

    1. A plp request: nothing is reported as computed, and the blocks it
       allocates are checked (ids, hashes, reference counts).
    2. A non-plp request: it is expected to hit the three full blocks
       that request #0 left in the cache.
    3. A second plp request, after the first two were freed: again
       nothing is reported as computed, and its blocks are expected to
       carry the same hashes as request #0's under different block ids.
    """
    block_size = 16
    manager = KVCacheManager(
        block_size=block_size,
        num_gpu_blocks=10,
        max_model_len=8192,
        sliding_window=None,
        enable_caching=True,
        num_preallocate_tokens=16,
    )

    def pool_block(block_id):
        # Look the block up through the manager on every call so that the
        # assertions always observe the pool's current state.
        return manager.block_pool.blocks[block_id]

    def free_queue():
        return manager.block_pool.free_block_queue

    # Shared prefix: three complete blocks (48 tokens); block i is filled
    # with token id i.
    common_token_ids = [tok for tok in range(3) for _ in range(block_size)]

    # --- Request #0: prompt logprobs, complete cache miss ----------------
    # The 7 trailing tokens leave the fourth block incomplete.
    req0_token_ids = common_token_ids + [3] * 7
    req0 = make_request("0", req0_token_ids, prompt_logprobs=5)
    hit_blocks, num_hit_tokens = manager.get_computed_blocks(req0)
    assert len(manager.req_to_block_hashes[req0.request_id]) == 3
    assert not hit_blocks
    assert num_hit_tokens == 0
    req0_blocks = manager.allocate_slots(req0, 55, hit_blocks)
    assert [blk.block_id for blk in req0_blocks] == [0, 1, 2, 3, 4]
    req0_block_hashes = [blk.block_hash for blk in req0_blocks]

    # The three full blocks must carry the chained hash of their tokens.
    parent_hash = None
    for idx in range(3):
        start = idx * block_size
        tokens = tuple(req0_token_ids[start:start + block_size])
        expected_hash = hash_block_tokens(parent_hash, tokens)
        full_block = pool_block(idx)
        assert full_block.block_hash == expected_hash
        assert full_block.ref_cnt == 1
        parent_hash = expected_hash.hash_value

    # The partial block and the preallocated block are in use but unhashed.
    for idx in (3, 4):
        tail_block = pool_block(idx)
        assert tail_block.block_hash is None
        assert tail_block.ref_cnt == 1

    # --- Request #1: no prompt logprobs, hits the shared prefix ----------
    # Request #0 is still holding its blocks at this point.
    # The 5 trailing tokens leave one block incomplete.
    req1 = make_request("1", common_token_ids + [3] * 5)
    hit_blocks, num_hit_tokens = manager.get_computed_blocks(req1)
    assert len(manager.req_to_block_hashes[req1.request_id]) == 3
    assert [blk.block_id for blk in hit_blocks] == [0, 1, 2]
    assert num_hit_tokens == 3 * block_size
    num_new_tokens = 53 - 3 * block_size
    req1_blocks = manager.allocate_slots(req1, num_new_tokens, hit_blocks)
    assert [blk.block_id for blk in req1_blocks] == [5, 6]
    # Each shared block is now referenced by both requests.
    assert all(blk.ref_cnt == 2 for blk in hit_blocks)

    # 10 blocks total, 7 in use (0-4 by req0, 5-6 by req1): 3 remain free.
    assert free_queue().num_free_blocks == 3

    manager.free(req0)
    manager.free(req1)

    # Every block is back in the free queue, in this order:
    #   never allocated:  7, 8, 9
    #   req0-only blocks: 4, 3
    #   req1-only blocks: 6, 5
    #   shared prefix:    2, 1, 0
    assert free_queue().num_free_blocks == 10
    free_block_ids = [
        blk.block_id for blk in free_queue().get_all_free_blocks()
    ]
    assert free_block_ids == [7, 8, 9, 4, 3, 6, 5, 2, 1, 0]

    # --- Request #2: prompt logprobs again, must NOT hit the cache -------
    # Its blocks duplicate the content cached by request #0.
    req2 = make_request("2", common_token_ids + [3] * 6, prompt_logprobs=5)
    hit_blocks, num_hit_tokens = manager.get_computed_blocks(req2)
    assert len(manager.req_to_block_hashes[req2.request_id]) == 3
    assert not hit_blocks
    assert num_hit_tokens == 0
    # NOTE(review): req2's prompt is 48 + 6 = 54 tokens, yet 55 tokens are
    # requested here (the same figure used for req0). Kept as in the
    # original test; confirm whether 54 was intended.
    req2_blocks = manager.allocate_slots(req2, 55, hit_blocks)
    req2_block_ids = [blk.block_id for blk in req2_blocks]
    # Same hashes as request #0, but held in different physical blocks.
    assert [blk.block_hash for blk in req2_blocks] == req0_block_hashes
    assert req2_block_ids != [0, 1, 2, 3, 4]

    # Request #0's hashes were validated above, so request #2's are valid
    # too; only the reference counts remain to be checked.
    for block_id in req2_block_ids:
        assert pool_block(block_id).ref_cnt == 1

    manager.free(req2)
|
||||
|
||||
|
||||
def test_decode():
|
||||
manager = KVCacheManager(
|
||||
block_size=16,
|
||||
|
||||
Reference in New Issue
Block a user