[Bugfix][CPUOffloadingManager] Prevent eviction of already-stored blocks in LRU/ARC prepare_store() (#35846)
Signed-off-by: Ronen Schaffer <ronen.schaffer@ibm.com>
This commit is contained in:
@@ -4,6 +4,7 @@ from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
from vllm.v1.kv_offload.abstract import (
|
||||
@@ -78,6 +79,54 @@ def verify_events(
|
||||
assert tuple(stores) == to_hash_sets(expected_stores)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("manager_class", [LRUOffloadingManager, ARCOffloadingManager])
def test_already_stored_block_not_evicted_during_prepare_store(manager_class):
    """
    Regression test for eviction inside prepare_store().

    A block that is already stored and appears in the prepare_store()
    input must survive the call, even when it is the least recently used
    block and space has to be freed for the new blocks in the same
    request. Runs against both LRUOffloadingManager and
    ARCOffloadingManager.

    Steps:
      1. Store blocks [1, 2] and complete the store.
      2. touch([1]) so that block 2 becomes the LRU candidate.
      3. prepare_store([2, 3, 4, 5]) on a 4-block backend:
           * block 2 is dropped from the to-store list (already stored),
           * block 1 is expected to be evicted to make room for
             [3, 4, 5]; without the fix block 2 would be evicted instead.
      4. After complete_store([2, 3, 4, 5]) block 2 must still be found
         by lookup().
    """
    backend = CPUBackend(block_size=256, num_blocks=4)
    manager = manager_class(backend, enable_events=True)

    # Step 1: populate the cache with blocks 1 and 2.
    manager.prepare_store(to_hashes([1, 2]))
    manager.complete_store(to_hashes([1, 2]))

    # Step 2: refresh block 1, leaving block 2 as the LRU candidate.
    manager.touch(to_hashes([1]))

    # Step 3: request a store that includes the already-stored block 2.
    # Block 2 is filtered out of the blocks to store and must not be
    # picked for eviction; block 1 (ID 0) is evicted instead, so the new
    # blocks [3, 4, 5] receive IDs 2, 3, 0.
    output = manager.prepare_store(to_hashes([2, 3, 4, 5]))
    expected = ExpectedPrepareStoreOutput(
        block_hashes_to_store=[3, 4, 5],
        store_block_ids=[2, 3, 0],
        block_hashes_evicted=[1],  # block 1 evicted, not block 2
    )
    verify_store_output(output, expected)

    # Step 4: completing the store must not silently drop block 2 ...
    manager.complete_store(to_hashes([2, 3, 4, 5]))

    # ... so a lookup for block 2 still reports one hit.
    assert manager.lookup(to_hashes([2])) == 1
|
||||
|
||||
|
||||
def test_cpu_manager():
|
||||
"""
|
||||
Tests LRUOffloadingManager with a CPUBackend.
|
||||
|
||||
@@ -123,8 +123,10 @@ class ARCOffloadingManager(OffloadingManager):
|
||||
def prepare_store(
|
||||
self, block_hashes: Iterable[BlockHash]
|
||||
) -> PrepareStoreOutput | None:
|
||||
block_hashes_list = list(block_hashes)
|
||||
|
||||
block_hashes_to_store = []
|
||||
for block_hash in block_hashes:
|
||||
for block_hash in block_hashes_list:
|
||||
if block_hash not in self.t1 and block_hash not in self.t2:
|
||||
block_hashes_to_store.append(block_hash)
|
||||
|
||||
@@ -140,12 +142,16 @@ class ARCOffloadingManager(OffloadingManager):
|
||||
)
|
||||
|
||||
to_evict = []
|
||||
if num_blocks_to_evict > 0:
|
||||
# Blocks from the original input are excluded from eviction candidates:
|
||||
# a block that was already stored must remain in the cache after this call.
|
||||
protected = set(block_hashes_list)
|
||||
while num_blocks_to_evict > 0:
|
||||
block_to_evict = None
|
||||
if len(self.t1) >= int(self.target_t1_size):
|
||||
# try to evict the least recently used (oldest) block from T1
|
||||
for block_hash, block in self.t1.items():
|
||||
if block.ref_cnt == 0:
|
||||
if block.ref_cnt == 0 and block_hash not in protected:
|
||||
block_to_evict = (block_hash, block)
|
||||
eviction_t = self.t1
|
||||
eviction_b = self.b1
|
||||
@@ -153,7 +159,7 @@ class ARCOffloadingManager(OffloadingManager):
|
||||
if not block_to_evict:
|
||||
# try to evict the least recently used (oldest) block from T2
|
||||
for block_hash, block in self.t2.items():
|
||||
if block.ref_cnt == 0:
|
||||
if block.ref_cnt == 0 and block_hash not in protected:
|
||||
block_to_evict = (block_hash, block)
|
||||
eviction_t = self.t2
|
||||
eviction_b = self.b2
|
||||
|
||||
@@ -57,9 +57,13 @@ class LRUOffloadingManager(OffloadingManager):
|
||||
def prepare_store(
|
||||
self, block_hashes: Iterable[BlockHash]
|
||||
) -> PrepareStoreOutput | None:
|
||||
block_hashes_list = list(block_hashes)
|
||||
|
||||
# filter out blocks that are already stored
|
||||
block_hashes_to_store = [
|
||||
block_hash for block_hash in block_hashes if block_hash not in self.blocks
|
||||
block_hash
|
||||
for block_hash in block_hashes_list
|
||||
if block_hash not in self.blocks
|
||||
]
|
||||
|
||||
num_blocks_to_evict = (
|
||||
@@ -69,8 +73,11 @@ class LRUOffloadingManager(OffloadingManager):
|
||||
# build list of blocks to evict
|
||||
to_evict = []
|
||||
if num_blocks_to_evict > 0:
|
||||
# Blocks from the original input are excluded from eviction candidates:
|
||||
# a block that was already stored must remain in the cache after this call.
|
||||
protected = set(block_hashes_list)
|
||||
for block_hash, block in self.blocks.items():
|
||||
if block.ref_cnt == 0:
|
||||
if block.ref_cnt == 0 and block_hash not in protected:
|
||||
to_evict.append(block_hash)
|
||||
num_blocks_to_evict -= 1
|
||||
if num_blocks_to_evict == 0:
|
||||
|
||||
Reference in New Issue
Block a user