feat(kv-offload): Strategy A — StoreReusedOffloadingManager gates CPU stores on reuse frequency (#35342)

Signed-off-by: srinivas_oo7 <Sriusa4414@gmail.com>
Signed-off-by: Sriusa4414@gmail.com
Signed-off-by: Srinivasoo7 <158864704+Srinivasoo7@users.noreply.github.com>
Co-authored-by: srinivas_oo7 <sklinkedin0120@gmail.com>
Co-authored-by: Srinivasoo7 <158864704+Srinivasoo7@users.noreply.github.com>
Co-authored-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
Srinivasoo7
2026-03-10 09:43:40 -05:00
committed by GitHub
parent ca5fb4bbd8
commit 106ff69c4e
3 changed files with 184 additions and 0 deletions

View File

@@ -544,3 +544,52 @@ def test_arc_manager_full_scenario():
# verify events
events = list(arc_manager.take_events())
assert len(events) > 0 # should have store and eviction events
def test_filter_reused_manager():
    """
    Tests FilterReusedOffloadingManager with a CPUBackend.
    """
    block_size = 256
    backend = CPUBackend(block_size=block_size, num_blocks=4)
    lru = LRUOffloadingManager(backend, enable_events=True)
    from vllm.v1.kv_offload.reuse_manager import FilterReusedOffloadingManager

    manager = FilterReusedOffloadingManager(
        backing=lru, store_threshold=2, max_tracker_size=3
    )

    # First sighting of [1, 2]: tracked, but below the store threshold.
    assert manager.lookup(to_hashes([1, 2])) == 0

    # prepare_store right away: everything should be filtered out.
    out = manager.prepare_store(to_hashes([1, 2]))
    assert out is not None
    assert out.block_hashes_to_store == []

    # Second sighting of [1] makes it eligible.
    assert manager.lookup(to_hashes([1])) == 0

    # Now [1] passes the gate while [2] is still filtered.
    out = manager.prepare_store(to_hashes([1, 2]))
    assert out is not None
    assert out.block_hashes_to_store == to_hashes([1])

    # First sighting of [3, 4]; with max_tracker_size=3 and [1] recently
    # touched, the LRU entry [2] gets evicted from the tracker.
    assert manager.lookup(to_hashes([3, 4])) == 0

    # Tracker now holds [1], [3], [4] — [2] is gone.
    assert to_hashes([2])[0] not in manager.counts

    # Looking up [2] again re-inserts it as a first sighting.
    assert manager.lookup(to_hashes([2])) == 0

    # Its count restarted at 1, so it is not eligible yet.
    assert manager.counts.get(to_hashes([2])[0]) == 1

    # And prepare_store keeps filtering it out.
    out = manager.prepare_store(to_hashes([2]))
    assert out is not None
    assert out.block_hashes_to_store == []

    manager.complete_store(to_hashes([1]))

View File

@@ -13,6 +13,7 @@ from vllm.v1.kv_offload.arc_manager import ARCOffloadingManager
from vllm.v1.kv_offload.backends.cpu import CPUBackend
from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.kv_offload.reuse_manager import FilterReusedOffloadingManager
from vllm.v1.kv_offload.spec import OffloadingSpec
from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandlers
from vllm.v1.kv_offload.worker.worker import OffloadingHandler
@@ -83,6 +84,20 @@ class CPUOffloadingSpec(OffloadingSpec):
f"Unknown eviction policy: {self.eviction_policy}. "
f"Supported policies: lru, arc"
)
# store_threshold: how many times a block must appear in lookup()
# before it is eligible for CPU offloading. Values < 2 disable
# filtering (a threshold of 1 equals no filter; 0 is the default).
store_threshold = int(self.extra_config.get("store_threshold", 0))
if store_threshold >= 2:
max_tracker_size = int(
self.extra_config.get("max_tracker_size", 64_000)
)
self._manager = FilterReusedOffloadingManager(
backing=self._manager,
store_threshold=store_threshold,
max_tracker_size=max_tracker_size,
)
return self._manager
def get_handlers(

View File

@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Reuse-frequency gating for CPU KV-cache offload stores.
FilterReusedOffloadingManager — OffloadingManager decorator that skips
storing blocks that have not yet been seen enough times.
"""
from collections import OrderedDict
from collections.abc import Iterable
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.kv_offload.abstract import (
LoadStoreSpec,
OffloadingEvent,
OffloadingManager,
PrepareStoreOutput,
)
class FilterReusedOffloadingManager(OffloadingManager):
    """Decorator around an :class:`OffloadingManager` that gates stores on
    reuse frequency: blocks seen fewer than *store_threshold* times in
    ``lookup()`` are never handed to the backing manager for storage.

    Every method forwards to *backing*; two are intercepted:

    * ``lookup`` — bumps an LRU-ordered hit counter for each block hash.
    * ``prepare_store`` — drops below-threshold hashes before forwarding
      to the backing ``prepare_store``.

    Args:
        backing: Underlying ``OffloadingManager`` that does the real work.
        store_threshold: Minimum number of ``lookup()`` sightings before a
            block becomes eligible for offloading. Must be >= 2 (a value
            of 1 would be equivalent to no filtering).
        max_tracker_size: Cap on the number of entries in the internal
            LRU hit-count table.
    """

    def __init__(
        self,
        backing: OffloadingManager,
        store_threshold: int = 2,
        max_tracker_size: int = 64_000,
    ):
        if store_threshold < 2:
            raise ValueError(
                "FilterReusedOffloadingManager store_threshold must be >= 2, "
                f"got {store_threshold}"
            )
        if max_tracker_size < 1:
            raise ValueError(
                "FilterReusedOffloadingManager max_tracker_size must be >= 1, "
                f"got {max_tracker_size}"
            )
        self._backing = backing
        self.store_threshold = store_threshold
        self.max_tracker_size = max_tracker_size
        # OrderedDict doubles as an LRU table: the least-recently-seen
        # entry sits at the front and is evictable in O(1).
        self.counts: OrderedDict[BlockHash, int] = OrderedDict()

    # ------------------------------------------------------------------
    # Intercepted methods
    # ------------------------------------------------------------------
    def lookup(self, block_hashes: Iterable[BlockHash]) -> int | None:
        """Bump the hit counter for every hash, then delegate the lookup."""
        hashes = list(block_hashes)
        counts = self.counts
        for h in hashes:
            if h not in counts:
                # New entry: make room first by dropping the LRU victim.
                if len(counts) >= self.max_tracker_size:
                    counts.popitem(last=False)
                counts[h] = 1
            else:
                counts.move_to_end(h)
                counts[h] += 1
        return self._backing.lookup(hashes)

    def prepare_store(
        self, block_hashes: Iterable[BlockHash]
    ) -> PrepareStoreOutput | None:
        """Forward only the hashes that have crossed the reuse threshold.

        Filtering happens *before* the backing manager is consulted, so
        blocks below the threshold never claim any CPU offload capacity.
        Forwarding an empty list is intentional and safe — both
        LRUOffloadingManager and ARCOffloadingManager handle it, returning
        a PrepareStoreOutput with empty lists.
        """
        threshold = self.store_threshold
        counts = self.counts
        eligible = [h for h in block_hashes if counts.get(h, 0) >= threshold]
        return self._backing.prepare_store(eligible)

    # ------------------------------------------------------------------
    # Pure delegation
    # ------------------------------------------------------------------
    def prepare_load(self, block_hashes: Iterable[BlockHash]) -> LoadStoreSpec:
        return self._backing.prepare_load(block_hashes)

    def touch(self, block_hashes: Iterable[BlockHash]) -> None:
        return self._backing.touch(block_hashes)

    def complete_load(self, block_hashes: Iterable[BlockHash]) -> None:
        return self._backing.complete_load(block_hashes)

    def complete_store(
        self, block_hashes: Iterable[BlockHash], success: bool = True
    ) -> None:
        return self._backing.complete_store(block_hashes, success)

    def take_events(self) -> Iterable[OffloadingEvent]:
        return self._backing.take_events()