Revert "[torch.compile] Significantly speed up cold start times" (#33820)
Signed-off-by: Richard Zou <zou3519@gmail.com>
@@ -121,7 +121,7 @@ class CompilerManager:
     and compiling the graph.

     The cache is a dict mapping
-    `(runtime_shape, graph_hash, backend_name)`
+    `(runtime_shape, graph_index, backend_name)`
     to `any_data` returned from the compiler.

     When serializing the cache, we save it to a Python file
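As context for the docstring above: the cache is keyed by tuples, which is why it is serialized to a Python file rather than JSON. A minimal standalone sketch of that idea (illustrative only, not the repository's actual serializer; the file name simply reuses `vllm_compile_cache.py` from the diff):

# Hypothetical sketch: JSON cannot represent tuple keys such as
# (runtime_shape, graph_index, backend_name), but repr()/ast.literal_eval can.
import ast

cache = {
    ((1, 8), 0, "inductor"): "artifact-handle-0",  # (runtime shape range, graph_index, backend_name)
    ((1, 8), 1, "inductor"): "artifact-handle-1",
}

# Serialize: write the dict as a Python literal.
with open("vllm_compile_cache.py", "w") as f:
    f.write(repr(cache))

# Deserialize: parse the literal back into a dict with tuple keys.
with open("vllm_compile_cache.py") as f:
    loaded = ast.literal_eval(f.read())

assert loaded == cache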
@@ -130,7 +130,7 @@ class CompilerManager:
     """

     def __init__(self, compilation_config: CompilationConfig) -> None:
-        self.cache: dict[tuple[Range, str, str], Any] = dict()
+        self.cache: dict[tuple[Range, int, str], Any] = dict()
         self.is_cache_updated = False
         self.compilation_config = compilation_config
         self.compiler = make_compiler(compilation_config)
@@ -173,7 +173,6 @@ class CompilerManager:
         self.disable_cache = disable_cache
         self.cache_dir = cache_dir
         self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")
-        self.loaded_cache_entries: dict[tuple[Range, str, str], Any] = {}

         if not disable_cache and os.path.exists(self.cache_file_path):
             # load the cache from the file
@@ -187,9 +186,9 @@ class CompilerManager:
             if not isinstance(value, ty):
                 raise TypeError(f"Expected {ty} but got {type(value)} for {value}")

-        def parse_key(key: Any) -> tuple[Range, str, str]:
-            range_tuple, graph_hash, compiler_name = key
-            check_type(graph_hash, str)
+        def parse_key(key: Any) -> tuple[Range, int, str]:
+            range_tuple, graph_index, compiler_name = key
+            check_type(graph_index, int)
             check_type(compiler_name, str)
             if isinstance(range_tuple, tuple):
                 start, end = range_tuple
@@ -197,7 +196,7 @@ class CompilerManager:
                 check_type(end, int)
                 range_tuple = Range(start=start, end=end)
             check_type(range_tuple, Range)
-            return range_tuple, graph_hash, compiler_name
+            return range_tuple, graph_index, compiler_name

         self.cache = {parse_key(key): value for key, value in cache.items()}

@@ -217,25 +216,18 @@ class CompilerManager:
         self,
         graph: fx.GraphModule,
         example_inputs: list[Any],
-        graph_hash: str,
+        graph_index: int,
         compile_range: Range,
     ) -> Callable[..., Any] | None:
-        key = (compile_range, graph_hash, self.compiler.name)
-        # See if we've already loaded this cache entry
-        if key in self.loaded_cache_entries:
-            return self.loaded_cache_entries[key]
-        # Otherwise, go load it from disk
-        if key not in self.cache:
+        if (compile_range, graph_index, self.compiler.name) not in self.cache:
             return None
-        handle = self.cache[key]
+        handle = self.cache[(compile_range, graph_index, self.compiler.name)]
         compiled_graph = self.compiler.load(
-            handle, graph, example_inputs, compile_range
+            handle, graph, example_inputs, graph_index, compile_range
         )
-        self.loaded_cache_entries[key] = compiled_graph
         logger.debug(
-            "Directly load the graph (hash %s) for compile range "
-            "%sfrom %s via handle %s",
-            graph_hash,
+            "Directly load the %s-th graph for compile range %sfrom %s via handle %s",
+            graph_index,
             str(compile_range),
             self.compiler.name,
             handle,
@@ -257,22 +249,12 @@ class CompilerManager:
         global compilation_start_time
         compilation_start_time = time.time()

-        from torch._functorch._aot_autograd.autograd_cache import (
-            AOTAutogradCachePickler,
-            sanitize_gm_for_cache,
-        )
-
-        with sanitize_gm_for_cache(graph):
-            pickler = AOTAutogradCachePickler(graph)
-            dumped_graph = pickler.dumps(graph)
-            graph_hash = hashlib.sha256(dumped_graph).hexdigest()
-
         compilation_counter.num_backend_compilations += 1

         compiled_graph = None

         # try to load from the cache
-        compiled_graph = self.load(graph, example_inputs, graph_hash, compile_range)
+        compiled_graph = self.load(graph, example_inputs, graph_index, compile_range)
         if compiled_graph is not None:
             if graph_index == num_graphs - 1:
                 # after loading the last graph for this shape, record the time.
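The removed block above derived a content hash of the FX graph to use as a cache key. A self-contained approximation of that hashing step (a sketch only: it hashes the graph's generated code with sha256 instead of using the private AOTAutogradCachePickler from the reverted change):

# Rough sketch of content-hashing an FX graph, assuming its generated code is
# an adequate proxy for its identity.
import hashlib

import torch
from torch import fx, nn


class TwoLinear(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.a = nn.Linear(4, 4)
        self.b = nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.b(self.a(x))


gm: fx.GraphModule = fx.symbolic_trace(TwoLinear())
# Structurally identical graphs produce the same digest.
graph_hash = hashlib.sha256(gm.code.encode()).hexdigest()
print(graph_hash[:16])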
@@ -308,13 +290,9 @@ class CompilerManager:

         assert compiled_graph is not None, "Failed to compile the graph"

-        self.loaded_cache_entries[(compile_range, graph_hash, self.compiler.name)] = (
-            compiled_graph
-        )
-
         # store the artifact in the cache
         if is_compile_cache_enabled(additional_inductor_config) and handle is not None:
-            self.cache[(compile_range, graph_hash, self.compiler.name)] = handle
+            self.cache[(compile_range, graph_index, self.compiler.name)] = handle
             compilation_counter.num_cache_entries_updated += 1
             self.is_cache_updated = True
             if graph_index == 0:
@@ -101,6 +101,7 @@ class CompilerInterface:
         handle: Any,
         graph: fx.GraphModule,
         example_inputs: list[Any],
+        graph_index: int,
         compile_range: Range,
     ) -> Callable[..., Any]:
         """
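The hunk above restores `graph_index` to the `CompilerInterface.load` signature, so every adaptor below gains the same parameter. A schematic subclass showing the widened signature (hypothetical: `Range` here is a stand-in dataclass, not the real vLLM type, and the adaptor does no actual compilation):

# Schematic only: illustrates the restored load() parameter order.
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

from torch import fx


@dataclass(frozen=True)
class Range:  # stand-in for the range type used in the cache key
    start: int
    end: int


class NoOpAdaptor:
    name = "noop"

    def load(
        self,
        handle: Any,
        graph: fx.GraphModule,
        example_inputs: list[Any],
        graph_index: int,
        compile_range: Range,
    ) -> Callable[..., Any]:
        # A real adaptor would rebuild the compiled artifact from `handle`;
        # here we simply return the uncompiled graph as the callable.
        return graph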
@@ -301,6 +302,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
         handle: Any,
         graph: fx.GraphModule,
         example_inputs: list[Any],
+        graph_index: int,
         compile_range: Range,
     ) -> Callable[..., Any]:
         assert isinstance(handle, tuple)
@@ -525,6 +527,7 @@ class InductorAdaptor(CompilerInterface):
         handle: Any,
         graph: fx.GraphModule,
         example_inputs: list[Any],
+        graph_index: int,
         compile_range: Range,
     ) -> Callable[..., Any]:
         assert isinstance(handle, tuple)