[feat] Add per-block extra_keys to KV events (#33304)
Signed-off-by: zhongdaor-nv <zhongdaor@nvidia.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -37,6 +37,12 @@ class BlockStored(KVCacheEvent):
|
|||||||
medium: str | None
|
medium: str | None
|
||||||
lora_name: str | None
|
lora_name: str | None
|
||||||
|
|
||||||
|
extra_keys: list[tuple[Any, ...] | None] | None = None
|
||||||
|
"""Extra keys used in block hash computation, one entry per block in
|
||||||
|
block_hashes. Each entry contains MM identifiers, LoRA name, cache_salt,
|
||||||
|
prompt embedding hashes, etc. for that specific block.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class BlockRemoved(KVCacheEvent):
|
class BlockRemoved(KVCacheEvent):
|
||||||
block_hashes: list[ExternalBlockHash]
|
block_hashes: list[ExternalBlockHash]
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import hashlib
|
||||||
import importlib
|
import importlib
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -498,14 +499,41 @@ def test_generate_block_hash_extra_keys_prompt_embeds():
|
|||||||
# Test with prompt embeds for the first block
|
# Test with prompt embeds for the first block
|
||||||
extra_keys, _ = generate_block_hash_extra_keys(request, 0, 5, 0)
|
extra_keys, _ = generate_block_hash_extra_keys(request, 0, 5, 0)
|
||||||
expected_embeds = prompt_embeds[0:5]
|
expected_embeds = prompt_embeds[0:5]
|
||||||
expected_bytes = kv_cache_utils.tensor_data(expected_embeds).tobytes()
|
expected_hash = hashlib.sha256(kv_cache_utils.tensor_data(expected_embeds)).digest()
|
||||||
assert extra_keys == (expected_bytes,)
|
assert extra_keys == (expected_hash,)
|
||||||
|
|
||||||
# Test with prompt embeds for the second block
|
# Test with prompt embeds for the second block
|
||||||
extra_keys, _ = generate_block_hash_extra_keys(request, 5, 10, 0)
|
extra_keys, _ = generate_block_hash_extra_keys(request, 5, 10, 0)
|
||||||
expected_embeds = prompt_embeds[5:10]
|
expected_embeds = prompt_embeds[5:10]
|
||||||
expected_bytes = kv_cache_utils.tensor_data(expected_embeds).tobytes()
|
expected_hash = hashlib.sha256(kv_cache_utils.tensor_data(expected_embeds)).digest()
|
||||||
assert extra_keys == (expected_bytes,)
|
assert extra_keys == (expected_hash,)
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_block_hash_extra_keys_prompt_embeds_cached(monkeypatch):
    """Check that the prompt-embeds hash for a block is computed only once.

    Extra keys are generated twice for the same token range of one request
    while ``kv_cache_utils.tensor_data`` is wrapped by a recording proxy.
    Both results must be equal and the proxy must have been invoked exactly
    once across the two calls.
    """
    prompt_embeds = torch.randn(10, 3)
    request = make_request(
        request_id="0",
        prompt_token_ids=None,
        mm_positions=None,
        mm_hashes=None,
        prompt_embeds=prompt_embeds,
        block_size=20,
    )

    # Keep a handle on the real implementation so the proxy can delegate.
    real_tensor_data = kv_cache_utils.tensor_data
    # One entry is appended per invocation of the patched function.
    call_log: list[None] = []

    def recording_tensor_data(tensor: torch.Tensor):
        call_log.append(None)
        return real_tensor_data(tensor)

    monkeypatch.setattr(kv_cache_utils, "tensor_data", recording_tensor_data)

    # Same request and same (start, end) range on both calls.
    first_keys = generate_block_hash_extra_keys(request, 0, 5, 0)[0]
    second_keys = generate_block_hash_extra_keys(request, 0, 5, 0)[0]

    assert first_keys == second_keys
    assert len(call_log) == 1
|
||||||
|
|
||||||
|
|
||||||
def test_generate_block_hash_extra_keys_different_prompt_embeds():
|
def test_generate_block_hash_extra_keys_different_prompt_embeds():
|
||||||
@@ -1858,22 +1886,26 @@ def test_request_block_hasher_with_prompt_embeds(hash_fn: Callable[[Any], bytes]
|
|||||||
block_hashes = request.block_hashes
|
block_hashes = request.block_hashes
|
||||||
assert len(block_hashes) == 2
|
assert len(block_hashes) == 2
|
||||||
|
|
||||||
block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes()
|
block1_embeds_hash = hashlib.sha256(
|
||||||
|
tensor_data(prompt_embeds[:block_size])
|
||||||
|
).digest()
|
||||||
expected_hash1 = hash_fn(
|
expected_hash1 = hash_fn(
|
||||||
(
|
(
|
||||||
kv_cache_utils.NONE_HASH,
|
kv_cache_utils.NONE_HASH,
|
||||||
tuple(prompt_token_ids[:block_size]),
|
tuple(prompt_token_ids[:block_size]),
|
||||||
(block1_embeds_bytes,),
|
(block1_embeds_hash,),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
assert block_hashes[0] == expected_hash1
|
assert block_hashes[0] == expected_hash1
|
||||||
|
|
||||||
block2_embeds_bytes = tensor_data(prompt_embeds[block_size:num_tokens]).tobytes()
|
block2_embeds_hash = hashlib.sha256(
|
||||||
|
tensor_data(prompt_embeds[block_size:num_tokens])
|
||||||
|
).digest()
|
||||||
expected_hash2 = hash_fn(
|
expected_hash2 = hash_fn(
|
||||||
(
|
(
|
||||||
block_hashes[0],
|
block_hashes[0],
|
||||||
tuple(prompt_token_ids[block_size:num_tokens]),
|
tuple(prompt_token_ids[block_size:num_tokens]),
|
||||||
(block2_embeds_bytes,),
|
(block2_embeds_hash,),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
assert block_hashes[1] == expected_hash2
|
assert block_hashes[1] == expected_hash2
|
||||||
@@ -1903,22 +1935,26 @@ def test_request_with_prompt_embeds_and_mm_inputs(hash_fn: Callable[[Any], bytes
|
|||||||
block_hashes = request.block_hashes
|
block_hashes = request.block_hashes
|
||||||
assert len(block_hashes) == 2
|
assert len(block_hashes) == 2
|
||||||
|
|
||||||
block1_embeds_bytes = tensor_data(prompt_embeds[:block_size]).tobytes()
|
block1_embeds_hash = hashlib.sha256(
|
||||||
|
tensor_data(prompt_embeds[:block_size])
|
||||||
|
).digest()
|
||||||
expected_hash1 = hash_fn(
|
expected_hash1 = hash_fn(
|
||||||
(
|
(
|
||||||
kv_cache_utils.NONE_HASH,
|
kv_cache_utils.NONE_HASH,
|
||||||
tuple(prompt_token_ids[:block_size]),
|
tuple(prompt_token_ids[:block_size]),
|
||||||
("hash1", block1_embeds_bytes),
|
("hash1", block1_embeds_hash),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
assert block_hashes[0] == expected_hash1
|
assert block_hashes[0] == expected_hash1
|
||||||
|
|
||||||
block2_embeds_bytes = tensor_data(prompt_embeds[block_size:num_tokens]).tobytes()
|
block2_embeds_hash = hashlib.sha256(
|
||||||
|
tensor_data(prompt_embeds[block_size:num_tokens])
|
||||||
|
).digest()
|
||||||
expected_hash2 = hash_fn(
|
expected_hash2 = hash_fn(
|
||||||
(
|
(
|
||||||
block_hashes[0],
|
block_hashes[0],
|
||||||
tuple(prompt_token_ids[block_size:num_tokens]),
|
tuple(prompt_token_ids[block_size:num_tokens]),
|
||||||
("hash2", block2_embeds_bytes),
|
("hash2", block2_embeds_hash),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
assert block_hashes[1] == expected_hash2
|
assert block_hashes[1] == expected_hash2
|
||||||
|
|||||||
@@ -60,6 +60,13 @@ class BlockStored(KVCacheEvent):
|
|||||||
medium: str | None
|
medium: str | None
|
||||||
lora_name: str | None
|
lora_name: str | None
|
||||||
|
|
||||||
|
extra_keys: list[tuple[Any, ...] | None] | None = None
|
||||||
|
"""Extra keys used in block hash computation, one entry per block in
|
||||||
|
block_hashes. Each entry contains MM identifiers, LoRA name, cache_salt,
|
||||||
|
prompt embedding hashes, etc. for that specific block. Exposed for external
|
||||||
|
KV cache consumers to reconstruct block hashes.
|
||||||
|
"""
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
def __hash__(self) -> int:
|
||||||
return hash(
|
return hash(
|
||||||
(
|
(
|
||||||
@@ -69,6 +76,7 @@ class BlockStored(KVCacheEvent):
|
|||||||
self.block_size,
|
self.block_size,
|
||||||
self.lora_id,
|
self.lora_id,
|
||||||
self.medium,
|
self.medium,
|
||||||
|
tuple(self.extra_keys) if self.extra_keys else None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from vllm.v1.core.kv_cache_utils import (
|
|||||||
ExternalBlockHash,
|
ExternalBlockHash,
|
||||||
FreeKVCacheBlockQueue,
|
FreeKVCacheBlockQueue,
|
||||||
KVCacheBlock,
|
KVCacheBlock,
|
||||||
|
generate_block_hash_extra_keys,
|
||||||
get_block_hash,
|
get_block_hash,
|
||||||
make_block_hash_with_group_id,
|
make_block_hash_with_group_id,
|
||||||
maybe_convert_block_hash,
|
maybe_convert_block_hash,
|
||||||
@@ -279,13 +280,31 @@ class BlockPool:
|
|||||||
block_hashes[num_cached_blocks - 1]
|
block_hashes[num_cached_blocks - 1]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Calculate token range for the blocks being cached
|
||||||
|
start_token_idx = num_cached_blocks * block_size
|
||||||
|
end_token_idx = num_full_blocks * block_size
|
||||||
|
|
||||||
|
# Generate extra keys for each block individually.
|
||||||
|
# Each block may have different extra_keys (e.g., different MM
|
||||||
|
# features, or cache_salt only for the first block).
|
||||||
|
# Skip null blocks to match the length of new_hashes.
|
||||||
|
extra_keys_list: list[tuple[Any, ...] | None] = []
|
||||||
|
curr_mm_idx = 0
|
||||||
|
for i in range(num_cached_blocks, num_full_blocks):
|
||||||
|
if blocks[i].is_null:
|
||||||
|
continue
|
||||||
|
block_start = i * block_size
|
||||||
|
block_end = block_start + block_size
|
||||||
|
extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
|
||||||
|
request, block_start, block_end, curr_mm_idx
|
||||||
|
)
|
||||||
|
extra_keys_list.append(extra_keys)
|
||||||
|
|
||||||
self.kv_event_queue.append(
|
self.kv_event_queue.append(
|
||||||
BlockStored(
|
BlockStored(
|
||||||
block_hashes=new_hashes,
|
block_hashes=new_hashes,
|
||||||
parent_block_hash=parent_block_hash,
|
parent_block_hash=parent_block_hash,
|
||||||
token_ids=request.all_token_ids[
|
token_ids=request.all_token_ids[start_token_idx:end_token_idx],
|
||||||
num_cached_blocks * block_size : num_full_blocks * block_size
|
|
||||||
],
|
|
||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
lora_id=request.lora_request.adapter_id
|
lora_id=request.lora_request.adapter_id
|
||||||
if request.lora_request
|
if request.lora_request
|
||||||
@@ -294,6 +313,7 @@ class BlockPool:
|
|||||||
lora_name=request.lora_request.name
|
lora_name=request.lora_request.name
|
||||||
if request.lora_request
|
if request.lora_request
|
||||||
else None,
|
else None,
|
||||||
|
extra_keys=extra_keys_list if extra_keys_list else None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
"""KV-Cache Utilities."""
|
"""KV-Cache Utilities."""
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
import hashlib
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import Callable, Iterable, Iterator, Sequence
|
from collections.abc import Callable, Iterable, Iterator, Sequence
|
||||||
@@ -475,14 +476,19 @@ def _gen_prompt_embeds_extra_hash_keys(
|
|||||||
end_token_idx: The end token index of the block.
|
end_token_idx: The end token index of the block.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Return prompt embeddings data of the request if it has prompt embeds.
|
Return a stable hash of the block prompt embeddings if prompt embeds
|
||||||
Return empty list otherwise.
|
are present. Return empty list otherwise.
|
||||||
"""
|
"""
|
||||||
if request.prompt_embeds is None:
|
if request.prompt_embeds is None:
|
||||||
return []
|
return []
|
||||||
block_prompt_embeds = request.prompt_embeds[start_token_idx:end_token_idx]
|
block_range = (start_token_idx, end_token_idx)
|
||||||
embeds_bytes = tensor_data(block_prompt_embeds).tobytes()
|
embeds_hash = request._prompt_embeds_per_block_hashes.get(block_range)
|
||||||
return [embeds_bytes]
|
if embeds_hash is None:
|
||||||
|
block_prompt_embeds = request.prompt_embeds[start_token_idx:end_token_idx]
|
||||||
|
# Hash prompt embeds once per block and cache on request
|
||||||
|
embeds_hash = hashlib.sha256(tensor_data(block_prompt_embeds)).digest()
|
||||||
|
request._prompt_embeds_per_block_hashes[block_range] = embeds_hash
|
||||||
|
return [embeds_hash]
|
||||||
|
|
||||||
|
|
||||||
def generate_block_hash_extra_keys(
|
def generate_block_hash_extra_keys(
|
||||||
@@ -490,7 +496,7 @@ def generate_block_hash_extra_keys(
|
|||||||
) -> tuple[tuple[Any, ...] | None, int]:
|
) -> tuple[tuple[Any, ...] | None, int]:
|
||||||
"""Generate extra keys for the block hash. The extra keys can come from
|
"""Generate extra keys for the block hash. The extra keys can come from
|
||||||
the multi-modal inputs, request specific metadata (e.g., LoRA names), and
|
the multi-modal inputs, request specific metadata (e.g., LoRA names), and
|
||||||
data from prompt embeddings.
|
hashed data from prompt embeddings.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: The request object.
|
request: The request object.
|
||||||
|
|||||||
@@ -114,6 +114,9 @@ class Request:
|
|||||||
|
|
||||||
self.prompt_token_ids = prompt_token_ids
|
self.prompt_token_ids = prompt_token_ids
|
||||||
self.prompt_embeds = prompt_embeds
|
self.prompt_embeds = prompt_embeds
|
||||||
|
# Cache per-block prompt-embed hashes to avoid rehashing the same
|
||||||
|
# tensor slices when generating extra keys.
|
||||||
|
self._prompt_embeds_per_block_hashes: dict[tuple[int, int], bytes] = {}
|
||||||
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
||||||
prompt_token_ids, prompt_embeds
|
prompt_token_ids, prompt_embeds
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user