[Frontend] Add per-request number of cached token stats (#10174)

zifeitong
2024-11-12 08:42:28 -08:00
committed by GitHub
parent 176fcb1c71
commit 47db6ec831
9 changed files with 89 additions and 23 deletions
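
As a usage sketch of the feature (not part of this commit; the model name
and prompts are placeholders), the new per-request stat surfaces on each
RequestOutput returned by the offline LLM API:

from vllm import LLM, SamplingParams

# Hypothetical example: any model and prompts work; prefix caching must be on.
llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)
params = SamplingParams(temperature=0.0, max_tokens=5)

shared_prefix = "Common instructions repeated across requests. "
# The first call populates the prefix cache; the second can hit it.
llm.generate([shared_prefix + "First question?"], params)
outputs = llm.generate([shared_prefix + "Second question?"], params)

for output in outputs:
    # num_cached_tokens is the per-request stat added by this commit.
    print(output.num_cached_tokens, "of", len(output.prompt_token_ids),
          "prompt tokens were served from the prefix cache")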


@@ -27,6 +27,7 @@ UNSTABLE_PROMPT_SEQUENCE = [
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("cached_position", [0, 1])
@pytest.mark.parametrize("block_size", [16])
def test_mixed_requests(
hf_runner,
vllm_runner,
@@ -36,11 +37,12 @@ def test_mixed_requests(
dtype: str,
max_tokens: int,
cached_position: int,
block_size: int,
monkeypatch,
) -> None:
"""
Test the case where some sequences have a prefix cache hit
and others don't. The cached position determines where the
cached sequence sits within the batch of prefills.
"""
override_backend_env_variable(monkeypatch, backend)
@@ -53,12 +55,30 @@ def test_mixed_requests(
model,
dtype=dtype,
enable_prefix_caching=True,
block_size=block_size,
) as vllm_model:
# Run the first prompt so the cache is populated
vllm_outputs = vllm_model.generate_greedy([cached_prompt], max_tokens)
# Run all the prompts
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
req_outputs = vllm_model.model.generate(example_prompts, greedy_params)
# Verify number of cached tokens
for i in range(len(req_outputs)):
if i == cached_position:
expected_num_cached_tokens = (
len(req_outputs[i].prompt_token_ids) //
block_size) * block_size
else:
expected_num_cached_tokens = 0
assert req_outputs[i].num_cached_tokens == expected_num_cached_tokens
vllm_outputs = [
(output.prompt_token_ids + list(output.outputs[0].token_ids),
output.prompt + output.outputs[0].text) for output in req_outputs
]
check_outputs_equal(
outputs_0_lst=hf_outputs,
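
The expectation in the verification loop above follows from prefix caching
working at KV-cache block granularity: only whole blocks can be reused, so a
prompt of n tokens with block size b reports at most (n // b) * b cached
tokens. A tiny illustration of the arithmetic (the prompt length is made up):

block_size = 16
prompt_len = 37  # hypothetical prompt length
expected_num_cached_tokens = (prompt_len // block_size) * block_size
assert expected_num_cached_tokens == 32  # two full blocks; the trailing 5 tokens don't count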