[Graph Partition][Cache] Use inductor partition ops config (#27702)

Signed-off-by: Boyuan Feng <boyuan@meta.com>
This commit is contained in:
Boyuan Feng
2025-11-05 05:04:48 -08:00
committed by GitHub
parent 6b7a81185d
commit 6ab183813c
4 changed files with 37 additions and 63 deletions

View File

@@ -97,10 +97,9 @@ class CompilerManager:
compilation (e.g. partition rules, pass context)."""
with pass_context(runtime_shape):
if self.compilation_config.use_inductor_graph_partition:
inductor_partition_ops = resolve_defined_ops(
with inductor_partition_rule_context(
self.compilation_config.splitting_ops
)
with inductor_partition_rule_context(inductor_partition_ops):
):
yield
else:
yield

View File

@@ -3,15 +3,12 @@
import contextlib
import logging
from typing import TYPE_CHECKING
import torch
from torch._library.utils import lookup_op
from vllm.logger import init_logger
if TYPE_CHECKING:
import torch
logger = init_logger(__name__)
@@ -56,47 +53,35 @@ def resolve_defined_ops(op_names: list[str]) -> list["torch._ops.OpOverload"]:
@contextlib.contextmanager
def inductor_partition_rule_context(overloads: list["torch._ops.OpOverload"]):
def inductor_partition_rule_context(splitting_ops: list[str]):
"""Context manager to temporarily register Inductor partition rules.
Registers custom partition rules for specified operators, forcing the
Inductor scheduler to partition the graph at these operators. The rules
are automatically restored to their previous state on exit.
Note: Callers should use resolve_defined_ops() to convert operator names
to OpOverload objects before calling this function.
Args:
overloads: List of resolved operator overload objects.
splitting_ops: List of operator names to partition on.
"""
if not overloads:
if not splitting_ops:
logger.debug("No partition ops provided; skipping rule registration.")
yield
return
from torch._inductor.scheduler import ( # type: ignore
_custom_should_partition_fns,
register_should_partition_rule,
)
def _always_partition(*_args, **_kwargs):
return True
# Save current state before registering
saved_rules = _custom_should_partition_fns.copy()
for overload in overloads:
register_should_partition_rule(
overload,
_always_partition,
)
saved_splitting_ops: list[str] = list(
torch._inductor.config.custom_should_partition_ops
)
torch._inductor.config.custom_should_partition_ops = splitting_ops
logger.debug("Registered inductor partition rules for %d operators", len(overloads))
logger.debug(
"Registered inductor partition rules for %d operators", len(splitting_ops)
)
try:
yield
finally:
# Clear and restore previous state
_custom_should_partition_fns.clear()
_custom_should_partition_fns.update(saved_rules)
torch._inductor.config.custom_should_partition_ops = saved_splitting_ops
logger.debug("Restored previous partition rules state.")

View File

@@ -113,27 +113,6 @@ class PostGradPassManager(CustomGraphPass):
self.post_cleanup = PostCleanupPass(config)
self.fix_functionalization = FixFunctionalizationPass(config)
# [HACK: Bug with Inductor graph partition and torch.compile cache]
# In PyTorch 2.9, torch.compile has a bug where the graph
# partition is not taken into account during caching.
# Because vLLM's Mode.VLLM_COMPILE is the only mode that uses
# Inductor graph partition, and VLLM_COMPILE implies there
# is a PostGradPassManager, we put the list of operators to graph
# partition into the PostGradPassManager's uuid (which
# then gets incorporated into Inductor's FX graph cache key).
# Remove this hack whenever torch.compile fixes it.
# This is the list of operators that vLLM asks Inductor to split.
self.inductor_splitting_ops = []
if (
config.compilation_config.use_inductor_graph_partition
and config.compilation_config.splitting_ops is not None
):
# Sort them so we're not dependent on the ordering.
self.inductor_splitting_ops = sorted(
config.compilation_config.splitting_ops
)
def add(self, pass_: InductorPass):
assert isinstance(pass_, InductorPass)
self.passes.append(pass_)
@@ -144,16 +123,9 @@ class PostGradPassManager(CustomGraphPass):
affects compilation caching. Its uuid depends on the UUIDs of all
dependent passes and the pass config. See InductorPass for more info.
"""
state = {
"pass_config": self.pass_config.uuid(),
"passes": [],
"inductor_splitting_ops": [],
}
state = {"pass_config": self.pass_config.uuid(), "passes": []}
for pass_ in self.passes:
state["passes"].append(pass_.uuid())
state["passes"].append(self.fix_functionalization.uuid())
# See [HACK: Bug with Inductor graph partition and torch.compile cache]
state["inductor_splitting_ops"].extend(self.inductor_splitting_ops)
return InductorPass.hash_dict(state)