[torch.compile] Significantly speed up cold start times (#33641)
Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
@@ -37,12 +37,13 @@ def test_moe_compilation_cold_start(monkeypatch, use_fresh_inductor_cache):
|
||||
# The forward pass consists of 32 transformer layers.
|
||||
# Then, we split on the attention operation. This results in
|
||||
# 33 subgraphs (not including the attention operation).
|
||||
# The 33 subgraphs then get standalone_compile'd.
|
||||
# We then standalone_compile the unique subgraphs.
|
||||
#
|
||||
# There are actually only 3 unique subgraphs for this model
|
||||
# (all of its transformer layers are the same modulo weights);
|
||||
# this is true for most vLLM models.
|
||||
# So we test that during cold start, the aot_autograd cache
|
||||
# misses for 3 subgraphs and hits for the rest.
|
||||
# So we test that during cold start, only 3 subgraphs are compiled
|
||||
# These 3 subgraphs should cache miss, and then there should be
|
||||
# no other compilation (so no cache hits).
|
||||
assert counters["aot_autograd"]["autograd_cache_miss"] == 3
|
||||
assert counters["aot_autograd"]["autograd_cache_hit"] == 30
|
||||
assert counters["aot_autograd"]["autograd_cache_hit"] == 0
|
||||
|
||||
@@ -121,7 +121,7 @@ class CompilerManager:
|
||||
and compiling the graph.
|
||||
|
||||
The cache is a dict mapping
|
||||
`(runtime_shape, graph_index, backend_name)`
|
||||
`(runtime_shape, graph_hash, backend_name)`
|
||||
to `any_data` returned from the compiler.
|
||||
|
||||
When serializing the cache, we save it to a Python file
|
||||
@@ -130,7 +130,7 @@ class CompilerManager:
|
||||
"""
|
||||
|
||||
def __init__(self, compilation_config: CompilationConfig) -> None:
|
||||
self.cache: dict[tuple[Range, int, str], Any] = dict()
|
||||
self.cache: dict[tuple[Range, str, str], Any] = dict()
|
||||
self.is_cache_updated = False
|
||||
self.compilation_config = compilation_config
|
||||
self.compiler = make_compiler(compilation_config)
|
||||
@@ -173,6 +173,7 @@ class CompilerManager:
|
||||
self.disable_cache = disable_cache
|
||||
self.cache_dir = cache_dir
|
||||
self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")
|
||||
self.loaded_cache_entries: dict[tuple[Range, str, str], Any] = {}
|
||||
|
||||
if not disable_cache and os.path.exists(self.cache_file_path):
|
||||
# load the cache from the file
|
||||
@@ -186,9 +187,9 @@ class CompilerManager:
|
||||
if not isinstance(value, ty):
|
||||
raise TypeError(f"Expected {ty} but got {type(value)} for {value}")
|
||||
|
||||
def parse_key(key: Any) -> tuple[Range, int, str]:
|
||||
range_tuple, graph_index, compiler_name = key
|
||||
check_type(graph_index, int)
|
||||
def parse_key(key: Any) -> tuple[Range, str, str]:
|
||||
range_tuple, graph_hash, compiler_name = key
|
||||
check_type(graph_hash, str)
|
||||
check_type(compiler_name, str)
|
||||
if isinstance(range_tuple, tuple):
|
||||
start, end = range_tuple
|
||||
@@ -196,7 +197,7 @@ class CompilerManager:
|
||||
check_type(end, int)
|
||||
range_tuple = Range(start=start, end=end)
|
||||
check_type(range_tuple, Range)
|
||||
return range_tuple, graph_index, compiler_name
|
||||
return range_tuple, graph_hash, compiler_name
|
||||
|
||||
self.cache = {parse_key(key): value for key, value in cache.items()}
|
||||
|
||||
@@ -216,18 +217,25 @@ class CompilerManager:
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
graph_hash: str,
|
||||
compile_range: Range,
|
||||
) -> Callable[..., Any] | None:
|
||||
if (compile_range, graph_index, self.compiler.name) not in self.cache:
|
||||
key = (compile_range, graph_hash, self.compiler.name)
|
||||
# See if we've already loaded this cache entry
|
||||
if key in self.loaded_cache_entries:
|
||||
return self.loaded_cache_entries[key]
|
||||
# Otherwise, go load it from disk
|
||||
if key not in self.cache:
|
||||
return None
|
||||
handle = self.cache[(compile_range, graph_index, self.compiler.name)]
|
||||
handle = self.cache[key]
|
||||
compiled_graph = self.compiler.load(
|
||||
handle, graph, example_inputs, graph_index, compile_range
|
||||
handle, graph, example_inputs, compile_range
|
||||
)
|
||||
self.loaded_cache_entries[key] = compiled_graph
|
||||
logger.debug(
|
||||
"Directly load the %s-th graph for compile range %sfrom %s via handle %s",
|
||||
graph_index,
|
||||
"Directly load the graph (hash %s) for compile range "
|
||||
"%sfrom %s via handle %s",
|
||||
graph_hash,
|
||||
str(compile_range),
|
||||
self.compiler.name,
|
||||
handle,
|
||||
@@ -249,12 +257,22 @@ class CompilerManager:
|
||||
global compilation_start_time
|
||||
compilation_start_time = time.time()
|
||||
|
||||
from torch._functorch._aot_autograd.autograd_cache import (
|
||||
AOTAutogradCachePickler,
|
||||
sanitize_gm_for_cache,
|
||||
)
|
||||
|
||||
with sanitize_gm_for_cache(graph):
|
||||
pickler = AOTAutogradCachePickler(graph)
|
||||
dumped_graph = pickler.dumps(graph)
|
||||
graph_hash = hashlib.sha256(dumped_graph).hexdigest()
|
||||
|
||||
compilation_counter.num_backend_compilations += 1
|
||||
|
||||
compiled_graph = None
|
||||
|
||||
# try to load from the cache
|
||||
compiled_graph = self.load(graph, example_inputs, graph_index, compile_range)
|
||||
compiled_graph = self.load(graph, example_inputs, graph_hash, compile_range)
|
||||
if compiled_graph is not None:
|
||||
if graph_index == num_graphs - 1:
|
||||
# after loading the last graph for this shape, record the time.
|
||||
@@ -290,9 +308,13 @@ class CompilerManager:
|
||||
|
||||
assert compiled_graph is not None, "Failed to compile the graph"
|
||||
|
||||
self.loaded_cache_entries[(compile_range, graph_hash, self.compiler.name)] = (
|
||||
compiled_graph
|
||||
)
|
||||
|
||||
# store the artifact in the cache
|
||||
if is_compile_cache_enabled(additional_inductor_config) and handle is not None:
|
||||
self.cache[(compile_range, graph_index, self.compiler.name)] = handle
|
||||
self.cache[(compile_range, graph_hash, self.compiler.name)] = handle
|
||||
compilation_counter.num_cache_entries_updated += 1
|
||||
self.is_cache_updated = True
|
||||
if graph_index == 0:
|
||||
|
||||
@@ -101,7 +101,6 @@ class CompilerInterface:
|
||||
handle: Any,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
compile_range: Range,
|
||||
) -> Callable[..., Any]:
|
||||
"""
|
||||
@@ -302,7 +301,6 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
handle: Any,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
compile_range: Range,
|
||||
) -> Callable[..., Any]:
|
||||
assert isinstance(handle, tuple)
|
||||
@@ -527,7 +525,6 @@ class InductorAdaptor(CompilerInterface):
|
||||
handle: Any,
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
compile_range: Range,
|
||||
) -> Callable[..., Any]:
|
||||
assert isinstance(handle, tuple)
|
||||
|
||||
Reference in New Issue
Block a user