[Bugfix] Fix inconsistent handling of cache reset (#33481)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-02-01 12:23:41 +08:00
committed by GitHub
parent d6416fdde9
commit 79b6ec6aab
7 changed files with 35 additions and 11 deletions

View File

@@ -82,7 +82,7 @@ vllm bench sweep serve \
You can use `--dry-run` to preview the commands to be run.
We only start the server once for each `--serve-params`, and keep it running for multiple `--bench-params`.
Between each benchmark run, we call the `/reset_prefix_cache` and `/reset_mm_cache` endpoints to get a clean slate for the next run.
Between each benchmark run, we call all `/reset_*_cache` endpoints to get a clean slate for the next run.
In case you are using a custom `--serve-cmd`, you can override the commands used for resetting the state by setting `--after-bench-cmd`.
!!! note

View File

@@ -12,6 +12,12 @@ from typing_extensions import Self
class ServerProcess:
VLLM_RESET_CACHE_ENDPOINTS = [
"/reset_prefix_cache",
"/reset_mm_cache",
"/reset_encoder_cache",
]
def __init__(
self,
server_cmd: list[str],
@@ -120,11 +126,9 @@ class ServerProcess:
server_address = self._get_vllm_server_address()
print(f"Resetting caches at {server_address}")
res = requests.post(f"{server_address}/reset_prefix_cache")
res.raise_for_status()
res = requests.post(f"{server_address}/reset_mm_cache")
res.raise_for_status()
for endpoint in self.VLLM_RESET_CACHE_ENDPOINTS:
res = requests.post(server_address + endpoint)
res.raise_for_status()
elif server_cmd[0].endswith("infinity_emb"):
if "--vector-disk-cache" in server_cmd:
raise NotImplementedError(

View File

@@ -286,10 +286,6 @@ class OpenAIServing:
raise TypeError(f"{reasoning_parser_name=} has not been registered") from e
return parser
async def reset_mm_cache(self) -> None:
    """Reset the multi-modal caches on both the frontend and the engine.

    Clears the serving-side input-processor MM cache first, then asks the
    engine client to reset its own MM cache.
    # NOTE(review): per the commit diff this method is being removed from
    # OpenAIServing — callers presumably go through the engine client
    # directly now; confirm against the rest of the patch.
    """
    self.input_processor.clear_mm_cache()
    await self.engine_client.reset_mm_cache()
async def beam_search(
self,
prompt: PromptType,

View File

@@ -741,6 +741,7 @@ class AsyncLLM(EngineClient):
if clear_cache:
await self.reset_prefix_cache()
await self.reset_mm_cache()
await self.reset_encoder_cache()
async def resume_generation(self) -> None:
"""Resume generation after :meth:`pause_generation`."""

View File

@@ -31,6 +31,22 @@ class EncoderRunner:
self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
self.encoder_cache: dict[str, torch.Tensor] = {}
def reset_mm_cache(self) -> None:
    """Clear the multi-modal cache used during profiling.

    Currently a no-op: the profiling-time MM cache is no longer needed
    during inference and there is nothing to release here yet.

    TODO: Implement MM budget for encoder dummy run.
    """
def reset_encoder_cache(self) -> None:
    """Drop every cached GPU-side encoder (vision) embedding.

    Invoked when model weights are updated, so that embeddings computed
    with the old weights can never be served again.
    """
    cache = self.encoder_cache
    cache.clear()
def add_request(self, req_id: str, mm_features: list[MultiModalFeatureSpec]):
self.req_id_to_mm_features[req_id] = mm_features

View File

@@ -339,7 +339,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
gc.collect()
def reset_mm_cache(self) -> None:
    """Reset the multi-modal cache by delegating to the encoder runner.

    The diff residue here interleaved the removed ``pass`` body with the
    added delegation call; this is the reconstructed post-commit form.
    """
    self.encoder_runner.reset_mm_cache()

def reset_encoder_cache(self) -> None:
    """Reset the GPU-side encoder cache by delegating to the encoder runner."""
    self.encoder_runner.reset_encoder_cache()
def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
# SP is not supported yet.

View File

@@ -717,6 +717,10 @@ class GPUModelRunner(
self.effective_drafter_max_model_len = self.max_model_len
def reset_mm_cache(self) -> None:
    """Clear the multi-modal cache that was used during profiling.

    The cache is only needed while profiling; dropping it frees the
    budgeted MM memory for inference. Does nothing when no MM budget
    is configured.
    """
    budget = self.mm_budget
    if not budget:
        return
    budget.reset_cache()