[Graph Partition][Cache] Use inductor partition ops config (#27702)

Signed-off-by: Boyuan Feng <boyuan@meta.com>
This commit is contained in:
Boyuan Feng
2025-11-05 05:04:48 -08:00
committed by GitHub
parent 6b7a81185d
commit 6ab183813c
4 changed files with 37 additions and 63 deletions

View File

@@ -97,10 +97,9 @@ class CompilerManager:
compilation (e.g. partition rules, pass context).""" compilation (e.g. partition rules, pass context)."""
with pass_context(runtime_shape): with pass_context(runtime_shape):
if self.compilation_config.use_inductor_graph_partition: if self.compilation_config.use_inductor_graph_partition:
inductor_partition_ops = resolve_defined_ops( with inductor_partition_rule_context(
self.compilation_config.splitting_ops self.compilation_config.splitting_ops
) ):
with inductor_partition_rule_context(inductor_partition_ops):
yield yield
else: else:
yield yield

View File

@@ -3,15 +3,12 @@
import contextlib import contextlib
import logging import logging
from typing import TYPE_CHECKING
import torch
from torch._library.utils import lookup_op from torch._library.utils import lookup_op
from vllm.logger import init_logger from vllm.logger import init_logger
if TYPE_CHECKING:
import torch
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -56,47 +53,35 @@ def resolve_defined_ops(op_names: list[str]) -> list["torch._ops.OpOverload"]:
@contextlib.contextmanager @contextlib.contextmanager
def inductor_partition_rule_context(overloads: list["torch._ops.OpOverload"]): def inductor_partition_rule_context(splitting_ops: list[str]):
"""Context manager to temporarily register Inductor partition rules. """Context manager to temporarily register Inductor partition rules.
Registers custom partition rules for specified operators, forcing the Registers custom partition rules for specified operators, forcing the
Inductor scheduler to partition the graph at these operators. The rules Inductor scheduler to partition the graph at these operators. The rules
are automatically restored to their previous state on exit. are automatically restored to their previous state on exit.
Note: Callers should use resolve_defined_ops() to convert operator names
to OpOverload objects before calling this function.
Args: Args:
overloads: List of resolved operator overload objects. splitting_ops: List of operator names to partition on.
""" """
if not overloads: if not splitting_ops:
logger.debug("No partition ops provided; skipping rule registration.") logger.debug("No partition ops provided; skipping rule registration.")
yield yield
return return
from torch._inductor.scheduler import ( # type: ignore
_custom_should_partition_fns,
register_should_partition_rule,
)
def _always_partition(*_args, **_kwargs):
return True
# Save current state before registering # Save current state before registering
saved_rules = _custom_should_partition_fns.copy()
for overload in overloads: saved_splitting_ops: list[str] = list(
register_should_partition_rule( torch._inductor.config.custom_should_partition_ops
overload, )
_always_partition, torch._inductor.config.custom_should_partition_ops = splitting_ops
)
logger.debug("Registered inductor partition rules for %d operators", len(overloads)) logger.debug(
"Registered inductor partition rules for %d operators", len(splitting_ops)
)
try: try:
yield yield
finally: finally:
# Clear and restore previous state # Clear and restore previous state
_custom_should_partition_fns.clear() torch._inductor.config.custom_should_partition_ops = saved_splitting_ops
_custom_should_partition_fns.update(saved_rules)
logger.debug("Restored previous partition rules state.") logger.debug("Restored previous partition rules state.")

View File

@@ -113,27 +113,6 @@ class PostGradPassManager(CustomGraphPass):
self.post_cleanup = PostCleanupPass(config) self.post_cleanup = PostCleanupPass(config)
self.fix_functionalization = FixFunctionalizationPass(config) self.fix_functionalization = FixFunctionalizationPass(config)
# [HACK: Bug with Inductor graph partition and torch.compile cache]
# In PyTorch 2.9, torch.compile has a bug where the graph
# partition is not taken into account during caching.
# Because vLLM's Mode.VLLM_COMPILE is the only mode that uses
# Inductor graph partition, and VLLM_COMPILE implies there
# is a PostGradPassManager, we put the list of operators to graph
# partition into the PostGradPassManager's uuid (which
# then gets incorporated into Inductor's FX graph cache key).
# Remove this hack whenever torch.compile fixes it.
# This is the list of operators that vLLM asks Inductor to split.
self.inductor_splitting_ops = []
if (
config.compilation_config.use_inductor_graph_partition
and config.compilation_config.splitting_ops is not None
):
# Sort them so we're not dependent on the ordering.
self.inductor_splitting_ops = sorted(
config.compilation_config.splitting_ops
)
def add(self, pass_: InductorPass): def add(self, pass_: InductorPass):
assert isinstance(pass_, InductorPass) assert isinstance(pass_, InductorPass)
self.passes.append(pass_) self.passes.append(pass_)
@@ -144,16 +123,9 @@ class PostGradPassManager(CustomGraphPass):
affects compilation caching. Its uuid depends on the UUIDs of all affects compilation caching. Its uuid depends on the UUIDs of all
dependent passes and the pass config. See InductorPass for more info. dependent passes and the pass config. See InductorPass for more info.
""" """
state = { state = {"pass_config": self.pass_config.uuid(), "passes": []}
"pass_config": self.pass_config.uuid(),
"passes": [],
"inductor_splitting_ops": [],
}
for pass_ in self.passes: for pass_ in self.passes:
state["passes"].append(pass_.uuid()) state["passes"].append(pass_.uuid())
state["passes"].append(self.fix_functionalization.uuid()) state["passes"].append(self.fix_functionalization.uuid())
# See [HACK: Bug with Inductor graph partition and torch.compile cache]
state["inductor_splitting_ops"].extend(self.inductor_splitting_ops)
return InductorPass.hash_dict(state) return InductorPass.hash_dict(state)

View File

@@ -272,7 +272,6 @@ def should_partition_patched(self, node, should_log: bool = False) -> bool:
from torch._inductor.scheduler import ( from torch._inductor.scheduler import (
BaseSchedulerNode, BaseSchedulerNode,
FusedSchedulerNode, FusedSchedulerNode,
_custom_should_partition_fns,
) )
from torch._inductor.utils import ( from torch._inductor.utils import (
_unstable_customized_partition_wrapper, _unstable_customized_partition_wrapper,
@@ -283,9 +282,21 @@ def should_partition_patched(self, node, should_log: bool = False) -> bool:
# Allow users to manually specify if a node should be partitioned # Allow users to manually specify if a node should be partitioned
# Can only do this for FallbackKernels # Can only do this for FallbackKernels
ir_node = node.node ir_node = node.node
if isinstance(ir_node, ir.FallbackKernel): if isinstance(ir_node, torch._inductor.ir.FallbackKernel) and (
operator = ir_node.op_overload op := ir_node.op_overload
if operator is not None and operator in _custom_should_partition_fns: ):
op_overload_packet_name = op.name()
op_overload_name = (
f"{op_overload_packet_name}.{op._overloadname}"
if isinstance(op, torch._ops.OpOverload)
else op_overload_packet_name
)
if (
op_overload_packet_name
in torch._inductor.config.custom_should_partition_ops
or op_overload_name in torch._inductor.config.custom_should_partition_ops
):
assert isinstance(op, torch._ops.OpOverload)
return True return True
# When not using cudagraphs, keep all kernels in the `call` function # When not using cudagraphs, keep all kernels in the `call` function
@@ -355,6 +366,13 @@ def _update_scheduler_patched(self) -> None:
if is_torch_equal("2.9.0"): if is_torch_equal("2.9.0"):
from torch._inductor.codegen.wrapper import PythonWrapperCodegen from torch._inductor.codegen.wrapper import PythonWrapperCodegen
from torch._inductor.graph import GraphLowering from torch._inductor.graph import GraphLowering
from torch.utils._config_module import _Config, _ConfigEntry
# `custom_should_partition_ops` is a config option introduced after 2.9.0, so
# registering it here does not overwrite any user-provided config.
torch._inductor.config._config["custom_should_partition_ops"] = _ConfigEntry(
_Config(default=[])
)
PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched
GraphLowering._update_scheduler = _update_scheduler_patched GraphLowering._update_scheduler = _update_scheduler_patched