[Prefix Cache] Add reproducible prefix-cache block hashing using SHA-256 + CBOR (64bit) (#20511)

Signed-off-by: Maroon Ayoub <maroon.ayoub@ibm.com>
This commit is contained in:
Maroon Ayoub
2025-07-14 05:45:31 +03:00
committed by GitHub
parent 8632e831ba
commit 66f6fbd393
8 changed files with 88 additions and 28 deletions

View File

@@ -47,3 +47,4 @@ python-json-logger # Used by logging as per examples/others/logging_configuratio
scipy # Required for phi-4-multimodal-instruct scipy # Required for phi-4-multimodal-instruct
ninja # Required for xgrammar, rocm, tpu, xpu ninja # Required for xgrammar, rocm, tpu, xpu
pybase64 # fast base64 implementation pybase64 # fast base64 implementation
cbor2 # Required for cross-language serialization of hashable objects

View File

@@ -11,6 +11,7 @@ ruff
# Required for argparse hook only # Required for argparse hook only
-f https://download.pytorch.org/whl/cpu -f https://download.pytorch.org/whl/cpu
cachetools cachetools
cbor2
cloudpickle cloudpickle
fastapi fastapi
msgspec msgspec

View File

@@ -8,7 +8,7 @@ import torch
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import GiB_bytes, sha256 from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit
from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.core.kv_cache_manager import KVCacheManager
# disable yapf here as it formats differently than isort such that both fail # disable yapf here as it formats differently than isort such that both fail
# yapf: disable # yapf: disable
@@ -16,7 +16,8 @@ from vllm.v1.core.kv_cache_utils import (
FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics, FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
estimate_max_model_len, generate_block_hash_extra_keys, estimate_max_model_len, generate_block_hash_extra_keys,
get_kv_cache_config, get_max_concurrency_for_kv_cache_config, get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
hash_block_tokens, hash_request_tokens, unify_kv_cache_configs) hash_block_tokens, hash_request_tokens, init_none_hash,
unify_kv_cache_configs)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor, KVCacheGroupSpec, KVCacheTensor,
SlidingWindowSpec) SlidingWindowSpec)
@@ -78,24 +79,27 @@ def new_sliding_window_spec(block_size=16,
sliding_window=sliding_window) sliding_window=sliding_window)
def test_none_hash(monkeypatch): @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
def test_none_hash(monkeypatch, hash_fn):
import vllm.v1.core.kv_cache_utils import vllm.v1.core.kv_cache_utils
# case 1: PYTHONHASHSEED is not set, use random # case 1: PYTHONHASHSEED is not set, use random
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.delenv('PYTHONHASHSEED', raising=False) m.delenv('PYTHONHASHSEED', raising=False)
reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils) reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils)
reloaded_kv_cache_utils.init_none_hash(hash_fn)
assert reloaded_kv_cache_utils.NONE_HASH is not None assert reloaded_kv_cache_utils.NONE_HASH is not None
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int) assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int)
assert reloaded_kv_cache_utils.NONE_HASH != 0 assert reloaded_kv_cache_utils.NONE_HASH != 0
# case 2: PYTHONHASHSEED is set, use the seed # case 2: PYTHONHASHSEED is set, use the seed and hash_fn
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv('PYTHONHASHSEED', 'python hash seed') m.setenv('PYTHONHASHSEED', 'python hash seed')
reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils) reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils)
reloaded_kv_cache_utils.init_none_hash(hash_fn)
assert reloaded_kv_cache_utils.NONE_HASH is not None assert reloaded_kv_cache_utils.NONE_HASH is not None
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int) assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int)
assert sha256('python hash seed') == reloaded_kv_cache_utils.NONE_HASH assert hash_fn('python hash seed') == reloaded_kv_cache_utils.NONE_HASH
def test_kv_cache_block(): def test_kv_cache_block():
@@ -287,9 +291,10 @@ def test_generate_block_hash_extra_keys_cache_salt():
assert next_mm_idx == 1 assert next_mm_idx == 1
@pytest.mark.parametrize("hash_fn", [sha256, hash]) @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
def test_hash_block_tokens(hash_fn): def test_hash_block_tokens(hash_fn):
import vllm.v1.core.kv_cache_utils import vllm.v1.core.kv_cache_utils
init_none_hash(hash_fn)
parent_block_hash = 123 parent_block_hash = 123
curr_block_token_ids = (1, 2, 3) curr_block_token_ids = (1, 2, 3)
extra_keys = ("key1", "key2") extra_keys = ("key1", "key2")
@@ -303,9 +308,10 @@ def test_hash_block_tokens(hash_fn):
assert block_hash.extra_keys == extra_keys assert block_hash.extra_keys == extra_keys
@pytest.mark.parametrize("hash_fn", [sha256, hash]) @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
def test_hash_request_tokens(hash_fn): def test_hash_request_tokens(hash_fn):
import vllm.v1.core.kv_cache_utils import vllm.v1.core.kv_cache_utils
init_none_hash(hash_fn)
request = make_request( request = make_request(
request_id=0, request_id=0,
prompt_token_ids=[_ for _ in range(6)], prompt_token_ids=[_ for _ in range(6)],
@@ -332,8 +338,10 @@ def test_hash_request_tokens(hash_fn):
assert block_hashes[1].extra_keys == ("hash2", ) assert block_hashes[1].extra_keys == ("hash2", )
@pytest.mark.parametrize("hash_fn", [sha256, hash]) @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
def test_hash_tokens_different_mm_input(hash_fn): def test_hash_tokens_different_mm_input(hash_fn):
init_none_hash(hash_fn)
request1 = make_request( request1 = make_request(
request_id=0, request_id=0,
prompt_token_ids=[_ for _ in range(6)], prompt_token_ids=[_ for _ in range(6)],
@@ -359,8 +367,10 @@ def test_hash_tokens_different_mm_input(hash_fn):
assert block_hashes1[1] != block_hashes2[1] assert block_hashes1[1] != block_hashes2[1]
@pytest.mark.parametrize("hash_fn", [sha256, hash]) @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
def test_hash_request_tokens_no_mm_inputs(hash_fn): def test_hash_request_tokens_no_mm_inputs(hash_fn):
init_none_hash(hash_fn)
request = make_request( request = make_request(
request_id=0, request_id=0,
prompt_token_ids=[_ for _ in range(6)], prompt_token_ids=[_ for _ in range(6)],
@@ -916,4 +926,4 @@ def test_get_kv_cache_config():
], ],
kv_cache_groups=[ kv_cache_groups=[
KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()) KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())
]) ])

View File

@@ -11,11 +11,12 @@ import torch
from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import sha256 from vllm.utils import sha256, sha256_cbor_64bit
from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
KVCacheBlock, hash_block_tokens) KVCacheBlock, hash_block_tokens,
init_none_hash)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, SlidingWindowSpec) KVCacheGroupSpec, SlidingWindowSpec)
@@ -91,7 +92,7 @@ def make_kv_cache_config_hybrid_model(block_size: int,
) )
@pytest.mark.parametrize("hash_algo", ["sha256", "hash"]) @pytest.mark.parametrize("hash_algo", ["sha256", "sha256_cbor_64bit", "hash"])
def test_prefill(hash_algo): def test_prefill(hash_algo):
manager = KVCacheManager( manager = KVCacheManager(
make_kv_cache_config(16, 11), make_kv_cache_config(16, 11),
@@ -101,7 +102,8 @@ def test_prefill(hash_algo):
) )
# choose the hash function according to the parameter # choose the hash function according to the parameter
hash_fn = sha256 if hash_algo == "sha256" else hash hash_fn = (sha256_cbor_64bit if hash_algo == "sha256_cbor_64bit" else
sha256 if hash_algo == "sha256" else hash)
# Complete 3 blocks (48 tokens) # Complete 3 blocks (48 tokens)
common_token_ids = [i for i in range(3) for _ in range(16)] common_token_ids = [i for i in range(3) for _ in range(16)]
@@ -696,12 +698,14 @@ def test_basic_prefix_caching_disabled():
assert not blocks assert not blocks
@pytest.mark.parametrize("hash_fn", [sha256, hash]) @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
def test_cache_blocks(hash_fn): def test_cache_blocks(hash_fn):
""" """
This is a unit test that tests the correctness of the _cache_full_blocks This is a unit test that tests the correctness of the _cache_full_blocks
function of KVCacheManager. function of KVCacheManager.
""" """
init_none_hash(hash_fn)
block_size = 4 block_size = 4
block_pool = BlockPool( block_pool = BlockPool(
num_gpu_blocks=5, num_gpu_blocks=5,

View File

@@ -1564,7 +1564,7 @@ class ModelConfig:
BlockSize = Literal[1, 8, 16, 32, 64, 128] BlockSize = Literal[1, 8, 16, 32, 64, 128]
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"] CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]
PrefixCachingHashAlgo = Literal["builtin", "sha256"] PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"]
@config @config
@@ -1609,7 +1609,12 @@ class CacheConfig:
prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin" prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin"
"""Set the hash algorithm for prefix caching:\n """Set the hash algorithm for prefix caching:\n
- "builtin" is Python's built-in hash.\n - "builtin" is Python's built-in hash.\n
- "sha256" is collision resistant but with certain overheads.""" - "sha256" is collision resistant but with certain overheads.
This option uses Pickle for object serialization before hashing.\n
- "sha256_cbor_64bit" provides a reproducible, cross-language compatible
hash. It serializes objects using canonical CBOR and hashes them with
SHA-256. The resulting hash consists of the lower 64 bits of the SHA-256
digest."""
cpu_offload_gb: float = 0 cpu_offload_gb: float = 0
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means """The space in GiB to offload to CPU, per GPU. Default is 0, which means
no offloading. Intuitively, this argument can be seen as a virtual way to no offloading. Intuitively, this argument can be seen as a virtual way to

View File

@@ -52,6 +52,7 @@ from urllib.parse import urlparse
from uuid import uuid4 from uuid import uuid4
import cachetools import cachetools
import cbor2
import cloudpickle import cloudpickle
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
@@ -3177,6 +3178,29 @@ def sha256(input) -> int:
byteorder="big") byteorder="big")
def sha256_cbor_64bit(input) -> int:
    """
    Compute a 64-bit hash of an object via canonical CBOR and SHA-256.

    The object is serialized with canonical CBOR, the serialized bytes are
    digested with SHA-256, and only the low-order 64 bits of the digest are
    kept. This option is useful for non-Python-dependent serialization and
    hashing.

    Args:
        input: Object to be serialized and hashed. Supported types include
            basic Python types and complex structures like lists, tuples,
            and dictionaries. Custom classes must implement CBOR
            serialization methods.

    Returns:
        An integer in the range [0, 2^64-1] representing the lower 64 bits
        of the SHA-256 hash of the CBOR serialized input.
    """
    encoded = cbor2.dumps(input, canonical=True)
    digest = hashlib.sha256(encoded).digest()
    # The digest is read as a big-endian integer, so its low-order 64 bits
    # are exactly its trailing 8 bytes.
    return int.from_bytes(digest[-8:], byteorder="big")
def is_torch_equal_or_newer(target: str) -> bool: def is_torch_equal_or_newer(target: str) -> bool:
"""Check if the installed torch version is >= the target version. """Check if the installed torch version is >= the target version.

View File

@@ -7,10 +7,10 @@ from typing import Optional
from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_events import KVCacheEvent
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import sha256 from vllm.utils import sha256, sha256_cbor_64bit
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
hash_request_tokens) hash_request_tokens, init_none_hash)
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
@@ -79,7 +79,10 @@ class KVCacheManager:
self.max_model_len = max_model_len self.max_model_len = max_model_len
self.enable_caching = enable_caching self.enable_caching = enable_caching
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash self.caching_hash_fn = (
sha256_cbor_64bit if caching_hash_algo == "sha256_cbor_64bit" else
sha256 if caching_hash_algo == "sha256" else hash)
init_none_hash(self.caching_hash_fn)
self.use_eagle = use_eagle self.use_eagle = use_eagle
self.log_stats = log_stats self.log_stats = log_stats
# FIXME: make prefix cache stats conditional on log_stats # FIXME: make prefix cache stats conditional on log_stats

View File

@@ -10,7 +10,7 @@ from typing import Any, Callable, NamedTuple, Optional
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import GiB_bytes, cdiv, sha256 from vllm.utils import GiB_bytes, cdiv, sha256_cbor_64bit
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec, KVCacheGroupSpec, KVCacheSpec,
KVCacheTensor, SlidingWindowSpec) KVCacheTensor, SlidingWindowSpec)
@@ -46,18 +46,30 @@ class BlockHashWithGroupId(NamedTuple):
return self.block_hash.hash_value return self.block_hash.hash_value
# The hash seed for the first block of the prefix block sequence. # The hash seed for the first block of any prefix block sequence.
#
# Even if the hash function is the builtin hash(), we use sha256 to generate
# the initial hash to simplify the code. This is not performance critical
# as it is done once per process.
# #
# We use a random value to avoid hash collisions or PYTHONHASHSEED environment # We use a random value to avoid hash collisions or PYTHONHASHSEED environment
# variable if set such that processes can share the seed if needed. # variable if set such that processes can share the seed if needed.
# This aligns with the behavior of Python's hash() function, which also uses # This aligns with the behavior of Python's hash() function, which also uses
# a random seed if PYTHONHASHSEED is not set. # a random seed if PYTHONHASHSEED is not set.
NONE_HASH = int.from_bytes(os.urandom(32), byteorder="big") if os.getenv( #
"PYTHONHASHSEED") is None else sha256(os.getenv("PYTHONHASHSEED")) # The function `init_none_hash` initializes this variable globally.
NONE_HASH: int
def init_none_hash(hash_fn: Callable) -> None:
    """Initialize the module-level ``NONE_HASH`` seed value.

    If the ``PYTHONHASHSEED`` environment variable is set, ``NONE_HASH`` is
    computed as ``hash_fn(<seed string>)``. Otherwise it is set to a random
    integer built from 32 bytes of ``os.urandom``.

    Args:
        hash_fn: Hash function applied to the ``PYTHONHASHSEED`` string when
            that variable is set. It is not called when the variable is
            unset.
    """
    global NONE_HASH

    hash_seed = os.getenv("PYTHONHASHSEED")
    if hash_seed is None and hash_fn is sha256_cbor_64bit:
        # Without a fixed seed the value below is random, so block hashes
        # differ between processes even with this hash function.
        logger.warning(
            "PYTHONHASHSEED is not set. This will lead to non-reproducible "
            "block-hashes when using sha256_cbor_64bit as the hash function. "
            "Consider setting PYTHONHASHSEED to a fixed value for "
            "reproducibility.")

    NONE_HASH = (int.from_bytes(os.urandom(32), byteorder="big")
                 if hash_seed is None else hash_fn(hash_seed))
class PrefixCachingMetrics: class PrefixCachingMetrics: