[compile] Fix extra cache save on warm start. (#35921)

Signed-off-by: zhxchen17 <zhxchen17@fb.com>
This commit is contained in:
Zhengxu Chen
2026-03-04 23:56:30 -05:00
committed by GitHub
parent 26366009c5
commit dd6dbd93f8
2 changed files with 27 additions and 6 deletions

View File

@@ -61,11 +61,11 @@ def test_moe_startup(monkeypatch, vllm_runner, fresh_vllm_cache):
counters.clear()
with compilation_counter.expect(
num_compiled_artifacts_loaded=3,
# TODO: warm start should not save any artifacts
# https://github.com/vllm-project/vllm/issues/35708
num_compiled_artifacts_saved=1,
num_compiled_artifacts_saved=0,
):
_run_vllm(vllm_runner)
assert counters["aot_autograd"]["total"] == 30
assert counters["aot_autograd"]["autograd_cache_miss"] == 0
assert counters["aot_autograd"]["autograd_cache_hit"] == 1
assert (
counters["aot_autograd"]["autograd_cache_hit"] == 0
) # No miss at aot_autograd level causing disk I/O.

View File

@@ -221,10 +221,28 @@ class CompilerManager:
) -> Callable[..., Any] | None:
if (compile_range, graph_index, self.compiler.name) not in self.cache:
return None
handle = self.cache[(compile_range, graph_index, self.compiler.name)]
def parse_value(value: Any) -> tuple[tuple[str, str], str]:
assert isinstance(value, dict)
handle = value["graph_handle"]
assert isinstance(handle[0], str)
assert isinstance(handle[1], str)
cache_key = value["cache_key"]
return handle, cache_key
try:
handle, cache_key = parse_value(
self.cache[(compile_range, graph_index, self.compiler.name)]
)
except Exception:
# When the cache is outdated, we should ignore the existing file.
# This should cause the correct cache to be generated again.
return None
compiled_graph = self.compiler.load(
handle, graph, example_inputs, graph_index, compile_range
)
self.loaded_artifacts[cache_key] = compiled_graph
logger.debug(
"Directly load the %s-th graph for compile range %sfrom %s via handle %s",
graph_index,
@@ -341,7 +359,10 @@ class CompilerManager:
# store the artifact in the cache
if is_compile_cache_enabled(additional_inductor_config) and handle is not None:
self.cache[(compile_range, graph_index, self.compiler.name)] = handle
self.cache[(compile_range, graph_index, self.compiler.name)] = {
"graph_handle": handle,
"cache_key": cache_key,
}
compilation_counter.num_cache_entries_updated += 1
self.is_cache_updated = True
if graph_index == 0: