diff --git a/tests/compile/test_startup.py b/tests/compile/test_startup.py index acdce9d0b..545299565 100644 --- a/tests/compile/test_startup.py +++ b/tests/compile/test_startup.py @@ -61,11 +61,11 @@ def test_moe_startup(monkeypatch, vllm_runner, fresh_vllm_cache): counters.clear() with compilation_counter.expect( num_compiled_artifacts_loaded=3, - # TODO: warm start should not save any artifacts - # https://github.com/vllm-project/vllm/issues/35708 - num_compiled_artifacts_saved=1, + num_compiled_artifacts_saved=0, ): _run_vllm(vllm_runner) assert counters["aot_autograd"]["total"] == 30 assert counters["aot_autograd"]["autograd_cache_miss"] == 0 - assert counters["aot_autograd"]["autograd_cache_hit"] == 1 + assert ( + counters["aot_autograd"]["autograd_cache_hit"] == 0 + ) # No miss at aot_autograd level causing disk I/O. diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 7b493d9b9..9d37a5331 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -221,10 +221,28 @@ class CompilerManager: ) -> Callable[..., Any] | None: if (compile_range, graph_index, self.compiler.name) not in self.cache: return None - handle = self.cache[(compile_range, graph_index, self.compiler.name)] + + def parse_value(value: Any) -> tuple[tuple[str, str], str]: + assert isinstance(value, dict) + handle = value["graph_handle"] + assert isinstance(handle[0], str) + assert isinstance(handle[1], str) + cache_key = value["cache_key"] + return handle, cache_key + + try: + handle, cache_key = parse_value( + self.cache[(compile_range, graph_index, self.compiler.name)] + ) + except Exception: + # When the cache is outdated, we should ignore the existing file. + # This should cause the correct cache to be generated again. + return None + compiled_graph = self.compiler.load( handle, graph, example_inputs, graph_index, compile_range ) + self.loaded_artifacts[cache_key] = compiled_graph logger.debug( "Directly load the %s-th graph for compile range %sfrom %s via handle %s", graph_index, @@ -341,7 +359,10 @@ class CompilerManager: # store the artifact in the cache if is_compile_cache_enabled(additional_inductor_config) and handle is not None: - self.cache[(compile_range, graph_index, self.compiler.name)] = handle + self.cache[(compile_range, graph_index, self.compiler.name)] = { + "graph_handle": handle, + "cache_key": cache_key, + } compilation_counter.num_cache_entries_updated += 1 self.is_cache_updated = True if graph_index == 0: