[v1] Move block_hashes from KVCacheManager to Request.block_hashes (#19728)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
Or Ozeri
2025-08-16 02:52:52 +03:00
committed by GitHub
parent b9dc9d2607
commit c280066f9d
19 changed files with 381 additions and 335 deletions

View File

@@ -7,6 +7,7 @@ import pytest
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import RequestStatus
from vllm.v1.utils import ConstantList
from .utils import create_requests, create_scheduler
@@ -140,7 +141,8 @@ def test_prefix_caching_for_prefill_dedup():
requests = create_requests(num_requests=5,
num_tokens=num_prompt_tokens,
max_tokens=3,
same_prompt=True)
same_prompt=True,
block_size=BLOCK_SIZE)
requests_copy = requests.copy()
# Two requests with the same prompt.
@@ -188,7 +190,8 @@ def test_prefix_caching_for_multi_turn():
block_size=BLOCK_SIZE)
requests = create_requests(num_requests=5,
num_tokens=num_prompt_tokens,
max_tokens=num_output_tokens)
max_tokens=num_output_tokens,
block_size=BLOCK_SIZE)
for req in requests:
scheduler.add_request(req)
@@ -208,14 +211,19 @@ def test_prefix_caching_for_multi_turn():
# Create next-turn requests whose prompts are the full output of the
# previous turn.
next_turn_requests = create_requests(
num_requests=5,
num_tokens=num_prompt_tokens + num_output_tokens,
max_tokens=num_output_tokens,
)
next_turn_requests = create_requests(num_requests=5,
num_tokens=num_prompt_tokens +
num_output_tokens,
max_tokens=num_output_tokens,
block_size=BLOCK_SIZE)
for i, req in enumerate(next_turn_requests):
req.prompt_token_ids = (requests[i].prompt_token_ids +
list(requests[i].output_token_ids))
req._all_token_ids = req.prompt_token_ids.copy()
req.all_token_ids = ConstantList(req._all_token_ids)
req.block_hashes = []
req.block_hashes = req.get_hash_new_full_blocks()
# Schedule the next-turn requests.
for req in next_turn_requests:
scheduler.add_request(req)

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
from typing import Optional
from typing import Callable, Optional
import pytest
import torch
@@ -19,7 +19,7 @@ from vllm.v1.core.kv_cache_utils import (
FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
estimate_max_model_len, generate_block_hash_extra_keys,
get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
hash_block_tokens, hash_request_tokens, init_none_hash,
get_request_block_hasher, hash_block_tokens, init_none_hash,
is_kv_cache_type_uniform, unify_kv_cache_configs)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor,
@@ -33,6 +33,8 @@ from vllm.v1.request import Request
def make_request(
request_id: str,
prompt_token_ids: list[int],
block_size: int = 3,
hash_fn: Callable = hash,
mm_positions: Optional[list[PlaceholderRange]] = None,
mm_hashes: Optional[list[str]] = None,
cache_salt: Optional[str] = None,
@@ -49,18 +51,17 @@ def make_request(
mm_item = MultiModalKwargsItem.from_elems([mm_elem])
mm_kwargs = [mm_item] * len(mm_positions)
return Request(
request_id=request_id,
prompt_token_ids=prompt_token_ids,
multi_modal_kwargs=mm_kwargs,
multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17),
pooling_params=None,
eos_token_id=100,
lora_request=None,
cache_salt=cache_salt,
)
return Request(request_id=request_id,
prompt_token_ids=prompt_token_ids,
multi_modal_kwargs=mm_kwargs,
multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17),
pooling_params=None,
eos_token_id=100,
lora_request=None,
cache_salt=cache_salt,
block_hasher=get_request_block_hasher(block_size, hash_fn))
def new_kv_cache_spec(block_size=16,
@@ -428,12 +429,14 @@ def test_hash_block_tokens(hash_fn):
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
def test_hash_request_tokens(hash_fn):
def test_request_block_hasher(hash_fn):
import vllm.v1.core.kv_cache_utils
init_none_hash(hash_fn)
request = make_request(
request_id="0",
prompt_token_ids=[_ for _ in range(6)],
block_size=3,
hash_fn=hash_fn,
mm_positions=[
PlaceholderRange(offset=0, length=3),
PlaceholderRange(offset=3, length=3),
@@ -441,9 +444,7 @@ def test_hash_request_tokens(hash_fn):
mm_hashes=["hash1", "hash2"],
)
block_size = 3
block_hashes = hash_request_tokens(hash_fn, block_size, request)
block_hashes = request.block_hashes
assert len(block_hashes) == 2
assert isinstance(block_hashes[0], vllm.v1.core.kv_cache_utils.BlockHash)
assert isinstance(block_hashes[1], vllm.v1.core.kv_cache_utils.BlockHash)
@@ -464,6 +465,8 @@ def test_hash_tokens_different_mm_input(hash_fn):
request1 = make_request(
request_id="0",
prompt_token_ids=[_ for _ in range(6)],
block_size=3,
hash_fn=hash_fn,
mm_positions=[
PlaceholderRange(offset=0, length=3),
PlaceholderRange(offset=3, length=3),
@@ -479,9 +482,8 @@ def test_hash_tokens_different_mm_input(hash_fn):
],
mm_hashes=["hash3", "hash2"],
)
block_size = 3
block_hashes1 = hash_request_tokens(hash_fn, block_size, request1)
block_hashes2 = hash_request_tokens(hash_fn, block_size, request2)
block_hashes1 = request1.block_hashes
block_hashes2 = request2.block_hashes
assert block_hashes1[0] != block_hashes2[0]
assert block_hashes1[1] != block_hashes2[1]
@@ -493,12 +495,13 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
request = make_request(
request_id="0",
prompt_token_ids=[_ for _ in range(6)],
block_size=3,
hash_fn=hash_fn,
mm_positions=None,
mm_hashes=None,
)
block_size = 3
block_hashes = hash_request_tokens(hash_fn, block_size, request)
block_hashes = request.block_hashes
assert len(block_hashes) == 2
assert block_hashes[0].token_ids == (0, 1, 2)
@@ -858,6 +861,7 @@ def test_allocate_with_lookahead():
request = make_request(
request_id="0",
prompt_token_ids=[],
block_size=block_size,
mm_positions=None,
mm_hashes=None,
)

View File

@@ -3,7 +3,7 @@
"""Compare the with and without prefix caching."""
import copy
from typing import Optional
from typing import Callable, Optional
import pytest
import torch
@@ -17,8 +17,9 @@ from vllm.utils import sha256, sha256_cbor_64bit
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
KVCacheBlock, hash_block_tokens,
init_none_hash)
KVCacheBlock,
get_request_block_hasher,
hash_block_tokens, init_none_hash)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, SlidingWindowSpec)
@@ -26,6 +27,8 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
def make_request(
request_id: str,
prompt_token_ids: list[int],
block_size: int,
hash_fn: Callable,
mm_positions: Optional[list[PlaceholderRange]] = None,
mm_hashes: Optional[list[str]] = None,
prompt_logprobs: Optional[int] = None,
@@ -43,19 +46,18 @@ def make_request(
mm_item = MultiModalKwargsItem.from_elems([mm_elem])
mm_kwargs = [mm_item] * len(mm_positions)
return Request(
request_id=request_id,
prompt_token_ids=prompt_token_ids,
multi_modal_kwargs=mm_kwargs,
multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17,
prompt_logprobs=prompt_logprobs),
pooling_params=None,
eos_token_id=100,
lora_request=None,
cache_salt=cache_salt,
)
return Request(request_id=request_id,
prompt_token_ids=prompt_token_ids,
multi_modal_kwargs=mm_kwargs,
multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(
max_tokens=17, prompt_logprobs=prompt_logprobs),
pooling_params=None,
eos_token_id=100,
lora_request=None,
cache_salt=cache_salt,
block_hasher=get_request_block_hasher(block_size, hash_fn))
def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
@@ -105,11 +107,11 @@ def make_kv_cache_config_hybrid_model(block_size: int,
@pytest.mark.parametrize("hash_algo", ["sha256", "sha256_cbor_64bit", "hash"])
def test_prefill(hash_algo):
block_size = 16
manager = KVCacheManager(
make_kv_cache_config(16, 11),
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
caching_hash_algo=hash_algo,
)
# choose the hash function according to the parameter
@@ -123,9 +125,9 @@ def test_prefill(hash_algo):
# Incomplete 1 block (7 tokens)
unique_token_ids = [3] * 7
all_token_ids = common_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
req0 = make_request("0", all_token_ids, block_size, hash_fn)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert len(manager.req_to_block_hashes[req0.request_id]) == 3
assert len(req0.block_hashes) == 3
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55,
@@ -152,9 +154,10 @@ def test_prefill(hash_algo):
# Cache hit in the common prefix when the original block is still in use.
# Incomplete 1 block (5 tokens)
unique_token_ids = [3] * 5
req1 = make_request("1", common_token_ids + unique_token_ids)
req1 = make_request("1", common_token_ids + unique_token_ids, block_size,
hash_fn)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert len(req1.block_hashes) == 3
assert computed_blocks.get_block_ids() == ([1, 2, 3], )
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
@@ -187,9 +190,10 @@ def test_prefill(hash_algo):
# Cache hit in the common prefix when the original block is already free.
# Incomplete 1 block (6 tokens)
unique_token_ids = [3] * 6
req2 = make_request("2", common_token_ids + unique_token_ids)
req2 = make_request("2", common_token_ids + unique_token_ids, block_size,
hash_fn)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(manager.req_to_block_hashes[req2.request_id]) == 3
assert len(req2.block_hashes) == 3
assert computed_blocks.get_block_ids() == ([1, 2, 3], )
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
@@ -208,7 +212,7 @@ def test_prefill(hash_algo):
manager.free(req2)
# Cache miss and eviction.
req3 = make_request("3", [99] * (16 * 10))
req3 = make_request("3", [99] * (16 * 10), block_size, hash_fn)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
@@ -242,9 +246,9 @@ def test_prefill_hybrid_model():
# Incomplete 1 block (7 tokens)
unique_token_ids = [3] * 7
all_token_ids = common_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
req0 = make_request("0", all_token_ids, block_size, hash_fn)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert len(manager.req_to_block_hashes[req0.request_id]) == 3
assert len(req0.block_hashes) == 3
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55,
@@ -274,9 +278,10 @@ def test_prefill_hybrid_model():
# Cache hit in the common prefix
# Incomplete 1 block (5 tokens)
unique_token_ids = [3] * 5
req1 = make_request("1", common_token_ids + unique_token_ids)
req1 = make_request("1", common_token_ids + unique_token_ids, block_size,
hash_fn)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert len(req1.block_hashes) == 3
assert computed_blocks.get_block_ids() == ([1, 2, 3], [0, 6,
7], [0, 10, 11])
assert num_computed_tokens == 3 * 16
@@ -290,7 +295,7 @@ def test_prefill_hybrid_model():
if block != manager.block_pool.null_block:
assert block.ref_cnt == 2
block_hashes = manager.req_to_block_hashes[req1.request_id]
block_hashes = req1.block_hashes
manager.free(req0)
manager.free(req1)
@@ -300,12 +305,13 @@ def test_prefill_hybrid_model():
def test_partial_request_hit(request_id: str,
hash_to_evict: list[BlockHashWithGroupId],
expect_hit_length: int):
req = make_request(request_id, common_token_ids + unique_token_ids)
req = make_request(request_id, common_token_ids + unique_token_ids,
block_size, hash)
for hash_with_group_id in hash_to_evict:
manager.block_pool.cached_block_hash_to_block.pop(
hash_with_group_id)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert len(manager.req_to_block_hashes[req.request_id]) == 3
assert len(req.block_hashes) == 3
assert num_computed_tokens == expect_hit_length * block_size
for block_per_group in computed_blocks.blocks:
assert len(block_per_group) == num_computed_tokens // block_size
@@ -364,8 +370,9 @@ def test_prefill_plp():
2. Schedule non-plp request and validate blocks
3. Schedule plp request; no hit should occur; validate blocks
'''
block_size = 16
manager = KVCacheManager(
make_kv_cache_config(16, 11),
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
)
@@ -380,9 +387,13 @@ def test_prefill_plp():
# Incomplete 1 block (7 tokens)
unique_token_ids = [3] * 7
all_token_ids = common_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids, prompt_logprobs=5)
req0 = make_request("0",
all_token_ids,
block_size,
hash_fn,
prompt_logprobs=5)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert len(manager.req_to_block_hashes[req0.request_id]) == 0
assert len(req0.block_hashes) == 3
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req0, 55,
@@ -411,9 +422,10 @@ def test_prefill_plp():
# Cache hit in the common prefix when the original block is still in use.
# Incomplete 1 block (5 tokens)
unique_token_ids = [3] * 5
req1 = make_request("1", common_token_ids + unique_token_ids)
req1 = make_request("1", common_token_ids + unique_token_ids, block_size,
hash_fn)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert len(req1.block_hashes) == 3
assert computed_blocks.get_block_ids() == ([1, 2, 3], )
assert num_computed_tokens == 3 * 16
num_new_tokens = 53 - 3 * 16
@@ -447,9 +459,11 @@ def test_prefill_plp():
unique_token_ids = [3] * 6
req2 = make_request("2",
common_token_ids + unique_token_ids,
block_size,
hash_fn,
prompt_logprobs=5)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(manager.req_to_block_hashes[req2.request_id]) == 0
assert len(req2.block_hashes) == 3
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
blocks = manager.allocate_slots(req2, 55,
@@ -469,8 +483,9 @@ def test_prefill_plp():
def test_decode():
block_size = 16
manager = KVCacheManager(
make_kv_cache_config(16, 11),
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
)
@@ -481,7 +496,8 @@ def test_decode():
# Fully cache miss
# Incomplete 1 block (7 tokens)
unique_token_ids = [3] * 7
req0 = make_request("0", common_token_ids + unique_token_ids)
req0 = make_request("0", common_token_ids + unique_token_ids, block_size,
hash)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
@@ -518,14 +534,15 @@ def test_decode():
def test_evict():
block_size = 16
manager = KVCacheManager(
make_kv_cache_config(16, 11),
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
)
last_token_id = 5 * 16 + 7
req0 = make_request("0", list(range(last_token_id)))
req0 = make_request("0", list(range(last_token_id)), block_size, hash)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
@@ -536,7 +553,8 @@ def test_evict():
# 3 blocks.
req1 = make_request("1", list(range(last_token_id,
last_token_id + 3 * 16)))
last_token_id + 3 * 16)), block_size,
hash)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
@@ -558,7 +576,7 @@ def test_evict():
] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7]
# Touch the first 2 blocks.
req2 = make_request("2", list(range(2 * 16 + 3)))
req2 = make_request("2", list(range(2 * 16 + 3)), block_size, hash)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert computed_blocks.get_block_ids() == ([1, 2], )
assert num_computed_tokens == 2 * 16
@@ -583,7 +601,7 @@ def test_hash_block_correct_reuse():
# Allocate 1 block and cache it.
num_tokens = block_size * 1
req = make_request("0", list(range(num_tokens)))
req = make_request("0", list(range(num_tokens)), block_size, hash)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
@@ -597,7 +615,7 @@ def test_hash_block_correct_reuse():
# Allocate a new block that's not full, make sure hash info on the
# block is cleared.
req = make_request("1", list(range(num_tokens - 1)))
req = make_request("1", list(range(num_tokens - 1)), block_size, hash)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
@@ -624,7 +642,7 @@ def test_computed_blocks_not_evicted():
# Allocate a block and cache it.
num_tokens = block_size * 1
req0 = make_request("0", list(range(num_tokens)))
req0 = make_request("0", list(range(num_tokens)), block_size, hash)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
@@ -635,7 +653,8 @@ def test_computed_blocks_not_evicted():
assert blocks.blocks[0][0].block_id == 1
# Allocate another block.
req1 = make_request("1", list(range(num_tokens, num_tokens * 2)))
req1 = make_request("1", list(range(num_tokens, num_tokens * 2)),
block_size, hash)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
@@ -651,7 +670,7 @@ def test_computed_blocks_not_evicted():
# Now if we have a cache hit on the first block, we should evict the second
# cached block rather than the first one.
req2 = make_request("2", list(range(num_tokens * 2)))
req2 = make_request("2", list(range(num_tokens * 2)), block_size, hash)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(computed_blocks.blocks[0]) == 1
assert computed_blocks.blocks[0][0].block_id == 1
@@ -675,7 +694,8 @@ def test_basic_prefix_caching_disabled():
enable_caching=False,
)
req1 = make_request("1", list(range(10))) # 2 blocks and some more
req1 = make_request("1", list(range(10)), block_size,
hash) # 2 blocks and some more
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks.blocks[0]
@@ -689,7 +709,8 @@ def test_basic_prefix_caching_disabled():
manager.free(req1)
# No caching.
req2 = make_request("2", list(range(16))) # shared prefix
req2 = make_request("2", list(range(16)), block_size,
hash) # shared prefix
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
@@ -699,7 +720,7 @@ def test_basic_prefix_caching_disabled():
assert len(blocks.blocks[0]) == 4
# New requests should not have any blocks.
req3 = make_request("3", list(range(4)))
req3 = make_request("3", list(range(4)), block_size, hash)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
@@ -727,20 +748,17 @@ def test_cache_blocks(hash_fn):
# Block 1: [4, 5, 6, 7]
# Block 2: [8, 9, 10, 11]
# Block 3: [12, 13]
req = make_request("0", list(range(14)))
req = make_request("0", list(range(14)), block_size, hash_fn)
# Test that blocks are cached correctly for 2 full blocks from the start.
blocks = [KVCacheBlock(block_id=i) for i in range(2)]
block_hashes: list[BlockHash] = []
block_pool.cache_full_blocks(
request=req,
blocks=blocks,
block_hashes=block_hashes,
num_cached_blocks=0,
num_full_blocks=2,
block_size=block_size,
hash_fn=hash_fn,
kv_cache_group_id=0,
)
@@ -752,11 +770,9 @@ def test_cache_blocks(hash_fn):
block_pool.cache_full_blocks(
request=req,
blocks=blocks,
block_hashes=block_hashes,
num_cached_blocks=2,
num_full_blocks=3,
block_size=block_size,
hash_fn=hash_fn,
kv_cache_group_id=0,
)
assert len(block_pool.cached_block_hash_to_block) == 3
@@ -775,23 +791,20 @@ def test_cache_blocks_multi_group():
# Block 1/5: [4, 5, 6, 7]
# Block 2/6: [8, 9, 10, 11]
# Block 3/7: [12, 13]
req = make_request("0", list(range(14)))
req = make_request("0", list(range(14)), block_size, hash)
# Cache the blocks for group 0.
blocks = [KVCacheBlock(block_id=i) for i in range(2)]
block_hashes: list[BlockHash] = []
block_pool.cache_full_blocks(
request=req,
blocks=blocks,
block_hashes=block_hashes,
num_cached_blocks=0,
num_full_blocks=2,
block_size=block_size,
hash_fn=hash,
kv_cache_group_id=0,
)
assert len(block_pool.cached_block_hash_to_block) == 2
assert len(block_hashes) == 2
assert len(req.block_hashes) == 3
assert all([block.block_hash is not None for block in blocks])
# Cache the blocks for group 1.
@@ -799,38 +812,36 @@ def test_cache_blocks_multi_group():
block_pool.cache_full_blocks(
request=req,
blocks=blocks,
block_hashes=block_hashes,
num_cached_blocks=0,
num_full_blocks=3,
block_size=block_size,
hash_fn=hash,
kv_cache_group_id=1,
)
assert len(block_pool.cached_block_hash_to_block) == 5
assert len(block_hashes) == 3
assert len(req.block_hashes) == 3
assert all([block.block_hash is not None for block in blocks])
# Block hash 0: hit for group 0 and 1
# Block hash 1: hit for group 0 and 1
# Block hash 2: hit for group 1
assert block_pool.get_cached_block(block_hashes[0],
assert block_pool.get_cached_block(req.block_hashes[0],
kv_cache_group_ids=[0]) is not None
assert block_pool.get_cached_block(block_hashes[1],
assert block_pool.get_cached_block(req.block_hashes[1],
kv_cache_group_ids=[0]) is not None
assert block_pool.get_cached_block(block_hashes[2],
assert block_pool.get_cached_block(req.block_hashes[2],
kv_cache_group_ids=[0]) is None
assert block_pool.get_cached_block(block_hashes[0],
assert block_pool.get_cached_block(req.block_hashes[0],
kv_cache_group_ids=[1]) is not None
assert block_pool.get_cached_block(block_hashes[1],
assert block_pool.get_cached_block(req.block_hashes[1],
kv_cache_group_ids=[1]) is not None
assert block_pool.get_cached_block(block_hashes[2],
assert block_pool.get_cached_block(req.block_hashes[2],
kv_cache_group_ids=[1]) is not None
assert block_pool.get_cached_block(block_hashes[0],
assert block_pool.get_cached_block(req.block_hashes[0],
kv_cache_group_ids=[0, 1]) is not None
assert block_pool.get_cached_block(block_hashes[1],
assert block_pool.get_cached_block(req.block_hashes[1],
kv_cache_group_ids=[0, 1]) is not None
assert block_pool.get_cached_block(block_hashes[2],
assert block_pool.get_cached_block(req.block_hashes[2],
kv_cache_group_ids=[0, 1]) is None
@@ -838,8 +849,9 @@ def test_mm_prefix_caching():
"""
This tests that the multi-modal prefix caching is correct.
"""
block_size = 16
manager = KVCacheManager(
make_kv_cache_config(16, 11),
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
)
@@ -865,6 +877,8 @@ def test_mm_prefix_caching():
mm_hashes = common_mm_hashes + ["ccc"]
req0 = make_request("0",
all_token_ids,
block_size,
hash,
mm_positions=mm_positions,
mm_hashes=mm_hashes)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
@@ -872,7 +886,7 @@ def test_mm_prefix_caching():
# Completed block should have hashes with extra keys.
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
block_hashes = manager.req_to_block_hashes[req0.request_id]
block_hashes = req0.block_hashes
assert len(block_hashes) == 3
assert block_hashes[0].extra_keys == ("aaa", )
assert block_hashes[1].extra_keys == ("aaa", "bbb")
@@ -905,6 +919,8 @@ def test_mm_prefix_caching():
mm_hashes = common_mm_hashes + ["ccc"]
req1 = make_request("1",
all_token_ids,
block_size,
hash,
mm_positions=mm_positions,
mm_hashes=mm_hashes)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
@@ -927,13 +943,13 @@ def test_cache_key_salting():
# 3 complete blocks and an incomplete block with 11 tokens.
common_token_ids = [i for i in range(3) for _ in range(block_size)]
token_ids = common_token_ids + [3] * 11
req0 = make_request("0", token_ids, cache_salt="salt1")
req0 = make_request("0", token_ids, block_size, hash, cache_salt="salt1")
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
# Completed block should have hashes with extra keys.
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
block_hashes = manager.req_to_block_hashes[req0.request_id]
block_hashes = req0.block_hashes
assert len(block_hashes) == 3
assert block_hashes[0].extra_keys == ("salt1", )
assert block_hashes[1].extra_keys is None
@@ -959,7 +975,7 @@ def test_cache_key_salting():
# Test cache hit with a new request that has the same salt.
token_ids = common_token_ids + [4] * 11
req1 = make_request("1", token_ids, cache_salt="salt1")
req1 = make_request("1", token_ids, block_size, hash, cache_salt="salt1")
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
# Should match only a prefix of 3 blocks.
assert len(computed_blocks.blocks[0]) == 3
@@ -967,11 +983,11 @@ def test_cache_key_salting():
# Test cache miss with same content but different salt.
token_ids = common_token_ids + [4] * 11
req2 = make_request("2", token_ids, cache_salt="salt2")
req2 = make_request("2", token_ids, block_size, hash, cache_salt="salt2")
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(computed_blocks.blocks[0]) == 0
assert num_computed_tokens == 0
block_hashes = manager.req_to_block_hashes[req2.request_id]
block_hashes = req2.block_hashes
assert len(block_hashes) == 3
assert block_hashes[0].extra_keys == ("salt2", )
@@ -992,7 +1008,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# Complete 3 blocks (48 tokens)
# | Common-0 | Common-1 | Common-2 | ... |
common_token_ids = [i for i in range(3) for _ in range(16)]
req0 = make_request("0", common_token_ids)
req0 = make_request("0", common_token_ids, block_size, hash)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
@@ -1003,7 +1019,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
req0.request_id]
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
req1 = make_request("1", common_token_ids * 2)
req1 = make_request("1", common_token_ids * 2, block_size, hash)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert computed_blocks.blocks[0] == block_part0
assert num_computed_tokens == 3 * 16
@@ -1020,19 +1036,19 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
req2 = make_request("2", [7] * block_size * 2)
req2 = make_request("2", [7] * block_size * 2, block_size, hash)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
manager.allocate_slots(req2, block_size * 2,
len(computed_blocks.blocks[0]) * 16,
len(computed_blocks.blocks[0]) * block_size,
computed_blocks)
# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
# but it cannot be allocated due to insufficient free blocks (2).
# In this case, the ref_cnt of the computed blocks should not be changed.
assert manager.block_pool.free_block_queue.num_free_blocks == 5
req3 = make_request("3", common_token_ids * 3)
req3 = make_request("3", common_token_ids * 3, block_size, hash)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert computed_blocks.blocks[0] == block_part1
assert num_computed_tokens == 6 * 16
@@ -1047,8 +1063,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
def test_reset_prefix_cache():
block_size = 16
manager = KVCacheManager(
make_kv_cache_config(16, 11),
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
)
@@ -1056,15 +1073,15 @@ def test_reset_prefix_cache():
full_block_token_ids = [i for i in range(3) for _ in range(16)]
unique_token_ids = [3] * 7
all_token_ids = full_block_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids)
req0 = make_request("0", all_token_ids, block_size, hash)
blocks = manager.allocate_slots(req0, 55)
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
unique_token_ids = [4] * 7
all_token_ids = full_block_token_ids + unique_token_ids
req1 = make_request("1", all_token_ids)
req1 = make_request("1", all_token_ids, block_size, hash)
computed_blocks, _ = manager.get_computed_blocks(req1)
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
assert len(req1.block_hashes) == 3
assert len(computed_blocks.blocks[0]) == 3
blocks = manager.allocate_slots(req1, 7,
len(computed_blocks.blocks[0]) * 16,
@@ -1086,8 +1103,9 @@ def test_reset_prefix_cache():
def test_prefix_cache_stats_disabled():
"""Test that prefix_cache_stats is None when log_stats is False."""
block_size = 16
manager = KVCacheManager(
make_kv_cache_config(16, 11),
make_kv_cache_config(block_size, 11),
max_model_len=8192,
enable_caching=True,
log_stats=False, # Disable logging stats
@@ -1095,7 +1113,7 @@ def test_prefix_cache_stats_disabled():
assert manager.prefix_cache_stats is None
# Call all functions that check whether log_stats is disabled.
req = make_request("0", list(range(16)))
req = make_request("0", list(range(16)), block_size, hash)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
@@ -1192,7 +1210,7 @@ def test_kv_cache_events(blocks_to_cache: int):
)
num_tokens = block_size * blocks_to_cache
req0 = make_request("0", list(range(num_tokens)))
req0 = make_request("0", list(range(num_tokens)), block_size, hash)
_ = manager.allocate_slots(req0, num_tokens)
events = manager.take_events()
@@ -1208,7 +1226,7 @@ def test_kv_cache_events(blocks_to_cache: int):
# Should see block_to_cache number of removed block events and a new block
# stored event
manager.free(req0)
req1 = make_request("1", list(range(num_tokens)))
req1 = make_request("1", list(range(num_tokens)), block_size, hash)
_ = manager.allocate_slots(req1, num_tokens)
events = manager.take_events()
@@ -1242,7 +1260,7 @@ def test_eagle_enabled_removes_last_block():
# Request with 3 full blocks (48 tokens)
token_ids = [0] * (3 * block_size)
req = make_request("divisible_request", token_ids)
req = make_request("divisible_request", token_ids, block_size, hash)
# Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req)
@@ -1252,7 +1270,7 @@ def test_eagle_enabled_removes_last_block():
manager.free(req)
# New request with same tokens + Eagle enabled
req_eagle = make_request("eagle_divisible", token_ids)
req_eagle = make_request("eagle_divisible", token_ids, block_size, hash)
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
# Should retain 1 block:
@@ -1273,7 +1291,7 @@ def test_eagle_with_partial_blocks():
)
# 2 full blocks + 5 tokens (non-divisible length)
token_ids = [0] * (2 * block_size + 5)
req = make_request("partial_block_test", token_ids)
req = make_request("partial_block_test", token_ids, block_size, hash)
# Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req)
@@ -1283,7 +1301,7 @@ def test_eagle_with_partial_blocks():
manager.free(req)
# New request with Eagle enabled
req_eagle = make_request("partial_eagle", token_ids)
req_eagle = make_request("partial_eagle", token_ids, block_size, hash)
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert len(computed_blocks.blocks[0]) == 1
@@ -1314,7 +1332,7 @@ def test_eagle_with_sliding_window():
# 2 full blocks + 5 tokens (non-divisible length)
token_ids = [0] * (2 * block_size + 5)
req = make_request("partial_block_test", token_ids)
req = make_request("partial_block_test", token_ids, block_size, hash)
# Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req)
@@ -1322,12 +1340,12 @@ def test_eagle_with_sliding_window():
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
# record the block hash of the first block in the request for later use
block_hash_first_block = manager.req_to_block_hashes[req.request_id][0]
block_hash_first_block = req.block_hashes[0]
assert block_hash_first_block is not None
manager.free(req)
# New request with Eagle enabled
req_eagle = make_request("partial_eagle", token_ids)
req_eagle = make_request("partial_eagle", token_ids, block_size, hash)
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert len(computed_blocks.blocks[0]) == 1
@@ -1340,7 +1358,8 @@ def test_eagle_with_sliding_window():
BlockHashWithGroupId(block_hash_first_block, 0))
# New request
req_after_evict = make_request("partial_eagle_after_evict", token_ids)
req_after_evict = make_request("partial_eagle_after_evict", token_ids,
block_size, hash)
computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict)
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
# not considered. But after dropping the last matched block due to eagle,

View File

@@ -589,7 +589,7 @@ def test_preempt_during_execution():
block_size=16,
num_blocks=11,
enable_prefix_caching=False)
requests = create_requests(num_requests=2, num_tokens=80)
requests = create_requests(num_requests=2, num_tokens=80, block_size=16)
# Schedule the first request.
scheduler.add_request(requests[0])
@@ -762,7 +762,7 @@ def _assert_right_scheduler_output(
def _assert_right_kv_cache_manager(
scheduler: Scheduler,
req_ids: list[str],
requests: list[Request],
num_tokens: int,
block_size: int,
num_requests: int,
@@ -772,12 +772,12 @@ def _assert_right_kv_cache_manager(
# Make sure the request stats are right.
EXPECTED_TOTAL_BLOCKS = num_tokens // block_size
for req_id in req_ids:
for req in requests:
blocks = (scheduler.kv_cache_manager.coordinator.
single_type_managers[0].req_to_blocks[req_id])
hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id]
single_type_managers[0].req_to_blocks[req.request_id])
hashes = req.block_hashes
assert (scheduler.kv_cache_manager.coordinator.single_type_managers[0].
num_cached_block[req_id] == EXPECTED_TOTAL_BLOCKS)
num_cached_block[req.request_id] == EXPECTED_TOTAL_BLOCKS)
assert len(blocks) == EXPECTED_TOTAL_BLOCKS
assert len(hashes) == EXPECTED_TOTAL_BLOCKS
@@ -840,7 +840,8 @@ def test_kv_connector_basic():
MAX_TOKENS = 3
requests = create_requests(num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS)
max_tokens=MAX_TOKENS,
block_size=BLOCK_SIZE)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
@@ -868,7 +869,7 @@ def test_kv_connector_basic():
)
# Ensure KVCacheManager is correct.
_assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS, BLOCK_SIZE,
_assert_right_kv_cache_manager(scheduler, requests, NUM_TOKENS, BLOCK_SIZE,
NUM_REQUESTS, NUM_TOTAL_BLOCKS)
# Continue Generation until done.
@@ -886,7 +887,8 @@ def test_kv_connector_basic():
NUM_TOKENS = NUM_TOKENS_PREFIX * 2
requests = create_requests(num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS)
max_tokens=MAX_TOKENS,
block_size=BLOCK_SIZE)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
@@ -915,7 +917,7 @@ def test_kv_connector_basic():
NUM_MATCHED_NEW_TOKENS))
# Ensure KVCacheManager is correct.
_assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS, BLOCK_SIZE,
_assert_right_kv_cache_manager(scheduler, requests, NUM_TOKENS, BLOCK_SIZE,
NUM_REQUESTS, NUM_TOTAL_BLOCKS)
# Continue Generation until done.
@@ -953,7 +955,8 @@ def test_kv_connector_unable_to_allocate():
MAX_TOKENS = 2
requests = create_requests(num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS)
max_tokens=MAX_TOKENS,
block_size=BLOCK_SIZE)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
@@ -1034,7 +1037,8 @@ def test_kv_connector_handles_preemption():
MAX_TOKENS = BLOCK_SIZE * 2
requests = create_requests(num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS)
max_tokens=MAX_TOKENS,
block_size=BLOCK_SIZE)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
@@ -1162,7 +1166,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
# KVCache Manager.
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
num_cached_block) == 0
num_free_blocks = (

View File

@@ -17,7 +17,6 @@ from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
def get_sliding_window_manager(sliding_window_spec, block_pool):
return SlidingWindowManager(sliding_window_spec,
block_pool,
caching_hash_fn=lambda x: x,
kv_cache_group_id=0)
@@ -25,7 +24,6 @@ def get_chunked_local_attention_manager(chunked_local_attention_spec,
block_pool):
return ChunkedLocalAttentionManager(chunked_local_attention_spec,
block_pool,
caching_hash_fn=lambda x: x,
kv_cache_group_id=0)

View File

@@ -10,6 +10,8 @@ from vllm.multimodal.inputs import (MultiModalBatchedField,
MultiModalFieldElem, MultiModalKwargsItem,
PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash)
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -114,6 +116,9 @@ def create_scheduler(
)
_none_hash_initialized = False
def create_requests(
num_requests: int,
num_tokens: int = 10,
@@ -122,7 +127,14 @@ def create_requests(
stop_token_ids: Optional[list[int]] = None,
prompt_logprobs: Optional[int] = None,
same_prompt: bool = False,
block_size: int = 16,
) -> list[Request]:
global _none_hash_initialized
if not _none_hash_initialized:
init_none_hash(hash)
_none_hash_initialized = True
block_hasher = get_request_block_hasher(block_size, hash)
sampling_params = SamplingParams(ignore_eos=False,
max_tokens=max_tokens,
stop_token_ids=stop_token_ids,
@@ -139,9 +151,11 @@ def create_requests(
)
mm_item = MultiModalKwargsItem.from_elems([mm_elem])
mm_kwargs = [mm_item] * len(mm_position)
mm_hashes = ["hash"] * len(mm_position)
else:
mm_position = None
mm_kwargs = None
mm_hashes = None
prompt_token_ids = ([0] * num_tokens if same_prompt else [i] *
num_tokens)
request = Request(
@@ -151,8 +165,9 @@ def create_requests(
pooling_params=None,
multi_modal_kwargs=mm_kwargs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
multi_modal_hashes=mm_hashes,
eos_token_id=EOS_TOKEN_ID,
block_hasher=block_hasher,
)
requests.append(request)
return requests

View File

@@ -147,6 +147,7 @@ def test_basic_interface():
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_prefill=True)
request_id = request.request_id
@@ -186,6 +187,7 @@ def test_prompt_less_than_block_size():
# Request will have 1 partial remote block.
request = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_prefill=True,
num_remote_blocks=1)

View File

@@ -21,6 +21,7 @@ def test_basic_lifecycle():
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(request_id=1,
block_size=BLOCK_SIZE,
max_tokens=1,
num_tokens=NUM_TOKENS,
do_remote_decode=True)
@@ -103,8 +104,10 @@ def test_short_prompt_lifecycle():
scheduler = create_scheduler(vllm_config)
# Not enough tokens for full block.
NUM_TOKENS = vllm_config.cache_config.block_size // 2
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_TOKENS = BLOCK_SIZE // 2
request = create_request(request_id=1,
block_size=BLOCK_SIZE,
max_tokens=1,
num_tokens=NUM_TOKENS,
do_remote_decode=True)
@@ -148,7 +151,9 @@ def test_prefix_cache_lifecycle():
NUM_EXTERNAL_FULL_BLOCKS = 3
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS)
request_normal = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS)
scheduler.add_request(request_normal)
scheduler_output = scheduler.schedule()
@@ -166,6 +171,7 @@ def test_prefix_cache_lifecycle():
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request_remote = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_decode=True)

View File

@@ -23,6 +23,7 @@ def test_basic_lifecycle():
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
request = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_prefill=True)
@@ -133,14 +134,17 @@ def test_interleaved_lifecycle():
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request_remote = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_prefill=True)
request_local_a = create_request(
request_id=2,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
)
request_local_b = create_request(
request_id=3,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
)
@@ -236,6 +240,7 @@ def test_no_spurious_prefix_caching():
# Both of these requests have prompts like [1,1,1,1,1, ...]
request_remote = create_request(
request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_prefill=True,
use_all_1s_for_prompt_tokens=True,
@@ -243,6 +248,7 @@ def test_no_spurious_prefix_caching():
request_local = create_request(
request_id=2,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_prefill=False,
use_all_1s_for_prompt_tokens=True,
@@ -292,6 +298,7 @@ def test_full_block_prompt():
NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS)
request = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS,
do_remote_prefill=True)
@@ -364,8 +371,11 @@ def test_cannot_schedule_after_recv():
NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
NUM_TOKENS_REMOTE = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL)
request_normal = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS_LOCAL)
request_remote = create_request(request_id=2,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS_REMOTE,
do_remote_prefill=True)
@@ -456,8 +466,11 @@ def test_cannot_recv():
NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5))
request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL)
request_normal = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS_LOCAL)
request_remote = create_request(request_id=2,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS_REMOTE,
do_remote_prefill=True)

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
from collections import defaultdict
from typing import Any, Optional
from typing import Any, Callable, Optional
import torch
@@ -14,6 +14,8 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
SharedStorageConnector)
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
@@ -40,7 +42,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
# KVCache Manager.
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
num_cached_block) == 0
num_free_blocks = (
@@ -115,16 +116,23 @@ def create_scheduler(
)
def create_request(
request_id: int,
num_tokens: int = 10,
max_tokens: int = 16,
do_remote_decode: bool = False,
do_remote_prefill: bool = False,
use_all_1s_for_prompt_tokens: bool = False,
num_remote_blocks: int = 3,
) -> Request:
_none_hash_initialized = False
def create_request(request_id: int,
num_tokens: int = 10,
max_tokens: int = 16,
do_remote_decode: bool = False,
do_remote_prefill: bool = False,
use_all_1s_for_prompt_tokens: bool = False,
num_remote_blocks: int = 3,
block_size: int = 16,
hash_fn: Callable = hash) -> Request:
"""Make dummy request for testing."""
global _none_hash_initialized
if not _none_hash_initialized:
init_none_hash(hash)
_none_hash_initialized = True
kv_transfer_params: Optional[dict[str, Any]] = None
@@ -158,6 +166,7 @@ def create_request(
multi_modal_placeholders=None,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,
block_hasher=get_request_block_hasher(block_size, hash_fn),
)
req.kv_transfer_params = kv_transfer_params
return req

View File

@@ -3243,6 +3243,24 @@ def sha256_cbor_64bit(input) -> int:
return full_hash & ((1 << 64) - 1)
def get_hash_fn_by_name(hash_fn_name: str) -> Callable:
    """Resolve a hash function from its registered name.

    Args:
        hash_fn_name: One of "sha256", "sha256_cbor_64bit" or "builtin".

    Returns:
        The hash function registered under that name.

    Raises:
        ValueError: If `hash_fn_name` is not a known hash function name.
    """
    # Checked lazily (no dict literal) so the module-level hash helpers are
    # only referenced when their name is actually requested.
    if hash_fn_name == "builtin":
        return hash
    elif hash_fn_name == "sha256":
        return sha256
    elif hash_fn_name == "sha256_cbor_64bit":
        return sha256_cbor_64bit
    raise ValueError(f"Unsupported hash function: {hash_fn_name}")
def is_torch_equal_or_newer(target: str) -> bool:
"""Check if the installed torch version is >= the target version.

View File

@@ -2,15 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from collections.abc import Iterable
from typing import Callable, Optional
from typing import Optional
from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
BlockStored, KVCacheEvent)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
FreeKVCacheBlockQueue, KVCacheBlock,
generate_block_hash_extra_keys,
hash_block_tokens)
FreeKVCacheBlockQueue, KVCacheBlock)
from vllm.v1.request import Request
logger = init_logger(__name__)
@@ -97,84 +95,39 @@ class BlockPool:
self,
request: Request,
blocks: list[KVCacheBlock],
block_hashes: list[BlockHash],
num_cached_blocks: int,
num_full_blocks: int,
block_size: int,
kv_cache_group_id: int,
hash_fn: Callable,
) -> None:
"""Cache a list of full blocks for prefix caching.
This function takes a list of blocks that will have their block hash
metadata to be updated and cached. Given a request, it computes the
block hashes for the blocks starting from `num_cached_blocks` to
`num_full_blocks`, updating the metadata for each block
and caching them in the `cached_block_hash_to_block`.
metadata to be updated and cached. Given a request, it updates the
metadata for each block and caches it in the
`cached_block_hash_to_block`.
The block hash values are computed by the Request object immediately
when it is created and when new tokens are appended.
Args:
request: The request to cache the blocks.
blocks: All blocks in the request.
block_hashes: Block hashes of the blocks in the request. Note that
this list may be shorter than the blocks list. In this case the
missed block hash will be computed in this function.
num_cached_blocks: The number of blocks that are already cached.
num_full_blocks: The number of blocks that are full and should
be cached after this function.
block_size: Number of tokens in each block.
kv_cache_group_id: The id of the KV cache group.
hash_fn: The hash function to use for block hashes.
"""
if num_cached_blocks == num_full_blocks:
return
new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
assert len(block_hashes) >= num_cached_blocks
new_block_hashes = block_hashes[num_cached_blocks:]
assert len(request.block_hashes) >= num_full_blocks
new_block_hashes = request.block_hashes[num_cached_blocks:]
# Update the new blocks with the block hashes through the chain.
if num_cached_blocks == 0:
prev_block_hash_value = None
else:
prev_block = blocks[num_cached_blocks - 1]
assert prev_block.block_hash is not None
prev_block_hash_value = prev_block.block_hash.get_hash_value()
parent_block_hash = prev_block_hash_value
new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events
else None)
for i, blk in enumerate(new_full_blocks):
assert blk.block_hash is None
if i < len(new_block_hashes):
# The block hash may already be computed in
# "get_computed_blocks" if the tokens are not generated by
# this request (either the prompt tokens or the previously
# generated tokens with preemption), or by other
# single_type_managers with the same block_size.
# In this case we simply reuse the block hash.
block_hash = new_block_hashes[i]
else:
# Otherwise compute the block hash and cache it in the request
# in case it will be preempted in the future.
blk_idx = num_cached_blocks + i
start_token_idx = blk_idx * block_size
end_token_idx = (blk_idx + 1) * block_size
block_tokens = request.all_token_ids[
start_token_idx:end_token_idx]
assert len(block_tokens) == block_size, (
f"Expected {block_size} tokens, got "
f"{len(block_tokens)} at {blk_idx}th block for request "
f"{request.request_id}({request})")
# Generate extra keys for multi-modal inputs. Note that since
# we reach to this branch only when the block is completed with
# generated tokens, we only need to consider the last mm input.
extra_keys, _ = generate_block_hash_extra_keys(
request, start_token_idx, end_token_idx, -1)
# Compute the hash of the current block.
block_hash = hash_block_tokens(hash_fn, prev_block_hash_value,
block_tokens, extra_keys)
block_hashes.append(block_hash)
block_hash = new_block_hashes[i]
# Update and add the full block to the cache.
block_hash_with_group_id = BlockHashWithGroupId(
@@ -184,9 +137,15 @@ class BlockPool:
blk.block_id] = blk
if new_hashes is not None:
new_hashes.append(block_hash.hash_value)
prev_block_hash_value = block_hash.hash_value
if self.enable_kv_cache_events:
if num_cached_blocks == 0:
parent_block_hash = None
else:
parent_block = blocks[num_cached_blocks - 1]
assert parent_block.block_hash is not None
parent_block_hash = parent_block.block_hash.get_hash_value()
self.kv_event_queue.append(
BlockStored(
block_hashes=new_hashes,

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Callable, Optional
from typing import Optional
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
@@ -23,7 +23,6 @@ class KVCacheCoordinator(ABC):
max_model_len: int,
use_eagle: bool,
enable_caching: bool,
caching_hash_fn: Callable,
enable_kv_cache_events: bool,
):
self.kv_cache_config = kv_cache_config
@@ -40,7 +39,6 @@ class KVCacheCoordinator(ABC):
kv_cache_spec=kv_cache_group.kv_cache_spec,
block_pool=self.block_pool,
kv_cache_group_id=i,
caching_hash_fn=caching_hash_fn,
) for i, kv_cache_group in enumerate(
self.kv_cache_config.kv_cache_groups))
@@ -99,19 +97,17 @@ class KVCacheCoordinator(ABC):
manager.allocate_new_blocks(request_id, num_tokens)
for manager in self.single_type_managers)
def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
num_computed_tokens: int) -> None:
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
"""
Cache the blocks for the request.
Args:
request: The request.
block_hashes: The block hashes of the request.
num_computed_tokens: The total number of tokens that need to be
cached (including tokens that are already cached).
"""
for manager in self.single_type_managers:
manager.cache_blocks(request, block_hashes, num_computed_tokens)
manager.cache_blocks(request, num_computed_tokens)
def free(self, request_id: str) -> None:
"""
@@ -184,10 +180,9 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
"""
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
use_eagle: bool, caching_hash_fn: Callable,
enable_kv_cache_events: bool):
use_eagle: bool, enable_kv_cache_events: bool):
super().__init__(kv_cache_config, max_model_len, use_eagle, False,
caching_hash_fn, enable_kv_cache_events)
enable_kv_cache_events)
self.num_single_type_manager = len(self.single_type_managers)
def get_num_common_prefix_blocks(self, request_id: str,
@@ -213,10 +208,9 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
use_eagle: bool, enable_caching: bool,
caching_hash_fn: Callable, enable_kv_cache_events: bool):
enable_kv_cache_events: bool):
super().__init__(kv_cache_config, max_model_len, use_eagle,
enable_caching, caching_hash_fn,
enable_kv_cache_events)
enable_caching, enable_kv_cache_events)
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[
0].kv_cache_spec
self.block_size = self.kv_cache_spec.block_size
@@ -250,10 +244,9 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
use_eagle: bool, enable_caching: bool,
caching_hash_fn: Callable, enable_kv_cache_events: bool):
enable_kv_cache_events: bool):
super().__init__(kv_cache_config, max_model_len, use_eagle,
enable_caching, caching_hash_fn,
enable_kv_cache_events)
enable_caching, enable_kv_cache_events)
self.verify_and_split_kv_cache_groups()
def verify_and_split_kv_cache_groups(self) -> None:
@@ -386,17 +379,15 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
def get_kv_cache_coordinator(
kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool,
enable_caching: bool, caching_hash_fn: Callable,
enable_caching: bool,
enable_kv_cache_events: bool) -> KVCacheCoordinator:
if not enable_caching:
return KVCacheCoordinatorNoPrefixCache(kv_cache_config, max_model_len,
use_eagle, caching_hash_fn,
use_eagle,
enable_kv_cache_events)
if len(kv_cache_config.kv_cache_groups) == 1:
return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len,
use_eagle, enable_caching,
caching_hash_fn,
enable_kv_cache_events)
return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle,
enable_caching, caching_hash_fn,
enable_kv_cache_events)
enable_caching, enable_kv_cache_events)

View File

@@ -1,16 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional
from vllm.distributed.kv_events import KVCacheEvent
from vllm.logger import init_logger
from vllm.utils import sha256, sha256_cbor_64bit
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
hash_request_tokens, init_none_hash)
from vllm.v1.core.kv_cache_utils import KVCacheBlock
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request, RequestStatus
@@ -71,23 +68,13 @@ class KVCacheManager:
kv_cache_config: KVCacheConfig,
max_model_len: int,
enable_caching: bool = True,
caching_hash_algo: str = "builtin",
use_eagle: bool = False,
log_stats: bool = False,
enable_kv_cache_events: bool = False,
) -> None:
self.max_model_len = max_model_len
if len(kv_cache_config.kv_cache_groups) == 0:
# Attention free models don't have kv cache,
# thus don't need prefix caching.
enable_caching = False
self.enable_caching = enable_caching
self.caching_hash_fn = (
sha256_cbor_64bit if caching_hash_algo == "sha256_cbor_64bit" else
sha256 if caching_hash_algo == "sha256" else hash)
init_none_hash(self.caching_hash_fn)
self.use_eagle = use_eagle
self.log_stats = log_stats
# FIXME: make prefix cache stats conditional on log_stats
@@ -107,19 +94,12 @@ class KVCacheManager:
max_model_len=self.max_model_len,
use_eagle=self.use_eagle,
enable_caching=self.enable_caching,
caching_hash_fn=self.caching_hash_fn,
enable_kv_cache_events=enable_kv_cache_events,
)
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
self.block_pool = self.coordinator.block_pool
self.kv_cache_config = kv_cache_config
# Mapping from request ID to kv block hashes.
# This is to avoid recomputing the block hashes for each call of
# `get_computed_blocks` or `allocate_slots`.
self.req_to_block_hashes: defaultdict[
str, list[BlockHash]] = defaultdict(list)
@property
def usage(self) -> float:
"""Get the KV cache usage.
@@ -161,15 +141,6 @@ class KVCacheManager:
and request.sampling_params.prompt_logprobs is not None)):
return self.create_empty_block_list(), 0
# The block hashes for the request may already be computed
# if the scheduler has tried to schedule the request before.
block_hashes = self.req_to_block_hashes[request.request_id]
if not block_hashes:
assert self.block_size is not None
block_hashes = hash_request_tokens(self.caching_hash_fn,
self.block_size, request)
self.req_to_block_hashes[request.request_id] = block_hashes
# NOTE: When all tokens hit the cache, we must recompute the last token
# to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
# This can trigger recomputation of an entire block, rather than just
@@ -178,7 +149,7 @@ class KVCacheManager:
# could slightly improve performance in the future.
max_cache_hit_length = request.num_tokens - 1
computed_blocks, num_new_computed_tokens = (
self.coordinator.find_longest_cache_hit(block_hashes,
self.coordinator.find_longest_cache_hit(request.block_hashes,
max_cache_hit_length))
if self.log_stats:
@@ -296,11 +267,7 @@ class KVCacheManager:
# at `request.num_tokens`, ensuring only "finalized" tokens are cached.
num_tokens_to_cache = min(num_computed_tokens + num_new_tokens,
request.num_tokens)
self.coordinator.cache_blocks(
request,
self.req_to_block_hashes[request.request_id],
num_tokens_to_cache,
)
self.coordinator.cache_blocks(request, num_tokens_to_cache)
return KVCacheBlocks(new_blocks)
@@ -373,14 +340,6 @@ class KVCacheManager:
return self.coordinator.get_num_common_prefix_blocks(
request.request_id, num_running_requests)
def free_block_hashes(self, request: Request) -> None:
"""Discard the block hashes for the request.
NOTE: Unlike `free`, this method should be called only when the request
is finished, not when it is preempted.
"""
self.req_to_block_hashes.pop(request.request_id, None)
def take_events(self) -> list[KVCacheEvent]:
"""Take the KV cache events from the block pool.
@@ -397,9 +356,7 @@ class KVCacheManager:
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
"""Cache the blocks for the request, if enabled."""
if self.enable_caching:
block_hashes = self.req_to_block_hashes[request.request_id]
self.coordinator.cache_blocks(request, block_hashes,
num_computed_tokens)
self.coordinator.cache_blocks(request, num_computed_tokens)
def create_empty_block_list(self) -> KVCacheBlocks:
"""Creates a new KVCacheBlocks instance with no blocks."""

View File

@@ -547,41 +547,61 @@ def hash_block_tokens(
curr_block_token_ids_tuple, extra_keys)
def hash_request_tokens(hash_function: Any, block_size: int,
request: Request) -> list[BlockHash]:
"""Computes hash values of a chain of blocks given a sequence of
token IDs. The hash value is used for prefix caching.
Args:
block_size: The size of each block.
request: The request object.
Returns:
The list of computed hash values.
def get_request_block_hasher(
block_size: int,
caching_hash_fn: Callable[[Any],
int]) -> Callable[[Request], list[BlockHash]]:
"""
token_ids = request.all_token_ids
Returns a function which computes the list of un-computed block hashes
of a request.
req_need_extra_keys = need_extra_keys(request)
req_extra_keys = None
curr_mm_idx = 0
Each request holds a list of its block hashes (request.block_hashes).
When a request is created, it calls the below function to compute
the hashes of all full blocks of the request's initial tokens.
The hashes are then stored in request.block_hashes.
Later, whenever new tokens are appended to the request, it calls
the below function again to compute any new full blocks of tokens.
The returned new hashes are appended to request.block_hashes.
"""
ret = []
parent_block_hash_value = None
# Only full blocks will be hashed
for start in range(0, len(token_ids) - block_size + 1, block_size):
end = start + block_size
block_token_ids = token_ids[start:end]
def request_block_hasher(request: Request) -> list[BlockHash]:
start_token_idx = len(request.block_hashes) * block_size
num_tokens = request.num_tokens
curr_mm_idx = 0
if start_token_idx > 0:
# Set curr_mm_idx = -1 to indicate the last mm input.
# Note that since we reach this branch only when the block is
# completed with generated tokens, we only need to consider the
# last mm input.
curr_mm_idx = -1
prev_block_hash_value = request.block_hashes[-1].hash_value \
if request.block_hashes else None
new_block_hashes: list[BlockHash] = []
while True:
end_token_idx = start_token_idx + block_size
if end_token_idx > num_tokens:
# We only hash full blocks
break
if req_need_extra_keys:
# MM and LoRA requests need extra keys for block-hash computation.
req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
request, start, end, curr_mm_idx)
extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
request, start_token_idx, end_token_idx, curr_mm_idx)
block_hash = hash_block_tokens(hash_function, parent_block_hash_value,
block_token_ids, req_extra_keys)
ret.append(block_hash)
parent_block_hash_value = block_hash.hash_value
return ret
# Compute the hash of the current block
block_tokens = request.all_token_ids[start_token_idx:end_token_idx]
block_hash = hash_block_tokens(caching_hash_fn,
prev_block_hash_value, block_tokens,
extra_keys)
new_block_hashes.append(block_hash)
start_token_idx += block_size
prev_block_hash_value = block_hash.hash_value
return new_block_hashes
return request_block_hasher
def max_memory_usage_bytes(vllm_config: VllmConfig,

View File

@@ -155,7 +155,6 @@ class Scheduler(SchedulerInterface):
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
enable_caching=self.cache_config.enable_prefix_caching,
caching_hash_algo=self.cache_config.prefix_caching_hash_algo,
use_eagle=self.use_eagle,
log_stats=self.log_stats,
enable_kv_cache_events=self.enable_kv_cache_events,
@@ -1036,7 +1035,6 @@ class Scheduler(SchedulerInterface):
def _free_blocks(self, request: Request):
assert request.is_finished()
self.kv_cache_manager.free(request)
self.kv_cache_manager.free_block_hashes(request)
del self.requests[request.request_id]
def get_num_unfinished_requests(self) -> int:

View File

@@ -3,7 +3,6 @@
import itertools
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Callable
from vllm.utils import cdiv
from vllm.v1.core.block_pool import BlockPool
@@ -25,7 +24,6 @@ class SingleTypeKVCacheManager(ABC):
kv_cache_spec: KVCacheSpec,
block_pool: BlockPool,
kv_cache_group_id: int,
caching_hash_fn: Callable,
) -> None:
"""
Initializes the SingleTypeKVCacheManager.
@@ -33,7 +31,6 @@ class SingleTypeKVCacheManager(ABC):
kv_cache_spec: The kv_cache_spec for this manager.
block_pool: The block pool.
kv_cache_group_id: The id of the kv cache group of this manager.
caching_hash_fn: The caching hash function.
"""
self.block_size = kv_cache_spec.block_size
@@ -52,7 +49,6 @@ class SingleTypeKVCacheManager(ABC):
# data for reempted ones.
self.num_cached_block: dict[str, int] = {}
self.caching_hash_fn = caching_hash_fn
self.kv_cache_group_id = kv_cache_group_id
self._null_block = block_pool.null_block
@@ -130,14 +126,12 @@ class SingleTypeKVCacheManager(ABC):
req_blocks.extend(new_blocks)
return new_blocks
def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
num_tokens: int) -> None:
def cache_blocks(self, request: Request, num_tokens: int) -> None:
"""
Cache the blocks for the request.
Args:
request: The request.
block_hashes: The block hashes of the request.
num_tokens: The total number of tokens that need to be cached
(including tokens that are already cached).
"""
@@ -147,12 +141,10 @@ class SingleTypeKVCacheManager(ABC):
self.block_pool.cache_full_blocks(
request=request,
blocks=self.req_to_blocks[request.request_id],
block_hashes=block_hashes,
num_cached_blocks=num_cached_blocks,
num_full_blocks=num_full_blocks,
block_size=self.block_size,
kv_cache_group_id=self.kv_cache_group_id,
hash_fn=self.caching_hash_fn,
)
self.num_cached_block[request.request_id] = num_full_blocks

View File

@@ -25,9 +25,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.utils import (decorate_logs, make_zmq_socket,
from vllm.utils import (decorate_logs, get_hash_fn_by_name, make_zmq_socket,
resolve_obj_by_qualname, set_process_title)
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
from vllm.v1.core.kv_cache_utils import (BlockHash, get_kv_cache_config,
get_request_block_hasher,
init_none_hash,
unify_kv_cache_configs)
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput
@@ -140,6 +142,19 @@ class EngineCore:
self.batch_queue_size)
self.batch_queue = queue.Queue(self.batch_queue_size)
self.request_block_hasher: Optional[Callable[[Request],
list[BlockHash]]] = None
if (self.vllm_config.cache_config.enable_prefix_caching
or self.scheduler.get_kv_connector() is not None):
block_size = vllm_config.cache_config.block_size
caching_hash_fn = get_hash_fn_by_name(
vllm_config.cache_config.prefix_caching_hash_algo)
init_none_hash(caching_hash_fn)
self.request_block_hasher = get_request_block_hasher(
block_size, caching_hash_fn)
def _initialize_kv_caches(
self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
start = time.time()
@@ -417,7 +432,8 @@ class EngineCore:
request.mm_kwargs = self.mm_input_cache_server.get_and_update(
request.mm_kwargs, request.mm_hashes)
req = Request.from_engine_core_request(request)
req = Request.from_engine_core_request(request,
self.request_block_hasher)
if req.use_structured_output:
# Note on thread safety: no race condition.
# `grammar_init` is only invoked in input processing thread. For

View File

@@ -3,7 +3,8 @@
import enum
import time
from typing import TYPE_CHECKING, Any, Optional, Union
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.pooling_params import PoolingParams
@@ -16,6 +17,7 @@ from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
from vllm.lora.request import LoRARequest
from vllm.v1.core.kv_cache_utils import BlockHash
class Request:
@@ -36,6 +38,8 @@ class Request:
structured_output_request: Optional["StructuredOutputRequest"] = None,
cache_salt: Optional[str] = None,
priority: int = 0,
block_hasher: Optional[Callable[["Request"],
list["BlockHash"]]] = None,
) -> None:
self.request_id = request_id
self.client_index = client_index
@@ -108,8 +112,18 @@ class Request:
# indicates that the output is corrupted
self.num_nans_in_logits = 0
self.block_hashes: list[BlockHash] = []
self.get_hash_new_full_blocks: Optional[Callable[
[], list[BlockHash]]] = None
if block_hasher is not None:
self.get_hash_new_full_blocks = partial(block_hasher, self)
self.block_hashes = self.get_hash_new_full_blocks()
@classmethod
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
def from_engine_core_request(
cls, request: EngineCoreRequest,
block_hasher: Optional[Callable[["Request"], list["BlockHash"]]]
) -> "Request":
if request.mm_kwargs is not None:
assert is_list_of(request.mm_kwargs, MultiModalKwargsItem), (
"mm_kwargs was not updated in EngineCore.add_request")
@@ -131,6 +145,7 @@ class Request:
if request.sampling_params else None,
cache_salt=request.cache_salt,
priority=request.priority,
block_hasher=block_hasher,
)
def append_output_token_ids(
@@ -144,6 +159,9 @@ class Request:
self._output_token_ids.extend(token_ids)
self._all_token_ids.extend(token_ids)
if self.get_hash_new_full_blocks is not None:
self.block_hashes.extend(self.get_hash_new_full_blocks())
@property
def is_output_corrupted(self) -> bool:
    """Whether this request's output is considered corrupted.

    True when at least one NaN has been observed in the logits for
    this request (``num_nans_in_logits`` is a running count).
    """
    return self.num_nans_in_logits > 0