[BugFix] Work around graph partition x torch.compile cache issue (#26956)

Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
Richard Zou
2025-10-15 23:06:11 -04:00
committed by GitHub
parent e19b16dde6
commit 9b6504c307
2 changed files with 31 additions and 12 deletions

View File

@@ -110,6 +110,27 @@ class PostGradPassManager(CustomGraphPass):
self.post_cleanup = PostCleanupPass(config)
self.fix_functionalization = FixFunctionalizationPass(config)
# [HACK: Bug with Inductor graph partition and torch.compile cache]
# In PyTorch 2.9, torch.compile has a bug where the graph
# partition is not taken into account during caching.
# Because vLLM's Mode.VLLM_COMPILE is the only mode that uses
# Inductor graph partition, and VLLM_COMPILE implies there
# is a PostGradPassManager, we put the list of operators to graph
# partition into the PostGradPassManager's uuid (which
# then gets incorporated into Inductor's FX graph cache key).
# Remove this hack whenever torch.compile fixes it.
# This is the list of operators that vLLM asks Inductor to split.
self.inductor_splitting_ops = []
if (
config.compilation_config.use_inductor_graph_partition
and config.compilation_config.splitting_ops is not None
):
# Sort them so we're not dependent on the ordering.
self.inductor_splitting_ops = sorted(
config.compilation_config.splitting_ops
)
def add(self, pass_: InductorPass):
    """Append ``pass_`` to this manager's list of passes.

    The argument is asserted to be an ``InductorPass`` instance before it
    is appended to ``self.passes``.
    """
    # Check the type first so a wrong argument never reaches the pass list.
    assert isinstance(pass_, InductorPass)
    registered_passes = self.passes
    registered_passes.append(pass_)
@@ -120,8 +141,16 @@ class PostGradPassManager(CustomGraphPass):
affects compilation caching. Its uuid depends on the UUIDs of all
dependent passes and the pass config. See InductorPass for more info.
"""
state = {"pass_config": self.pass_config.uuid(), "passes": []}
state = {
"pass_config": self.pass_config.uuid(),
"passes": [],
"inductor_splitting_ops": [],
}
for pass_ in self.passes:
state["passes"].append(pass_.uuid())
state["passes"].append(self.fix_functionalization.uuid())
# See [HACK: Bug with Inductor graph partition and torch.compile cache]
state["inductor_splitting_ops"].extend(self.inductor_splitting_ops)
return InductorPass.hash_dict(state)