diff --git a/tests/compile/test_cold_start.py b/tests/compile/test_cold_start.py
index 1d24d1839..dd770c583 100644
--- a/tests/compile/test_cold_start.py
+++ b/tests/compile/test_cold_start.py
@@ -37,12 +37,13 @@ def test_moe_compilation_cold_start(monkeypatch, use_fresh_inductor_cache):
     # The forward pass consists of 32 transformer layers.
     # Then, we split on the attention operation. This results in
     # 33 subgraphs (not including the attention operation).
-    # The 33 subgraphs then get standalone_compile'd.
+    # We then standalone_compile the unique subgraphs.
     #
     # There are actually only 3 unique subgraphs for this model
     # (all of its transformer layers are the same modulo weights);
     # this is true for most vLLM models.
-    # So we test that during cold start, the aot_autograd cache
-    # misses for 3 subgraphs and hits for the rest.
+    # So we test that during cold start, only 3 subgraphs are compiled.
+    # These 3 subgraphs should cache miss, and then there should be
+    # no other compilation (so no cache hits).
     assert counters["aot_autograd"]["autograd_cache_miss"] == 3
-    assert counters["aot_autograd"]["autograd_cache_hit"] == 30
+    assert counters["aot_autograd"]["autograd_cache_hit"] == 0
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 89981fc29..38ba97c7f 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -121,7 +121,7 @@ class CompilerManager:
     and compiling the graph.
 
     The cache is a dict mapping
-    `(runtime_shape, graph_index, backend_name)`
+    `(runtime_shape, graph_hash, backend_name)`
     to `any_data` returned from the compiler.
 
     When serializing the cache, we save it to a Python file
@@ -130,7 +130,7 @@ class CompilerManager:
     """
 
    def __init__(self, compilation_config: CompilationConfig) -> None:
-        self.cache: dict[tuple[Range, int, str], Any] = dict()
+        self.cache: dict[tuple[Range, str, str], Any] = dict()
         self.is_cache_updated = False
         self.compilation_config = compilation_config
         self.compiler = make_compiler(compilation_config)
@@ -173,6 +173,7 @@ class CompilerManager:
         self.disable_cache = disable_cache
         self.cache_dir = cache_dir
         self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")
+        self.loaded_cache_entries: dict[tuple[Range, str, str], Any] = {}
 
         if not disable_cache and os.path.exists(self.cache_file_path):
             # load the cache from the file
@@ -186,9 +187,9 @@ class CompilerManager:
                if not isinstance(value, ty):
                    raise TypeError(f"Expected {ty} but got {type(value)} for {value}")
 
-            def parse_key(key: Any) -> tuple[Range, int, str]:
-                range_tuple, graph_index, compiler_name = key
-                check_type(graph_index, int)
+            def parse_key(key: Any) -> tuple[Range, str, str]:
+                range_tuple, graph_hash, compiler_name = key
+                check_type(graph_hash, str)
                 check_type(compiler_name, str)
                 if isinstance(range_tuple, tuple):
                     start, end = range_tuple
@@ -196,7 +197,7 @@ class CompilerManager:
                     check_type(end, int)
                     range_tuple = Range(start=start, end=end)
                 check_type(range_tuple, Range)
-                return range_tuple, graph_index, compiler_name
+                return range_tuple, graph_hash, compiler_name
 
             self.cache = {parse_key(key): value for key, value in cache.items()}
 
@@ -216,18 +217,25 @@ class CompilerManager:
         self,
         graph: fx.GraphModule,
         example_inputs: list[Any],
-        graph_index: int,
+        graph_hash: str,
         compile_range: Range,
     ) -> Callable[..., Any] | None:
-        if (compile_range, graph_index, self.compiler.name) not in self.cache:
+        key = (compile_range, graph_hash, self.compiler.name)
+        # See if we've already loaded this cache entry
+        if key in self.loaded_cache_entries:
+            return self.loaded_cache_entries[key]
+        # Otherwise, go load it from disk
+        if key not in self.cache:
             return None
-        handle = self.cache[(compile_range, graph_index, self.compiler.name)]
+        handle = self.cache[key]
         compiled_graph = self.compiler.load(
-            handle, graph, example_inputs, graph_index, compile_range
+            handle, graph, example_inputs, compile_range
         )
+        self.loaded_cache_entries[key] = compiled_graph
         logger.debug(
-            "Directly load the %s-th graph for compile range %sfrom %s via handle %s",
-            graph_index,
+            "Directly load the graph (hash %s) for compile range "
+            "%s from %s via handle %s",
+            graph_hash,
             str(compile_range),
             self.compiler.name,
             handle,
@@ -249,12 +257,22 @@ class CompilerManager:
         global compilation_start_time
         compilation_start_time = time.time()
 
+        from torch._functorch._aot_autograd.autograd_cache import (
+            AOTAutogradCachePickler,
+            sanitize_gm_for_cache,
+        )
+
+        with sanitize_gm_for_cache(graph):
+            pickler = AOTAutogradCachePickler(graph)
+            dumped_graph = pickler.dumps(graph)
+        graph_hash = hashlib.sha256(dumped_graph).hexdigest()
+
         compilation_counter.num_backend_compilations += 1
 
         compiled_graph = None
 
         # try to load from the cache
-        compiled_graph = self.load(graph, example_inputs, graph_index, compile_range)
+        compiled_graph = self.load(graph, example_inputs, graph_hash, compile_range)
         if compiled_graph is not None:
             if graph_index == num_graphs - 1:
                 # after loading the last graph for this shape, record the time.
@@ -290,9 +308,13 @@ class CompilerManager:
 
         assert compiled_graph is not None, "Failed to compile the graph"
 
+        self.loaded_cache_entries[(compile_range, graph_hash, self.compiler.name)] = (
+            compiled_graph
+        )
+
         # store the artifact in the cache
         if is_compile_cache_enabled(additional_inductor_config) and handle is not None:
-            self.cache[(compile_range, graph_index, self.compiler.name)] = handle
+            self.cache[(compile_range, graph_hash, self.compiler.name)] = handle
             compilation_counter.num_cache_entries_updated += 1
             self.is_cache_updated = True
             if graph_index == 0:
diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py
index 606503539..875e628d6 100644
--- a/vllm/compilation/compiler_interface.py
+++ b/vllm/compilation/compiler_interface.py
@@ -101,7 +101,6 @@ class CompilerInterface:
         handle: Any,
         graph: fx.GraphModule,
         example_inputs: list[Any],
-        graph_index: int,
         compile_range: Range,
     ) -> Callable[..., Any]:
         """
@@ -302,7 +301,6 @@ class InductorStandaloneAdaptor(CompilerInterface):
         handle: Any,
         graph: fx.GraphModule,
         example_inputs: list[Any],
-        graph_index: int,
         compile_range: Range,
     ) -> Callable[..., Any]:
         assert isinstance(handle, tuple)
@@ -527,7 +525,6 @@ class InductorAdaptor(CompilerInterface):
         handle: Any,
         graph: fx.GraphModule,
         example_inputs: list[Any],
-        graph_index: int,
         compile_range: Range,
     ) -> Callable[..., Any]:
         assert isinstance(handle, tuple)
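For context, here is how the patch derives the new cache key: it pickles the FX graph with `AOTAutogradCachePickler` (under `sanitize_gm_for_cache`, which strips instance-specific state before pickling) and takes a SHA-256 of the bytes. The sketch below wraps exactly the calls shown in the diff in a hypothetical helper, `graph_hash_of`; the `Layer` module is a made-up example, and note that `torch._functorch._aot_autograd.autograd_cache` is a private PyTorch module whose contents may change between releases.

```python
# Sketch mirroring the hash computation in the diff above; graph_hash_of and
# Layer are hypothetical, and the torch._functorch import is private API.
import hashlib

import torch
from torch import fx
from torch._functorch._aot_autograd.autograd_cache import (
    AOTAutogradCachePickler,
    sanitize_gm_for_cache,
)


def graph_hash_of(graph: fx.GraphModule) -> str:
    # Strip unpicklable / instance-specific attributes, then pickle and hash.
    with sanitize_gm_for_cache(graph):
        dumped = AOTAutogradCachePickler(graph).dumps(graph)
    return hashlib.sha256(dumped).hexdigest()


class Layer(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)


# Two structurally identical graphs should produce the same key, which is
# what lets identical transformer layers share a single cache entry.
gm1, gm2 = fx.symbolic_trace(Layer()), fx.symbolic_trace(Layer())
print(graph_hash_of(gm1) == graph_hash_of(gm2))  # expected: True
```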
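And here is a minimal, self-contained sketch of why the test's expected counters change from 3 misses / 30 hits to 3 misses / 0 hits: once the key is a content hash, duplicate subgraphs are deduplicated in memory (via `loaded_cache_entries`) before the compiler or the on-disk cache is ever consulted. `TinyCompilerManager` and `serialize_graph` are illustrative stand-ins, not vLLM's actual classes.

```python
# Illustrative sketch only: a content-hash keyed, in-memory dedup of compiles.
import hashlib
from typing import Any, Callable


def serialize_graph(graph: Any) -> bytes:
    # Hypothetical stand-in for the AOTAutogradCachePickler serialization:
    # any stable byte encoding of the graph's structure works for the sketch.
    return repr(graph).encode()


class TinyCompilerManager:
    def __init__(self) -> None:
        # The real key is (compile_range, graph_hash, compiler_name); the
        # sketch keeps only the hash.
        self.loaded_cache_entries: dict[str, Callable[..., Any]] = {}

    def compile(
        self, graph: Any, compile_fn: Callable[[Any], Callable[..., Any]]
    ) -> Callable[..., Any]:
        graph_hash = hashlib.sha256(serialize_graph(graph)).hexdigest()
        if graph_hash in self.loaded_cache_entries:
            # Duplicate subgraph: reuse without touching compiler or cache.
            return self.loaded_cache_entries[graph_hash]
        compiled = compile_fn(graph)  # only unique subgraphs get here
        self.loaded_cache_entries[graph_hash] = compiled
        return compiled


# 32 identical layer graphs -> compile_fn runs exactly once; in the cold-start
# test the model has 3 unique subgraphs, hence 3 misses and 0 hits.
manager = TinyCompilerManager()
compiled = [manager.compile("layer", lambda g: (lambda x: x)) for _ in range(32)]
assert compiled.count(compiled[0]) == 32
```

Keying by content rather than position also means cache entries no longer depend on `graph_index`, which is why the `load` hooks in `compiler_interface.py` can drop that parameter entirely.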