[torch.compile] Stop compiling identical artifacts (#34003)

Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
Richard Zou
2026-02-07 08:24:48 -05:00
committed by GitHub
parent dd6a6e1190
commit 81fe69cae5
2 changed files with 64 additions and 11 deletions

View File

@@ -37,12 +37,12 @@ def test_moe_compilation_cold_start(monkeypatch, use_fresh_inductor_cache):
# The forward pass consists of 32 transformer layers.
# Then, we split on the attention operation. This results in
# 33 subgraphs (not including the attention operation).
# The 33 subgraphs then get standalone_compile'd.
# We then generate compiled artifacts for the unique subgraphs.
#
# There are actually only 3 unique subgraphs for this model
# (all of its transformer layers are the same modulo weights);
# this is true for most vLLM models.
# So we test that during cold start, the aot_autograd cache
# misses for 3 subgraphs and hits for the rest.
# So we test that during cold start, we are only compiling
# for 3 unique subgraphs.
assert counters["aot_autograd"]["autograd_cache_miss"] == 3
assert counters["aot_autograd"]["autograd_cache_hit"] == 30
assert counters["aot_autograd"]["autograd_cache_hit"] == 0

View File

@@ -134,6 +134,7 @@ class CompilerManager:
self.is_cache_updated = False
self.compilation_config = compilation_config
self.compiler = make_compiler(compilation_config)
self.loaded_artifacts: dict[str, Any] = {}
def compute_hash(self, vllm_config: VllmConfig) -> str:
    """Return a cache key for *vllm_config* by delegating to the backend compiler."""
    return self.compiler.compute_hash(vllm_config)
@@ -282,13 +283,61 @@ class CompilerManager:
maybe_key += f"{compile_range.start}_{compile_range.end}"
maybe_key += f"_subgraph_{graph_index}"
with self.compile_context(compile_range):
compiled_graph, handle = self.compiler.compile(
graph,
example_inputs,
additional_inductor_config,
compile_range,
maybe_key,
)
# There is a compilation time optimization here.
#
# If the (input metadata, graph, compiler config) are the same, then
# we want to avoid compiling the same artifact again. If we didn't
# do this optimization, the backend compilation (InductorAdaptor or
# InductorStandaloneAdaptor) would still cache hit and produce the
# artifact faster than a fresh compile, but it would be a duplicate
# artifact that incurs unnecessary work, e.g. disk IO.
#
# The optimization is: If the backend compilation cache hits,
# then do an early return from the backend compilation and look up
# which of the previous in-memory artifacts we created to reuse.
#
# We implemented this by monkey-patching torch (torch does not
# easily expose the cache_key function), but in the future torch
# should expose the cache_key function that we can just call
# directly before invoking backend compilation.
cache_key = None
orig = torch._functorch._aot_autograd.autograd_cache.autograd_cache_key
def autograd_cache_key(*args, **kwargs):
result = orig(*args, **kwargs)
if result is None:
return None
nonlocal cache_key
cache_key = result[0]
if cache_key in self.loaded_artifacts:
raise StopCompiling()
return result
from unittest.mock import patch
with (
# Graphs that are isomorphic (different node names but same
# structure) should be treated as the same.
torch._functorch.config.patch(autograd_cache_normalize_inputs=True),
patch(
"torch._functorch._aot_autograd.autograd_cache.autograd_cache_key",
autograd_cache_key,
),
):
try:
compiled_graph, handle = self.compiler.compile(
graph,
example_inputs,
additional_inductor_config,
compile_range,
maybe_key,
)
except StopCompiling:
assert cache_key is not None
return self.loaded_artifacts[cache_key]
if cache_key is not None and compiled_graph is not None:
self.loaded_artifacts[cache_key] = compiled_graph
assert compiled_graph is not None, "Failed to compile the graph"
@@ -326,6 +375,10 @@ class CompilerManager:
return compiled_graph
class StopCompiling(BaseException):
    """Control-flow signal used to abort backend compilation early.

    Raised inside the patched ``autograd_cache_key`` when the computed
    cache key already has an in-memory artifact in ``loaded_artifacts``,
    and caught by the caller, which returns the cached artifact instead.

    NOTE(review): inherits from BaseException rather than Exception,
    presumably so generic ``except Exception`` handlers in the compiler
    stack cannot swallow it — confirm.
    """
    pass
@dataclasses.dataclass
class SplitItem:
submod_name: str