fix: disambiguate multimodal prefix cache keys (#36708)
Signed-off-by: tianshu.yu <tianshuyu.formal@gmail.com>
This commit is contained in:
Committed by: GitHub
Parent: e5a77a5015
Commit: 269bf46d99
@@ -447,12 +447,12 @@ def test_generate_block_hash_extra_keys():
|
||||
|
||||
# Test with no extra keys
|
||||
extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 0, 5, 0)
|
||||
assert extra_keys == ("hash1",)
|
||||
assert extra_keys == (("hash1", 0),)
|
||||
assert next_mm_idx == 1
|
||||
|
||||
# Test with partial overlap
|
||||
extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 3, 8, 0)
|
||||
assert extra_keys == ("hash1",)
|
||||
assert extra_keys == (("hash1", -3),)
|
||||
assert next_mm_idx == 1
|
||||
|
||||
# Test with no overlap
|
||||
@@ -462,7 +462,7 @@ def test_generate_block_hash_extra_keys():
|
||||
|
||||
# Test with multiple extra keys
|
||||
extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 0, 15, 0)
|
||||
assert extra_keys == ("hash1", "hash2")
|
||||
assert extra_keys == (("hash1", 0), ("hash2", 10))
|
||||
assert next_mm_idx == 2
|
||||
|
||||
|
||||
@@ -513,7 +513,7 @@ def test_generate_block_hash_extra_keys_cache_salt():
|
||||
|
||||
# Test with no extra keys
|
||||
extra_keys, next_mm_idx = generate_block_hash_extra_keys(request_mm, 0, 5, 0)
|
||||
assert extra_keys == ("hash1", "salt")
|
||||
assert extra_keys == (("hash1", 0), "salt")
|
||||
assert next_mm_idx == 1
|
||||
|
||||
|
||||
@@ -637,8 +637,10 @@ def test_request_block_hasher(hash_fn):
|
||||
|
||||
block_hashes = request.block_hashes
|
||||
assert len(block_hashes) == 2
|
||||
assert block_hashes[0] == hash_fn((kv_cache_utils.NONE_HASH, (0, 1, 2), ("hash1",)))
|
||||
assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), ("hash2",)))
|
||||
assert block_hashes[0] == hash_fn(
|
||||
(kv_cache_utils.NONE_HASH, (0, 1, 2), (("hash1", 0),))
|
||||
)
|
||||
assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), (("hash2", 0),)))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
|
||||
@@ -1973,7 +1975,7 @@ def test_request_with_prompt_embeds_and_mm_inputs(hash_fn: Callable[[Any], bytes
|
||||
(
|
||||
kv_cache_utils.NONE_HASH,
|
||||
tuple(prompt_token_ids[:block_size]),
|
||||
("hash1", block1_embeds_hash),
|
||||
(("hash1", 0), block1_embeds_hash),
|
||||
)
|
||||
)
|
||||
assert block_hashes[0] == expected_hash1
|
||||
@@ -1985,7 +1987,7 @@ def test_request_with_prompt_embeds_and_mm_inputs(hash_fn: Callable[[Any], bytes
|
||||
(
|
||||
block_hashes[0],
|
||||
tuple(prompt_token_ids[block_size:num_tokens]),
|
||||
("hash2", block2_embeds_hash),
|
||||
(("hash2", 0), block2_embeds_hash),
|
||||
)
|
||||
)
|
||||
assert block_hashes[1] == expected_hash2
|
||||
|
||||
@@ -1570,20 +1570,24 @@ def test_mm_prefix_caching():
|
||||
block_hashes = req0.block_hashes
|
||||
assert len(block_hashes) == 3
|
||||
assert block_hashes[0] == sha256(
|
||||
(kv_cache_utils.NONE_HASH, tuple(all_token_ids[:block_size]), ("aaa",))
|
||||
(
|
||||
kv_cache_utils.NONE_HASH,
|
||||
tuple(all_token_ids[:block_size]),
|
||||
(("aaa", 11),),
|
||||
)
|
||||
)
|
||||
assert block_hashes[1] == sha256(
|
||||
(
|
||||
block_hashes[0],
|
||||
tuple(all_token_ids[block_size : block_size * 2]),
|
||||
("aaa", "bbb"),
|
||||
(("aaa", -5), ("bbb", 14)),
|
||||
)
|
||||
)
|
||||
assert block_hashes[2] == sha256(
|
||||
(
|
||||
block_hashes[1],
|
||||
tuple(all_token_ids[block_size * 2 : block_size * 3]),
|
||||
("bbb",),
|
||||
(("bbb", -2),),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1603,7 +1607,11 @@ def test_mm_prefix_caching():
|
||||
assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
|
||||
assert len(block_hashes) == 4
|
||||
assert block_hashes[3] == sha256(
|
||||
(block_hashes[2], tuple(all_token_ids[3 * block_size :] + [8] * 5), ("ccc",))
|
||||
(
|
||||
block_hashes[2],
|
||||
tuple(all_token_ids[3 * block_size :] + [8] * 5),
|
||||
(("ccc", 0),),
|
||||
)
|
||||
)
|
||||
|
||||
# Cache hit.
|
||||
|
||||
@@ -413,7 +413,7 @@ def _gen_mm_extra_hash_keys(
|
||||
# We do not need to check all mm inputs if the start token index is out of
|
||||
# range. This usually happens in the late prefill phase and decoding phase.
|
||||
last_pos = mm_features[-1].mm_position
|
||||
if last_pos.offset + last_pos.length < start_token_idx:
|
||||
if last_pos.offset + last_pos.length <= start_token_idx:
|
||||
return extra_keys, start_mm_idx
|
||||
|
||||
# Support start_mm_idx == -1 to indicate the last mm input.
|
||||
@@ -428,13 +428,16 @@ def _gen_mm_extra_hash_keys(
|
||||
offset = mm_feature.mm_position.offset
|
||||
length = mm_feature.mm_position.length
|
||||
if end_token_idx > offset:
|
||||
if start_token_idx > offset + length:
|
||||
if start_token_idx >= offset + length:
|
||||
# This block has passed the current mm input.
|
||||
curr_mm_idx += 1
|
||||
continue
|
||||
|
||||
# The block contains the current mm input.
|
||||
extra_keys.append(mm_feature.identifier)
|
||||
# The block contains the current mm input. Include its offset
|
||||
# relative to the start of the block so prefix-cache keys stay
|
||||
# distinct when the same MM item appears at different positions
|
||||
# within otherwise-identical placeholder blocks.
|
||||
extra_keys.append((mm_feature.identifier, offset - start_token_idx))
|
||||
|
||||
if end_token_idx >= offset + length:
|
||||
# If this block contains the end of the current mm input,
|
||||
|
||||
Reference in New Issue
Block a user