diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 8153fed69..d8ecf28cb 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -447,12 +447,12 @@ def test_generate_block_hash_extra_keys(): # Test with no extra keys extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 0, 5, 0) - assert extra_keys == ("hash1",) + assert extra_keys == (("hash1", 0),) assert next_mm_idx == 1 # Test with partial overlap extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 3, 8, 0) - assert extra_keys == ("hash1",) + assert extra_keys == (("hash1", -3),) assert next_mm_idx == 1 # Test with no overlap @@ -462,7 +462,7 @@ def test_generate_block_hash_extra_keys(): # Test with multiple extra keys extra_keys, next_mm_idx = generate_block_hash_extra_keys(request, 0, 15, 0) - assert extra_keys == ("hash1", "hash2") + assert extra_keys == (("hash1", 0), ("hash2", 10)) assert next_mm_idx == 2 @@ -513,7 +513,7 @@ def test_generate_block_hash_extra_keys_cache_salt(): # Test with no extra keys extra_keys, next_mm_idx = generate_block_hash_extra_keys(request_mm, 0, 5, 0) - assert extra_keys == ("hash1", "salt") + assert extra_keys == (("hash1", 0), "salt") assert next_mm_idx == 1 @@ -637,8 +637,10 @@ def test_request_block_hasher(hash_fn): block_hashes = request.block_hashes assert len(block_hashes) == 2 - assert block_hashes[0] == hash_fn((kv_cache_utils.NONE_HASH, (0, 1, 2), ("hash1",))) - assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), ("hash2",))) + assert block_hashes[0] == hash_fn( + (kv_cache_utils.NONE_HASH, (0, 1, 2), (("hash1", 0),)) + ) + assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), (("hash2", 0),))) @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) @@ -1973,7 +1975,7 @@ def test_request_with_prompt_embeds_and_mm_inputs(hash_fn: Callable[[Any], bytes ( kv_cache_utils.NONE_HASH, tuple(prompt_token_ids[:block_size]), - ("hash1", block1_embeds_hash), + (("hash1", 0), block1_embeds_hash), ) ) assert block_hashes[0] == expected_hash1 @@ -1985,7 +1987,7 @@ def test_request_with_prompt_embeds_and_mm_inputs(hash_fn: Callable[[Any], bytes ( block_hashes[0], tuple(prompt_token_ids[block_size:num_tokens]), - ("hash2", block2_embeds_hash), + (("hash2", 0), block2_embeds_hash), ) ) assert block_hashes[1] == expected_hash2 diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 28355eb54..b8b387fff 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -1570,20 +1570,24 @@ def test_mm_prefix_caching(): block_hashes = req0.block_hashes assert len(block_hashes) == 3 assert block_hashes[0] == sha256( - (kv_cache_utils.NONE_HASH, tuple(all_token_ids[:block_size]), ("aaa",)) + ( + kv_cache_utils.NONE_HASH, + tuple(all_token_ids[:block_size]), + (("aaa", 11),), + ) ) assert block_hashes[1] == sha256( ( block_hashes[0], tuple(all_token_ids[block_size : block_size * 2]), - ("aaa", "bbb"), + (("aaa", -5), ("bbb", 14)), ) ) assert block_hashes[2] == sha256( ( block_hashes[1], tuple(all_token_ids[block_size * 2 : block_size * 3]), - ("bbb",), + (("bbb", -2),), ) ) @@ -1603,7 +1607,11 @@ def test_mm_prefix_caching(): assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 assert len(block_hashes) == 4 assert block_hashes[3] == sha256( - (block_hashes[2], tuple(all_token_ids[3 * block_size :] + [8] * 5), ("ccc",)) + ( + block_hashes[2], + tuple(all_token_ids[3 * block_size :] + [8] * 5), + (("ccc", 0),), + ) ) # Cache hit. diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 83ada0530..9ab5af0f6 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -413,7 +413,7 @@ def _gen_mm_extra_hash_keys( # We do not need to check all mm inputs if the start token index is out of # range. This usually happens in the late prefill phase and decoding phase. last_pos = mm_features[-1].mm_position - if last_pos.offset + last_pos.length < start_token_idx: + if last_pos.offset + last_pos.length <= start_token_idx: return extra_keys, start_mm_idx # Support start_mm_idx == -1 to indicate the last mm input. @@ -428,13 +428,16 @@ def _gen_mm_extra_hash_keys( offset = mm_feature.mm_position.offset length = mm_feature.mm_position.length if end_token_idx > offset: - if start_token_idx > offset + length: + if start_token_idx >= offset + length: # This block has passed the current mm input. curr_mm_idx += 1 continue - # The block contains the current mm input. - extra_keys.append(mm_feature.identifier) + # The block contains the current mm input. Include its offset + # relative to the start of the block so prefix-cache keys stay + # distinct when the same MM item appears at different positions + # within otherwise-identical placeholder blocks. + extra_keys.append((mm_feature.identifier, offset - start_token_idx)) if end_token_idx >= offset + length: # If this block contains the end of the current mm input,