[feat] Add per-block extra_keys to KV events (#33304)

Signed-off-by: zhongdaor-nv <zhongdaor@nvidia.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Author: zhongdaor-nv
Date: 2026-02-20 21:11:40 -07:00
Committed by: GitHub
Parent: 991d6bff38
Commit: a0fe7ea2f0

6 changed files with 100 additions and 21 deletions


@@ -37,6 +37,12 @@ class BlockStored(KVCacheEvent):
     medium: str | None
     lora_name: str | None
+    extra_keys: list[tuple[Any, ...] | None] | None = None
+    """Extra keys used in block hash computation, one entry per block in
+    block_hashes. Each entry contains MM identifiers, LoRA name, cache_salt,
+    prompt embeddings data, etc. for that specific block.
+    """
+


 class BlockRemoved(KVCacheEvent):
     block_hashes: list[ExternalBlockHash]


@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import hashlib
 import importlib
 from collections.abc import Callable
 from typing import Any
@@ -498,14 +499,41 @@ def test_generate_block_hash_extra_keys_prompt_embeds():
     # Test with prompt embeds for the first block
     extra_keys, _ = generate_block_hash_extra_keys(request, 0, 5, 0)
     expected_embeds = prompt_embeds[0:5]
-    expected_bytes = kv_cache_utils.tensor_data(expected_embeds).tobytes()
-    assert extra_keys == (expected_bytes,)
+    expected_hash = hashlib.sha256(kv_cache_utils.tensor_data(expected_embeds)).digest()
+    assert extra_keys == (expected_hash,)

     # Test with prompt embeds for the second block
     extra_keys, _ = generate_block_hash_extra_keys(request, 5, 10, 0)
     expected_embeds = prompt_embeds[5:10]
-    expected_bytes = kv_cache_utils.tensor_data(expected_embeds).tobytes()
-    assert extra_keys == (expected_bytes,)
+    expected_hash = hashlib.sha256(kv_cache_utils.tensor_data(expected_embeds)).digest()
+    assert extra_keys == (expected_hash,)
+
+
+def test_generate_block_hash_extra_keys_prompt_embeds_cached(monkeypatch):
+    prompt_embeds = torch.randn(10, 3)
+    request = make_request(
+        request_id="0",
+        prompt_token_ids=None,
+        mm_positions=None,
+        mm_hashes=None,
+        prompt_embeds=prompt_embeds,
+        block_size=20,
+    )
+
+    num_tensor_data_calls = 0
+    original_tensor_data = kv_cache_utils.tensor_data
+
+    def counting_tensor_data(tensor: torch.Tensor):
+        nonlocal num_tensor_data_calls
+        num_tensor_data_calls += 1
+        return original_tensor_data(tensor)
+
+    monkeypatch.setattr(kv_cache_utils, "tensor_data", counting_tensor_data)
+
+    extra_keys_1, _ = generate_block_hash_extra_keys(request, 0, 5, 0)
+    extra_keys_2, _ = generate_block_hash_extra_keys(request, 0, 5, 0)
+    assert extra_keys_1 == extra_keys_2
+    assert num_tensor_data_calls == 1


 def test_generate_block_hash_extra_keys_different_prompt_embeds():
@@ -1858,22 +1886,26 @@ def test_request_block_hasher_with_prompt_embeds(hash_fn: Callable[[Any], bytes]):
     block_hashes = request.block_hashes
     assert len(block_hashes) == 2

-    block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes()
+    block1_embeds_hash = hashlib.sha256(
+        tensor_data(prompt_embeds[:block_size])
+    ).digest()
     expected_hash1 = hash_fn(
         (
             kv_cache_utils.NONE_HASH,
             tuple(prompt_token_ids[:block_size]),
-            (block1_embeds_bytes,),
+            (block1_embeds_hash,),
         )
     )
     assert block_hashes[0] == expected_hash1

-    block2_embeds_bytes = tensor_data(prompt_embeds[block_size:num_tokens]).tobytes()
+    block2_embeds_hash = hashlib.sha256(
+        tensor_data(prompt_embeds[block_size:num_tokens])
+    ).digest()
     expected_hash2 = hash_fn(
         (
             block_hashes[0],
             tuple(prompt_token_ids[block_size:num_tokens]),
-            (block2_embeds_bytes,),
+            (block2_embeds_hash,),
         )
     )
     assert block_hashes[1] == expected_hash2
@@ -1903,22 +1935,26 @@ def test_request_with_prompt_embeds_and_mm_inputs(hash_fn: Callable[[Any], bytes]):
     block_hashes = request.block_hashes
     assert len(block_hashes) == 2

-    block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes()
+    block1_embeds_hash = hashlib.sha256(
+        tensor_data(prompt_embeds[:block_size])
+    ).digest()
     expected_hash1 = hash_fn(
         (
             kv_cache_utils.NONE_HASH,
             tuple(prompt_token_ids[:block_size]),
-            ("hash1", block1_embeds_bytes),
+            ("hash1", block1_embeds_hash),
         )
     )
     assert block_hashes[0] == expected_hash1

-    block2_embeds_bytes = tensor_data(prompt_embeds[block_size:num_tokens]).tobytes()
+    block2_embeds_hash = hashlib.sha256(
+        tensor_data(prompt_embeds[block_size:num_tokens])
+    ).digest()
     expected_hash2 = hash_fn(
         (
             block_hashes[0],
             tuple(prompt_token_ids[block_size:num_tokens]),
-            ("hash2", block2_embeds_bytes),
+            ("hash2", block2_embeds_hash),
         )
     )
     assert block_hashes[1] == expected_hash2


@@ -60,6 +60,13 @@ class BlockStored(KVCacheEvent):
     medium: str | None
     lora_name: str | None
+    extra_keys: list[tuple[Any, ...] | None] | None = None
+    """Extra keys used in block hash computation, one entry per block in
+    block_hashes. Each entry contains MM identifiers, LoRA name, cache_salt,
+    prompt embedding hashes, etc. for that specific block. Exposed for external
+    KV cache consumers to reconstruct block hashes.
+    """
+

     def __hash__(self) -> int:
         return hash(
             (
@@ -69,6 +76,7 @@ class BlockStored(KVCacheEvent):
                 self.block_size,
                 self.lora_id,
                 self.medium,
+                tuple(self.extra_keys) if self.extra_keys else None,
             )
         )
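
Note: because extra_keys carries exactly one entry per block in block_hashes, a consumer of these events can line each entry up with the corresponding block hash and token slice. Below is a minimal sketch of that alignment; BlockStoredLike and describe_blocks are simplified stand-ins invented for illustration, not the real event class or a vLLM API.

from dataclasses import dataclass
from typing import Any


@dataclass
class BlockStoredLike:
    """Simplified stand-in for the BlockStored event above (illustration only)."""

    block_hashes: list[bytes]
    token_ids: list[int]
    block_size: int
    extra_keys: list[tuple[Any, ...] | None] | None = None


def describe_blocks(event: BlockStoredLike) -> None:
    # extra_keys is either None or has one entry per block in block_hashes,
    # so it can be zipped with the per-block token slices.
    extra = event.extra_keys or [None] * len(event.block_hashes)
    for i, (block_hash, keys) in enumerate(zip(event.block_hashes, extra)):
        tokens = event.token_ids[i * event.block_size : (i + 1) * event.block_size]
        print(f"block {i}: hash={block_hash!r} tokens={tokens} extra_keys={keys}")


describe_blocks(
    BlockStoredLike(
        block_hashes=[b"h0", b"h1"],
        token_ids=list(range(32)),
        block_size=16,
        extra_keys=[("my-cache-salt",), None],  # e.g. cache_salt only on block 0
    )
)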


@@ -20,6 +20,7 @@ from vllm.v1.core.kv_cache_utils import (
     ExternalBlockHash,
     FreeKVCacheBlockQueue,
     KVCacheBlock,
+    generate_block_hash_extra_keys,
     get_block_hash,
     make_block_hash_with_group_id,
     maybe_convert_block_hash,
@@ -279,13 +280,31 @@ class BlockPool:
                     block_hashes[num_cached_blocks - 1]
                 )

+            # Calculate token range for the blocks being cached
+            start_token_idx = num_cached_blocks * block_size
+            end_token_idx = num_full_blocks * block_size
+
+            # Generate extra keys for each block individually.
+            # Each block may have different extra_keys (e.g., different MM
+            # features, or cache_salt only for the first block).
+            # Skip null blocks to match the length of new_hashes.
+            extra_keys_list: list[tuple[Any, ...] | None] = []
+            curr_mm_idx = 0
+            for i in range(num_cached_blocks, num_full_blocks):
+                if blocks[i].is_null:
+                    continue
+                block_start = i * block_size
+                block_end = block_start + block_size
+                extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
+                    request, block_start, block_end, curr_mm_idx
+                )
+                extra_keys_list.append(extra_keys)
+
             self.kv_event_queue.append(
                 BlockStored(
                     block_hashes=new_hashes,
                     parent_block_hash=parent_block_hash,
-                    token_ids=request.all_token_ids[
-                        num_cached_blocks * block_size : num_full_blocks * block_size
-                    ],
+                    token_ids=request.all_token_ids[start_token_idx:end_token_idx],
                     block_size=block_size,
                     lora_id=request.lora_request.adapter_id
                     if request.lora_request
@@ -294,6 +313,7 @@ class BlockPool:
                     lora_name=request.lora_request.name
                     if request.lora_request
                     else None,
+                    extra_keys=extra_keys_list if extra_keys_list else None,
                 )
             )
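
The new loop above emits one extra_keys entry per non-null block so the resulting list stays aligned with new_hashes. A rough standalone sketch of that skip-and-align behavior follows; FakeBlock and fake_extra_keys are toy stand-ins invented here (the real generate_block_hash_extra_keys also threads a multi-modal index through the loop, as shown in the diff).

from dataclasses import dataclass
from typing import Any


@dataclass
class FakeBlock:
    is_null: bool


def fake_extra_keys(start: int, end: int, mm_idx: int) -> tuple[tuple[Any, ...] | None, int]:
    # Toy stand-in: pretend only the first block carries a cache salt.
    return (("my-salt",) if start == 0 else None), mm_idx


block_size = 16
blocks = [FakeBlock(False), FakeBlock(True), FakeBlock(False)]  # middle block is null

extra_keys_list: list[tuple[Any, ...] | None] = []
curr_mm_idx = 0
for i, block in enumerate(blocks):
    if block.is_null:
        # Null blocks are skipped so the list length matches the stored hashes.
        continue
    start, end = i * block_size, (i + 1) * block_size
    keys, curr_mm_idx = fake_extra_keys(start, end, curr_mm_idx)
    extra_keys_list.append(keys)

print(extra_keys_list)  # [('my-salt',), None] -- one entry per non-null block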


@@ -3,6 +3,7 @@
 """KV-Cache Utilities."""

 import copy
+import hashlib
 import os
 from collections import defaultdict
 from collections.abc import Callable, Iterable, Iterator, Sequence
@@ -475,14 +476,19 @@ def _gen_prompt_embeds_extra_hash_keys(
         end_token_idx: The end token index of the block.

     Returns:
-        Return prompt embeddings data of the request if it has prompt embeds.
-        Return empty list otherwise.
+        Return a stable hash of the block prompt embeddings if prompt embeds
+        are present. Return empty list otherwise.
     """
     if request.prompt_embeds is None:
         return []
-    block_prompt_embeds = request.prompt_embeds[start_token_idx:end_token_idx]
-    embeds_bytes = tensor_data(block_prompt_embeds).tobytes()
-    return [embeds_bytes]
+    block_range = (start_token_idx, end_token_idx)
+    embeds_hash = request._prompt_embeds_per_block_hashes.get(block_range)
+    if embeds_hash is None:
+        block_prompt_embeds = request.prompt_embeds[start_token_idx:end_token_idx]
+        # Hash prompt embeds once per block and cache on request
+        embeds_hash = hashlib.sha256(tensor_data(block_prompt_embeds)).digest()
+        request._prompt_embeds_per_block_hashes[block_range] = embeds_hash
+    return [embeds_hash]


 def generate_block_hash_extra_keys(
@@ -490,7 +496,7 @@ def generate_block_hash_extra_keys(
 ) -> tuple[tuple[Any, ...] | None, int]:
     """Generate extra keys for the block hash. The extra keys can come from
     the multi-modal inputs, request specific metadata (e.g., LoRA names), and
-    data from prompt embeddings.
+    hashed data from prompt embeddings.

     Args:
         request: The request object.
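
The helper above now hashes each block's prompt-embedding slice with SHA-256 and caches the digest on the request, keyed by the (start, end) token range, so repeated block-hash computations reuse the digest instead of re-reading the tensor. A rough standalone sketch of that memoization pattern, assuming a plain dict in place of the per-request cache and .numpy().tobytes() in place of tensor_data:

import hashlib

import torch

# Stand-in for request._prompt_embeds_per_block_hashes (per-request in the diff).
_per_block_hashes: dict[tuple[int, int], bytes] = {}


def block_embeds_hash(prompt_embeds: torch.Tensor, start: int, end: int) -> bytes:
    block_range = (start, end)
    cached = _per_block_hashes.get(block_range)
    if cached is not None:
        return cached
    # Hash the block's embedding slice once and remember the digest.
    data = prompt_embeds[start:end].contiguous().numpy().tobytes()
    digest = hashlib.sha256(data).digest()
    _per_block_hashes[block_range] = digest
    return digest


embeds = torch.randn(32, 8)
h1 = block_embeds_hash(embeds, 0, 16)
h2 = block_embeds_hash(embeds, 0, 16)  # served from the cache, no re-hashing
assert h1 == h2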


@@ -114,6 +114,9 @@ class Request:
         self.prompt_token_ids = prompt_token_ids
         self.prompt_embeds = prompt_embeds
+        # Cache per-block prompt-embed hashes to avoid rehashing the same
+        # tensor slices when generating extra keys.
+        self._prompt_embeds_per_block_hashes: dict[tuple[int, int], bytes] = {}
         self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
             prompt_token_ids, prompt_embeds
         )