diff --git a/docs/usage/security.md b/docs/usage/security.md
index e619eec66..0a54221ec 100644
--- a/docs/usage/security.md
+++ b/docs/usage/security.md
@@ -172,6 +172,7 @@ These endpoints are **only available when the environment variable `VLLM_SERVER_
 - `/server_info` - Get detailed server configuration
 - `/reset_prefix_cache` - Reset prefix cache (can disrupt service)
 - `/reset_mm_cache` - Reset multimodal cache (can disrupt service)
+- `/reset_encoder_cache` - Reset encoder cache (can disrupt service)
 - `/sleep` - Put engine to sleep (causes denial of service)
 - `/wake_up` - Wake engine from sleep
 - `/is_sleeping` - Check if engine is sleeping
diff --git a/tests/v1/core/test_encoder_cache_manager.py b/tests/v1/core/test_encoder_cache_manager.py
index 511ff48c4..f82c0070c 100644
--- a/tests/v1/core/test_encoder_cache_manager.py
+++ b/tests/v1/core/test_encoder_cache_manager.py
@@ -4,7 +4,10 @@
 import pytest
 import torch
 
 from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange
-from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
+from vllm.v1.core.encoder_cache_manager import (
+    EncoderCacheManager,
+    EncoderDecoderCacheManager,
+)
 
 pytestmark = pytest.mark.cpu_test
@@ -247,3 +250,88 @@ def test_encoder_cache_mask_based_retrieval():
 
     assert num_embeds_before == 0
     assert num_embeds_in_range == 2
+
+
+def test_reset_clears_all_state():
+    """Test that reset() clears all cached entries and restores capacity."""
+    manager = EncoderCacheManager(cache_size=20)
+
+    req1 = MockRequest("req1", ["img1", "img2"], [5, 3])
+    req2 = MockRequest("req2", ["img3"], [4])
+
+    manager.allocate(req1, 0)
+    manager.allocate(req1, 1)
+    manager.allocate(req2, 0)
+    manager.free_encoder_input(req1, 0)
+
+    req3 = MockRequest("req3", ["img4"], [10])
+    manager.free_encoder_input(req1, 1)
+    manager.free_encoder_input(req2, 0)
+    manager.can_allocate(req3, 0, int(1e9), 0)
+    manager.allocate(req3, 0)
+
+    assert len(manager.cached) > 0
+    assert manager.num_free_slots < 20
+
+    manager.reset()
+
+    assert len(manager.cached) == 0
+    assert len(manager.freeable) == 0
+    assert len(manager.freed) == 0
+    assert manager.num_free_slots == 20
+    assert manager.num_freeable_slots == 20
+
+
+def test_reset_allows_fresh_allocations():
+    manager = EncoderCacheManager(cache_size=10)
+
+    req1 = MockRequest("req1", ["img1"], [10])
+    manager.allocate(req1, 0)
+    assert manager.num_free_slots == 0
+
+    manager.reset()
+
+    req2 = MockRequest("req2", ["img2"], [8])
+    assert manager.can_allocate(req2, 0, int(1e9), 0)
+    manager.allocate(req2, 0)
+
+    assert manager.num_free_slots == 2
+    assert "img2" in manager.cached
+    assert "img1" not in manager.cached
+
+
+def test_encoder_decoder_cache_manager_reset():
+    manager = EncoderDecoderCacheManager(cache_size=20)
+
+    req1 = MockRequest("req1", ["img1"], [5])
+    req2 = MockRequest("req2", ["img2"], [3])
+
+    manager.allocate(req1, 0)
+    manager.allocate(req2, 0)
+    manager.free(req1)
+    manager.get_freed_mm_hashes()
+
+    assert manager.num_free_slots < 20
+
+    manager.reset()
+
+    assert len(manager.allocated) == 0
+    assert len(manager.to_free) == 0
+    assert manager.num_free_slots == 20
+
+
+def test_encoder_decoder_cache_manager_reset_allows_fresh_allocations():
+    manager = EncoderDecoderCacheManager(cache_size=10)
+
+    req1 = MockRequest("req1", ["img1"], [10])
+    manager.allocate(req1, 0)
+    assert manager.num_free_slots == 0
+
+    manager.reset()
+
+    req2 = MockRequest("req2", ["img2"], [8])
+    assert manager.can_allocate(req2, 0, int(1e9), 0)
+    manager.allocate(req2, 0)
+
+    assert manager.num_free_slots == 2
+    assert "img2" in manager.allocated
diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index 377ffaa1c..dd31de840 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -113,6 +113,11 @@ class EngineClient(ABC):
         """Reset the multi-modal cache"""
         ...
 
+    @abstractmethod
+    async def reset_encoder_cache(self) -> None:
+        """Reset the encoder cache"""
+        ...
+
     @abstractmethod
     async def reset_prefix_cache(
         self, reset_running_requests: bool = False, reset_connector: bool = False
diff --git a/vllm/entrypoints/serve/cache/api_router.py b/vllm/entrypoints/serve/cache/api_router.py
index d65989546..10015f02c 100644
--- a/vllm/entrypoints/serve/cache/api_router.py
+++ b/vllm/entrypoints/serve/cache/api_router.py
@@ -55,6 +55,17 @@ async def reset_mm_cache(raw_request: Request):
     return Response(status_code=200)
 
 
+@router.post("/reset_encoder_cache")
+async def reset_encoder_cache(raw_request: Request):
+    """
+    Reset the encoder cache. Note that we currently do not check if the
+    encoder cache is successfully reset in the API server.
+    """
+    logger.info("Resetting encoder cache...")
+    await engine_client(raw_request).reset_encoder_cache()
+    return Response(status_code=200)
+
+
 def attach_router(app: FastAPI):
     if not envs.VLLM_SERVER_DEV_MODE:
         return
diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py
index 6b605bf2f..56f40535e 100644
--- a/vllm/v1/core/encoder_cache_manager.py
+++ b/vllm/v1/core/encoder_cache_manager.py
@@ -77,6 +77,18 @@
         self.freeable: OrderedDict[str, int] = OrderedDict()
         self.freed: list[str] = []
 
+    def reset(self) -> None:
+        """Reset the encoder cache to its initial state.
+
+        This clears all cached encoder outputs and resets capacity tracking.
+        Called when model weights are updated to invalidate stale embeddings.
+        """
+        self.cached.clear()
+        self.freeable.clear()
+        self.freed.clear()
+        self.num_free_slots = self.cache_size
+        self.num_freeable_slots = self.cache_size
+
     def check_and_update_cache(self, request: Request, input_id: int) -> bool:
         """Check if encoder output for a specific multimodal input is cached.
 
@@ -360,6 +372,12 @@ class EncoderDecoderCacheManager(EncoderCacheManager):
         self.allocated: list[str] = []
         self.to_free: list[str] = []
 
+    def reset(self) -> None:
+        """Reset the encoder cache to its initial state."""
+        self.num_free_slots = self.cache_size
+        self.allocated.clear()
+        self.to_free.clear()
+
     def check_and_update_cache(self, request: Request, input_id: int) -> bool:
         return False
 
diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py
index 8c8563d13..79aabcdc3 100644
--- a/vllm/v1/core/sched/interface.py
+++ b/vllm/v1/core/sched/interface.py
@@ -183,6 +183,15 @@
         """
         raise NotImplementedError
 
+    @abstractmethod
+    def reset_encoder_cache(self) -> None:
+        """Reset the encoder cache to invalidate all cached encoder outputs.
+
+        This should be called when model weights are updated to ensure
+        stale vision embeddings are not reused.
+        """
+        raise NotImplementedError
+
     @abstractmethod
     def get_request_counts(self) -> tuple[int, int]:
         """Returns (num_running_reqs, num_waiting_reqs)."""
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 923630d75..b1667c075 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -1763,6 +1763,14 @@ class Scheduler(SchedulerInterface):
 
         return True
 
+    def reset_encoder_cache(self) -> None:
+        """Reset the encoder cache to invalidate all cached encoder outputs.
+
+        This should be called when model weights are updated to ensure
+        stale vision embeddings are not reused.
+        """
+        self.encoder_cache_manager.reset()
+
     def make_stats(
         self,
         spec_decoding_stats: SpecDecodingStats | None = None,
@@ -1788,6 +1796,7 @@
             num_running_reqs=len(self.running),
             num_waiting_reqs=len(self.waiting),
             kv_cache_usage=self.kv_cache_manager.usage,
+            encoder_cache_usage=self._get_encoder_cache_usage(),
             prefix_cache_stats=prefix_cache_stats,
             connector_prefix_cache_stats=connector_prefix_cache_stats,
             kv_cache_eviction_events=eviction_events,
@@ -1797,6 +1806,14 @@
             perf_stats=perf_stats,
         )
 
+    def _get_encoder_cache_usage(self) -> float:
+        """Get encoder cache usage as a fraction (0.0 to 1.0)."""
+        ecm = self.encoder_cache_manager
+        if ecm.cache_size == 0:
+            return 0.0
+        used_slots = ecm.cache_size - ecm.num_free_slots
+        return used_slots / ecm.cache_size
+
     def make_spec_decoding_stats(
         self,
         spec_decoding_stats: SpecDecodingStats | None,
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 952da9e40..f1788d8fc 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -882,6 +882,9 @@ class AsyncLLM(EngineClient):
             reset_running_requests, reset_connector
         )
 
+    async def reset_encoder_cache(self) -> None:
+        await self.engine_core.reset_encoder_cache_async()
+
     async def sleep(self, level: int = 1) -> None:
         await self.reset_prefix_cache()
         await self.engine_core.sleep_async(level)
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index d5e75824d..216d610b4 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -565,6 +565,26 @@ class EngineCore:
             reset_running_requests, reset_connector
        )
 
+    def reset_encoder_cache(self) -> None:
+        """Reset the encoder cache to invalidate all cached encoder outputs.
+
+        This should be called when model weights are updated to ensure
+        stale vision embeddings computed with old weights are not reused.
+        Clears both the scheduler's cache manager and the GPU model runner's cache.
+        """
+        # NOTE: Since this is mainly for debugging, we don't attempt to
+        # re-sync the internal caches (P0 sender, P1 receiver)
+        if self.scheduler.has_unfinished_requests():
+            logger.warning(
+                "Resetting the encoder cache when requests are "
+                "in progress may lead to desynced internal caches."
+            )
+
+        # Reset the scheduler's encoder cache manager (logical state)
+        self.scheduler.reset_encoder_cache()
+        # Reset the GPU model runner's encoder cache (physical storage)
+        self.model_executor.reset_encoder_cache()
+
     def sleep(self, level: int = 1):
         self.model_executor.sleep(level)
 
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index c9a1d53c8..308086198 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -144,6 +144,9 @@ class EngineCoreClient(ABC):
     ) -> bool:
         raise NotImplementedError
 
+    def reset_encoder_cache(self) -> None:
+        raise NotImplementedError
+
     def sleep(self, level: int = 1) -> None:
         raise NotImplementedError
 
@@ -216,6 +219,9 @@
     ) -> bool:
         raise NotImplementedError
 
+    async def reset_encoder_cache_async(self) -> None:
+        raise NotImplementedError
+
     async def sleep_async(self, level: int = 1) -> None:
         raise NotImplementedError
 
@@ -300,6 +306,9 @@ class InprocClient(EngineCoreClient):
             reset_running_requests, reset_connector
         )
 
+    def reset_encoder_cache(self) -> None:
+        self.engine_core.reset_encoder_cache()
+
     def sleep(self, level: int = 1) -> None:
         self.engine_core.sleep(level)
 
@@ -765,6 +774,9 @@ class SyncMPClient(MPClient):
             "reset_prefix_cache", reset_running_requests, reset_connector
         )
 
+    def reset_encoder_cache(self) -> None:
+        self.call_utility("reset_encoder_cache")
+
     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.call_utility("add_lora", lora_request)
 
@@ -973,6 +985,9 @@ class AsyncMPClient(MPClient):
             "reset_prefix_cache", reset_running_requests, reset_connector
         )
 
+    async def reset_encoder_cache_async(self) -> None:
+        await self.call_utility_async("reset_encoder_cache")
+
     async def sleep_async(self, level: int = 1) -> None:
         await self.call_utility_async("sleep", level)
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index 5811e94dd..8f72b6178 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -332,6 +332,14 @@ class LLMEngine:
             reset_running_requests, reset_connector
        )
 
+    def reset_encoder_cache(self) -> None:
+        """Reset the encoder cache to invalidate all cached encoder outputs.
+
+        This should be called when model weights are updated to ensure
+        stale vision embeddings computed with old weights are not reused.
+        """
+        self.engine_core.reset_encoder_cache()
+
     def sleep(self, level: int = 1):
         self.engine_core.sleep(level)
 
diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py
index 8ada52435..0fef6c1d1 100644
--- a/vllm/v1/executor/abstract.py
+++ b/vllm/v1/executor/abstract.py
@@ -294,6 +294,10 @@ class Executor(ABC):
         """Reset the multi-modal cache in each worker."""
         self.collective_rpc("reset_mm_cache")
 
+    def reset_encoder_cache(self) -> None:
+        """Reset the encoder cache in each worker to clear cached encoder outputs."""
+        self.collective_rpc("reset_encoder_cache")
+
     def sleep(self, level: int = 1):
         if self.is_sleeping:
             logger.warning("Executor is already sleeping.")
diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py
index cb1a860e3..3404a720e 100644
--- a/vllm/v1/metrics/stats.py
+++ b/vllm/v1/metrics/stats.py
@@ -173,6 +173,7 @@ class SchedulerStats:
     current_wave: int = 0
 
     kv_cache_usage: float = 0.0
+    encoder_cache_usage: float = 0.0
 
     prefix_cache_stats: PrefixCacheStats = field(default_factory=PrefixCacheStats)
     connector_prefix_cache_stats: PrefixCacheStats | None = None
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 9ef024f22..95505fd1b 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -720,6 +720,14 @@ class GPUModelRunner(
         if self.mm_budget:
             self.mm_budget.reset_cache()
 
+    def reset_encoder_cache(self) -> None:
+        """Clear the GPU-side encoder cache storing vision embeddings.
+
+        This should be called when model weights are updated to ensure
+        stale embeddings computed with old weights are not reused.
+        """
+        self.encoder_cache.clear()
+
     @torch.inference_mode()
     def init_fp8_kv_scales(self) -> None:
         """
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 529fb8acf..f6e59526e 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -539,6 +539,9 @@ class Worker(WorkerBase):
     def reset_mm_cache(self) -> None:
         self.model_runner.reset_mm_cache()
 
+    def reset_encoder_cache(self) -> None:
+        self.model_runner.reset_encoder_cache()
+
     def get_model(self) -> nn.Module:
         return self.model_runner.get_model()