Support clear mm and encoder cache (#33452)

Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
jma99_2333
2026-01-31 07:22:25 -08:00
committed by GitHub
parent 13b842f271
commit 22d9a056d5
15 changed files with 212 additions and 1 deletion

View File

@@ -172,6 +172,7 @@ These endpoints are **only available when the environment variable `VLLM_SERVER_
- `/server_info` - Get detailed server configuration
- `/reset_prefix_cache` - Reset prefix cache (can disrupt service)
- `/reset_mm_cache` - Reset multimodal cache (can disrupt service)
- `/reset_encoder_cache` - Reset encoder cache (can disrupt service; see the usage sketch below)
- `/sleep` - Put engine to sleep (causes denial of service)
- `/wake_up` - Wake engine from sleep
- `/is_sleeping` - Check if engine is sleeping
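
The new `/reset_encoder_cache` endpoint can be exercised like the other dev-mode routes. A minimal sketch, assuming a server launched with `VLLM_SERVER_DEV_MODE=1` on the default `localhost:8000` (host and port are assumptions, not part of this diff):

import requests

# POST with no body; on success the handler simply returns HTTP 200.
resp = requests.post("http://localhost:8000/reset_encoder_cache")
assert resp.status_code == 200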

View File

@@ -4,7 +4,10 @@ import pytest
import torch
from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange
from vllm.v1.core.encoder_cache_manager import (
    EncoderCacheManager,
    EncoderDecoderCacheManager,
)

pytestmark = pytest.mark.cpu_test
@@ -247,3 +250,88 @@ def test_encoder_cache_mask_based_retrieval():
    assert num_embeds_before == 0
    assert num_embeds_in_range == 2

def test_reset_clears_all_state():
    """Test that reset() clears all cached entries and restores capacity."""
    manager = EncoderCacheManager(cache_size=20)
    req1 = MockRequest("req1", ["img1", "img2"], [5, 3])
    req2 = MockRequest("req2", ["img3"], [4])
    manager.allocate(req1, 0)
    manager.allocate(req1, 1)
    manager.allocate(req2, 0)
    manager.free_encoder_input(req1, 0)
    req3 = MockRequest("req3", ["img4"], [10])
    manager.free_encoder_input(req1, 1)
    manager.free_encoder_input(req2, 0)
    manager.can_allocate(req3, 0, int(1e9), 0)
    manager.allocate(req3, 0)
    assert len(manager.cached) > 0
    assert manager.num_free_slots < 20
    manager.reset()
    assert len(manager.cached) == 0
    assert len(manager.freeable) == 0
    assert len(manager.freed) == 0
    assert manager.num_free_slots == 20
    assert manager.num_freeable_slots == 20

def test_reset_allows_fresh_allocations():
    """Test that the cache accepts new allocations after reset()."""
    manager = EncoderCacheManager(cache_size=10)
    req1 = MockRequest("req1", ["img1"], [10])
    manager.allocate(req1, 0)
    assert manager.num_free_slots == 0
    manager.reset()
    req2 = MockRequest("req2", ["img2"], [8])
    assert manager.can_allocate(req2, 0, int(1e9), 0)
    manager.allocate(req2, 0)
    assert manager.num_free_slots == 2
    assert "img2" in manager.cached
    assert "img1" not in manager.cached

def test_encoder_decoder_cache_manager_reset():
    """Test that reset() clears the encoder-decoder manager's state."""
    manager = EncoderDecoderCacheManager(cache_size=20)
    req1 = MockRequest("req1", ["img1"], [5])
    req2 = MockRequest("req2", ["img2"], [3])
    manager.allocate(req1, 0)
    manager.allocate(req2, 0)
    manager.free(req1)
    manager.get_freed_mm_hashes()
    assert manager.num_free_slots < 20
    manager.reset()
    assert len(manager.allocated) == 0
    assert len(manager.to_free) == 0
    assert manager.num_free_slots == 20

def test_encoder_decoder_cache_manager_reset_allows_fresh_allocations():
    """Test that the encoder-decoder manager accepts allocations after reset()."""
    manager = EncoderDecoderCacheManager(cache_size=10)
    req1 = MockRequest("req1", ["img1"], [10])
    manager.allocate(req1, 0)
    assert manager.num_free_slots == 0
    manager.reset()
    req2 = MockRequest("req2", ["img2"], [8])
    assert manager.can_allocate(req2, 0, int(1e9), 0)
    manager.allocate(req2, 0)
    assert manager.num_free_slots == 2
    assert "img2" in manager.allocated

View File

@@ -113,6 +113,11 @@ class EngineClient(ABC):
"""Reset the multi-modal cache"""
...
@abstractmethod
async def reset_encoder_cache(self) -> None:
"""Reset the encoder cache"""
...
@abstractmethod
async def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False

View File

@@ -55,6 +55,17 @@ async def reset_mm_cache(raw_request: Request):
    return Response(status_code=200)

@router.post("/reset_encoder_cache")
async def reset_encoder_cache(raw_request: Request):
    """
    Reset the encoder cache. Note that the API server does not currently
    verify that the encoder cache was actually reset.
    """
    logger.info("Resetting encoder cache...")
    await engine_client(raw_request).reset_encoder_cache()
    return Response(status_code=200)

def attach_router(app: FastAPI):
    if not envs.VLLM_SERVER_DEV_MODE:
        return

View File

@@ -77,6 +77,18 @@ class EncoderCacheManager:
        self.freeable: OrderedDict[str, int] = OrderedDict()
        self.freed: list[str] = []

    def reset(self) -> None:
        """Reset the encoder cache to its initial state.

        This clears all cached encoder outputs and resets capacity tracking.
        Called when model weights are updated to invalidate stale embeddings.
        """
        self.cached.clear()
        self.freeable.clear()
        self.freed.clear()
        self.num_free_slots = self.cache_size
        self.num_freeable_slots = self.cache_size

    def check_and_update_cache(self, request: Request, input_id: int) -> bool:
        """Check if encoder output for a specific multimodal input is cached.
@@ -360,6 +372,12 @@ class EncoderDecoderCacheManager(EncoderCacheManager):
        self.allocated: list[str] = []
        self.to_free: list[str] = []

    def reset(self) -> None:
        """Reset the encoder cache to its initial state."""
        self.num_free_slots = self.cache_size
        self.allocated.clear()
        self.to_free.clear()

    def check_and_update_cache(self, request: Request, input_id: int) -> bool:
        return False
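
Note the asymmetry between the two `reset()` implementations: `EncoderCacheManager` maintains `cached`, `freeable`, and `freed` plus two capacity counters, so all five are restored, whereas `EncoderDecoderCacheManager` never reuses entries across requests (its `check_and_update_cache` always returns `False`), so only `allocated`, `to_free`, and `num_free_slots` need resetting.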

View File

@@ -183,6 +183,15 @@ class SchedulerInterface(ABC):
"""
raise NotImplementedError
@abstractmethod
def reset_encoder_cache(self) -> None:
"""Reset the encoder cache to invalidate all cached encoder outputs.
This should be called when model weights are updated to ensure
stale vision embeddings are not reused.
"""
raise NotImplementedError
@abstractmethod
def get_request_counts(self) -> tuple[int, int]:
"""Returns (num_running_reqs, num_waiting_reqs)."""

View File

@@ -1763,6 +1763,14 @@ class Scheduler(SchedulerInterface):
        return True

    def reset_encoder_cache(self) -> None:
        """Reset the encoder cache to invalidate all cached encoder outputs.

        This should be called when model weights are updated to ensure
        stale vision embeddings are not reused.
        """
        self.encoder_cache_manager.reset()

    def make_stats(
        self,
        spec_decoding_stats: SpecDecodingStats | None = None,
@@ -1788,6 +1796,7 @@
            num_running_reqs=len(self.running),
            num_waiting_reqs=len(self.waiting),
            kv_cache_usage=self.kv_cache_manager.usage,
            encoder_cache_usage=self._get_encoder_cache_usage(),
            prefix_cache_stats=prefix_cache_stats,
            connector_prefix_cache_stats=connector_prefix_cache_stats,
            kv_cache_eviction_events=eviction_events,
@@ -1797,6 +1806,14 @@
            perf_stats=perf_stats,
        )

    def _get_encoder_cache_usage(self) -> float:
        """Get encoder cache usage as a fraction (0.0 to 1.0)."""
        ecm = self.encoder_cache_manager
        if ecm.cache_size == 0:
            return 0.0
        used_slots = ecm.cache_size - ecm.num_free_slots
        return used_slots / ecm.cache_size

    def make_spec_decoding_stats(
        self,
        spec_decoding_stats: SpecDecodingStats | None,
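
A quick worked example of the fraction `_get_encoder_cache_usage` reports (the values are illustrative only):

# Mirrors the method above: used_slots / cache_size, guarding cache_size == 0.
cache_size = 20        # total encoder cache capacity in slots
num_free_slots = 12    # slots currently unoccupied
usage = (cache_size - num_free_slots) / cache_size if cache_size else 0.0
assert usage == 0.4    # 8 of 20 slots hold cached encoder outputs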

View File

@@ -882,6 +882,9 @@ class AsyncLLM(EngineClient):
            reset_running_requests, reset_connector
        )

    async def reset_encoder_cache(self) -> None:
        await self.engine_core.reset_encoder_cache_async()

    async def sleep(self, level: int = 1) -> None:
        await self.reset_prefix_cache()
        await self.engine_core.sleep_async(level)

View File

@@ -565,6 +565,26 @@ class EngineCore:
            reset_running_requests, reset_connector
        )

    def reset_encoder_cache(self) -> None:
        """Reset the encoder cache to invalidate all cached encoder outputs.

        This should be called when model weights are updated to ensure
        stale vision embeddings computed with old weights are not reused.
        Clears both the scheduler's cache manager and the GPU model runner's
        cache.
        """
        # NOTE: Since this is mainly for debugging, we don't attempt to
        # re-sync the internal caches (P0 sender, P1 receiver).
        if self.scheduler.has_unfinished_requests():
            logger.warning(
                "Resetting the encoder cache when requests are "
                "in progress may lead to desynced internal caches."
            )
        # Reset the scheduler's encoder cache manager (logical state).
        self.scheduler.reset_encoder_cache()
        # Reset the GPU model runner's encoder cache (physical storage).
        self.model_executor.reset_encoder_cache()

    def sleep(self, level: int = 1):
        self.model_executor.sleep(level)
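
Taken together with the hunks below, the full reset path is: the `/reset_encoder_cache` route calls `AsyncLLM.reset_encoder_cache`, which issues an engine-core utility call to `EngineCore.reset_encoder_cache`; that clears the scheduler's bookkeeping via `Scheduler.reset_encoder_cache()` and then fans out over `Executor.collective_rpc("reset_encoder_cache")` to each `Worker`, which finally clears `GPUModelRunner.encoder_cache`.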

View File

@@ -144,6 +144,9 @@ class EngineCoreClient(ABC):
    ) -> bool:
        raise NotImplementedError

    def reset_encoder_cache(self) -> None:
        raise NotImplementedError

    def sleep(self, level: int = 1) -> None:
        raise NotImplementedError
@@ -216,6 +219,9 @@ class EngineCoreClient(ABC):
    ) -> bool:
        raise NotImplementedError

    async def reset_encoder_cache_async(self) -> None:
        raise NotImplementedError

    async def sleep_async(self, level: int = 1) -> None:
        raise NotImplementedError
@@ -300,6 +306,9 @@ class InprocClient(EngineCoreClient):
            reset_running_requests, reset_connector
        )

    def reset_encoder_cache(self) -> None:
        self.engine_core.reset_encoder_cache()

    def sleep(self, level: int = 1) -> None:
        self.engine_core.sleep(level)
@@ -765,6 +774,9 @@ class SyncMPClient(MPClient):
            "reset_prefix_cache", reset_running_requests, reset_connector
        )

    def reset_encoder_cache(self) -> None:
        self.call_utility("reset_encoder_cache")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.call_utility("add_lora", lora_request)
@@ -973,6 +985,9 @@ class AsyncMPClient(MPClient):
            "reset_prefix_cache", reset_running_requests, reset_connector
        )

    async def reset_encoder_cache_async(self) -> None:
        await self.call_utility_async("reset_encoder_cache")

    async def sleep_async(self, level: int = 1) -> None:
        await self.call_utility_async("sleep", level)

View File

@@ -332,6 +332,14 @@ class LLMEngine:
            reset_running_requests, reset_connector
        )

    def reset_encoder_cache(self) -> None:
        """Reset the encoder cache to invalidate all cached encoder outputs.

        This should be called when model weights are updated to ensure
        stale vision embeddings computed with old weights are not reused.
        """
        self.engine_core.reset_encoder_cache()

    def sleep(self, level: int = 1):
        self.engine_core.sleep(level)
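
For offline use, the same reset is reachable through `LLMEngine`. A minimal sketch, assuming `llm` is an existing `vllm.LLM` instance and that updated weights have already been loaded by some external mechanism (the weight update itself is out of scope for this commit):

# Drop vision embeddings computed with the old weights so the next
# multimodal request recomputes them with the new weights.
llm.llm_engine.reset_encoder_cache()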

View File

@@ -294,6 +294,10 @@ class Executor(ABC):
"""Reset the multi-modal cache in each worker."""
self.collective_rpc("reset_mm_cache")
def reset_encoder_cache(self) -> None:
"""Reset the encoder cache in each worker to clear cached encoder outputs."""
self.collective_rpc("reset_encoder_cache")
def sleep(self, level: int = 1):
if self.is_sleeping:
logger.warning("Executor is already sleeping.")

View File

@@ -173,6 +173,7 @@ class SchedulerStats:
    current_wave: int = 0
    kv_cache_usage: float = 0.0
    encoder_cache_usage: float = 0.0
    prefix_cache_stats: PrefixCacheStats = field(default_factory=PrefixCacheStats)
    connector_prefix_cache_stats: PrefixCacheStats | None = None
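
The new `encoder_cache_usage` field is populated by the scheduler via `_get_encoder_cache_usage()` (see the scheduler hunk above), exposing encoder-cache occupancy as a 0.0 to 1.0 fraction alongside `kv_cache_usage`.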

View File

@@ -720,6 +720,14 @@ class GPUModelRunner(
        if self.mm_budget:
            self.mm_budget.reset_cache()

    def reset_encoder_cache(self) -> None:
        """Clear the GPU-side encoder cache storing vision embeddings.

        This should be called when model weights are updated to ensure
        stale embeddings computed with old weights are not reused.
        """
        self.encoder_cache.clear()

    @torch.inference_mode()
    def init_fp8_kv_scales(self) -> None:
        """

View File

@@ -539,6 +539,9 @@ class Worker(WorkerBase):
    def reset_mm_cache(self) -> None:
        self.model_runner.reset_mm_cache()

    def reset_encoder_cache(self) -> None:
        self.model_runner.reset_encoder_cache()

    def get_model(self) -> nn.Module:
        return self.model_runner.get_model()