Support clearing mm and encoder cache (#33452)
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>

@@ -172,6 +172,7 @@ These endpoints are **only available when the environment variable `VLLM_SERVER_
 - `/server_info` - Get detailed server configuration
 - `/reset_prefix_cache` - Reset prefix cache (can disrupt service)
 - `/reset_mm_cache` - Reset multimodal cache (can disrupt service)
+- `/reset_encoder_cache` - Reset encoder cache (can disrupt service)
 - `/sleep` - Put engine to sleep (causes denial of service)
 - `/wake_up` - Wake engine from sleep
 - `/is_sleeping` - Check if engine is sleeping

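For quick manual verification, here is a minimal sketch that drives the cache-reset endpoints from Python. The base URL and the launch command are assumptions about the deployment, not part of this diff:

import requests

# Assumes a server launched in dev mode, e.g.
#   VLLM_SERVER_DEV_MODE=1 vllm serve <model>
BASE = "http://localhost:8000"  # hypothetical address

for endpoint in ("/reset_prefix_cache", "/reset_mm_cache", "/reset_encoder_cache"):
    resp = requests.post(f"{BASE}{endpoint}")
    print(endpoint, resp.status_code)  # each route returns an empty 200 on success
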
@@ -4,7 +4,10 @@ import pytest
 import torch
 
 from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange
-from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
+from vllm.v1.core.encoder_cache_manager import (
+    EncoderCacheManager,
+    EncoderDecoderCacheManager,
+)
 
 pytestmark = pytest.mark.cpu_test
 
@@ -247,3 +250,88 @@ def test_encoder_cache_mask_based_retrieval():
 
     assert num_embeds_before == 0
     assert num_embeds_in_range == 2
+
+
+def test_reset_clears_all_state():
+    """Test that reset() clears all cached entries and restores capacity."""
+    manager = EncoderCacheManager(cache_size=20)
+
+    req1 = MockRequest("req1", ["img1", "img2"], [5, 3])
+    req2 = MockRequest("req2", ["img3"], [4])
+
+    manager.allocate(req1, 0)
+    manager.allocate(req1, 1)
+    manager.allocate(req2, 0)
+    manager.free_encoder_input(req1, 0)
+
+    req3 = MockRequest("req3", ["img4"], [10])
+    manager.free_encoder_input(req1, 1)
+    manager.free_encoder_input(req2, 0)
+    manager.can_allocate(req3, 0, int(1e9), 0)
+    manager.allocate(req3, 0)
+
+    assert len(manager.cached) > 0
+    assert manager.num_free_slots < 20
+
+    manager.reset()
+
+    assert len(manager.cached) == 0
+    assert len(manager.freeable) == 0
+    assert len(manager.freed) == 0
+    assert manager.num_free_slots == 20
+    assert manager.num_freeable_slots == 20
+
+
+def test_reset_allows_fresh_allocations():
+    manager = EncoderCacheManager(cache_size=10)
+
+    req1 = MockRequest("req1", ["img1"], [10])
+    manager.allocate(req1, 0)
+    assert manager.num_free_slots == 0
+
+    manager.reset()
+
+    req2 = MockRequest("req2", ["img2"], [8])
+    assert manager.can_allocate(req2, 0, int(1e9), 0)
+    manager.allocate(req2, 0)
+
+    assert manager.num_free_slots == 2
+    assert "img2" in manager.cached
+    assert "img1" not in manager.cached
+
+
+def test_encoder_decoder_cache_manager_reset():
+    manager = EncoderDecoderCacheManager(cache_size=20)
+
+    req1 = MockRequest("req1", ["img1"], [5])
+    req2 = MockRequest("req2", ["img2"], [3])
+
+    manager.allocate(req1, 0)
+    manager.allocate(req2, 0)
+    manager.free(req1)
+    manager.get_freed_mm_hashes()
+
+    assert manager.num_free_slots < 20
+
+    manager.reset()
+
+    assert len(manager.allocated) == 0
+    assert len(manager.to_free) == 0
+    assert manager.num_free_slots == 20
+
+
+def test_encoder_decoder_cache_manager_reset_allows_fresh_allocations():
+    manager = EncoderDecoderCacheManager(cache_size=10)
+
+    req1 = MockRequest("req1", ["img1"], [10])
+    manager.allocate(req1, 0)
+    assert manager.num_free_slots == 0
+
+    manager.reset()
+
+    req2 = MockRequest("req2", ["img2"], [8])
+    assert manager.can_allocate(req2, 0, int(1e9), 0)
+    manager.allocate(req2, 0)
+
+    assert manager.num_free_slots == 2
+    assert "img2" in manager.allocated

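The tests above construct requests through a `MockRequest` helper that is defined earlier in the test file, outside this hunk. A hypothetical stand-in showing the shape the calls assume; the real helper likely builds `MultiModalFeatureSpec`/`PlaceholderRange` objects, per the imports above:

class MockRequest:
    """Hypothetical sketch only; the real helper lives earlier in the test file."""

    def __init__(self, request_id: str, mm_hashes: list[str], num_tokens: list[int]):
        self.request_id = request_id
        self.mm_hashes = mm_hashes    # one identifier per multimodal input
        self.num_tokens = num_tokens  # encoder-token footprint of each input
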
@@ -113,6 +113,11 @@ class EngineClient(ABC):
         """Reset the multi-modal cache"""
         ...
 
+    @abstractmethod
+    async def reset_encoder_cache(self) -> None:
+        """Reset the encoder cache"""
+        ...
+
     @abstractmethod
     async def reset_prefix_cache(
         self, reset_running_requests: bool = False, reset_connector: bool = False

vllm/entrypoints/serve/cache/api_router.py
@@ -55,6 +55,17 @@ async def reset_mm_cache(raw_request: Request):
     return Response(status_code=200)
 
 
+@router.post("/reset_encoder_cache")
+async def reset_encoder_cache(raw_request: Request):
+    """
+    Reset the encoder cache. Note that we currently do not check if the
+    encoder cache is successfully reset in the API server.
+    """
+    logger.info("Resetting encoder cache...")
+    await engine_client(raw_request).reset_encoder_cache()
+    return Response(status_code=200)
+
+
 def attach_router(app: FastAPI):
     if not envs.VLLM_SERVER_DEV_MODE:
         return

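A sketch of exercising the new route in isolation with FastAPI's test client; the app wiring here is hypothetical, since in vLLM the router is only attached when `VLLM_SERVER_DEV_MODE` is set (see `attach_router` above):

from fastapi.testclient import TestClient

def check_reset_endpoint(app):
    # `app` is assumed to be a FastAPI instance that already has the cache
    # router attached and an engine client reachable from the request.
    client = TestClient(app)
    resp = client.post("/reset_encoder_cache")
    assert resp.status_code == 200
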
@@ -77,6 +77,18 @@ class EncoderCacheManager:
         self.freeable: OrderedDict[str, int] = OrderedDict()
         self.freed: list[str] = []
 
+    def reset(self) -> None:
+        """Reset the encoder cache to its initial state.
+
+        This clears all cached encoder outputs and resets capacity tracking.
+        Called when model weights are updated to invalidate stale embeddings.
+        """
+        self.cached.clear()
+        self.freeable.clear()
+        self.freed.clear()
+        self.num_free_slots = self.cache_size
+        self.num_freeable_slots = self.cache_size
+
     def check_and_update_cache(self, request: Request, input_id: int) -> bool:
         """Check if encoder output for a specific multimodal input is cached.
 
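The intended invariant, mirrored by the tests earlier in this diff: after `reset()`, the manager is indistinguishable from a freshly constructed one. A small sketch:

from vllm.v1.core.encoder_cache_manager import EncoderCacheManager

manager = EncoderCacheManager(cache_size=20)
# ... allocate and free some entries ...
manager.reset()
assert manager.num_free_slots == manager.cache_size
assert not manager.cached and not manager.freeable and not manager.freed
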
@@ -360,6 +372,12 @@ class EncoderDecoderCacheManager(EncoderCacheManager):
         self.allocated: list[str] = []
         self.to_free: list[str] = []
 
+    def reset(self) -> None:
+        """Reset the encoder cache to its initial state."""
+        self.num_free_slots = self.cache_size
+        self.allocated.clear()
+        self.to_free.clear()
+
     def check_and_update_cache(self, request: Request, input_id: int) -> bool:
         return False

@@ -183,6 +183,15 @@ class SchedulerInterface(ABC):
         """
         raise NotImplementedError
 
+    @abstractmethod
+    def reset_encoder_cache(self) -> None:
+        """Reset the encoder cache to invalidate all cached encoder outputs.
+
+        This should be called when model weights are updated to ensure
+        stale vision embeddings are not reused.
+        """
+        raise NotImplementedError
+
     @abstractmethod
     def get_request_counts(self) -> tuple[int, int]:
         """Returns (num_running_reqs, num_waiting_reqs)."""

@@ -1763,6 +1763,14 @@ class Scheduler(SchedulerInterface):
 
         return True
 
+    def reset_encoder_cache(self) -> None:
+        """Reset the encoder cache to invalidate all cached encoder outputs.
+
+        This should be called when model weights are updated to ensure
+        stale vision embeddings are not reused.
+        """
+        self.encoder_cache_manager.reset()
+
     def make_stats(
         self,
         spec_decoding_stats: SpecDecodingStats | None = None,

@@ -1788,6 +1796,7 @@ class Scheduler(SchedulerInterface):
             num_running_reqs=len(self.running),
             num_waiting_reqs=len(self.waiting),
             kv_cache_usage=self.kv_cache_manager.usage,
+            encoder_cache_usage=self._get_encoder_cache_usage(),
             prefix_cache_stats=prefix_cache_stats,
             connector_prefix_cache_stats=connector_prefix_cache_stats,
             kv_cache_eviction_events=eviction_events,

@@ -1797,6 +1806,14 @@ class Scheduler(SchedulerInterface):
             perf_stats=perf_stats,
         )
 
+    def _get_encoder_cache_usage(self) -> float:
+        """Get encoder cache usage as a fraction (0.0 to 1.0)."""
+        ecm = self.encoder_cache_manager
+        if ecm.cache_size == 0:
+            return 0.0
+        used_slots = ecm.cache_size - ecm.num_free_slots
+        return used_slots / ecm.cache_size
+
     def make_spec_decoding_stats(
         self,
         spec_decoding_stats: SpecDecodingStats | None,

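A quick numeric check of the usage fraction, with illustrative values rather than anything from the diff:

# E.g. a manager sized for 20 encoder-token slots with 5 slots free
# reports 15 used slots, i.e. a usage fraction of 0.75.
cache_size, num_free_slots = 20, 5
usage = (cache_size - num_free_slots) / cache_size if cache_size else 0.0
assert usage == 0.75
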
@@ -882,6 +882,9 @@ class AsyncLLM(EngineClient):
             reset_running_requests, reset_connector
         )
 
+    async def reset_encoder_cache(self) -> None:
+        await self.engine_core.reset_encoder_cache_async()
+
     async def sleep(self, level: int = 1) -> None:
         await self.reset_prefix_cache()
         await self.engine_core.sleep_async(level)

@@ -565,6 +565,26 @@ class EngineCore:
             reset_running_requests, reset_connector
         )
 
+    def reset_encoder_cache(self) -> None:
+        """Reset the encoder cache to invalidate all cached encoder outputs.
+
+        This should be called when model weights are updated to ensure
+        stale vision embeddings computed with old weights are not reused.
+        Clears both the scheduler's cache manager and the GPU model runner's cache.
+        """
+        # NOTE: Since this is mainly for debugging, we don't attempt to
+        # re-sync the internal caches (P0 sender, P1 receiver)
+        if self.scheduler.has_unfinished_requests():
+            logger.warning(
+                "Resetting the encoder cache when requests are "
+                "in progress may lead to desynced internal caches."
+            )
+
+        # Reset the scheduler's encoder cache manager (logical state)
+        self.scheduler.reset_encoder_cache()
+        # Reset the GPU model runner's encoder cache (physical storage)
+        self.model_executor.reset_encoder_cache()
+
     def sleep(self, level: int = 1):
         self.model_executor.sleep(level)

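Taken together, the hunks in this commit wire the reset along the following path (a summary annotation of the diff, not code from it):

# POST /reset_encoder_cache                          (api_router.py, dev mode)
#   -> EngineClient.reset_encoder_cache              (AsyncLLM / LLMEngine)
#     -> EngineCoreClient "reset_encoder_cache" RPC  (Inproc/SyncMP/AsyncMP clients)
#       -> EngineCore.reset_encoder_cache
#            -> Scheduler.reset_encoder_cache        # logical state (cache manager)
#            -> Executor.reset_encoder_cache         # collective_rpc to workers
#                 -> Worker.reset_encoder_cache
#                      -> GPUModelRunner.reset_encoder_cache  # GPU tensors
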
@@ -144,6 +144,9 @@ class EngineCoreClient(ABC):
     ) -> bool:
         raise NotImplementedError
 
+    def reset_encoder_cache(self) -> None:
+        raise NotImplementedError
+
     def sleep(self, level: int = 1) -> None:
         raise NotImplementedError

@@ -216,6 +219,9 @@ class EngineCoreClient(ABC):
     ) -> bool:
         raise NotImplementedError
 
+    async def reset_encoder_cache_async(self) -> None:
+        raise NotImplementedError
+
     async def sleep_async(self, level: int = 1) -> None:
         raise NotImplementedError

@@ -300,6 +306,9 @@ class InprocClient(EngineCoreClient):
             reset_running_requests, reset_connector
         )
 
+    def reset_encoder_cache(self) -> None:
+        self.engine_core.reset_encoder_cache()
+
     def sleep(self, level: int = 1) -> None:
         self.engine_core.sleep(level)

@@ -765,6 +774,9 @@ class SyncMPClient(MPClient):
             "reset_prefix_cache", reset_running_requests, reset_connector
         )
 
+    def reset_encoder_cache(self) -> None:
+        self.call_utility("reset_encoder_cache")
+
     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.call_utility("add_lora", lora_request)

@@ -973,6 +985,9 @@ class AsyncMPClient(MPClient):
             "reset_prefix_cache", reset_running_requests, reset_connector
         )
 
+    async def reset_encoder_cache_async(self) -> None:
+        await self.call_utility_async("reset_encoder_cache")
+
     async def sleep_async(self, level: int = 1) -> None:
         await self.call_utility_async("sleep", level)

@@ -332,6 +332,14 @@ class LLMEngine:
             reset_running_requests, reset_connector
         )
 
+    def reset_encoder_cache(self) -> None:
+        """Reset the encoder cache to invalidate all cached encoder outputs.
+
+        This should be called when model weights are updated to ensure
+        stale vision embeddings computed with old weights are not reused.
+        """
+        self.engine_core.reset_encoder_cache()
+
     def sleep(self, level: int = 1):
         self.engine_core.sleep(level)

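For offline use, the same entry point is reachable from the engine object. A sketch where the model id and the weight-update step are placeholders, and access via the `LLM.llm_engine` attribute is assumed:

from vllm import LLM

llm = LLM(model="some/multimodal-model")  # hypothetical model id
# ... suppose the model weights are updated in place here ...
llm.llm_engine.reset_encoder_cache()      # drop vision embeddings computed with old weights
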
@@ -294,6 +294,10 @@ class Executor(ABC):
         """Reset the multi-modal cache in each worker."""
         self.collective_rpc("reset_mm_cache")
 
+    def reset_encoder_cache(self) -> None:
+        """Reset the encoder cache in each worker to clear cached encoder outputs."""
+        self.collective_rpc("reset_encoder_cache")
+
     def sleep(self, level: int = 1):
         if self.is_sleeping:
             logger.warning("Executor is already sleeping.")

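The `collective_rpc` call above dispatches a method by name to every worker. A toy in-process sketch of that pattern, not vLLM's actual multi-process implementation:

class ToyExecutor:
    def __init__(self, workers):
        self.workers = workers

    def collective_rpc(self, method: str, *args):
        # Invoke the named method on each worker and gather the results.
        return [getattr(worker, method)(*args) for worker in self.workers]

    def reset_encoder_cache(self) -> None:
        self.collective_rpc("reset_encoder_cache")
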
@@ -173,6 +173,7 @@ class SchedulerStats:
     current_wave: int = 0
 
     kv_cache_usage: float = 0.0
+    encoder_cache_usage: float = 0.0
 
     prefix_cache_stats: PrefixCacheStats = field(default_factory=PrefixCacheStats)
     connector_prefix_cache_stats: PrefixCacheStats | None = None

@@ -720,6 +720,14 @@ class GPUModelRunner(
         if self.mm_budget:
             self.mm_budget.reset_cache()
 
+    def reset_encoder_cache(self) -> None:
+        """Clear the GPU-side encoder cache storing vision embeddings.
+
+        This should be called when model weights are updated to ensure
+        stale embeddings computed with old weights are not reused.
+        """
+        self.encoder_cache.clear()
+
     @torch.inference_mode()
     def init_fp8_kv_scales(self) -> None:
         """

@@ -539,6 +539,9 @@ class Worker(WorkerBase):
     def reset_mm_cache(self) -> None:
         self.model_runner.reset_mm_cache()
 
+    def reset_encoder_cache(self) -> None:
+        self.model_runner.reset_encoder_cache()
+
     def get_model(self) -> nn.Module:
         return self.model_runner.get_model()