[torch.compile] Significantly speed up cold start times (#33641)

Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
Richard Zou
2026-02-03 09:12:11 -08:00
committed by GitHub
parent 2267cb1cfd
commit b1bb18de8d
3 changed files with 41 additions and 21 deletions

View File

@@ -37,12 +37,13 @@ def test_moe_compilation_cold_start(monkeypatch, use_fresh_inductor_cache):
# The forward pass consists of 32 transformer layers.
# Then, we split on the attention operation. This results in
# 33 subgraphs (not including the attention operation).
# The 33 subgraphs then get standalone_compile'd.
# We then standalone_compile the unique subgraphs.
#
# There are actually only 3 unique subgraphs for this model
# (all of its transformer layers are the same modulo weights);
# this is true for most vLLM models.
# So we test that during cold start, the aot_autograd cache
# misses for 3 subgraphs and hits for the rest.
# So we test that during cold start, only 3 subgraphs are compiled
# These 3 subgraphs should cache miss, and then there should be
# no other compilation (so no cache hits).
assert counters["aot_autograd"]["autograd_cache_miss"] == 3
assert counters["aot_autograd"]["autograd_cache_hit"] == 30
assert counters["aot_autograd"]["autograd_cache_hit"] == 0

View File

@@ -121,7 +121,7 @@ class CompilerManager:
and compiling the graph.
The cache is a dict mapping
`(runtime_shape, graph_index, backend_name)`
`(runtime_shape, graph_hash, backend_name)`
to `any_data` returned from the compiler.
When serializing the cache, we save it to a Python file
@@ -130,7 +130,7 @@ class CompilerManager:
"""
def __init__(self, compilation_config: CompilationConfig) -> None:
self.cache: dict[tuple[Range, int, str], Any] = dict()
self.cache: dict[tuple[Range, str, str], Any] = dict()
self.is_cache_updated = False
self.compilation_config = compilation_config
self.compiler = make_compiler(compilation_config)
@@ -173,6 +173,7 @@ class CompilerManager:
self.disable_cache = disable_cache
self.cache_dir = cache_dir
self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")
self.loaded_cache_entries: dict[tuple[Range, str, str], Any] = {}
if not disable_cache and os.path.exists(self.cache_file_path):
# load the cache from the file
@@ -186,9 +187,9 @@ class CompilerManager:
if not isinstance(value, ty):
raise TypeError(f"Expected {ty} but got {type(value)} for {value}")
def parse_key(key: Any) -> tuple[Range, int, str]:
range_tuple, graph_index, compiler_name = key
check_type(graph_index, int)
def parse_key(key: Any) -> tuple[Range, str, str]:
range_tuple, graph_hash, compiler_name = key
check_type(graph_hash, str)
check_type(compiler_name, str)
if isinstance(range_tuple, tuple):
start, end = range_tuple
@@ -196,7 +197,7 @@ class CompilerManager:
check_type(end, int)
range_tuple = Range(start=start, end=end)
check_type(range_tuple, Range)
return range_tuple, graph_index, compiler_name
return range_tuple, graph_hash, compiler_name
self.cache = {parse_key(key): value for key, value in cache.items()}
@@ -216,18 +217,25 @@ class CompilerManager:
self,
graph: fx.GraphModule,
example_inputs: list[Any],
graph_index: int,
graph_hash: str,
compile_range: Range,
) -> Callable[..., Any] | None:
if (compile_range, graph_index, self.compiler.name) not in self.cache:
key = (compile_range, graph_hash, self.compiler.name)
# See if we've already loaded this cache entry
if key in self.loaded_cache_entries:
return self.loaded_cache_entries[key]
# Otherwise, go load it from disk
if key not in self.cache:
return None
handle = self.cache[(compile_range, graph_index, self.compiler.name)]
handle = self.cache[key]
compiled_graph = self.compiler.load(
handle, graph, example_inputs, graph_index, compile_range
handle, graph, example_inputs, compile_range
)
self.loaded_cache_entries[key] = compiled_graph
logger.debug(
"Directly load the %s-th graph for compile range %sfrom %s via handle %s",
graph_index,
"Directly load the graph (hash %s) for compile range "
"%sfrom %s via handle %s",
graph_hash,
str(compile_range),
self.compiler.name,
handle,
@@ -249,12 +257,22 @@ class CompilerManager:
global compilation_start_time
compilation_start_time = time.time()
from torch._functorch._aot_autograd.autograd_cache import (
AOTAutogradCachePickler,
sanitize_gm_for_cache,
)
with sanitize_gm_for_cache(graph):
pickler = AOTAutogradCachePickler(graph)
dumped_graph = pickler.dumps(graph)
graph_hash = hashlib.sha256(dumped_graph).hexdigest()
compilation_counter.num_backend_compilations += 1
compiled_graph = None
# try to load from the cache
compiled_graph = self.load(graph, example_inputs, graph_index, compile_range)
compiled_graph = self.load(graph, example_inputs, graph_hash, compile_range)
if compiled_graph is not None:
if graph_index == num_graphs - 1:
# after loading the last graph for this shape, record the time.
@@ -290,9 +308,13 @@ class CompilerManager:
assert compiled_graph is not None, "Failed to compile the graph"
self.loaded_cache_entries[(compile_range, graph_hash, self.compiler.name)] = (
compiled_graph
)
# store the artifact in the cache
if is_compile_cache_enabled(additional_inductor_config) and handle is not None:
self.cache[(compile_range, graph_index, self.compiler.name)] = handle
self.cache[(compile_range, graph_hash, self.compiler.name)] = handle
compilation_counter.num_cache_entries_updated += 1
self.is_cache_updated = True
if graph_index == 0:

View File

@@ -101,7 +101,6 @@ class CompilerInterface:
handle: Any,
graph: fx.GraphModule,
example_inputs: list[Any],
graph_index: int,
compile_range: Range,
) -> Callable[..., Any]:
"""
@@ -302,7 +301,6 @@ class InductorStandaloneAdaptor(CompilerInterface):
handle: Any,
graph: fx.GraphModule,
example_inputs: list[Any],
graph_index: int,
compile_range: Range,
) -> Callable[..., Any]:
assert isinstance(handle, tuple)
@@ -527,7 +525,6 @@ class InductorAdaptor(CompilerInterface):
handle: Any,
graph: fx.GraphModule,
example_inputs: list[Any],
graph_index: int,
compile_range: Range,
) -> Callable[..., Any]:
assert isinstance(handle, tuple)