[BugFix] Work around graph partition x torch.compile cache issue (#26956)

Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
Richard Zou
2025-10-15 23:06:11 -04:00
committed by GitHub
parent e19b16dde6
commit 9b6504c307
2 changed files with 31 additions and 12 deletions

View File

@@ -337,9 +337,8 @@ def run_model(llama_config, compile_config: CompilationConfig) -> torch.Tensor:
def test_toy_llama( def test_toy_llama(
backend: str, use_inductor_graph_partition: bool, monkeypatch, tmp_path backend: str, use_inductor_graph_partition: bool, monkeypatch, tmp_path
): ):
# We disable the vLLM compile cache into a new tmp dir for 2 reasons: # We disable the vLLM compile cache into a new tmp dir for one reason:
# 1. To make sure we can properly track the number of Inductor compilations. # 1. To make sure we can properly track the number of Inductor compilations.
# 2. Inductor partitioning does not play nicely with Autograd cache (below)
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
@@ -369,15 +368,6 @@ def test_toy_llama(
cudagraph_capture_sizes=[1, 2], cudagraph_capture_sizes=[1, 2],
) )
# FIXME(luka/boyuan): the graph from the previous test case
# (no inductor partition) gets cached by AotAutograd so then the
# compilation with inductor partitioning incorrectly loads an unpartitioned
# graph and never partitions. I think this is a bug with custom inductor
# partitioning but does not affect vLLM more generally as vLLM uses its own
# cache (which takes inductor partitioning into account).
if use_inductor_graph_partition:
compile_config_no_split.inductor_compile_config["force_disable_caches"] = True
compile_config_split = deepcopy(compile_config_no_split) compile_config_split = deepcopy(compile_config_no_split)
compile_config_split.splitting_ops = ["silly::attention"] compile_config_split.splitting_ops = ["silly::attention"]

View File

@@ -110,6 +110,27 @@ class PostGradPassManager(CustomGraphPass):
self.post_cleanup = PostCleanupPass(config) self.post_cleanup = PostCleanupPass(config)
self.fix_functionalization = FixFunctionalizationPass(config) self.fix_functionalization = FixFunctionalizationPass(config)
# [HACK: Bug with Inductor graph partition and torch.compile cache]
# In PyTorch 2.9, torch.compile has a bug where the graph
# partition is not taken into account during caching.
# Because vLLM's Mode.VLLM_COMPILE is the only mode that uses
# Inductor graph partition, and VLLM_COMPILE implies there
# is a PostGradPassManager, we put the list of operators to graph
# partition into the PostGradPassManager's uuid (which
# then gets incorporated into Inductor's FX graph cache key).
# Remove this hack once torch.compile fixes this bug.
# This is the list of operators that vLLM asks Inductor to split.
self.inductor_splitting_ops = []
if (
config.compilation_config.use_inductor_graph_partition
and config.compilation_config.splitting_ops is not None
):
# Sort them so we're not dependent on the ordering.
self.inductor_splitting_ops = sorted(
config.compilation_config.splitting_ops
)
def add(self, pass_: InductorPass): def add(self, pass_: InductorPass):
assert isinstance(pass_, InductorPass) assert isinstance(pass_, InductorPass)
self.passes.append(pass_) self.passes.append(pass_)
@@ -120,8 +141,16 @@ class PostGradPassManager(CustomGraphPass):
affects compilation caching. Its uuid depends on the UUIDs of all affects compilation caching. Its uuid depends on the UUIDs of all
dependent passes and the pass config. See InductorPass for more info. dependent passes and the pass config. See InductorPass for more info.
""" """
state = {"pass_config": self.pass_config.uuid(), "passes": []} state = {
"pass_config": self.pass_config.uuid(),
"passes": [],
"inductor_splitting_ops": [],
}
for pass_ in self.passes: for pass_ in self.passes:
state["passes"].append(pass_.uuid()) state["passes"].append(pass_.uuid())
state["passes"].append(self.fix_functionalization.uuid()) state["passes"].append(self.fix_functionalization.uuid())
# See [HACK: Bug with Inductor graph partition and torch.compile cache]
state["inductor_splitting_ops"].extend(self.inductor_splitting_ops)
return InductorPass.hash_dict(state) return InductorPass.hash_dict(state)