Support SHA256 as hash function in prefix caching (#15297)
Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com>
This commit is contained in:
@@ -7,7 +7,7 @@ import pytest
|
||||
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import cdiv
|
||||
from vllm.utils import cdiv, sha256
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
|
||||
@@ -39,16 +39,21 @@ def make_request(request_id,
|
||||
)
|
||||
|
||||
|
||||
def test_prefill():
|
||||
@pytest.mark.parametrize("hash_algo", ["sha256", "hash"])
|
||||
def test_prefill(hash_algo):
|
||||
manager = KVCacheManager(
|
||||
block_size=16,
|
||||
num_gpu_blocks=10,
|
||||
max_model_len=8192,
|
||||
sliding_window=None,
|
||||
enable_caching=True,
|
||||
caching_hash_algo=hash_algo,
|
||||
num_preallocate_tokens=16,
|
||||
)
|
||||
|
||||
# choose the hash function according to the parameter
|
||||
hash_fn = sha256 if hash_algo == "sha256" else hash
|
||||
|
||||
# Complete 3 blocks (48 tokens)
|
||||
common_token_ids = [i for i in range(3) for _ in range(16)]
|
||||
|
||||
@@ -68,7 +73,8 @@ def test_prefill():
|
||||
parent_block_hash = None
|
||||
for block_id in (0, 1, 2):
|
||||
block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
|
||||
block_hash = hash_block_tokens(parent_block_hash, block_tokens)
|
||||
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
||||
block_tokens)
|
||||
assert manager.block_pool.blocks[block_id].block_hash == block_hash
|
||||
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
||||
parent_block_hash = block_hash.hash_value
|
||||
@@ -163,6 +169,8 @@ def test_prefill_plp():
|
||||
enable_caching=True,
|
||||
num_preallocate_tokens=16,
|
||||
)
|
||||
# the default hash function is hash
|
||||
hash_fn = hash
|
||||
|
||||
# Complete 3 blocks (48 tokens)
|
||||
common_token_ids = [i for i in range(3) for _ in range(16)]
|
||||
@@ -185,7 +193,8 @@ def test_prefill_plp():
|
||||
parent_block_hash = None
|
||||
for block_id in (0, 1, 2):
|
||||
block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
|
||||
block_hash = hash_block_tokens(parent_block_hash, block_tokens)
|
||||
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
||||
block_tokens)
|
||||
assert manager.block_pool.blocks[block_id].block_hash == block_hash
|
||||
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
||||
parent_block_hash = block_hash.hash_value
|
||||
@@ -522,7 +531,8 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
|
||||
assert len(blocks) == 1 + num_preallocated_blocks
|
||||
|
||||
|
||||
def test_cache_blocks():
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, hash])
|
||||
def test_cache_blocks(hash_fn):
|
||||
"""
|
||||
This is a unit test that tests the correctness of the _cache_full_blocks
|
||||
function of KVCacheManager.
|
||||
@@ -550,6 +560,7 @@ def test_cache_blocks():
|
||||
num_cached_blocks=0,
|
||||
num_full_blocks=2,
|
||||
block_size=block_size,
|
||||
hash_fn=hash_fn,
|
||||
)
|
||||
|
||||
assert len(block_pool.cached_block_hash_to_block) == 2
|
||||
@@ -564,6 +575,7 @@ def test_cache_blocks():
|
||||
num_cached_blocks=2,
|
||||
num_full_blocks=3,
|
||||
block_size=block_size,
|
||||
hash_fn=hash_fn,
|
||||
)
|
||||
assert len(block_pool.cached_block_hash_to_block) == 3
|
||||
assert blocks[0].block_hash is not None
|
||||
|
||||
Reference in New Issue
Block a user