[Bugfix][CPUOffloadingManager] Prevent eviction of already-stored blocks in LRU/ARC prepare_store() (#35846)

Signed-off-by: Ronen Schaffer <ronen.schaffer@ibm.com>
This commit is contained in:
Ronen Schaffer
2026-03-04 14:25:33 +02:00
committed by GitHub
parent 1aaec59d79
commit bb6888b8b1
3 changed files with 67 additions and 5 deletions

View File

@@ -4,6 +4,7 @@ from collections.abc import Iterable
from dataclasses import dataclass
import numpy as np
import pytest
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.kv_offload.abstract import (
@@ -78,6 +79,54 @@ def verify_events(
assert tuple(stores) == to_hash_sets(expected_stores)
@pytest.mark.parametrize(
    "manager_class",
    [LRUOffloadingManager, ARCOffloadingManager],
)
def test_already_stored_block_not_evicted_during_prepare_store(manager_class):
    """
    Regression test: prepare_store() must not evict a block that is already
    stored and is part of the request being prepared.

    Runs once for LRUOffloadingManager and once for ARCOffloadingManager.

    Steps:
      1. Store blocks [1, 2] in a 4-block CPU backend and complete the store.
      2. touch([1]), which leaves block 2 as the LRU eviction candidate.
      3. prepare_store([2, 3, 4, 5]): block 2 is already stored, so only
         [3, 4, 5] are reported as needing a store. Making room for them
         takes one eviction, which must hit block 1 rather than block 2.
      4. complete_store([2, 3, 4, 5]), then check that a lookup of block 2
         still returns 1.
    """
    backend = CPUBackend(block_size=256, num_blocks=4)
    manager = manager_class(backend, enable_events=True)

    # Step 1: populate the cache with blocks 1 and 2.
    manager.prepare_store(to_hashes([1, 2]))
    manager.complete_store(to_hashes([1, 2]))

    # Step 2: refresh block 1 so that block 2 becomes the LRU candidate.
    manager.touch(to_hashes([1]))

    # Step 3: request a store that includes the already-stored block 2.
    # Without the fix, block 2 would be dropped from the to-store list and
    # then evicted as the LRU candidate to make room for [3, 4, 5].
    actual_output = manager.prepare_store(to_hashes([2, 3, 4, 5]))

    # Block 1 (ID 0) is the one evicted; new blocks 3, 4, 5 land on
    # IDs 2, 3 and the just-freed 0.
    expected_output = ExpectedPrepareStoreOutput(
        block_hashes_to_store=[3, 4, 5],
        store_block_ids=[2, 3, 0],
        block_hashes_evicted=[1],
    )
    verify_store_output(actual_output, expected_output)

    # Step 4: completing the store must not silently drop block 2 ...
    manager.complete_store(to_hashes([2, 3, 4, 5]))

    # ... so it must still be found in the cache.
    block_2_hits = manager.lookup(to_hashes([2]))
    assert block_2_hits == 1
def test_cpu_manager():
"""
Tests LRUOffloadingManager with a CPUBackend.