[v1] Move block_hashes from KVCacheManager to Request.block_hashes (#19728)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
@@ -7,6 +7,7 @@ import pytest
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import RequestStatus
|
||||
from vllm.v1.utils import ConstantList
|
||||
|
||||
from .utils import create_requests, create_scheduler
|
||||
|
||||
@@ -140,7 +141,8 @@ def test_prefix_caching_for_prefill_dedup():
|
||||
requests = create_requests(num_requests=5,
|
||||
num_tokens=num_prompt_tokens,
|
||||
max_tokens=3,
|
||||
same_prompt=True)
|
||||
same_prompt=True,
|
||||
block_size=BLOCK_SIZE)
|
||||
requests_copy = requests.copy()
|
||||
|
||||
# Two requests with the same prompt.
|
||||
@@ -188,7 +190,8 @@ def test_prefix_caching_for_multi_turn():
|
||||
block_size=BLOCK_SIZE)
|
||||
requests = create_requests(num_requests=5,
|
||||
num_tokens=num_prompt_tokens,
|
||||
max_tokens=num_output_tokens)
|
||||
max_tokens=num_output_tokens,
|
||||
block_size=BLOCK_SIZE)
|
||||
|
||||
for req in requests:
|
||||
scheduler.add_request(req)
|
||||
@@ -208,14 +211,19 @@ def test_prefix_caching_for_multi_turn():
|
||||
|
||||
# Create next-turn requests whose prompts are the full output of the
|
||||
# previous turn.
|
||||
next_turn_requests = create_requests(
|
||||
num_requests=5,
|
||||
num_tokens=num_prompt_tokens + num_output_tokens,
|
||||
max_tokens=num_output_tokens,
|
||||
)
|
||||
next_turn_requests = create_requests(num_requests=5,
|
||||
num_tokens=num_prompt_tokens +
|
||||
num_output_tokens,
|
||||
max_tokens=num_output_tokens,
|
||||
block_size=BLOCK_SIZE)
|
||||
for i, req in enumerate(next_turn_requests):
|
||||
req.prompt_token_ids = (requests[i].prompt_token_ids +
|
||||
list(requests[i].output_token_ids))
|
||||
req._all_token_ids = req.prompt_token_ids.copy()
|
||||
req.all_token_ids = ConstantList(req._all_token_ids)
|
||||
req.block_hashes = []
|
||||
req.block_hashes = req.get_hash_new_full_blocks()
|
||||
|
||||
# Schedule the next-turn requests.
|
||||
for req in next_turn_requests:
|
||||
scheduler.add_request(req)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import importlib
|
||||
from typing import Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -19,7 +19,7 @@ from vllm.v1.core.kv_cache_utils import (
|
||||
FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
|
||||
estimate_max_model_len, generate_block_hash_extra_keys,
|
||||
get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
|
||||
hash_block_tokens, hash_request_tokens, init_none_hash,
|
||||
get_request_block_hasher, hash_block_tokens, init_none_hash,
|
||||
is_kv_cache_type_uniform, unify_kv_cache_configs)
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, KVCacheTensor,
|
||||
@@ -33,6 +33,8 @@ from vllm.v1.request import Request
|
||||
def make_request(
|
||||
request_id: str,
|
||||
prompt_token_ids: list[int],
|
||||
block_size: int = 3,
|
||||
hash_fn: Callable = hash,
|
||||
mm_positions: Optional[list[PlaceholderRange]] = None,
|
||||
mm_hashes: Optional[list[str]] = None,
|
||||
cache_salt: Optional[str] = None,
|
||||
@@ -49,18 +51,17 @@ def make_request(
|
||||
mm_item = MultiModalKwargsItem.from_elems([mm_elem])
|
||||
mm_kwargs = [mm_item] * len(mm_positions)
|
||||
|
||||
return Request(
|
||||
request_id=request_id,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
multi_modal_kwargs=mm_kwargs,
|
||||
multi_modal_hashes=mm_hashes,
|
||||
multi_modal_placeholders=mm_positions,
|
||||
sampling_params=SamplingParams(max_tokens=17),
|
||||
pooling_params=None,
|
||||
eos_token_id=100,
|
||||
lora_request=None,
|
||||
cache_salt=cache_salt,
|
||||
)
|
||||
return Request(request_id=request_id,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
multi_modal_kwargs=mm_kwargs,
|
||||
multi_modal_hashes=mm_hashes,
|
||||
multi_modal_placeholders=mm_positions,
|
||||
sampling_params=SamplingParams(max_tokens=17),
|
||||
pooling_params=None,
|
||||
eos_token_id=100,
|
||||
lora_request=None,
|
||||
cache_salt=cache_salt,
|
||||
block_hasher=get_request_block_hasher(block_size, hash_fn))
|
||||
|
||||
|
||||
def new_kv_cache_spec(block_size=16,
|
||||
@@ -428,12 +429,14 @@ def test_hash_block_tokens(hash_fn):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
|
||||
def test_hash_request_tokens(hash_fn):
|
||||
def test_request_block_hasher(hash_fn):
|
||||
import vllm.v1.core.kv_cache_utils
|
||||
init_none_hash(hash_fn)
|
||||
request = make_request(
|
||||
request_id="0",
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
block_size=3,
|
||||
hash_fn=hash_fn,
|
||||
mm_positions=[
|
||||
PlaceholderRange(offset=0, length=3),
|
||||
PlaceholderRange(offset=3, length=3),
|
||||
@@ -441,9 +444,7 @@ def test_hash_request_tokens(hash_fn):
|
||||
mm_hashes=["hash1", "hash2"],
|
||||
)
|
||||
|
||||
block_size = 3
|
||||
block_hashes = hash_request_tokens(hash_fn, block_size, request)
|
||||
|
||||
block_hashes = request.block_hashes
|
||||
assert len(block_hashes) == 2
|
||||
assert isinstance(block_hashes[0], vllm.v1.core.kv_cache_utils.BlockHash)
|
||||
assert isinstance(block_hashes[1], vllm.v1.core.kv_cache_utils.BlockHash)
|
||||
@@ -464,6 +465,8 @@ def test_hash_tokens_different_mm_input(hash_fn):
|
||||
request1 = make_request(
|
||||
request_id="0",
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
block_size=3,
|
||||
hash_fn=hash_fn,
|
||||
mm_positions=[
|
||||
PlaceholderRange(offset=0, length=3),
|
||||
PlaceholderRange(offset=3, length=3),
|
||||
@@ -479,9 +482,8 @@ def test_hash_tokens_different_mm_input(hash_fn):
|
||||
],
|
||||
mm_hashes=["hash3", "hash2"],
|
||||
)
|
||||
block_size = 3
|
||||
block_hashes1 = hash_request_tokens(hash_fn, block_size, request1)
|
||||
block_hashes2 = hash_request_tokens(hash_fn, block_size, request2)
|
||||
block_hashes1 = request1.block_hashes
|
||||
block_hashes2 = request2.block_hashes
|
||||
assert block_hashes1[0] != block_hashes2[0]
|
||||
assert block_hashes1[1] != block_hashes2[1]
|
||||
|
||||
@@ -493,12 +495,13 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
|
||||
request = make_request(
|
||||
request_id="0",
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
block_size=3,
|
||||
hash_fn=hash_fn,
|
||||
mm_positions=None,
|
||||
mm_hashes=None,
|
||||
)
|
||||
|
||||
block_size = 3
|
||||
block_hashes = hash_request_tokens(hash_fn, block_size, request)
|
||||
block_hashes = request.block_hashes
|
||||
|
||||
assert len(block_hashes) == 2
|
||||
assert block_hashes[0].token_ids == (0, 1, 2)
|
||||
@@ -858,6 +861,7 @@ def test_allocate_with_lookahead():
|
||||
request = make_request(
|
||||
request_id="0",
|
||||
prompt_token_ids=[],
|
||||
block_size=block_size,
|
||||
mm_positions=None,
|
||||
mm_hashes=None,
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
"""Compare the with and without prefix caching."""
|
||||
|
||||
import copy
|
||||
from typing import Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -17,8 +17,9 @@ from vllm.utils import sha256, sha256_cbor_64bit
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
|
||||
KVCacheBlock, hash_block_tokens,
|
||||
init_none_hash)
|
||||
KVCacheBlock,
|
||||
get_request_block_hasher,
|
||||
hash_block_tokens, init_none_hash)
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, SlidingWindowSpec)
|
||||
|
||||
@@ -26,6 +27,8 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
def make_request(
|
||||
request_id: str,
|
||||
prompt_token_ids: list[int],
|
||||
block_size: int,
|
||||
hash_fn: Callable,
|
||||
mm_positions: Optional[list[PlaceholderRange]] = None,
|
||||
mm_hashes: Optional[list[str]] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
@@ -43,19 +46,18 @@ def make_request(
|
||||
mm_item = MultiModalKwargsItem.from_elems([mm_elem])
|
||||
mm_kwargs = [mm_item] * len(mm_positions)
|
||||
|
||||
return Request(
|
||||
request_id=request_id,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
multi_modal_kwargs=mm_kwargs,
|
||||
multi_modal_hashes=mm_hashes,
|
||||
multi_modal_placeholders=mm_positions,
|
||||
sampling_params=SamplingParams(max_tokens=17,
|
||||
prompt_logprobs=prompt_logprobs),
|
||||
pooling_params=None,
|
||||
eos_token_id=100,
|
||||
lora_request=None,
|
||||
cache_salt=cache_salt,
|
||||
)
|
||||
return Request(request_id=request_id,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
multi_modal_kwargs=mm_kwargs,
|
||||
multi_modal_hashes=mm_hashes,
|
||||
multi_modal_placeholders=mm_positions,
|
||||
sampling_params=SamplingParams(
|
||||
max_tokens=17, prompt_logprobs=prompt_logprobs),
|
||||
pooling_params=None,
|
||||
eos_token_id=100,
|
||||
lora_request=None,
|
||||
cache_salt=cache_salt,
|
||||
block_hasher=get_request_block_hasher(block_size, hash_fn))
|
||||
|
||||
|
||||
def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
|
||||
@@ -105,11 +107,11 @@ def make_kv_cache_config_hybrid_model(block_size: int,
|
||||
|
||||
@pytest.mark.parametrize("hash_algo", ["sha256", "sha256_cbor_64bit", "hash"])
|
||||
def test_prefill(hash_algo):
|
||||
block_size = 16
|
||||
manager = KVCacheManager(
|
||||
make_kv_cache_config(16, 11),
|
||||
make_kv_cache_config(block_size, 11),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
caching_hash_algo=hash_algo,
|
||||
)
|
||||
|
||||
# choose the hash function according to the parameter
|
||||
@@ -123,9 +125,9 @@ def test_prefill(hash_algo):
|
||||
# Incomplete 1 block (7 tokens)
|
||||
unique_token_ids = [3] * 7
|
||||
all_token_ids = common_token_ids + unique_token_ids
|
||||
req0 = make_request("0", all_token_ids)
|
||||
req0 = make_request("0", all_token_ids, block_size, hash_fn)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert len(manager.req_to_block_hashes[req0.request_id]) == 3
|
||||
assert len(req0.block_hashes) == 3
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
@@ -152,9 +154,10 @@ def test_prefill(hash_algo):
|
||||
# Cache hit in the common prefix when the original block is still in use.
|
||||
# Incomplete 1 block (5 tokens)
|
||||
unique_token_ids = [3] * 5
|
||||
req1 = make_request("1", common_token_ids + unique_token_ids)
|
||||
req1 = make_request("1", common_token_ids + unique_token_ids, block_size,
|
||||
hash_fn)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
|
||||
assert len(req1.block_hashes) == 3
|
||||
assert computed_blocks.get_block_ids() == ([1, 2, 3], )
|
||||
assert num_computed_tokens == 3 * 16
|
||||
num_new_tokens = 53 - 3 * 16
|
||||
@@ -187,9 +190,10 @@ def test_prefill(hash_algo):
|
||||
# Cache hit in the common prefix when the original block is already free.
|
||||
# Incomplete 1 block (6 tokens)
|
||||
unique_token_ids = [3] * 6
|
||||
req2 = make_request("2", common_token_ids + unique_token_ids)
|
||||
req2 = make_request("2", common_token_ids + unique_token_ids, block_size,
|
||||
hash_fn)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert len(manager.req_to_block_hashes[req2.request_id]) == 3
|
||||
assert len(req2.block_hashes) == 3
|
||||
assert computed_blocks.get_block_ids() == ([1, 2, 3], )
|
||||
assert num_computed_tokens == 3 * 16
|
||||
num_new_tokens = 53 - 3 * 16
|
||||
@@ -208,7 +212,7 @@ def test_prefill(hash_algo):
|
||||
manager.free(req2)
|
||||
|
||||
# Cache miss and eviction.
|
||||
req3 = make_request("3", [99] * (16 * 10))
|
||||
req3 = make_request("3", [99] * (16 * 10), block_size, hash_fn)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@@ -242,9 +246,9 @@ def test_prefill_hybrid_model():
|
||||
# Incomplete 1 block (7 tokens)
|
||||
unique_token_ids = [3] * 7
|
||||
all_token_ids = common_token_ids + unique_token_ids
|
||||
req0 = make_request("0", all_token_ids)
|
||||
req0 = make_request("0", all_token_ids, block_size, hash_fn)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert len(manager.req_to_block_hashes[req0.request_id]) == 3
|
||||
assert len(req0.block_hashes) == 3
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
@@ -274,9 +278,10 @@ def test_prefill_hybrid_model():
|
||||
# Cache hit in the common prefix
|
||||
# Incomplete 1 block (5 tokens)
|
||||
unique_token_ids = [3] * 5
|
||||
req1 = make_request("1", common_token_ids + unique_token_ids)
|
||||
req1 = make_request("1", common_token_ids + unique_token_ids, block_size,
|
||||
hash_fn)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
|
||||
assert len(req1.block_hashes) == 3
|
||||
assert computed_blocks.get_block_ids() == ([1, 2, 3], [0, 6,
|
||||
7], [0, 10, 11])
|
||||
assert num_computed_tokens == 3 * 16
|
||||
@@ -290,7 +295,7 @@ def test_prefill_hybrid_model():
|
||||
if block != manager.block_pool.null_block:
|
||||
assert block.ref_cnt == 2
|
||||
|
||||
block_hashes = manager.req_to_block_hashes[req1.request_id]
|
||||
block_hashes = req1.block_hashes
|
||||
manager.free(req0)
|
||||
manager.free(req1)
|
||||
|
||||
@@ -300,12 +305,13 @@ def test_prefill_hybrid_model():
|
||||
def test_partial_request_hit(request_id: str,
|
||||
hash_to_evict: list[BlockHashWithGroupId],
|
||||
expect_hit_length: int):
|
||||
req = make_request(request_id, common_token_ids + unique_token_ids)
|
||||
req = make_request(request_id, common_token_ids + unique_token_ids,
|
||||
block_size, hash)
|
||||
for hash_with_group_id in hash_to_evict:
|
||||
manager.block_pool.cached_block_hash_to_block.pop(
|
||||
hash_with_group_id)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||
assert len(manager.req_to_block_hashes[req.request_id]) == 3
|
||||
assert len(req.block_hashes) == 3
|
||||
assert num_computed_tokens == expect_hit_length * block_size
|
||||
for block_per_group in computed_blocks.blocks:
|
||||
assert len(block_per_group) == num_computed_tokens // block_size
|
||||
@@ -364,8 +370,9 @@ def test_prefill_plp():
|
||||
2. Schedule non-plp request and validate blocks
|
||||
3. Schedule plp request; no hit should occur; validate blocks
|
||||
'''
|
||||
block_size = 16
|
||||
manager = KVCacheManager(
|
||||
make_kv_cache_config(16, 11),
|
||||
make_kv_cache_config(block_size, 11),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
)
|
||||
@@ -380,9 +387,13 @@ def test_prefill_plp():
|
||||
# Incomplete 1 block (7 tokens)
|
||||
unique_token_ids = [3] * 7
|
||||
all_token_ids = common_token_ids + unique_token_ids
|
||||
req0 = make_request("0", all_token_ids, prompt_logprobs=5)
|
||||
req0 = make_request("0",
|
||||
all_token_ids,
|
||||
block_size,
|
||||
hash_fn,
|
||||
prompt_logprobs=5)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert len(manager.req_to_block_hashes[req0.request_id]) == 0
|
||||
assert len(req0.block_hashes) == 3
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req0, 55,
|
||||
@@ -411,9 +422,10 @@ def test_prefill_plp():
|
||||
# Cache hit in the common prefix when the original block is still in use.
|
||||
# Incomplete 1 block (5 tokens)
|
||||
unique_token_ids = [3] * 5
|
||||
req1 = make_request("1", common_token_ids + unique_token_ids)
|
||||
req1 = make_request("1", common_token_ids + unique_token_ids, block_size,
|
||||
hash_fn)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
|
||||
assert len(req1.block_hashes) == 3
|
||||
assert computed_blocks.get_block_ids() == ([1, 2, 3], )
|
||||
assert num_computed_tokens == 3 * 16
|
||||
num_new_tokens = 53 - 3 * 16
|
||||
@@ -447,9 +459,11 @@ def test_prefill_plp():
|
||||
unique_token_ids = [3] * 6
|
||||
req2 = make_request("2",
|
||||
common_token_ids + unique_token_ids,
|
||||
block_size,
|
||||
hash_fn,
|
||||
prompt_logprobs=5)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert len(manager.req_to_block_hashes[req2.request_id]) == 0
|
||||
assert len(req2.block_hashes) == 3
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
blocks = manager.allocate_slots(req2, 55,
|
||||
@@ -469,8 +483,9 @@ def test_prefill_plp():
|
||||
|
||||
|
||||
def test_decode():
|
||||
block_size = 16
|
||||
manager = KVCacheManager(
|
||||
make_kv_cache_config(16, 11),
|
||||
make_kv_cache_config(block_size, 11),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
)
|
||||
@@ -481,7 +496,8 @@ def test_decode():
|
||||
# Full cache miss
|
||||
# Incomplete 1 block (7 tokens)
|
||||
unique_token_ids = [3] * 7
|
||||
req0 = make_request("0", common_token_ids + unique_token_ids)
|
||||
req0 = make_request("0", common_token_ids + unique_token_ids, block_size,
|
||||
hash)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@@ -518,14 +534,15 @@ def test_decode():
|
||||
|
||||
|
||||
def test_evict():
|
||||
block_size = 16
|
||||
manager = KVCacheManager(
|
||||
make_kv_cache_config(16, 11),
|
||||
make_kv_cache_config(block_size, 11),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
)
|
||||
|
||||
last_token_id = 5 * 16 + 7
|
||||
req0 = make_request("0", list(range(last_token_id)))
|
||||
req0 = make_request("0", list(range(last_token_id)), block_size, hash)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@@ -536,7 +553,8 @@ def test_evict():
|
||||
|
||||
# 3 blocks.
|
||||
req1 = make_request("1", list(range(last_token_id,
|
||||
last_token_id + 3 * 16)))
|
||||
last_token_id + 3 * 16)), block_size,
|
||||
hash)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@@ -558,7 +576,7 @@ def test_evict():
|
||||
] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7]
|
||||
|
||||
# Touch the first 2 blocks.
|
||||
req2 = make_request("2", list(range(2 * 16 + 3)))
|
||||
req2 = make_request("2", list(range(2 * 16 + 3)), block_size, hash)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert computed_blocks.get_block_ids() == ([1, 2], )
|
||||
assert num_computed_tokens == 2 * 16
|
||||
@@ -583,7 +601,7 @@ def test_hash_block_correct_reuse():
|
||||
|
||||
# Allocate 1 block and cache it.
|
||||
num_tokens = block_size * 1
|
||||
req = make_request("0", list(range(num_tokens)))
|
||||
req = make_request("0", list(range(num_tokens)), block_size, hash)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@@ -597,7 +615,7 @@ def test_hash_block_correct_reuse():
|
||||
|
||||
# Allocate a new block that's not full, make sure hash info on the
|
||||
# block is cleared.
|
||||
req = make_request("1", list(range(num_tokens - 1)))
|
||||
req = make_request("1", list(range(num_tokens - 1)), block_size, hash)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@@ -624,7 +642,7 @@ def test_computed_blocks_not_evicted():
|
||||
|
||||
# Allocate a block and cache it.
|
||||
num_tokens = block_size * 1
|
||||
req0 = make_request("0", list(range(num_tokens)))
|
||||
req0 = make_request("0", list(range(num_tokens)), block_size, hash)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@@ -635,7 +653,8 @@ def test_computed_blocks_not_evicted():
|
||||
assert blocks.blocks[0][0].block_id == 1
|
||||
|
||||
# Allocate another block.
|
||||
req1 = make_request("1", list(range(num_tokens, num_tokens * 2)))
|
||||
req1 = make_request("1", list(range(num_tokens, num_tokens * 2)),
|
||||
block_size, hash)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@@ -651,7 +670,7 @@ def test_computed_blocks_not_evicted():
|
||||
|
||||
# Now if we have a cache hit on the first block, we should evict the second
|
||||
# cached block rather than the first one.
|
||||
req2 = make_request("2", list(range(num_tokens * 2)))
|
||||
req2 = make_request("2", list(range(num_tokens * 2)), block_size, hash)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert len(computed_blocks.blocks[0]) == 1
|
||||
assert computed_blocks.blocks[0][0].block_id == 1
|
||||
@@ -675,7 +694,8 @@ def test_basic_prefix_caching_disabled():
|
||||
enable_caching=False,
|
||||
)
|
||||
|
||||
req1 = make_request("1", list(range(10))) # 2 blocks and some more
|
||||
req1 = make_request("1", list(range(10)), block_size,
|
||||
hash) # 2 blocks and some more
|
||||
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert not computed_blocks.blocks[0]
|
||||
@@ -689,7 +709,8 @@ def test_basic_prefix_caching_disabled():
|
||||
manager.free(req1)
|
||||
|
||||
# No caching.
|
||||
req2 = make_request("2", list(range(16))) # shared prefix
|
||||
req2 = make_request("2", list(range(16)), block_size,
|
||||
hash) # shared prefix
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@@ -699,7 +720,7 @@ def test_basic_prefix_caching_disabled():
|
||||
assert len(blocks.blocks[0]) == 4
|
||||
|
||||
# New requests should not have any blocks.
|
||||
req3 = make_request("3", list(range(4)))
|
||||
req3 = make_request("3", list(range(4)), block_size, hash)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@@ -727,20 +748,17 @@ def test_cache_blocks(hash_fn):
|
||||
# Block 1: [4, 5, 6, 7]
|
||||
# Block 2: [8, 9, 10, 11]
|
||||
# Block 3: [12, 13]
|
||||
req = make_request("0", list(range(14)))
|
||||
req = make_request("0", list(range(14)), block_size, hash_fn)
|
||||
|
||||
# Test that blocks are cached correctly for 2 full blocks from the start.
|
||||
blocks = [KVCacheBlock(block_id=i) for i in range(2)]
|
||||
block_hashes: list[BlockHash] = []
|
||||
|
||||
block_pool.cache_full_blocks(
|
||||
request=req,
|
||||
blocks=blocks,
|
||||
block_hashes=block_hashes,
|
||||
num_cached_blocks=0,
|
||||
num_full_blocks=2,
|
||||
block_size=block_size,
|
||||
hash_fn=hash_fn,
|
||||
kv_cache_group_id=0,
|
||||
)
|
||||
|
||||
@@ -752,11 +770,9 @@ def test_cache_blocks(hash_fn):
|
||||
block_pool.cache_full_blocks(
|
||||
request=req,
|
||||
blocks=blocks,
|
||||
block_hashes=block_hashes,
|
||||
num_cached_blocks=2,
|
||||
num_full_blocks=3,
|
||||
block_size=block_size,
|
||||
hash_fn=hash_fn,
|
||||
kv_cache_group_id=0,
|
||||
)
|
||||
assert len(block_pool.cached_block_hash_to_block) == 3
|
||||
@@ -775,23 +791,20 @@ def test_cache_blocks_multi_group():
|
||||
# Block 1/5: [4, 5, 6, 7]
|
||||
# Block 2/6: [8, 9, 10, 11]
|
||||
# Block 3/7: [12, 13]
|
||||
req = make_request("0", list(range(14)))
|
||||
req = make_request("0", list(range(14)), block_size, hash)
|
||||
|
||||
# Cache the blocks for group 0.
|
||||
blocks = [KVCacheBlock(block_id=i) for i in range(2)]
|
||||
block_hashes: list[BlockHash] = []
|
||||
block_pool.cache_full_blocks(
|
||||
request=req,
|
||||
blocks=blocks,
|
||||
block_hashes=block_hashes,
|
||||
num_cached_blocks=0,
|
||||
num_full_blocks=2,
|
||||
block_size=block_size,
|
||||
hash_fn=hash,
|
||||
kv_cache_group_id=0,
|
||||
)
|
||||
assert len(block_pool.cached_block_hash_to_block) == 2
|
||||
assert len(block_hashes) == 2
|
||||
assert len(req.block_hashes) == 3
|
||||
assert all([block.block_hash is not None for block in blocks])
|
||||
|
||||
# Cache the blocks for group 1.
|
||||
@@ -799,38 +812,36 @@ def test_cache_blocks_multi_group():
|
||||
block_pool.cache_full_blocks(
|
||||
request=req,
|
||||
blocks=blocks,
|
||||
block_hashes=block_hashes,
|
||||
num_cached_blocks=0,
|
||||
num_full_blocks=3,
|
||||
block_size=block_size,
|
||||
hash_fn=hash,
|
||||
kv_cache_group_id=1,
|
||||
)
|
||||
assert len(block_pool.cached_block_hash_to_block) == 5
|
||||
assert len(block_hashes) == 3
|
||||
assert len(req.block_hashes) == 3
|
||||
assert all([block.block_hash is not None for block in blocks])
|
||||
|
||||
# Block hash 0: hit for group 0 and 1
|
||||
# Block hash 1: hit for group 0 and 1
|
||||
# Block hash 2: hit for group 1
|
||||
|
||||
assert block_pool.get_cached_block(block_hashes[0],
|
||||
assert block_pool.get_cached_block(req.block_hashes[0],
|
||||
kv_cache_group_ids=[0]) is not None
|
||||
assert block_pool.get_cached_block(block_hashes[1],
|
||||
assert block_pool.get_cached_block(req.block_hashes[1],
|
||||
kv_cache_group_ids=[0]) is not None
|
||||
assert block_pool.get_cached_block(block_hashes[2],
|
||||
assert block_pool.get_cached_block(req.block_hashes[2],
|
||||
kv_cache_group_ids=[0]) is None
|
||||
assert block_pool.get_cached_block(block_hashes[0],
|
||||
assert block_pool.get_cached_block(req.block_hashes[0],
|
||||
kv_cache_group_ids=[1]) is not None
|
||||
assert block_pool.get_cached_block(block_hashes[1],
|
||||
assert block_pool.get_cached_block(req.block_hashes[1],
|
||||
kv_cache_group_ids=[1]) is not None
|
||||
assert block_pool.get_cached_block(block_hashes[2],
|
||||
assert block_pool.get_cached_block(req.block_hashes[2],
|
||||
kv_cache_group_ids=[1]) is not None
|
||||
assert block_pool.get_cached_block(block_hashes[0],
|
||||
assert block_pool.get_cached_block(req.block_hashes[0],
|
||||
kv_cache_group_ids=[0, 1]) is not None
|
||||
assert block_pool.get_cached_block(block_hashes[1],
|
||||
assert block_pool.get_cached_block(req.block_hashes[1],
|
||||
kv_cache_group_ids=[0, 1]) is not None
|
||||
assert block_pool.get_cached_block(block_hashes[2],
|
||||
assert block_pool.get_cached_block(req.block_hashes[2],
|
||||
kv_cache_group_ids=[0, 1]) is None
|
||||
|
||||
|
||||
@@ -838,8 +849,9 @@ def test_mm_prefix_caching():
|
||||
"""
|
||||
This tests that the multi-modal prefix caching is correct.
|
||||
"""
|
||||
block_size = 16
|
||||
manager = KVCacheManager(
|
||||
make_kv_cache_config(16, 11),
|
||||
make_kv_cache_config(block_size, 11),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
)
|
||||
@@ -865,6 +877,8 @@ def test_mm_prefix_caching():
|
||||
mm_hashes = common_mm_hashes + ["ccc"]
|
||||
req0 = make_request("0",
|
||||
all_token_ids,
|
||||
block_size,
|
||||
hash,
|
||||
mm_positions=mm_positions,
|
||||
mm_hashes=mm_hashes)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
@@ -872,7 +886,7 @@ def test_mm_prefix_caching():
|
||||
# Completed block should have hashes with extra keys.
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
block_hashes = manager.req_to_block_hashes[req0.request_id]
|
||||
block_hashes = req0.block_hashes
|
||||
assert len(block_hashes) == 3
|
||||
assert block_hashes[0].extra_keys == ("aaa", )
|
||||
assert block_hashes[1].extra_keys == ("aaa", "bbb")
|
||||
@@ -905,6 +919,8 @@ def test_mm_prefix_caching():
|
||||
mm_hashes = common_mm_hashes + ["ccc"]
|
||||
req1 = make_request("1",
|
||||
all_token_ids,
|
||||
block_size,
|
||||
hash,
|
||||
mm_positions=mm_positions,
|
||||
mm_hashes=mm_hashes)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
@@ -927,13 +943,13 @@ def test_cache_key_salting():
|
||||
# 3 complete blocks and an incomplete block with 11 tokens.
|
||||
common_token_ids = [i for i in range(3) for _ in range(block_size)]
|
||||
token_ids = common_token_ids + [3] * 11
|
||||
req0 = make_request("0", token_ids, cache_salt="salt1")
|
||||
req0 = make_request("0", token_ids, block_size, hash, cache_salt="salt1")
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
|
||||
# Completed block should have hashes with extra keys.
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
block_hashes = manager.req_to_block_hashes[req0.request_id]
|
||||
block_hashes = req0.block_hashes
|
||||
assert len(block_hashes) == 3
|
||||
assert block_hashes[0].extra_keys == ("salt1", )
|
||||
assert block_hashes[1].extra_keys is None
|
||||
@@ -959,7 +975,7 @@ def test_cache_key_salting():
|
||||
|
||||
# Test cache hit with a new request that has the same salt.
|
||||
token_ids = common_token_ids + [4] * 11
|
||||
req1 = make_request("1", token_ids, cache_salt="salt1")
|
||||
req1 = make_request("1", token_ids, block_size, hash, cache_salt="salt1")
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
# Should match only a prefix of 3 blocks.
|
||||
assert len(computed_blocks.blocks[0]) == 3
|
||||
@@ -967,11 +983,11 @@ def test_cache_key_salting():
|
||||
|
||||
# Test cache miss with same content but different salt.
|
||||
token_ids = common_token_ids + [4] * 11
|
||||
req2 = make_request("2", token_ids, cache_salt="salt2")
|
||||
req2 = make_request("2", token_ids, block_size, hash, cache_salt="salt2")
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert len(computed_blocks.blocks[0]) == 0
|
||||
assert num_computed_tokens == 0
|
||||
block_hashes = manager.req_to_block_hashes[req2.request_id]
|
||||
block_hashes = req2.block_hashes
|
||||
assert len(block_hashes) == 3
|
||||
assert block_hashes[0].extra_keys == ("salt2", )
|
||||
|
||||
@@ -992,7 +1008,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
# Complete 3 blocks (48 tokens)
|
||||
# | Common-0 | Common-1 | Common-2 | ... |
|
||||
common_token_ids = [i for i in range(3) for _ in range(16)]
|
||||
req0 = make_request("0", common_token_ids)
|
||||
req0 = make_request("0", common_token_ids, block_size, hash)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@@ -1003,7 +1019,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
req0.request_id]
|
||||
|
||||
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
|
||||
req1 = make_request("1", common_token_ids * 2)
|
||||
req1 = make_request("1", common_token_ids * 2, block_size, hash)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
|
||||
assert computed_blocks.blocks[0] == block_part0
|
||||
assert num_computed_tokens == 3 * 16
|
||||
@@ -1020,19 +1036,19 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
|
||||
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
|
||||
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
|
||||
req2 = make_request("2", [7] * block_size * 2)
|
||||
req2 = make_request("2", [7] * block_size * 2, block_size, hash)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
manager.allocate_slots(req2, block_size * 2,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
len(computed_blocks.blocks[0]) * block_size,
|
||||
computed_blocks)
|
||||
|
||||
# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
|
||||
# but it cannot be allocated due to insufficient free blocks (2).
|
||||
# In this case, the ref_cnt of the computed blocks should not be changed.
|
||||
assert manager.block_pool.free_block_queue.num_free_blocks == 5
|
||||
req3 = make_request("3", common_token_ids * 3)
|
||||
req3 = make_request("3", common_token_ids * 3, block_size, hash)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
|
||||
assert computed_blocks.blocks[0] == block_part1
|
||||
assert num_computed_tokens == 6 * 16
|
||||
@@ -1047,8 +1063,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
|
||||
|
||||
def test_reset_prefix_cache():
|
||||
block_size = 16
|
||||
manager = KVCacheManager(
|
||||
make_kv_cache_config(16, 11),
|
||||
make_kv_cache_config(block_size, 11),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
)
|
||||
@@ -1056,15 +1073,15 @@ def test_reset_prefix_cache():
|
||||
full_block_token_ids = [i for i in range(3) for _ in range(16)]
|
||||
unique_token_ids = [3] * 7
|
||||
all_token_ids = full_block_token_ids + unique_token_ids
|
||||
req0 = make_request("0", all_token_ids)
|
||||
req0 = make_request("0", all_token_ids, block_size, hash)
|
||||
blocks = manager.allocate_slots(req0, 55)
|
||||
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
|
||||
|
||||
unique_token_ids = [4] * 7
|
||||
all_token_ids = full_block_token_ids + unique_token_ids
|
||||
req1 = make_request("1", all_token_ids)
|
||||
req1 = make_request("1", all_token_ids, block_size, hash)
|
||||
computed_blocks, _ = manager.get_computed_blocks(req1)
|
||||
assert len(manager.req_to_block_hashes[req1.request_id]) == 3
|
||||
assert len(req1.block_hashes) == 3
|
||||
assert len(computed_blocks.blocks[0]) == 3
|
||||
blocks = manager.allocate_slots(req1, 7,
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
@@ -1086,8 +1103,9 @@ def test_reset_prefix_cache():
|
||||
|
||||
def test_prefix_cache_stats_disabled():
|
||||
"""Test that prefix_cache_stats is None when log_stats is False."""
|
||||
block_size = 16
|
||||
manager = KVCacheManager(
|
||||
make_kv_cache_config(16, 11),
|
||||
make_kv_cache_config(block_size, 11),
|
||||
max_model_len=8192,
|
||||
enable_caching=True,
|
||||
log_stats=False, # Disable logging stats
|
||||
@@ -1095,7 +1113,7 @@ def test_prefix_cache_stats_disabled():
|
||||
assert manager.prefix_cache_stats is None
|
||||
|
||||
# Call all functions that check whether log_stats is disabled.
|
||||
req = make_request("0", list(range(16)))
|
||||
req = make_request("0", list(range(16)), block_size, hash)
|
||||
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
|
||||
assert not computed_blocks.blocks[0]
|
||||
assert num_computed_tokens == 0
|
||||
@@ -1192,7 +1210,7 @@ def test_kv_cache_events(blocks_to_cache: int):
|
||||
)
|
||||
|
||||
num_tokens = block_size * blocks_to_cache
|
||||
req0 = make_request("0", list(range(num_tokens)))
|
||||
req0 = make_request("0", list(range(num_tokens)), block_size, hash)
|
||||
_ = manager.allocate_slots(req0, num_tokens)
|
||||
events = manager.take_events()
|
||||
|
||||
@@ -1208,7 +1226,7 @@ def test_kv_cache_events(blocks_to_cache: int):
|
||||
# Should see block_to_cache number of removed block events and a new block
|
||||
# stored event
|
||||
manager.free(req0)
|
||||
req1 = make_request("1", list(range(num_tokens)))
|
||||
req1 = make_request("1", list(range(num_tokens)), block_size, hash)
|
||||
_ = manager.allocate_slots(req1, num_tokens)
|
||||
events = manager.take_events()
|
||||
|
||||
@@ -1242,7 +1260,7 @@ def test_eagle_enabled_removes_last_block():
|
||||
|
||||
# Request with 3 full blocks (48 tokens)
|
||||
token_ids = [0] * (3 * block_size)
|
||||
req = make_request("divisible_request", token_ids)
|
||||
req = make_request("divisible_request", token_ids, block_size, hash)
|
||||
|
||||
# Prime the cache
|
||||
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||
@@ -1252,7 +1270,7 @@ def test_eagle_enabled_removes_last_block():
|
||||
manager.free(req)
|
||||
|
||||
# New request with same tokens + Eagle enabled
|
||||
req_eagle = make_request("eagle_divisible", token_ids)
|
||||
req_eagle = make_request("eagle_divisible", token_ids, block_size, hash)
|
||||
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
|
||||
|
||||
# Should retain 1 block:
|
||||
@@ -1273,7 +1291,7 @@ def test_eagle_with_partial_blocks():
|
||||
)
|
||||
# 2 full blocks + 5 tokens (non-divisible length)
|
||||
token_ids = [0] * (2 * block_size + 5)
|
||||
req = make_request("partial_block_test", token_ids)
|
||||
req = make_request("partial_block_test", token_ids, block_size, hash)
|
||||
|
||||
# Prime the cache
|
||||
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||
@@ -1283,7 +1301,7 @@ def test_eagle_with_partial_blocks():
|
||||
manager.free(req)
|
||||
|
||||
# New request with Eagle enabled
|
||||
req_eagle = make_request("partial_eagle", token_ids)
|
||||
req_eagle = make_request("partial_eagle", token_ids, block_size, hash)
|
||||
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
|
||||
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
|
||||
assert len(computed_blocks.blocks[0]) == 1
|
||||
@@ -1314,7 +1332,7 @@ def test_eagle_with_sliding_window():
|
||||
|
||||
# 2 full blocks + 5 tokens (non-divisible length)
|
||||
token_ids = [0] * (2 * block_size + 5)
|
||||
req = make_request("partial_block_test", token_ids)
|
||||
req = make_request("partial_block_test", token_ids, block_size, hash)
|
||||
|
||||
# Prime the cache
|
||||
computed_blocks, _ = manager.get_computed_blocks(req)
|
||||
@@ -1322,12 +1340,12 @@ def test_eagle_with_sliding_window():
|
||||
len(computed_blocks.blocks[0]) * 16,
|
||||
computed_blocks)
|
||||
# record the block hash of the first block in the request for later use
|
||||
block_hash_first_block = manager.req_to_block_hashes[req.request_id][0]
|
||||
block_hash_first_block = req.block_hashes[0]
|
||||
assert block_hash_first_block is not None
|
||||
manager.free(req)
|
||||
|
||||
# New request with Eagle enabled
|
||||
req_eagle = make_request("partial_eagle", token_ids)
|
||||
req_eagle = make_request("partial_eagle", token_ids, block_size, hash)
|
||||
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
|
||||
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
|
||||
assert len(computed_blocks.blocks[0]) == 1
|
||||
@@ -1340,7 +1358,8 @@ def test_eagle_with_sliding_window():
|
||||
BlockHashWithGroupId(block_hash_first_block, 0))
|
||||
|
||||
# New request
|
||||
req_after_evict = make_request("partial_eagle_after_evict", token_ids)
|
||||
req_after_evict = make_request("partial_eagle_after_evict", token_ids,
|
||||
block_size, hash)
|
||||
computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict)
|
||||
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
|
||||
# not considered. But after dropping the last matched block due to eagle,
|
||||
|
||||
@@ -589,7 +589,7 @@ def test_preempt_during_execution():
|
||||
block_size=16,
|
||||
num_blocks=11,
|
||||
enable_prefix_caching=False)
|
||||
requests = create_requests(num_requests=2, num_tokens=80)
|
||||
requests = create_requests(num_requests=2, num_tokens=80, block_size=16)
|
||||
|
||||
# Schedule the first request.
|
||||
scheduler.add_request(requests[0])
|
||||
@@ -762,7 +762,7 @@ def _assert_right_scheduler_output(
|
||||
|
||||
def _assert_right_kv_cache_manager(
|
||||
scheduler: Scheduler,
|
||||
req_ids: list[str],
|
||||
requests: list[Request],
|
||||
num_tokens: int,
|
||||
block_size: int,
|
||||
num_requests: int,
|
||||
@@ -772,12 +772,12 @@ def _assert_right_kv_cache_manager(
|
||||
|
||||
# Make sure the request stats are right.
|
||||
EXPECTED_TOTAL_BLOCKS = num_tokens // block_size
|
||||
for req_id in req_ids:
|
||||
for req in requests:
|
||||
blocks = (scheduler.kv_cache_manager.coordinator.
|
||||
single_type_managers[0].req_to_blocks[req_id])
|
||||
hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id]
|
||||
single_type_managers[0].req_to_blocks[req.request_id])
|
||||
hashes = req.block_hashes
|
||||
assert (scheduler.kv_cache_manager.coordinator.single_type_managers[0].
|
||||
num_cached_block[req_id] == EXPECTED_TOTAL_BLOCKS)
|
||||
num_cached_block[req.request_id] == EXPECTED_TOTAL_BLOCKS)
|
||||
assert len(blocks) == EXPECTED_TOTAL_BLOCKS
|
||||
assert len(hashes) == EXPECTED_TOTAL_BLOCKS
|
||||
|
||||
@@ -840,7 +840,8 @@ def test_kv_connector_basic():
|
||||
MAX_TOKENS = 3
|
||||
requests = create_requests(num_requests=NUM_REQUESTS,
|
||||
num_tokens=NUM_TOKENS,
|
||||
max_tokens=MAX_TOKENS)
|
||||
max_tokens=MAX_TOKENS,
|
||||
block_size=BLOCK_SIZE)
|
||||
req_ids = []
|
||||
req_to_index = {}
|
||||
for i, request in enumerate(requests):
|
||||
@@ -868,7 +869,7 @@ def test_kv_connector_basic():
|
||||
)
|
||||
|
||||
# Ensure KVCacheManager is correct.
|
||||
_assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS, BLOCK_SIZE,
|
||||
_assert_right_kv_cache_manager(scheduler, requests, NUM_TOKENS, BLOCK_SIZE,
|
||||
NUM_REQUESTS, NUM_TOTAL_BLOCKS)
|
||||
|
||||
# Continue Generation until done.
|
||||
@@ -886,7 +887,8 @@ def test_kv_connector_basic():
|
||||
NUM_TOKENS = NUM_TOKENS_PREFIX * 2
|
||||
requests = create_requests(num_requests=NUM_REQUESTS,
|
||||
num_tokens=NUM_TOKENS,
|
||||
max_tokens=MAX_TOKENS)
|
||||
max_tokens=MAX_TOKENS,
|
||||
block_size=BLOCK_SIZE)
|
||||
req_ids = []
|
||||
req_to_index = {}
|
||||
for i, request in enumerate(requests):
|
||||
@@ -915,7 +917,7 @@ def test_kv_connector_basic():
|
||||
NUM_MATCHED_NEW_TOKENS))
|
||||
|
||||
# Ensure KVCacheManager is correct.
|
||||
_assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS, BLOCK_SIZE,
|
||||
_assert_right_kv_cache_manager(scheduler, requests, NUM_TOKENS, BLOCK_SIZE,
|
||||
NUM_REQUESTS, NUM_TOTAL_BLOCKS)
|
||||
|
||||
# Continue Generation until done.
|
||||
@@ -953,7 +955,8 @@ def test_kv_connector_unable_to_allocate():
|
||||
MAX_TOKENS = 2
|
||||
requests = create_requests(num_requests=NUM_REQUESTS,
|
||||
num_tokens=NUM_TOKENS,
|
||||
max_tokens=MAX_TOKENS)
|
||||
max_tokens=MAX_TOKENS,
|
||||
block_size=BLOCK_SIZE)
|
||||
req_ids = []
|
||||
req_to_index = {}
|
||||
for i, request in enumerate(requests):
|
||||
@@ -1034,7 +1037,8 @@ def test_kv_connector_handles_preemption():
|
||||
MAX_TOKENS = BLOCK_SIZE * 2
|
||||
requests = create_requests(num_requests=NUM_REQUESTS,
|
||||
num_tokens=NUM_TOKENS,
|
||||
max_tokens=MAX_TOKENS)
|
||||
max_tokens=MAX_TOKENS,
|
||||
block_size=BLOCK_SIZE)
|
||||
req_ids = []
|
||||
req_to_index = {}
|
||||
for i, request in enumerate(requests):
|
||||
@@ -1162,7 +1166,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
|
||||
# KVCache Manager.
|
||||
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
|
||||
req_to_blocks) == 0
|
||||
assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
|
||||
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
|
||||
num_cached_block) == 0
|
||||
num_free_blocks = (
|
||||
|
||||
@@ -17,7 +17,6 @@ from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
|
||||
def get_sliding_window_manager(sliding_window_spec, block_pool):
|
||||
return SlidingWindowManager(sliding_window_spec,
|
||||
block_pool,
|
||||
caching_hash_fn=lambda x: x,
|
||||
kv_cache_group_id=0)
|
||||
|
||||
|
||||
@@ -25,7 +24,6 @@ def get_chunked_local_attention_manager(chunked_local_attention_spec,
|
||||
block_pool):
|
||||
return ChunkedLocalAttentionManager(chunked_local_attention_spec,
|
||||
block_pool,
|
||||
caching_hash_fn=lambda x: x,
|
||||
kv_cache_group_id=0)
|
||||
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@ from vllm.multimodal.inputs import (MultiModalBatchedField,
|
||||
MultiModalFieldElem, MultiModalKwargsItem,
|
||||
PlaceholderRange)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
|
||||
init_none_hash)
|
||||
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
@@ -114,6 +116,9 @@ def create_scheduler(
|
||||
)
|
||||
|
||||
|
||||
_none_hash_initialized = False
|
||||
|
||||
|
||||
def create_requests(
|
||||
num_requests: int,
|
||||
num_tokens: int = 10,
|
||||
@@ -122,7 +127,14 @@ def create_requests(
|
||||
stop_token_ids: Optional[list[int]] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
same_prompt: bool = False,
|
||||
block_size: int = 16,
|
||||
) -> list[Request]:
|
||||
global _none_hash_initialized
|
||||
if not _none_hash_initialized:
|
||||
init_none_hash(hash)
|
||||
_none_hash_initialized = True
|
||||
|
||||
block_hasher = get_request_block_hasher(block_size, hash)
|
||||
sampling_params = SamplingParams(ignore_eos=False,
|
||||
max_tokens=max_tokens,
|
||||
stop_token_ids=stop_token_ids,
|
||||
@@ -139,9 +151,11 @@ def create_requests(
|
||||
)
|
||||
mm_item = MultiModalKwargsItem.from_elems([mm_elem])
|
||||
mm_kwargs = [mm_item] * len(mm_position)
|
||||
mm_hashes = ["hash"] * len(mm_position)
|
||||
else:
|
||||
mm_position = None
|
||||
mm_kwargs = None
|
||||
mm_hashes = None
|
||||
prompt_token_ids = ([0] * num_tokens if same_prompt else [i] *
|
||||
num_tokens)
|
||||
request = Request(
|
||||
@@ -151,8 +165,9 @@ def create_requests(
|
||||
pooling_params=None,
|
||||
multi_modal_kwargs=mm_kwargs,
|
||||
multi_modal_placeholders=mm_position,
|
||||
multi_modal_hashes=None,
|
||||
multi_modal_hashes=mm_hashes,
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
block_hasher=block_hasher,
|
||||
)
|
||||
requests.append(request)
|
||||
return requests
|
||||
|
||||
@@ -147,6 +147,7 @@ def test_basic_interface():
|
||||
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
|
||||
|
||||
request = create_request(request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_prefill=True)
|
||||
request_id = request.request_id
|
||||
@@ -186,6 +187,7 @@ def test_prompt_less_than_block_size():
|
||||
|
||||
# Request will have 1 partial remote block.
|
||||
request = create_request(request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_prefill=True,
|
||||
num_remote_blocks=1)
|
||||
|
||||
@@ -21,6 +21,7 @@ def test_basic_lifecycle():
|
||||
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
|
||||
|
||||
request = create_request(request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
max_tokens=1,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_decode=True)
|
||||
@@ -103,8 +104,10 @@ def test_short_prompt_lifecycle():
|
||||
scheduler = create_scheduler(vllm_config)
|
||||
|
||||
# Not enough tokens for full block.
|
||||
NUM_TOKENS = vllm_config.cache_config.block_size // 2
|
||||
BLOCK_SIZE = vllm_config.cache_config.block_size
|
||||
NUM_TOKENS = BLOCK_SIZE // 2
|
||||
request = create_request(request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
max_tokens=1,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_decode=True)
|
||||
@@ -148,7 +151,9 @@ def test_prefix_cache_lifecycle():
|
||||
NUM_EXTERNAL_FULL_BLOCKS = 3
|
||||
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
|
||||
|
||||
request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS)
|
||||
request_normal = create_request(request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS)
|
||||
|
||||
scheduler.add_request(request_normal)
|
||||
scheduler_output = scheduler.schedule()
|
||||
@@ -166,6 +171,7 @@ def test_prefix_cache_lifecycle():
|
||||
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
|
||||
|
||||
request_remote = create_request(request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_decode=True)
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ def test_basic_lifecycle():
|
||||
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
|
||||
|
||||
request = create_request(request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_prefill=True)
|
||||
|
||||
@@ -133,14 +134,17 @@ def test_interleaved_lifecycle():
|
||||
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
|
||||
|
||||
request_remote = create_request(request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_prefill=True)
|
||||
request_local_a = create_request(
|
||||
request_id=2,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS,
|
||||
)
|
||||
request_local_b = create_request(
|
||||
request_id=3,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS,
|
||||
)
|
||||
|
||||
@@ -236,6 +240,7 @@ def test_no_spurious_prefix_caching():
|
||||
# Both of these requests have prompts like [1,1,1,1,1, ...]
|
||||
request_remote = create_request(
|
||||
request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_prefill=True,
|
||||
use_all_1s_for_prompt_tokens=True,
|
||||
@@ -243,6 +248,7 @@ def test_no_spurious_prefix_caching():
|
||||
|
||||
request_local = create_request(
|
||||
request_id=2,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_prefill=False,
|
||||
use_all_1s_for_prompt_tokens=True,
|
||||
@@ -292,6 +298,7 @@ def test_full_block_prompt():
|
||||
NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS)
|
||||
|
||||
request = create_request(request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_prefill=True)
|
||||
|
||||
@@ -364,8 +371,11 @@ def test_cannot_schedule_after_recv():
|
||||
NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
|
||||
NUM_TOKENS_REMOTE = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
|
||||
|
||||
request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL)
|
||||
request_normal = create_request(request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS_LOCAL)
|
||||
request_remote = create_request(request_id=2,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS_REMOTE,
|
||||
do_remote_prefill=True)
|
||||
|
||||
@@ -456,8 +466,11 @@ def test_cannot_recv():
|
||||
NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
|
||||
NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5))
|
||||
|
||||
request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL)
|
||||
request_normal = create_request(request_id=1,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS_LOCAL)
|
||||
request_remote = create_request(request_id=2,
|
||||
block_size=BLOCK_SIZE,
|
||||
num_tokens=NUM_TOKENS_REMOTE,
|
||||
do_remote_prefill=True)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -14,6 +14,8 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
|
||||
SharedStorageConnector)
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
|
||||
init_none_hash)
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec)
|
||||
@@ -40,7 +42,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
|
||||
# KVCache Manager.
|
||||
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
|
||||
req_to_blocks) == 0
|
||||
assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
|
||||
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
|
||||
num_cached_block) == 0
|
||||
num_free_blocks = (
|
||||
@@ -115,16 +116,23 @@ def create_scheduler(
|
||||
)
|
||||
|
||||
|
||||
def create_request(
|
||||
request_id: int,
|
||||
num_tokens: int = 10,
|
||||
max_tokens: int = 16,
|
||||
do_remote_decode: bool = False,
|
||||
do_remote_prefill: bool = False,
|
||||
use_all_1s_for_prompt_tokens: bool = False,
|
||||
num_remote_blocks: int = 3,
|
||||
) -> Request:
|
||||
_none_hash_initialized = False
|
||||
|
||||
|
||||
def create_request(request_id: int,
|
||||
num_tokens: int = 10,
|
||||
max_tokens: int = 16,
|
||||
do_remote_decode: bool = False,
|
||||
do_remote_prefill: bool = False,
|
||||
use_all_1s_for_prompt_tokens: bool = False,
|
||||
num_remote_blocks: int = 3,
|
||||
block_size: int = 16,
|
||||
hash_fn: Callable = hash) -> Request:
|
||||
"""Make dummy request for testing."""
|
||||
global _none_hash_initialized
|
||||
if not _none_hash_initialized:
|
||||
init_none_hash(hash)
|
||||
_none_hash_initialized = True
|
||||
|
||||
kv_transfer_params: Optional[dict[str, Any]] = None
|
||||
|
||||
@@ -158,6 +166,7 @@ def create_request(
|
||||
multi_modal_placeholders=None,
|
||||
multi_modal_hashes=None,
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
block_hasher=get_request_block_hasher(block_size, hash_fn),
|
||||
)
|
||||
req.kv_transfer_params = kv_transfer_params
|
||||
return req
|
||||
|
||||
@@ -3243,6 +3243,24 @@ def sha256_cbor_64bit(input) -> int:
|
||||
return full_hash & ((1 << 64) - 1)
|
||||
|
||||
|
||||
def get_hash_fn_by_name(hash_fn_name: str) -> Callable:
|
||||
"""Get a hash function by name, or raise an error if
|
||||
the function is not found.
|
||||
Args:
|
||||
hash_fn_name: Name of the hash function.
|
||||
Returns:
|
||||
A hash function.
|
||||
"""
|
||||
if hash_fn_name == "sha256":
|
||||
return sha256
|
||||
if hash_fn_name == "sha256_cbor_64bit":
|
||||
return sha256_cbor_64bit
|
||||
if hash_fn_name == "builtin":
|
||||
return hash
|
||||
|
||||
raise ValueError(f"Unsupported hash function: {hash_fn_name}")
|
||||
|
||||
|
||||
def is_torch_equal_or_newer(target: str) -> bool:
|
||||
"""Check if the installed torch version is >= the target version.
|
||||
|
||||
|
||||
@@ -2,15 +2,13 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from typing import Callable, Optional
|
||||
from typing import Optional
|
||||
|
||||
from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
|
||||
BlockStored, KVCacheEvent)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
|
||||
FreeKVCacheBlockQueue, KVCacheBlock,
|
||||
generate_block_hash_extra_keys,
|
||||
hash_block_tokens)
|
||||
FreeKVCacheBlockQueue, KVCacheBlock)
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -97,84 +95,39 @@ class BlockPool:
|
||||
self,
|
||||
request: Request,
|
||||
blocks: list[KVCacheBlock],
|
||||
block_hashes: list[BlockHash],
|
||||
num_cached_blocks: int,
|
||||
num_full_blocks: int,
|
||||
block_size: int,
|
||||
kv_cache_group_id: int,
|
||||
hash_fn: Callable,
|
||||
) -> None:
|
||||
"""Cache a list of full blocks for prefix caching.
|
||||
This function takes a list of blocks that will have their block hash
|
||||
metadata to be updated and cached. Given a request, it computes the
|
||||
block hashes for the blocks starting from `num_cached_blocks` to
|
||||
`num_full_blocks`, updating the metadata for each block
|
||||
and caching them in the `cached_block_hash_to_block`.
|
||||
metadata to be updated and cached. Given a request, it updates the
|
||||
metadata for each block and caching it in the
|
||||
`cached_block_hash_to_block`.
|
||||
The block hashes values are computed by the Request object immediately
|
||||
when it is created and when new tokens are appended.
|
||||
|
||||
Args:
|
||||
request: The request to cache the blocks.
|
||||
blocks: All blocks in the request.
|
||||
block_hashes: Block hashes of the blocks in the request. Note that
|
||||
this list may be shorter than the blocks list. In this case the
|
||||
missed block hash will be computed in this function.
|
||||
num_cached_blocks: The number of blocks that are already cached.
|
||||
num_full_blocks: The number of blocks that are full and should
|
||||
be cached after this function.
|
||||
block_size: Number of tokens in each block.
|
||||
kv_cache_group_id: The id of the KV cache group.
|
||||
hash_fn: The hash function to use for block hashes.
|
||||
"""
|
||||
if num_cached_blocks == num_full_blocks:
|
||||
return
|
||||
new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
|
||||
assert len(block_hashes) >= num_cached_blocks
|
||||
new_block_hashes = block_hashes[num_cached_blocks:]
|
||||
assert len(request.block_hashes) >= num_full_blocks
|
||||
new_block_hashes = request.block_hashes[num_cached_blocks:]
|
||||
|
||||
# Update the new blocks with the block hashes through the chain.
|
||||
if num_cached_blocks == 0:
|
||||
prev_block_hash_value = None
|
||||
else:
|
||||
prev_block = blocks[num_cached_blocks - 1]
|
||||
assert prev_block.block_hash is not None
|
||||
prev_block_hash_value = prev_block.block_hash.get_hash_value()
|
||||
|
||||
parent_block_hash = prev_block_hash_value
|
||||
new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events
|
||||
else None)
|
||||
for i, blk in enumerate(new_full_blocks):
|
||||
assert blk.block_hash is None
|
||||
|
||||
if i < len(new_block_hashes):
|
||||
# The block hash may already be computed in
|
||||
# "get_computed_blocks" if the tokens are not generated by
|
||||
# this request (either the prompt tokens or the previously
|
||||
# generated tokens with preemption), or by other
|
||||
# single_type_managers with the same block_size.
|
||||
# In this case we simply reuse the block hash.
|
||||
block_hash = new_block_hashes[i]
|
||||
else:
|
||||
# Otherwise compute the block hash and cache it in the request
|
||||
# in case it will be preempted in the future.
|
||||
blk_idx = num_cached_blocks + i
|
||||
start_token_idx = blk_idx * block_size
|
||||
end_token_idx = (blk_idx + 1) * block_size
|
||||
block_tokens = request.all_token_ids[
|
||||
start_token_idx:end_token_idx]
|
||||
assert len(block_tokens) == block_size, (
|
||||
f"Expected {block_size} tokens, got "
|
||||
f"{len(block_tokens)} at {blk_idx}th block for request "
|
||||
f"{request.request_id}({request})")
|
||||
|
||||
# Generate extra keys for multi-modal inputs. Note that since
|
||||
# we reach to this branch only when the block is completed with
|
||||
# generated tokens, we only need to consider the last mm input.
|
||||
extra_keys, _ = generate_block_hash_extra_keys(
|
||||
request, start_token_idx, end_token_idx, -1)
|
||||
|
||||
# Compute the hash of the current block.
|
||||
block_hash = hash_block_tokens(hash_fn, prev_block_hash_value,
|
||||
block_tokens, extra_keys)
|
||||
block_hashes.append(block_hash)
|
||||
block_hash = new_block_hashes[i]
|
||||
|
||||
# Update and added the full block to the cache.
|
||||
block_hash_with_group_id = BlockHashWithGroupId(
|
||||
@@ -184,9 +137,15 @@ class BlockPool:
|
||||
blk.block_id] = blk
|
||||
if new_hashes is not None:
|
||||
new_hashes.append(block_hash.hash_value)
|
||||
prev_block_hash_value = block_hash.hash_value
|
||||
|
||||
if self.enable_kv_cache_events:
|
||||
if num_cached_blocks == 0:
|
||||
parent_block_hash = None
|
||||
else:
|
||||
parent_block = blocks[num_cached_blocks - 1]
|
||||
assert parent_block.block_hash is not None
|
||||
parent_block_hash = parent_block.block_hash.get_hash_value()
|
||||
|
||||
self.kv_event_queue.append(
|
||||
BlockStored(
|
||||
block_hashes=new_hashes,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Optional
|
||||
from typing import Optional
|
||||
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
|
||||
@@ -23,7 +23,6 @@ class KVCacheCoordinator(ABC):
|
||||
max_model_len: int,
|
||||
use_eagle: bool,
|
||||
enable_caching: bool,
|
||||
caching_hash_fn: Callable,
|
||||
enable_kv_cache_events: bool,
|
||||
):
|
||||
self.kv_cache_config = kv_cache_config
|
||||
@@ -40,7 +39,6 @@ class KVCacheCoordinator(ABC):
|
||||
kv_cache_spec=kv_cache_group.kv_cache_spec,
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_group_id=i,
|
||||
caching_hash_fn=caching_hash_fn,
|
||||
) for i, kv_cache_group in enumerate(
|
||||
self.kv_cache_config.kv_cache_groups))
|
||||
|
||||
@@ -99,19 +97,17 @@ class KVCacheCoordinator(ABC):
|
||||
manager.allocate_new_blocks(request_id, num_tokens)
|
||||
for manager in self.single_type_managers)
|
||||
|
||||
def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
|
||||
num_computed_tokens: int) -> None:
|
||||
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
|
||||
"""
|
||||
Cache the blocks for the request.
|
||||
|
||||
Args:
|
||||
request: The request.
|
||||
block_hashes: The block hashes of the request.
|
||||
num_tokens: The total number of tokens that need to be cached
|
||||
(including tokens that are already cached).
|
||||
"""
|
||||
for manager in self.single_type_managers:
|
||||
manager.cache_blocks(request, block_hashes, num_computed_tokens)
|
||||
manager.cache_blocks(request, num_computed_tokens)
|
||||
|
||||
def free(self, request_id: str) -> None:
|
||||
"""
|
||||
@@ -184,10 +180,9 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
|
||||
"""
|
||||
|
||||
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
|
||||
use_eagle: bool, caching_hash_fn: Callable,
|
||||
enable_kv_cache_events: bool):
|
||||
use_eagle: bool, enable_kv_cache_events: bool):
|
||||
super().__init__(kv_cache_config, max_model_len, use_eagle, False,
|
||||
caching_hash_fn, enable_kv_cache_events)
|
||||
enable_kv_cache_events)
|
||||
self.num_single_type_manager = len(self.single_type_managers)
|
||||
|
||||
def get_num_common_prefix_blocks(self, request_id: str,
|
||||
@@ -213,10 +208,9 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
|
||||
|
||||
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
|
||||
use_eagle: bool, enable_caching: bool,
|
||||
caching_hash_fn: Callable, enable_kv_cache_events: bool):
|
||||
enable_kv_cache_events: bool):
|
||||
super().__init__(kv_cache_config, max_model_len, use_eagle,
|
||||
enable_caching, caching_hash_fn,
|
||||
enable_kv_cache_events)
|
||||
enable_caching, enable_kv_cache_events)
|
||||
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[
|
||||
0].kv_cache_spec
|
||||
self.block_size = self.kv_cache_spec.block_size
|
||||
@@ -250,10 +244,9 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
||||
|
||||
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
|
||||
use_eagle: bool, enable_caching: bool,
|
||||
caching_hash_fn: Callable, enable_kv_cache_events: bool):
|
||||
enable_kv_cache_events: bool):
|
||||
super().__init__(kv_cache_config, max_model_len, use_eagle,
|
||||
enable_caching, caching_hash_fn,
|
||||
enable_kv_cache_events)
|
||||
enable_caching, enable_kv_cache_events)
|
||||
self.verify_and_split_kv_cache_groups()
|
||||
|
||||
def verify_and_split_kv_cache_groups(self) -> None:
|
||||
@@ -386,17 +379,15 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
||||
|
||||
def get_kv_cache_coordinator(
|
||||
kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool,
|
||||
enable_caching: bool, caching_hash_fn: Callable,
|
||||
enable_caching: bool,
|
||||
enable_kv_cache_events: bool) -> KVCacheCoordinator:
|
||||
if not enable_caching:
|
||||
return KVCacheCoordinatorNoPrefixCache(kv_cache_config, max_model_len,
|
||||
use_eagle, caching_hash_fn,
|
||||
use_eagle,
|
||||
enable_kv_cache_events)
|
||||
if len(kv_cache_config.kv_cache_groups) == 1:
|
||||
return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len,
|
||||
use_eagle, enable_caching,
|
||||
caching_hash_fn,
|
||||
enable_kv_cache_events)
|
||||
return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle,
|
||||
enable_caching, caching_hash_fn,
|
||||
enable_kv_cache_events)
|
||||
enable_caching, enable_kv_cache_events)
|
||||
|
||||
@@ -1,16 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import sha256, sha256_cbor_64bit
|
||||
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
|
||||
hash_request_tokens, init_none_hash)
|
||||
from vllm.v1.core.kv_cache_utils import KVCacheBlock
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
@@ -71,23 +68,13 @@ class KVCacheManager:
|
||||
kv_cache_config: KVCacheConfig,
|
||||
max_model_len: int,
|
||||
enable_caching: bool = True,
|
||||
caching_hash_algo: str = "builtin",
|
||||
use_eagle: bool = False,
|
||||
log_stats: bool = False,
|
||||
enable_kv_cache_events: bool = False,
|
||||
) -> None:
|
||||
self.max_model_len = max_model_len
|
||||
|
||||
if len(kv_cache_config.kv_cache_groups) == 0:
|
||||
# Attention free models don't have kv cache,
|
||||
# thus don't need prefix caching.
|
||||
enable_caching = False
|
||||
self.enable_caching = enable_caching
|
||||
|
||||
self.caching_hash_fn = (
|
||||
sha256_cbor_64bit if caching_hash_algo == "sha256_cbor_64bit" else
|
||||
sha256 if caching_hash_algo == "sha256" else hash)
|
||||
init_none_hash(self.caching_hash_fn)
|
||||
self.use_eagle = use_eagle
|
||||
self.log_stats = log_stats
|
||||
# FIXME: make prefix cache stats conditional on log_stats
|
||||
@@ -107,19 +94,12 @@ class KVCacheManager:
|
||||
max_model_len=self.max_model_len,
|
||||
use_eagle=self.use_eagle,
|
||||
enable_caching=self.enable_caching,
|
||||
caching_hash_fn=self.caching_hash_fn,
|
||||
enable_kv_cache_events=enable_kv_cache_events,
|
||||
)
|
||||
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
|
||||
self.block_pool = self.coordinator.block_pool
|
||||
self.kv_cache_config = kv_cache_config
|
||||
|
||||
# Mapping from request ID to kv block hashes.
|
||||
# This is to avoid recomputing the block hashes for each call of
|
||||
# `get_computed_blocks` or `allocate_slots`.
|
||||
self.req_to_block_hashes: defaultdict[
|
||||
str, list[BlockHash]] = defaultdict(list)
|
||||
|
||||
@property
|
||||
def usage(self) -> float:
|
||||
"""Get the KV cache usage.
|
||||
@@ -161,15 +141,6 @@ class KVCacheManager:
|
||||
and request.sampling_params.prompt_logprobs is not None)):
|
||||
return self.create_empty_block_list(), 0
|
||||
|
||||
# The block hashes for the request may already be computed
|
||||
# if the scheduler has tried to schedule the request before.
|
||||
block_hashes = self.req_to_block_hashes[request.request_id]
|
||||
if not block_hashes:
|
||||
assert self.block_size is not None
|
||||
block_hashes = hash_request_tokens(self.caching_hash_fn,
|
||||
self.block_size, request)
|
||||
self.req_to_block_hashes[request.request_id] = block_hashes
|
||||
|
||||
# NOTE: When all tokens hit the cache, we must recompute the last token
|
||||
# to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
|
||||
# This can trigger recomputation of an entire block, rather than just
|
||||
@@ -178,7 +149,7 @@ class KVCacheManager:
|
||||
# could slightly improve performance in the future.
|
||||
max_cache_hit_length = request.num_tokens - 1
|
||||
computed_blocks, num_new_computed_tokens = (
|
||||
self.coordinator.find_longest_cache_hit(block_hashes,
|
||||
self.coordinator.find_longest_cache_hit(request.block_hashes,
|
||||
max_cache_hit_length))
|
||||
|
||||
if self.log_stats:
|
||||
@@ -296,11 +267,7 @@ class KVCacheManager:
|
||||
# at `request.num_tokens`, ensuring only "finalized" tokens are cached.
|
||||
num_tokens_to_cache = min(num_computed_tokens + num_new_tokens,
|
||||
request.num_tokens)
|
||||
self.coordinator.cache_blocks(
|
||||
request,
|
||||
self.req_to_block_hashes[request.request_id],
|
||||
num_tokens_to_cache,
|
||||
)
|
||||
self.coordinator.cache_blocks(request, num_tokens_to_cache)
|
||||
|
||||
return KVCacheBlocks(new_blocks)
|
||||
|
||||
@@ -373,14 +340,6 @@ class KVCacheManager:
|
||||
return self.coordinator.get_num_common_prefix_blocks(
|
||||
request.request_id, num_running_requests)
|
||||
|
||||
def free_block_hashes(self, request: Request) -> None:
|
||||
"""Discard the block hashes for the request.
|
||||
|
||||
NOTE: Unlike `free`, this method should be called only when the request
|
||||
is finished, not when it is preempted.
|
||||
"""
|
||||
self.req_to_block_hashes.pop(request.request_id, None)
|
||||
|
||||
def take_events(self) -> list[KVCacheEvent]:
|
||||
"""Take the KV cache events from the block pool.
|
||||
|
||||
@@ -397,9 +356,7 @@ class KVCacheManager:
|
||||
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
|
||||
"""Cache the blocks for the request, if enabled."""
|
||||
if self.enable_caching:
|
||||
block_hashes = self.req_to_block_hashes[request.request_id]
|
||||
self.coordinator.cache_blocks(request, block_hashes,
|
||||
num_computed_tokens)
|
||||
self.coordinator.cache_blocks(request, num_computed_tokens)
|
||||
|
||||
def create_empty_block_list(self) -> KVCacheBlocks:
|
||||
"""Creates a new KVCacheBlocks instance with no blocks."""
|
||||
|
||||
@@ -547,41 +547,61 @@ def hash_block_tokens(
|
||||
curr_block_token_ids_tuple, extra_keys)
|
||||
|
||||
|
||||
def hash_request_tokens(hash_function: Any, block_size: int,
|
||||
request: Request) -> list[BlockHash]:
|
||||
"""Computes hash values of a chain of blocks given a sequence of
|
||||
token IDs. The hash value is used for prefix caching.
|
||||
|
||||
Args:
|
||||
block_size: The size of each block.
|
||||
request: The request object.
|
||||
|
||||
Returns:
|
||||
The list of computed hash values.
|
||||
def get_request_block_hasher(
|
||||
block_size: int,
|
||||
caching_hash_fn: Callable[[Any],
|
||||
int]) -> Callable[[Request], list[BlockHash]]:
|
||||
"""
|
||||
token_ids = request.all_token_ids
|
||||
Returns a function which computes the list of un-computed block hashes
|
||||
of a request.
|
||||
|
||||
req_need_extra_keys = need_extra_keys(request)
|
||||
req_extra_keys = None
|
||||
curr_mm_idx = 0
|
||||
Each request holds a list of its block hashes (request.block_hashes).
|
||||
When a request is created, it calls the below function to compute
|
||||
the hashes of all full blocks of the request's initial tokens.
|
||||
The hashes are then stored in request.block_hashes.
|
||||
Later, whenever new tokens are appended to the request, it calls
|
||||
the below function again to compute any new full blocks of tokens.
|
||||
The returned new hashes are appended to request.block_hashes.
|
||||
"""
|
||||
|
||||
ret = []
|
||||
parent_block_hash_value = None
|
||||
# Only full blocks will be hashed
|
||||
for start in range(0, len(token_ids) - block_size + 1, block_size):
|
||||
end = start + block_size
|
||||
block_token_ids = token_ids[start:end]
|
||||
def request_block_hasher(request: Request) -> list[BlockHash]:
|
||||
start_token_idx = len(request.block_hashes) * block_size
|
||||
num_tokens = request.num_tokens
|
||||
|
||||
curr_mm_idx = 0
|
||||
if start_token_idx > 0:
|
||||
# Set curr_mm_idx = -1 to indicate the last mm input.
|
||||
# Note that since we reach this branch only when the block is
|
||||
# completed with generated tokens, we only need to consider the
|
||||
# last mm input.
|
||||
curr_mm_idx = -1
|
||||
|
||||
prev_block_hash_value = request.block_hashes[-1].hash_value \
|
||||
if request.block_hashes else None
|
||||
new_block_hashes: list[BlockHash] = []
|
||||
while True:
|
||||
end_token_idx = start_token_idx + block_size
|
||||
if end_token_idx > num_tokens:
|
||||
# We only hash full blocks
|
||||
break
|
||||
|
||||
if req_need_extra_keys:
|
||||
# MM and LoRA requests need extra keys for block-hash computation.
|
||||
req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
|
||||
request, start, end, curr_mm_idx)
|
||||
extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
|
||||
request, start_token_idx, end_token_idx, curr_mm_idx)
|
||||
|
||||
block_hash = hash_block_tokens(hash_function, parent_block_hash_value,
|
||||
block_token_ids, req_extra_keys)
|
||||
ret.append(block_hash)
|
||||
parent_block_hash_value = block_hash.hash_value
|
||||
return ret
|
||||
# Compute the hash of the current block
|
||||
block_tokens = request.all_token_ids[start_token_idx:end_token_idx]
|
||||
block_hash = hash_block_tokens(caching_hash_fn,
|
||||
prev_block_hash_value, block_tokens,
|
||||
extra_keys)
|
||||
|
||||
new_block_hashes.append(block_hash)
|
||||
start_token_idx += block_size
|
||||
prev_block_hash_value = block_hash.hash_value
|
||||
|
||||
return new_block_hashes
|
||||
|
||||
return request_block_hasher
|
||||
|
||||
|
||||
def max_memory_usage_bytes(vllm_config: VllmConfig,
|
||||
|
||||
@@ -155,7 +155,6 @@ class Scheduler(SchedulerInterface):
|
||||
kv_cache_config=kv_cache_config,
|
||||
max_model_len=self.max_model_len,
|
||||
enable_caching=self.cache_config.enable_prefix_caching,
|
||||
caching_hash_algo=self.cache_config.prefix_caching_hash_algo,
|
||||
use_eagle=self.use_eagle,
|
||||
log_stats=self.log_stats,
|
||||
enable_kv_cache_events=self.enable_kv_cache_events,
|
||||
@@ -1036,7 +1035,6 @@ class Scheduler(SchedulerInterface):
|
||||
def _free_blocks(self, request: Request):
|
||||
assert request.is_finished()
|
||||
self.kv_cache_manager.free(request)
|
||||
self.kv_cache_manager.free_block_hashes(request)
|
||||
del self.requests[request.request_id]
|
||||
|
||||
def get_num_unfinished_requests(self) -> int:
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import itertools
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from typing import Callable
|
||||
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
@@ -25,7 +24,6 @@ class SingleTypeKVCacheManager(ABC):
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
block_pool: BlockPool,
|
||||
kv_cache_group_id: int,
|
||||
caching_hash_fn: Callable,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the SingleTypeKVCacheManager.
|
||||
@@ -33,7 +31,6 @@ class SingleTypeKVCacheManager(ABC):
|
||||
kv_cache_spec: The kv_cache_spec for this manager.
|
||||
block_pool: The block pool.
|
||||
kv_cache_group_id: The id of the kv cache group of this manager.
|
||||
caching_hash_fn: The caching hash function.
|
||||
"""
|
||||
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
@@ -52,7 +49,6 @@ class SingleTypeKVCacheManager(ABC):
|
||||
# data for preempted ones.
|
||||
self.num_cached_block: dict[str, int] = {}
|
||||
|
||||
self.caching_hash_fn = caching_hash_fn
|
||||
self.kv_cache_group_id = kv_cache_group_id
|
||||
self._null_block = block_pool.null_block
|
||||
|
||||
@@ -130,14 +126,12 @@ class SingleTypeKVCacheManager(ABC):
|
||||
req_blocks.extend(new_blocks)
|
||||
return new_blocks
|
||||
|
||||
def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
|
||||
num_tokens: int) -> None:
|
||||
def cache_blocks(self, request: Request, num_tokens: int) -> None:
|
||||
"""
|
||||
Cache the blocks for the request.
|
||||
|
||||
Args:
|
||||
request: The request.
|
||||
block_hashes: The block hashes of the request.
|
||||
num_tokens: The total number of tokens that need to be cached
|
||||
(including tokens that are already cached).
|
||||
"""
|
||||
@@ -147,12 +141,10 @@ class SingleTypeKVCacheManager(ABC):
|
||||
self.block_pool.cache_full_blocks(
|
||||
request=request,
|
||||
blocks=self.req_to_blocks[request.request_id],
|
||||
block_hashes=block_hashes,
|
||||
num_cached_blocks=num_cached_blocks,
|
||||
num_full_blocks=num_full_blocks,
|
||||
block_size=self.block_size,
|
||||
kv_cache_group_id=self.kv_cache_group_id,
|
||||
hash_fn=self.caching_hash_fn,
|
||||
)
|
||||
|
||||
self.num_cached_block[request.request_id] = num_full_blocks
|
||||
|
||||
@@ -25,9 +25,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.tasks import POOLING_TASKS, SupportedTask
|
||||
from vllm.transformers_utils.config import (
|
||||
maybe_register_config_serialize_by_value)
|
||||
from vllm.utils import (decorate_logs, make_zmq_socket,
|
||||
from vllm.utils import (decorate_logs, get_hash_fn_by_name, make_zmq_socket,
|
||||
resolve_obj_by_qualname, set_process_title)
|
||||
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, get_kv_cache_config,
|
||||
get_request_block_hasher,
|
||||
init_none_hash,
|
||||
unify_kv_cache_configs)
|
||||
from vllm.v1.core.sched.interface import SchedulerInterface
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
@@ -140,6 +142,19 @@ class EngineCore:
|
||||
self.batch_queue_size)
|
||||
self.batch_queue = queue.Queue(self.batch_queue_size)
|
||||
|
||||
self.request_block_hasher: Optional[Callable[[Request],
|
||||
list[BlockHash]]] = None
|
||||
if (self.vllm_config.cache_config.enable_prefix_caching
|
||||
or self.scheduler.get_kv_connector() is not None):
|
||||
|
||||
block_size = vllm_config.cache_config.block_size
|
||||
caching_hash_fn = get_hash_fn_by_name(
|
||||
vllm_config.cache_config.prefix_caching_hash_algo)
|
||||
init_none_hash(caching_hash_fn)
|
||||
|
||||
self.request_block_hasher = get_request_block_hasher(
|
||||
block_size, caching_hash_fn)
|
||||
|
||||
def _initialize_kv_caches(
|
||||
self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
|
||||
start = time.time()
|
||||
@@ -417,7 +432,8 @@ class EngineCore:
|
||||
request.mm_kwargs = self.mm_input_cache_server.get_and_update(
|
||||
request.mm_kwargs, request.mm_hashes)
|
||||
|
||||
req = Request.from_engine_core_request(request)
|
||||
req = Request.from_engine_core_request(request,
|
||||
self.request_block_hasher)
|
||||
if req.use_structured_output:
|
||||
# Note on thread safety: no race condition.
|
||||
# `grammar_init` is only invoked in input processing thread. For
|
||||
|
||||
@@ -3,7 +3,8 @@
|
||||
|
||||
import enum
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
|
||||
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
|
||||
from vllm.pooling_params import PoolingParams
|
||||
@@ -16,6 +17,7 @@ from vllm.v1.utils import ConstantList
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
|
||||
|
||||
class Request:
|
||||
@@ -36,6 +38,8 @@ class Request:
|
||||
structured_output_request: Optional["StructuredOutputRequest"] = None,
|
||||
cache_salt: Optional[str] = None,
|
||||
priority: int = 0,
|
||||
block_hasher: Optional[Callable[["Request"],
|
||||
list["BlockHash"]]] = None,
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.client_index = client_index
|
||||
@@ -108,8 +112,18 @@ class Request:
|
||||
# indicates that the output is corrupted
|
||||
self.num_nans_in_logits = 0
|
||||
|
||||
self.block_hashes: list[BlockHash] = []
|
||||
self.get_hash_new_full_blocks: Optional[Callable[
|
||||
[], list[BlockHash]]] = None
|
||||
if block_hasher is not None:
|
||||
self.get_hash_new_full_blocks = partial(block_hasher, self)
|
||||
self.block_hashes = self.get_hash_new_full_blocks()
|
||||
|
||||
@classmethod
|
||||
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
|
||||
def from_engine_core_request(
|
||||
cls, request: EngineCoreRequest,
|
||||
block_hasher: Optional[Callable[["Request"], list["BlockHash"]]]
|
||||
) -> "Request":
|
||||
if request.mm_kwargs is not None:
|
||||
assert is_list_of(request.mm_kwargs, MultiModalKwargsItem), (
|
||||
"mm_kwargs was not updated in EngineCore.add_request")
|
||||
@@ -131,6 +145,7 @@ class Request:
|
||||
if request.sampling_params else None,
|
||||
cache_salt=request.cache_salt,
|
||||
priority=request.priority,
|
||||
block_hasher=block_hasher,
|
||||
)
|
||||
|
||||
def append_output_token_ids(
|
||||
@@ -144,6 +159,9 @@ class Request:
|
||||
self._output_token_ids.extend(token_ids)
|
||||
self._all_token_ids.extend(token_ids)
|
||||
|
||||
if self.get_hash_new_full_blocks is not None:
|
||||
self.block_hashes.extend(self.get_hash_new_full_blocks())
|
||||
|
||||
@property
|
||||
def is_output_corrupted(self) -> bool:
|
||||
return self.num_nans_in_logits > 0
|
||||
|
||||
Reference in New Issue
Block a user