diff --git a/tests/compile/test_cold_start.py b/tests/compile/test_cold_start.py index dd770c583..1d24d1839 100644 --- a/tests/compile/test_cold_start.py +++ b/tests/compile/test_cold_start.py @@ -37,13 +37,12 @@ def test_moe_compilation_cold_start(monkeypatch, use_fresh_inductor_cache): # The forward pass consists of 32 transformer layers. # Then, we split on the attention operation. This results in # 33 subgraphs (not including the attention operation). - # We then standalone_compile the unique subgraphs. + # The 33 subgraphs then get standalone_compile'd. # # There are actually only 3 unique subgraphs for this model # (all of its transformer layers are the same modulo weights); # this is true for most vLLM models. - # So we test that during cold start, only 3 subgraphs are compiled - # These 3 subgraphs should cache miss, and then there should be - # no other compilation (so no cache hits). + # So we test that during cold start, the aot_autograd cache + # misses for 3 subgraphs and hits for the rest. assert counters["aot_autograd"]["autograd_cache_miss"] == 3 - assert counters["aot_autograd"]["autograd_cache_hit"] == 0 + assert counters["aot_autograd"]["autograd_cache_hit"] == 30 diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 38ba97c7f..89981fc29 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -121,7 +121,7 @@ class CompilerManager: and compiling the graph. The cache is a dict mapping - `(runtime_shape, graph_hash, backend_name)` + `(runtime_shape, graph_index, backend_name)` to `any_data` returned from the compiler. When serializing the cache, we save it to a Python file @@ -130,7 +130,7 @@ class CompilerManager: """ def __init__(self, compilation_config: CompilationConfig) -> None: - self.cache: dict[tuple[Range, str, str], Any] = dict() + self.cache: dict[tuple[Range, int, str], Any] = dict() self.is_cache_updated = False self.compilation_config = compilation_config self.compiler = make_compiler(compilation_config) @@ -173,7 +173,6 @@ class CompilerManager: self.disable_cache = disable_cache self.cache_dir = cache_dir self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py") - self.loaded_cache_entries: dict[tuple[Range, str, str], Any] = {} if not disable_cache and os.path.exists(self.cache_file_path): # load the cache from the file @@ -187,9 +186,9 @@ class CompilerManager: if not isinstance(value, ty): raise TypeError(f"Expected {ty} but got {type(value)} for {value}") - def parse_key(key: Any) -> tuple[Range, str, str]: - range_tuple, graph_hash, compiler_name = key - check_type(graph_hash, str) + def parse_key(key: Any) -> tuple[Range, int, str]: + range_tuple, graph_index, compiler_name = key + check_type(graph_index, int) check_type(compiler_name, str) if isinstance(range_tuple, tuple): start, end = range_tuple @@ -197,7 +196,7 @@ class CompilerManager: check_type(end, int) range_tuple = Range(start=start, end=end) check_type(range_tuple, Range) - return range_tuple, graph_hash, compiler_name + return range_tuple, graph_index, compiler_name self.cache = {parse_key(key): value for key, value in cache.items()} @@ -217,25 +216,18 @@ class CompilerManager: self, graph: fx.GraphModule, example_inputs: list[Any], - graph_hash: str, + graph_index: int, compile_range: Range, ) -> Callable[..., Any] | None: - key = (compile_range, graph_hash, self.compiler.name) - # See if we've already loaded this cache entry - if key in self.loaded_cache_entries: - return self.loaded_cache_entries[key] - # Otherwise, go load it from disk - if key not in self.cache: + if (compile_range, graph_index, self.compiler.name) not in self.cache: return None - handle = self.cache[key] + handle = self.cache[(compile_range, graph_index, self.compiler.name)] compiled_graph = self.compiler.load( - handle, graph, example_inputs, compile_range + handle, graph, example_inputs, graph_index, compile_range ) - self.loaded_cache_entries[key] = compiled_graph logger.debug( - "Directly load the graph (hash %s) for compile range " - "%sfrom %s via handle %s", - graph_hash, + "Directly load the %s-th graph for compile range %sfrom %s via handle %s", + graph_index, str(compile_range), self.compiler.name, handle, @@ -257,22 +249,12 @@ class CompilerManager: global compilation_start_time compilation_start_time = time.time() - from torch._functorch._aot_autograd.autograd_cache import ( - AOTAutogradCachePickler, - sanitize_gm_for_cache, - ) - - with sanitize_gm_for_cache(graph): - pickler = AOTAutogradCachePickler(graph) - dumped_graph = pickler.dumps(graph) - graph_hash = hashlib.sha256(dumped_graph).hexdigest() - compilation_counter.num_backend_compilations += 1 compiled_graph = None # try to load from the cache - compiled_graph = self.load(graph, example_inputs, graph_hash, compile_range) + compiled_graph = self.load(graph, example_inputs, graph_index, compile_range) if compiled_graph is not None: if graph_index == num_graphs - 1: # after loading the last graph for this shape, record the time. @@ -308,13 +290,9 @@ class CompilerManager: assert compiled_graph is not None, "Failed to compile the graph" - self.loaded_cache_entries[(compile_range, graph_hash, self.compiler.name)] = ( - compiled_graph - ) - # store the artifact in the cache if is_compile_cache_enabled(additional_inductor_config) and handle is not None: - self.cache[(compile_range, graph_hash, self.compiler.name)] = handle + self.cache[(compile_range, graph_index, self.compiler.name)] = handle compilation_counter.num_cache_entries_updated += 1 self.is_cache_updated = True if graph_index == 0: diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 875e628d6..606503539 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -101,6 +101,7 @@ class CompilerInterface: handle: Any, graph: fx.GraphModule, example_inputs: list[Any], + graph_index: int, compile_range: Range, ) -> Callable[..., Any]: """ @@ -301,6 +302,7 @@ class InductorStandaloneAdaptor(CompilerInterface): handle: Any, graph: fx.GraphModule, example_inputs: list[Any], + graph_index: int, compile_range: Range, ) -> Callable[..., Any]: assert isinstance(handle, tuple) @@ -525,6 +527,7 @@ class InductorAdaptor(CompilerInterface): handle: Any, graph: fx.GraphModule, example_inputs: list[Any], + graph_index: int, compile_range: Range, ) -> Callable[..., Any]: assert isinstance(handle, tuple)