[compile] Fix extra cache save on warm start. (#35921)
Signed-off-by: zhxchen17 <zhxchen17@fb.com>
This commit is contained in:
@@ -61,11 +61,11 @@ def test_moe_startup(monkeypatch, vllm_runner, fresh_vllm_cache):
|
||||
counters.clear()
|
||||
with compilation_counter.expect(
|
||||
num_compiled_artifacts_loaded=3,
|
||||
# TODO: warm start should not save any artifacts
|
||||
# https://github.com/vllm-project/vllm/issues/35708
|
||||
num_compiled_artifacts_saved=1,
|
||||
num_compiled_artifacts_saved=0,
|
||||
):
|
||||
_run_vllm(vllm_runner)
|
||||
assert counters["aot_autograd"]["total"] == 30
|
||||
assert counters["aot_autograd"]["autograd_cache_miss"] == 0
|
||||
assert counters["aot_autograd"]["autograd_cache_hit"] == 1
|
||||
assert (
|
||||
counters["aot_autograd"]["autograd_cache_hit"] == 0
|
||||
) # No miss at aot_autograd level causing disk I/O.
|
||||
|
||||
@@ -221,10 +221,28 @@ class CompilerManager:
|
||||
) -> Callable[..., Any] | None:
|
||||
if (compile_range, graph_index, self.compiler.name) not in self.cache:
|
||||
return None
|
||||
handle = self.cache[(compile_range, graph_index, self.compiler.name)]
|
||||
|
||||
def parse_value(value: Any) -> tuple[tuple[str, str], str]:
|
||||
assert isinstance(value, dict)
|
||||
handle = value["graph_handle"]
|
||||
assert isinstance(handle[0], str)
|
||||
assert isinstance(handle[1], str)
|
||||
cache_key = value["cache_key"]
|
||||
return handle, cache_key
|
||||
|
||||
try:
|
||||
handle, cache_key = parse_value(
|
||||
self.cache[(compile_range, graph_index, self.compiler.name)]
|
||||
)
|
||||
except Exception:
|
||||
# When the cache is outdated, we should ignore the existing file.
|
||||
# This should cause the correct cache to be generated again.
|
||||
return None
|
||||
|
||||
compiled_graph = self.compiler.load(
|
||||
handle, graph, example_inputs, graph_index, compile_range
|
||||
)
|
||||
self.loaded_artifacts[cache_key] = compiled_graph
|
||||
logger.debug(
|
||||
"Directly load the %s-th graph for compile range %sfrom %s via handle %s",
|
||||
graph_index,
|
||||
@@ -341,7 +359,10 @@ class CompilerManager:
|
||||
|
||||
# store the artifact in the cache
|
||||
if is_compile_cache_enabled(additional_inductor_config) and handle is not None:
|
||||
self.cache[(compile_range, graph_index, self.compiler.name)] = handle
|
||||
self.cache[(compile_range, graph_index, self.compiler.name)] = {
|
||||
"graph_handle": handle,
|
||||
"cache_key": cache_key,
|
||||
}
|
||||
compilation_counter.num_cache_entries_updated += 1
|
||||
self.is_cache_updated = True
|
||||
if graph_index == 0:
|
||||
|
||||
Reference in New Issue
Block a user