[feat] Add per-block extra_keys to KV events (#33304)
Signed-off-by: zhongdaor-nv <zhongdaor@nvidia.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -37,6 +37,12 @@ class BlockStored(KVCacheEvent):
|
||||
medium: str | None
|
||||
lora_name: str | None
|
||||
|
||||
extra_keys: list[tuple[Any, ...] | None] | None = None
|
||||
"""Extra keys used in block hash computation, one entry per block in
|
||||
block_hashes. Each entry contains MM identifiers, LoRA name, cache_salt,
|
||||
prompt embedding hashes, etc. for that specific block.
|
||||
"""
|
||||
|
||||
|
||||
class BlockRemoved(KVCacheEvent):
|
||||
block_hashes: list[ExternalBlockHash]
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import hashlib
|
||||
import importlib
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
@@ -498,14 +499,41 @@ def test_generate_block_hash_extra_keys_prompt_embeds():
|
||||
# Test with prompt embeds for the first block
|
||||
extra_keys, _ = generate_block_hash_extra_keys(request, 0, 5, 0)
|
||||
expected_embeds = prompt_embeds[0:5]
|
||||
expected_bytes = kv_cache_utils.tensor_data(expected_embeds).tobytes()
|
||||
assert extra_keys == (expected_bytes,)
|
||||
expected_hash = hashlib.sha256(kv_cache_utils.tensor_data(expected_embeds)).digest()
|
||||
assert extra_keys == (expected_hash,)
|
||||
|
||||
# Test with prompt embeds for the second block
|
||||
extra_keys, _ = generate_block_hash_extra_keys(request, 5, 10, 0)
|
||||
expected_embeds = prompt_embeds[5:10]
|
||||
expected_bytes = kv_cache_utils.tensor_data(expected_embeds).tobytes()
|
||||
assert extra_keys == (expected_bytes,)
|
||||
expected_hash = hashlib.sha256(kv_cache_utils.tensor_data(expected_embeds)).digest()
|
||||
assert extra_keys == (expected_hash,)
|
||||
|
||||
|
||||
def test_generate_block_hash_extra_keys_prompt_embeds_cached(monkeypatch):
    """Check that a block's prompt-embed extra key is computed only once.

    ``kv_cache_utils.tensor_data`` is replaced by a wrapper that counts its
    invocations. Requesting the extra keys for the same token range twice
    must return equal keys while the wrapper is invoked exactly one time.
    """
    embeds = torch.randn(10, 3)
    request = make_request(
        request_id="0",
        prompt_token_ids=None,
        mm_positions=None,
        mm_hashes=None,
        prompt_embeds=embeds,
        block_size=20,
    )

    # Keep a handle on the real implementation so the wrapper can delegate.
    real_tensor_data = kv_cache_utils.tensor_data
    call_count = 0

    def tensor_data_spy(tensor: torch.Tensor):
        nonlocal call_count
        call_count += 1
        return real_tensor_data(tensor)

    monkeypatch.setattr(kv_cache_utils, "tensor_data", tensor_data_spy)

    # Same request and same (start, end) range on both calls.
    first_keys, _ = generate_block_hash_extra_keys(request, 0, 5, 0)
    second_keys, _ = generate_block_hash_extra_keys(request, 0, 5, 0)

    assert first_keys == second_keys
    assert call_count == 1
|
||||
|
||||
|
||||
def test_generate_block_hash_extra_keys_different_prompt_embeds():
|
||||
@@ -1858,22 +1886,26 @@ def test_request_block_hasher_with_prompt_embeds(hash_fn: Callable[[Any], bytes]
|
||||
block_hashes = request.block_hashes
|
||||
assert len(block_hashes) == 2
|
||||
|
||||
block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes()
|
||||
block1_embeds_hash = hashlib.sha256(
|
||||
tensor_data(prompt_embeds[:block_size])
|
||||
).digest()
|
||||
expected_hash1 = hash_fn(
|
||||
(
|
||||
kv_cache_utils.NONE_HASH,
|
||||
tuple(prompt_token_ids[:block_size]),
|
||||
(block1_embeds_bytes,),
|
||||
(block1_embeds_hash,),
|
||||
)
|
||||
)
|
||||
assert block_hashes[0] == expected_hash1
|
||||
|
||||
block2_embeds_bytes = tensor_data(prompt_embeds[block_size:num_tokens]).tobytes()
|
||||
block2_embeds_hash = hashlib.sha256(
|
||||
tensor_data(prompt_embeds[block_size:num_tokens])
|
||||
).digest()
|
||||
expected_hash2 = hash_fn(
|
||||
(
|
||||
block_hashes[0],
|
||||
tuple(prompt_token_ids[block_size:num_tokens]),
|
||||
(block2_embeds_bytes,),
|
||||
(block2_embeds_hash,),
|
||||
)
|
||||
)
|
||||
assert block_hashes[1] == expected_hash2
|
||||
@@ -1903,22 +1935,26 @@ def test_request_with_prompt_embeds_and_mm_inputs(hash_fn: Callable[[Any], bytes
|
||||
block_hashes = request.block_hashes
|
||||
assert len(block_hashes) == 2
|
||||
|
||||
block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes()
|
||||
block1_embeds_hash = hashlib.sha256(
|
||||
tensor_data(prompt_embeds[:block_size])
|
||||
).digest()
|
||||
expected_hash1 = hash_fn(
|
||||
(
|
||||
kv_cache_utils.NONE_HASH,
|
||||
tuple(prompt_token_ids[:block_size]),
|
||||
("hash1", block1_embeds_bytes),
|
||||
("hash1", block1_embeds_hash),
|
||||
)
|
||||
)
|
||||
assert block_hashes[0] == expected_hash1
|
||||
|
||||
block2_embeds_bytes = tensor_data(prompt_embeds[block_size:num_tokens]).tobytes()
|
||||
block2_embeds_hash = hashlib.sha256(
|
||||
tensor_data(prompt_embeds[block_size:num_tokens])
|
||||
).digest()
|
||||
expected_hash2 = hash_fn(
|
||||
(
|
||||
block_hashes[0],
|
||||
tuple(prompt_token_ids[block_size:num_tokens]),
|
||||
("hash2", block2_embeds_bytes),
|
||||
("hash2", block2_embeds_hash),
|
||||
)
|
||||
)
|
||||
assert block_hashes[1] == expected_hash2
|
||||
|
||||
@@ -60,6 +60,13 @@ class BlockStored(KVCacheEvent):
|
||||
medium: str | None
|
||||
lora_name: str | None
|
||||
|
||||
extra_keys: list[tuple[Any, ...] | None] | None = None
|
||||
"""Extra keys used in block hash computation, one entry per block in
|
||||
block_hashes. Each entry contains MM identifiers, LoRA name, cache_salt,
|
||||
prompt embedding hashes, etc. for that specific block. Exposed for external
|
||||
KV cache consumers to reconstruct block hashes.
|
||||
"""
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(
|
||||
(
|
||||
@@ -69,6 +76,7 @@ class BlockStored(KVCacheEvent):
|
||||
self.block_size,
|
||||
self.lora_id,
|
||||
self.medium,
|
||||
tuple(self.extra_keys) if self.extra_keys else None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from vllm.v1.core.kv_cache_utils import (
|
||||
ExternalBlockHash,
|
||||
FreeKVCacheBlockQueue,
|
||||
KVCacheBlock,
|
||||
generate_block_hash_extra_keys,
|
||||
get_block_hash,
|
||||
make_block_hash_with_group_id,
|
||||
maybe_convert_block_hash,
|
||||
@@ -279,13 +280,31 @@ class BlockPool:
|
||||
block_hashes[num_cached_blocks - 1]
|
||||
)
|
||||
|
||||
# Calculate token range for the blocks being cached
|
||||
start_token_idx = num_cached_blocks * block_size
|
||||
end_token_idx = num_full_blocks * block_size
|
||||
|
||||
# Generate extra keys for each block individually.
|
||||
# Each block may have different extra_keys (e.g., different MM
|
||||
# features, or cache_salt only for the first block).
|
||||
# Skip null blocks to match the length of new_hashes.
|
||||
extra_keys_list: list[tuple[Any, ...] | None] = []
|
||||
curr_mm_idx = 0
|
||||
for i in range(num_cached_blocks, num_full_blocks):
|
||||
if blocks[i].is_null:
|
||||
continue
|
||||
block_start = i * block_size
|
||||
block_end = block_start + block_size
|
||||
extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
|
||||
request, block_start, block_end, curr_mm_idx
|
||||
)
|
||||
extra_keys_list.append(extra_keys)
|
||||
|
||||
self.kv_event_queue.append(
|
||||
BlockStored(
|
||||
block_hashes=new_hashes,
|
||||
parent_block_hash=parent_block_hash,
|
||||
token_ids=request.all_token_ids[
|
||||
num_cached_blocks * block_size : num_full_blocks * block_size
|
||||
],
|
||||
token_ids=request.all_token_ids[start_token_idx:end_token_idx],
|
||||
block_size=block_size,
|
||||
lora_id=request.lora_request.adapter_id
|
||||
if request.lora_request
|
||||
@@ -294,6 +313,7 @@ class BlockPool:
|
||||
lora_name=request.lora_request.name
|
||||
if request.lora_request
|
||||
else None,
|
||||
extra_keys=extra_keys_list if extra_keys_list else None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
"""KV-Cache Utilities."""
|
||||
|
||||
import copy
|
||||
import hashlib
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Iterable, Iterator, Sequence
|
||||
@@ -475,14 +476,19 @@ def _gen_prompt_embeds_extra_hash_keys(
|
||||
end_token_idx: The end token index of the block.
|
||||
|
||||
Returns:
|
||||
Return prompt embeddings data of the request if it has prompt embeds.
|
||||
Return empty list otherwise.
|
||||
Return a stable hash of the block prompt embeddings if prompt embeds
|
||||
are present. Return empty list otherwise.
|
||||
"""
|
||||
if request.prompt_embeds is None:
|
||||
return []
|
||||
block_prompt_embeds = request.prompt_embeds[start_token_idx:end_token_idx]
|
||||
embeds_bytes = tensor_data(block_prompt_embeds).tobytes()
|
||||
return [embeds_bytes]
|
||||
block_range = (start_token_idx, end_token_idx)
|
||||
embeds_hash = request._prompt_embeds_per_block_hashes.get(block_range)
|
||||
if embeds_hash is None:
|
||||
block_prompt_embeds = request.prompt_embeds[start_token_idx:end_token_idx]
|
||||
# Hash prompt embeds once per block and cache on request
|
||||
embeds_hash = hashlib.sha256(tensor_data(block_prompt_embeds)).digest()
|
||||
request._prompt_embeds_per_block_hashes[block_range] = embeds_hash
|
||||
return [embeds_hash]
|
||||
|
||||
|
||||
def generate_block_hash_extra_keys(
|
||||
@@ -490,7 +496,7 @@ def generate_block_hash_extra_keys(
|
||||
) -> tuple[tuple[Any, ...] | None, int]:
|
||||
"""Generate extra keys for the block hash. The extra keys can come from
|
||||
the multi-modal inputs, request specific metadata (e.g., LoRA names), and
|
||||
data from prompt embeddings.
|
||||
hashed data from prompt embeddings.
|
||||
|
||||
Args:
|
||||
request: The request object.
|
||||
|
||||
@@ -114,6 +114,9 @@ class Request:
|
||||
|
||||
self.prompt_token_ids = prompt_token_ids
|
||||
self.prompt_embeds = prompt_embeds
|
||||
# Cache per-block prompt-embed hashes to avoid rehashing the same
|
||||
# tensor slices when generating extra keys.
|
||||
self._prompt_embeds_per_block_hashes: dict[tuple[int, int], bytes] = {}
|
||||
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
||||
prompt_token_ids, prompt_embeds
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user