[feat] Add per-block extra_keys to KV events (#33304)

Signed-off-by: zhongdaor-nv <zhongdaor@nvidia.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
zhongdaor-nv
2026-02-20 21:11:40 -07:00
committed by GitHub
parent 991d6bff38
commit a0fe7ea2f0
6 changed files with 100 additions and 21 deletions

View File

@@ -37,6 +37,12 @@ class BlockStored(KVCacheEvent):
     medium: str | None
     lora_name: str | None
+    extra_keys: list[tuple[Any, ...] | None] | None = None
+    """Extra keys used in block hash computation, one entry per block in
+    block_hashes. Each entry contains MM identifiers, LoRA name, cache_salt,
+    prompt embeddings data, etc. for that specific block.
+    """
+

 class BlockRemoved(KVCacheEvent):
     block_hashes: list[ExternalBlockHash]
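
For a concrete sense of the shape (values hypothetical), a three-block request whose first block carries a cache_salt and whose second block overlaps one multi-modal item would produce one entry per block hash:

# Hypothetical per-block layout; actual key contents depend on the request
# (MM identifiers, LoRA name, cache_salt, prompt-embed hashes).
extra_keys = [
    ("my-cache-salt",),   # cache_salt contributes only to the first block
    ("mm_item_hash_0",),  # block overlapping a multi-modal input
    None,                 # block with no extra keys at all
]
assert len(extra_keys) == 3  # aligned index-for-index with block_hashes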

View File

@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import hashlib
 import importlib
 from collections.abc import Callable
 from typing import Any
@@ -498,14 +499,41 @@ def test_generate_block_hash_extra_keys_prompt_embeds():
     # Test with prompt embeds for the first block
     extra_keys, _ = generate_block_hash_extra_keys(request, 0, 5, 0)
     expected_embeds = prompt_embeds[0:5]
-    expected_bytes = kv_cache_utils.tensor_data(expected_embeds).tobytes()
-    assert extra_keys == (expected_bytes,)
+    expected_hash = hashlib.sha256(kv_cache_utils.tensor_data(expected_embeds)).digest()
+    assert extra_keys == (expected_hash,)

     # Test with prompt embeds for the second block
     extra_keys, _ = generate_block_hash_extra_keys(request, 5, 10, 0)
     expected_embeds = prompt_embeds[5:10]
-    expected_bytes = kv_cache_utils.tensor_data(expected_embeds).tobytes()
-    assert extra_keys == (expected_bytes,)
+    expected_hash = hashlib.sha256(kv_cache_utils.tensor_data(expected_embeds)).digest()
+    assert extra_keys == (expected_hash,)
+
+
+def test_generate_block_hash_extra_keys_prompt_embeds_cached(monkeypatch):
+    prompt_embeds = torch.randn(10, 3)
+    request = make_request(
+        request_id="0",
+        prompt_token_ids=None,
+        mm_positions=None,
+        mm_hashes=None,
+        prompt_embeds=prompt_embeds,
+        block_size=20,
+    )
+
+    num_tensor_data_calls = 0
+    original_tensor_data = kv_cache_utils.tensor_data
+
+    def counting_tensor_data(tensor: torch.Tensor):
+        nonlocal num_tensor_data_calls
+        num_tensor_data_calls += 1
+        return original_tensor_data(tensor)
+
+    monkeypatch.setattr(kv_cache_utils, "tensor_data", counting_tensor_data)
+
+    extra_keys_1, _ = generate_block_hash_extra_keys(request, 0, 5, 0)
+    extra_keys_2, _ = generate_block_hash_extra_keys(request, 0, 5, 0)
+
+    assert extra_keys_1 == extra_keys_2
+    assert num_tensor_data_calls == 1


 def test_generate_block_hash_extra_keys_different_prompt_embeds():
@@ -1858,22 +1886,26 @@ def test_request_block_hasher_with_prompt_embeds(hash_fn: Callable[[Any], bytes]
     block_hashes = request.block_hashes
     assert len(block_hashes) == 2

-    block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes()
+    block1_embeds_hash = hashlib.sha256(
+        tensor_data(prompt_embeds[:block_size])
+    ).digest()
     expected_hash1 = hash_fn(
         (
             kv_cache_utils.NONE_HASH,
             tuple(prompt_token_ids[:block_size]),
-            (block1_embeds_bytes,),
+            (block1_embeds_hash,),
         )
     )
     assert block_hashes[0] == expected_hash1

-    block2_embeds_bytes = tensor_data(prompt_embeds[block_size:num_tokens]).tobytes()
+    block2_embeds_hash = hashlib.sha256(
+        tensor_data(prompt_embeds[block_size:num_tokens])
+    ).digest()
     expected_hash2 = hash_fn(
         (
             block_hashes[0],
             tuple(prompt_token_ids[block_size:num_tokens]),
-            (block2_embeds_bytes,),
+            (block2_embeds_hash,),
         )
     )
     assert block_hashes[1] == expected_hash2
@@ -1903,22 +1935,26 @@ def test_request_with_prompt_embeds_and_mm_inputs(hash_fn: Callable[[Any], bytes
     block_hashes = request.block_hashes
     assert len(block_hashes) == 2

-    block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes()
+    block1_embeds_hash = hashlib.sha256(
+        tensor_data(prompt_embeds[:block_size])
+    ).digest()
     expected_hash1 = hash_fn(
         (
             kv_cache_utils.NONE_HASH,
             tuple(prompt_token_ids[:block_size]),
-            ("hash1", block1_embeds_bytes),
+            ("hash1", block1_embeds_hash),
         )
     )
     assert block_hashes[0] == expected_hash1

-    block2_embeds_bytes = tensor_data(prompt_embeds[block_size:num_tokens]).tobytes()
+    block2_embeds_hash = hashlib.sha256(
+        tensor_data(prompt_embeds[block_size:num_tokens])
+    ).digest()
     expected_hash2 = hash_fn(
         (
             block_hashes[0],
             tuple(prompt_token_ids[block_size:num_tokens]),
-            ("hash2", block2_embeds_bytes),
+            ("hash2", block2_embeds_hash),
         )
     )
     assert block_hashes[1] == expected_hash2

View File

@@ -60,6 +60,13 @@ class BlockStored(KVCacheEvent):
     medium: str | None
     lora_name: str | None
+    extra_keys: list[tuple[Any, ...] | None] | None = None
+    """Extra keys used in block hash computation, one entry per block in
+    block_hashes. Each entry contains MM identifiers, LoRA name, cache_salt,
+    prompt embedding hashes, etc. for that specific block. Exposed for external
+    KV cache consumers to reconstruct block hashes.
+    """
+
     def __hash__(self) -> int:
         return hash(
             (
@@ -69,6 +76,7 @@ class BlockStored(KVCacheEvent):
                 self.block_size,
                 self.lora_id,
                 self.medium,
+                tuple(self.extra_keys) if self.extra_keys else None,
             )
         )
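
A sketch of what the "reconstruct block hashes" use case could look like on the consumer side. This is not vLLM API: hash_fn stands in for whatever block hash function the deployment is configured with, none_hash for vLLM's NONE_HASH sentinel, and the (parent, token_ids, extra_keys) tuple layout mirrors the test expectations earlier in this commit.

# Sketch only; assumes the event covers consecutive full blocks.
def reconstruct_block_hashes(event, hash_fn, none_hash):
    hashes = []
    parent = event.parent_block_hash
    if parent is None:
        parent = none_hash
    for i in range(len(event.block_hashes)):
        start = i * event.block_size
        token_ids = tuple(event.token_ids[start : start + event.block_size])
        # Per-block extra keys; entries may be None for blocks without any.
        extra = event.extra_keys[i] if event.extra_keys else None
        parent = hash_fn((parent, token_ids, extra))
        hashes.append(parent)
    return hashes

The recomputed values can then be compared against event.block_hashes, modulo any external-hash conversion vLLM applies before publishing.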

View File

@@ -20,6 +20,7 @@ from vllm.v1.core.kv_cache_utils import (
     ExternalBlockHash,
     FreeKVCacheBlockQueue,
     KVCacheBlock,
+    generate_block_hash_extra_keys,
     get_block_hash,
     make_block_hash_with_group_id,
     maybe_convert_block_hash,
@@ -279,13 +280,31 @@ class BlockPool:
                 block_hashes[num_cached_blocks - 1]
             )

+            # Calculate token range for the blocks being cached
+            start_token_idx = num_cached_blocks * block_size
+            end_token_idx = num_full_blocks * block_size
+
+            # Generate extra keys for each block individually.
+            # Each block may have different extra_keys (e.g., different MM
+            # features, or cache_salt only for the first block).
+            # Skip null blocks to match the length of new_hashes.
+            extra_keys_list: list[tuple[Any, ...] | None] = []
+            curr_mm_idx = 0
+            for i in range(num_cached_blocks, num_full_blocks):
+                if blocks[i].is_null:
+                    continue
+                block_start = i * block_size
+                block_end = block_start + block_size
+                extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
+                    request, block_start, block_end, curr_mm_idx
+                )
+                extra_keys_list.append(extra_keys)
+
             self.kv_event_queue.append(
                 BlockStored(
                     block_hashes=new_hashes,
                     parent_block_hash=parent_block_hash,
-                    token_ids=request.all_token_ids[
-                        num_cached_blocks * block_size : num_full_blocks * block_size
-                    ],
+                    token_ids=request.all_token_ids[start_token_idx:end_token_idx],
                     block_size=block_size,
                     lora_id=request.lora_request.adapter_id
                     if request.lora_request
@@ -294,6 +313,7 @@ class BlockPool:
                     lora_name=request.lora_request.name
                     if request.lora_request
                     else None,
+                    extra_keys=extra_keys_list if extra_keys_list else None,
                 )
             )
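
As a quick sanity check of the indexing above (hypothetical numbers), each uncached full block i covers the half-open token window [i * block_size, (i + 1) * block_size):

# Hypothetical numbers, for illustration only.
block_size = 16
num_cached_blocks, num_full_blocks = 1, 4
windows = [
    (i * block_size, i * block_size + block_size)
    for i in range(num_cached_blocks, num_full_blocks)
]
assert windows == [(16, 32), (32, 48), (48, 64)]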

View File

@@ -3,6 +3,7 @@
"""KV-Cache Utilities."""
import copy
import hashlib
import os
from collections import defaultdict
from collections.abc import Callable, Iterable, Iterator, Sequence
@@ -475,14 +476,19 @@ def _gen_prompt_embeds_extra_hash_keys(
         end_token_idx: The end token index of the block.

     Returns:
-        Return prompt embeddings data of the request if it has prompt embeds.
-        Return empty list otherwise.
+        Return a stable hash of the block prompt embeddings if prompt embeds
+        are present. Return empty list otherwise.
     """
     if request.prompt_embeds is None:
         return []
-    block_prompt_embeds = request.prompt_embeds[start_token_idx:end_token_idx]
-    embeds_bytes = tensor_data(block_prompt_embeds).tobytes()
-    return [embeds_bytes]
+    block_range = (start_token_idx, end_token_idx)
+    embeds_hash = request._prompt_embeds_per_block_hashes.get(block_range)
+    if embeds_hash is None:
+        block_prompt_embeds = request.prompt_embeds[start_token_idx:end_token_idx]
+        # Hash prompt embeds once per block and cache on request
+        embeds_hash = hashlib.sha256(tensor_data(block_prompt_embeds)).digest()
+        request._prompt_embeds_per_block_hashes[block_range] = embeds_hash
+    return [embeds_hash]
@@ -490,7 +496,7 @@ def generate_block_hash_extra_keys(
 ) -> tuple[tuple[Any, ...] | None, int]:
     """Generate extra keys for the block hash. The extra keys can come from
     the multi-modal inputs, request specific metadata (e.g., LoRA names), and
-    data from prompt embeddings.
+    hashed data from prompt embeddings.

     Args:
         request: The request object.
View File

@@ -114,6 +114,9 @@ class Request:
         self.prompt_token_ids = prompt_token_ids
         self.prompt_embeds = prompt_embeds
+        # Cache per-block prompt-embed hashes to avoid rehashing the same
+        # tensor slices when generating extra keys.
+        self._prompt_embeds_per_block_hashes: dict[tuple[int, int], bytes] = {}
         self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
             prompt_token_ids, prompt_embeds
         )
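
A minimal standalone illustration of the cache's key structure (not vLLM code): keys are (start_token_idx, end_token_idx) windows, values are 32-byte SHA-256 digests, so each block-sized slice of the prompt embeddings is hashed at most once per request.

import hashlib

# Sketch of the dict shape; real values are digests of
# tensor_data(prompt_embeds[start:end]).
cache: dict[tuple[int, int], bytes] = {}
cache[(0, 16)] = hashlib.sha256(b"embeds for tokens [0, 16)").digest()
cache[(16, 32)] = hashlib.sha256(b"embeds for tokens [16, 32)").digest()
assert all(len(digest) == 32 for digest in cache.values())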