fix: disambiguate multimodal prefix cache keys (#36708)

Signed-off-by: tianshu.yu <tianshuyu.formal@gmail.com>
This commit is contained in:
tianshu-Michael-yu
2026-03-19 19:33:20 -07:00
committed by GitHub
parent e5a77a5015
commit 269bf46d99
3 changed files with 29 additions and 16 deletions

View File

@@ -447,12 +447,12 @@ def test_generate_block_hash_extra_keys():
# Test with no extra keys
extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 0, 5, 0)
assert extra_keys == ("hash1",)
assert extra_keys == (("hash1", 0),)
assert next_mm_idx == 1
# Test with partial overlap
extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 3, 8, 0)
assert extra_keys == ("hash1",)
assert extra_keys == (("hash1", -3),)
assert next_mm_idx == 1
# Test with no overlap
@@ -462,7 +462,7 @@ def test_generate_block_hash_extra_keys():
# Test with multiple extra keys
extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 0, 15, 0)
assert extra_keys == ("hash1", "hash2")
assert extra_keys == (("hash1", 0), ("hash2", 10))
assert next_mm_idx == 2
@@ -513,7 +513,7 @@ def test_generate_block_hash_extra_keys_cache_salt():
# Test with no extra keys
extra_keys, next_mm_idx = generate_block_hash_extra_keys(request_mm, 0, 5, 0)
assert extra_keys == ("hash1", "salt")
assert extra_keys == (("hash1", 0), "salt")
assert next_mm_idx == 1
@@ -637,8 +637,10 @@ def test_request_block_hasher(hash_fn):
block_hashes = request.block_hashes
assert len(block_hashes) == 2
assert block_hashes[0] == hash_fn((kv_cache_utils.NONE_HASH, (0, 1, 2), ("hash1",)))
assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), ("hash2",)))
assert block_hashes[0] == hash_fn(
(kv_cache_utils.NONE_HASH, (0, 1, 2), (("hash1", 0),))
)
assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), (("hash2", 0),)))
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
@@ -1973,7 +1975,7 @@ def test_request_with_prompt_embeds_and_mm_inputs(hash_fn: Callable[[Any], bytes
(
kv_cache_utils.NONE_HASH,
tuple(prompt_token_ids[:block_size]),
("hash1", block1_embeds_hash),
(("hash1", 0), block1_embeds_hash),
)
)
assert block_hashes[0] == expected_hash1
@@ -1985,7 +1987,7 @@ def test_request_with_prompt_embeds_and_mm_inputs(hash_fn: Callable[[Any], bytes
(
block_hashes[0],
tuple(prompt_token_ids[block_size:num_tokens]),
("hash2", block2_embeds_hash),
(("hash2", 0), block2_embeds_hash),
)
)
assert block_hashes[1] == expected_hash2

View File

@@ -1570,20 +1570,24 @@ def test_mm_prefix_caching():
block_hashes = req0.block_hashes
assert len(block_hashes) == 3
assert block_hashes[0] == sha256(
(kv_cache_utils.NONE_HASH, tuple(all_token_ids[:block_size]), ("aaa",))
(
kv_cache_utils.NONE_HASH,
tuple(all_token_ids[:block_size]),
(("aaa", 11),),
)
)
assert block_hashes[1] == sha256(
(
block_hashes[0],
tuple(all_token_ids[block_size : block_size * 2]),
("aaa", "bbb"),
(("aaa", -5), ("bbb", 14)),
)
)
assert block_hashes[2] == sha256(
(
block_hashes[1],
tuple(all_token_ids[block_size * 2 : block_size * 3]),
("bbb",),
(("bbb", -2),),
)
)
@@ -1603,7 +1607,11 @@ def test_mm_prefix_caching():
assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
assert len(block_hashes) == 4
assert block_hashes[3] == sha256(
(block_hashes[2], tuple(all_token_ids[3 * block_size :] + [8] * 5), ("ccc",))
(
block_hashes[2],
tuple(all_token_ids[3 * block_size :] + [8] * 5),
(("ccc", 0),),
)
)
# Cache hit.

View File

@@ -413,7 +413,7 @@ def _gen_mm_extra_hash_keys(
# We do not need to check all mm inputs if the start token index is out of
# range. This usually happens in the late prefill phase and decoding phase.
last_pos = mm_features[-1].mm_position
if last_pos.offset + last_pos.length < start_token_idx:
if last_pos.offset + last_pos.length <= start_token_idx:
return extra_keys, start_mm_idx
# Support start_mm_idx == -1 to indicate the last mm input.
@@ -428,13 +428,16 @@ def _gen_mm_extra_hash_keys(
offset = mm_feature.mm_position.offset
length = mm_feature.mm_position.length
if end_token_idx > offset:
if start_token_idx > offset + length:
if start_token_idx >= offset + length:
# This block has passed the current mm input.
curr_mm_idx += 1
continue
# The block contains the current mm input.
extra_keys.append(mm_feature.identifier)
# The block contains the current mm input. Include its offset
# relative to the start of the block so prefix-cache keys stay
# distinct when the same MM item appears at different positions
# within otherwise-identical placeholder blocks.
extra_keys.append((mm_feature.identifier, offset - start_token_idx))
if end_token_idx >= offset + length:
# If this block contains the end of the current mm input,