diff --git a/docs/benchmarking/sweeps.md b/docs/benchmarking/sweeps.md
index 467d198ec..d56d8ab45 100644
--- a/docs/benchmarking/sweeps.md
+++ b/docs/benchmarking/sweeps.md
@@ -82,7 +82,7 @@ vllm bench sweep serve \
 
     You can use `--dry-run` to preview the commands to be run.
     We only start the server once for each `--serve-params`, and keep it running for multiple `--bench-params`.
-    Between each benchmark run, we call the `/reset_prefix_cache` and `/reset_mm_cache` endpoints to get a clean slate for the next run.
+    Between each benchmark run, we call all `/reset_*_cache` endpoints to get a clean slate for the next run.
     In case you are using a custom `--serve-cmd`, you can override the commands used for resetting the state by setting `--after-bench-cmd`.
 
 !!! note
diff --git a/vllm/benchmarks/sweep/server.py b/vllm/benchmarks/sweep/server.py
index 6c6c0abcb..87d841ac8 100644
--- a/vllm/benchmarks/sweep/server.py
+++ b/vllm/benchmarks/sweep/server.py
@@ -12,6 +12,12 @@ from typing_extensions import Self
 
 
 class ServerProcess:
+    VLLM_RESET_CACHE_ENDPOINTS = [
+        "/reset_prefix_cache",
+        "/reset_mm_cache",
+        "/reset_encoder_cache",
+    ]
+
     def __init__(
         self,
         server_cmd: list[str],
@@ -120,11 +126,9 @@ class ServerProcess:
             server_address = self._get_vllm_server_address()
             print(f"Resetting caches at {server_address}")
 
-            res = requests.post(f"{server_address}/reset_prefix_cache")
-            res.raise_for_status()
-
-            res = requests.post(f"{server_address}/reset_mm_cache")
-            res.raise_for_status()
+            for endpoint in self.VLLM_RESET_CACHE_ENDPOINTS:
+                res = requests.post(server_address + endpoint)
+                res.raise_for_status()
         elif server_cmd[0].endswith("infinity_emb"):
             if "--vector-disk-cache" in server_cmd:
                 raise NotImplementedError(
diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py
index 7e2f65d26..df8e7b19f 100644
--- a/vllm/entrypoints/openai/engine/serving.py
+++ b/vllm/entrypoints/openai/engine/serving.py
@@ -286,10 +286,6 @@ class OpenAIServing:
             raise TypeError(f"{reasoning_parser_name=} has not been registered") from e
         return parser
 
-    async def reset_mm_cache(self) -> None:
-        self.input_processor.clear_mm_cache()
-        await self.engine_client.reset_mm_cache()
-
     async def beam_search(
         self,
         prompt: PromptType,
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index d0df49f53..f1a3e341f 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -741,6 +741,7 @@ class AsyncLLM(EngineClient):
         if clear_cache:
             await self.reset_prefix_cache()
             await self.reset_mm_cache()
+            await self.reset_encoder_cache()
 
     async def resume_generation(self) -> None:
         """Resume generation after :meth:`pause_generation`."""
diff --git a/vllm/v1/worker/gpu/mm/encoder_runner.py b/vllm/v1/worker/gpu/mm/encoder_runner.py
index bfe0bf1a3..817bdd5d8 100644
--- a/vllm/v1/worker/gpu/mm/encoder_runner.py
+++ b/vllm/v1/worker/gpu/mm/encoder_runner.py
@@ -31,6 +31,22 @@ class EncoderRunner:
         self.req_id_to_mm_features: dict[str, list[MultiModalFeatureSpec]] = {}
         self.encoder_cache: dict[str, torch.Tensor] = {}
 
+    def reset_mm_cache(self) -> None:
+        """
+        Clear the multi-modal cache that was used during profiling
+        but is no longer needed during inference.
+        """
+        # TODO: Implement MM budget for encoder dummy run
+        pass
+
+    def reset_encoder_cache(self) -> None:
+        """Clear the GPU-side encoder cache storing vision embeddings.
+
+        This should be called when model weights are updated to ensure
+        stale embeddings computed with old weights are not reused.
+        """
+        self.encoder_cache.clear()
+
     def add_request(self, req_id: str, mm_features: list[MultiModalFeatureSpec]):
         self.req_id_to_mm_features[req_id] = mm_features
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 900ffaf5c..66a9d6b96 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -339,7 +339,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         gc.collect()
 
     def reset_mm_cache(self) -> None:
-        pass
+        self.encoder_runner.reset_mm_cache()
+
+    def reset_encoder_cache(self) -> None:
+        self.encoder_runner.reset_encoder_cache()
 
     def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
         # SP is not supported yet.
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 95505fd1b..61e166133 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -717,6 +717,10 @@ class GPUModelRunner(
         self.effective_drafter_max_model_len = self.max_model_len
 
     def reset_mm_cache(self) -> None:
+        """
+        Clear the multi-modal cache that was used during profiling
+        but is no longer needed during inference.
+        """
         if self.mm_budget:
             self.mm_budget.reset_cache()
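For anyone wiring the same reset into a custom harness (for example from an `--after-bench-cmd` wrapper), below is a minimal standalone sketch of the loop that `ServerProcess.reset_caches` now runs. The endpoint list mirrors `VLLM_RESET_CACHE_ENDPOINTS` above; the `http://localhost:8000` address and the `reset_caches` helper name are assumptions for illustration, not part of this change.

```python
import requests

# Mirrors ServerProcess.VLLM_RESET_CACHE_ENDPOINTS from the diff above.
RESET_CACHE_ENDPOINTS = [
    "/reset_prefix_cache",
    "/reset_mm_cache",
    "/reset_encoder_cache",
]


def reset_caches(server_address: str) -> None:
    """POST to every /reset_*_cache endpoint, failing fast on any HTTP error."""
    for endpoint in RESET_CACHE_ENDPOINTS:
        res = requests.post(server_address + endpoint)
        res.raise_for_status()


if __name__ == "__main__":
    reset_caches("http://localhost:8000")  # assumed address for illustration
```

Keeping the endpoints in a single class attribute means a future `/reset_*_cache` endpoint only has to be registered in one place, rather than as another copy-pasted `requests.post` pair.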
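To make the new call chain easier to follow, here is a condensed sketch of how `reset_encoder_cache` propagates down to the worker. Class bodies are trimmed to the relevant methods and the constructor wiring is simplified from the hunks above, not the real initialization code.

```python
import torch


class EncoderRunner:
    """Condensed from vllm/v1/worker/gpu/mm/encoder_runner.py."""

    def __init__(self) -> None:
        # Cache key -> precomputed vision embeddings.
        self.encoder_cache: dict[str, torch.Tensor] = {}

    def reset_encoder_cache(self) -> None:
        # Drop embeddings computed with the old weights.
        self.encoder_cache.clear()


class GPUModelRunner:
    """Condensed from vllm/v1/worker/gpu/model_runner.py."""

    def __init__(self) -> None:
        self.encoder_runner = EncoderRunner()

    def reset_encoder_cache(self) -> None:
        # Thin delegation down to the encoder runner.
        self.encoder_runner.reset_encoder_cache()
```

With this chain in place, `AsyncLLM.pause_generation(clear_cache=True)` clears all three caches (prefix, multi-modal, and encoder), so generation after a weight update resumes without reusing state computed under the old weights.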