[torch.compile] Stop compiling identical artifacts (#34003)
Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
@@ -37,12 +37,12 @@ def test_moe_compilation_cold_start(monkeypatch, use_fresh_inductor_cache):
|
||||
# The forward pass consists of 32 transformer layers.
|
||||
# Then, we split on the attention operation. This results in
|
||||
# 33 subgraphs (not including the attention operation).
|
||||
# The 33 subgraphs then get standalone_compile'd.
|
||||
# We then generate compiled artifacts for the unique subgraphs.
|
||||
#
|
||||
# There are actually only 3 unique subgraphs for this model
|
||||
# (all of its transformer layers are the same modulo weights);
|
||||
# this is true for most vLLM models.
|
||||
# So we test that during cold start, the aot_autograd cache
|
||||
# misses for 3 subgraphs and hits for the rest.
|
||||
# So we test that during cold start, we are only compiling
|
||||
# for 3 unique subgraphs.
|
||||
assert counters["aot_autograd"]["autograd_cache_miss"] == 3
|
||||
assert counters["aot_autograd"]["autograd_cache_hit"] == 30
|
||||
assert counters["aot_autograd"]["autograd_cache_hit"] == 0
|
||||
|
||||
@@ -134,6 +134,7 @@ class CompilerManager:
|
||||
self.is_cache_updated = False
|
||||
self.compilation_config = compilation_config
|
||||
self.compiler = make_compiler(compilation_config)
|
||||
self.loaded_artifacts: dict[str, Any] = {}
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
return self.compiler.compute_hash(vllm_config)
|
||||
@@ -282,13 +283,61 @@ class CompilerManager:
|
||||
maybe_key += f"{compile_range.start}_{compile_range.end}"
|
||||
maybe_key += f"_subgraph_{graph_index}"
|
||||
with self.compile_context(compile_range):
|
||||
compiled_graph, handle = self.compiler.compile(
|
||||
graph,
|
||||
example_inputs,
|
||||
additional_inductor_config,
|
||||
compile_range,
|
||||
maybe_key,
|
||||
)
|
||||
# There is a compilation time optimization here.
|
||||
#
|
||||
# If the (input metadata, graph, compiler config) are the same, then
|
||||
# we want to avoid compiling the same artifact again. If we didn't
|
||||
# do this optimization, the backend compilation (InductorAdaptor or
|
||||
# InductorStandaloneAdaptor)
|
||||
# can hit its cache and produce an artifact faster if it was
|
||||
# already created, but it is still a duplicate artifact that
|
||||
# requires unnecessary things e.g. disk IO.
|
||||
#
|
||||
# The optimization is: If the backend compilation cache hits,
|
||||
# then do an early return from the backend compilation and look up
|
||||
# which of the previous in-memory artifacts we created to reuse.
|
||||
#
|
||||
# We implemented this by monkey-patching torch (torch does not
|
||||
# easily expose the cache_key function), but in the future torch
|
||||
# should expose the cache_key function that we can just call
|
||||
# directly before invoking backend compilation.
|
||||
cache_key = None
|
||||
orig = torch._functorch._aot_autograd.autograd_cache.autograd_cache_key
|
||||
|
||||
def autograd_cache_key(*args, **kwargs):
|
||||
result = orig(*args, **kwargs)
|
||||
if result is None:
|
||||
return None
|
||||
nonlocal cache_key
|
||||
cache_key = result[0]
|
||||
if cache_key in self.loaded_artifacts:
|
||||
raise StopCompiling()
|
||||
return result
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
with (
|
||||
# Graphs that are isometric (different node names but same
|
||||
# structure) should be treated as the same.
|
||||
torch._functorch.config.patch(autograd_cache_normalize_inputs=True),
|
||||
patch(
|
||||
"torch._functorch._aot_autograd.autograd_cache.autograd_cache_key",
|
||||
autograd_cache_key,
|
||||
),
|
||||
):
|
||||
try:
|
||||
compiled_graph, handle = self.compiler.compile(
|
||||
graph,
|
||||
example_inputs,
|
||||
additional_inductor_config,
|
||||
compile_range,
|
||||
maybe_key,
|
||||
)
|
||||
except StopCompiling:
|
||||
assert cache_key is not None
|
||||
return self.loaded_artifacts[cache_key]
|
||||
if cache_key is not None and compiled_graph is not None:
|
||||
self.loaded_artifacts[cache_key] = compiled_graph
|
||||
|
||||
assert compiled_graph is not None, "Failed to compile the graph"
|
||||
|
||||
@@ -326,6 +375,10 @@ class CompilerManager:
|
||||
return compiled_graph
|
||||
|
||||
|
||||
class StopCompiling(BaseException):
|
||||
pass
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SplitItem:
|
||||
submod_name: str
|
||||
|
||||
Reference in New Issue
Block a user