[Graph Partition][Cache] Use inductor partition ops config (#27702)
Signed-off-by: Boyuan Feng <boyuan@meta.com>
This commit is contained in:
@@ -97,10 +97,9 @@ class CompilerManager:
|
|||||||
compilation (e.g. partition rules, pass context)."""
|
compilation (e.g. partition rules, pass context)."""
|
||||||
with pass_context(runtime_shape):
|
with pass_context(runtime_shape):
|
||||||
if self.compilation_config.use_inductor_graph_partition:
|
if self.compilation_config.use_inductor_graph_partition:
|
||||||
inductor_partition_ops = resolve_defined_ops(
|
with inductor_partition_rule_context(
|
||||||
self.compilation_config.splitting_ops
|
self.compilation_config.splitting_ops
|
||||||
)
|
):
|
||||||
with inductor_partition_rule_context(inductor_partition_ops):
|
|
||||||
yield
|
yield
|
||||||
else:
|
else:
|
||||||
yield
|
yield
|
||||||
|
|||||||
@@ -3,15 +3,12 @@
|
|||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
|
import torch
|
||||||
from torch._library.utils import lookup_op
|
from torch._library.utils import lookup_op
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
import torch
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -56,47 +53,35 @@ def resolve_defined_ops(op_names: list[str]) -> list["torch._ops.OpOverload"]:
|
|||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def inductor_partition_rule_context(overloads: list["torch._ops.OpOverload"]):
|
def inductor_partition_rule_context(splitting_ops: list[str]):
|
||||||
"""Context manager to temporarily register Inductor partition rules.
|
"""Context manager to temporarily register Inductor partition rules.
|
||||||
|
|
||||||
Registers custom partition rules for specified operators, forcing the
|
Registers custom partition rules for specified operators, forcing the
|
||||||
Inductor scheduler to partition the graph at these operators. The rules
|
Inductor scheduler to partition the graph at these operators. The rules
|
||||||
are automatically restored to their previous state on exit.
|
are automatically restored to their previous state on exit.
|
||||||
|
|
||||||
Note: Callers should use resolve_defined_ops() to convert operator names
|
|
||||||
to OpOverload objects before calling this function.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
overloads: List of resolved operator overload objects.
|
splitting_ops: List of operator names to partition on.
|
||||||
"""
|
"""
|
||||||
if not overloads:
|
if not splitting_ops:
|
||||||
logger.debug("No partition ops provided; skipping rule registration.")
|
logger.debug("No partition ops provided; skipping rule registration.")
|
||||||
yield
|
yield
|
||||||
return
|
return
|
||||||
|
|
||||||
from torch._inductor.scheduler import ( # type: ignore
|
|
||||||
_custom_should_partition_fns,
|
|
||||||
register_should_partition_rule,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _always_partition(*_args, **_kwargs):
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Save current state before registering
|
# Save current state before registering
|
||||||
saved_rules = _custom_should_partition_fns.copy()
|
|
||||||
|
|
||||||
for overload in overloads:
|
saved_splitting_ops: list[str] = list(
|
||||||
register_should_partition_rule(
|
torch._inductor.config.custom_should_partition_ops
|
||||||
overload,
|
|
||||||
_always_partition,
|
|
||||||
)
|
)
|
||||||
|
torch._inductor.config.custom_should_partition_ops = splitting_ops
|
||||||
|
|
||||||
logger.debug("Registered inductor partition rules for %d operators", len(overloads))
|
logger.debug(
|
||||||
|
"Registered inductor partition rules for %d operators", len(splitting_ops)
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
# Clear and restore previous state
|
# Clear and restore previous state
|
||||||
_custom_should_partition_fns.clear()
|
torch._inductor.config.custom_should_partition_ops = saved_splitting_ops
|
||||||
_custom_should_partition_fns.update(saved_rules)
|
|
||||||
logger.debug("Restored previous partition rules state.")
|
logger.debug("Restored previous partition rules state.")
|
||||||
|
|||||||
@@ -113,27 +113,6 @@ class PostGradPassManager(CustomGraphPass):
|
|||||||
self.post_cleanup = PostCleanupPass(config)
|
self.post_cleanup = PostCleanupPass(config)
|
||||||
self.fix_functionalization = FixFunctionalizationPass(config)
|
self.fix_functionalization = FixFunctionalizationPass(config)
|
||||||
|
|
||||||
# [HACK: Bug with Inductor graph partition and torch.compile cache]
|
|
||||||
# In PyTorch 2.9, torch.compile has a bug where the graph
|
|
||||||
# partition is not taken into account during caching.
|
|
||||||
# Because vLLM's Mode.VLLM_COMPILE is the only mode that uses
|
|
||||||
# Inductor graph partition, and VLLM_COMPILE implies there
|
|
||||||
# is a PostGradPassManager, we put the list of operators to graph
|
|
||||||
# partition into the PostGradPassManager's uuid (which
|
|
||||||
# then gets incorporated into Inductor's FX graph cache key).
|
|
||||||
# Remove this hack whenever torch.compile fixes it.
|
|
||||||
|
|
||||||
# This is the list of operators that vLLM asks Inductor to split.
|
|
||||||
self.inductor_splitting_ops = []
|
|
||||||
if (
|
|
||||||
config.compilation_config.use_inductor_graph_partition
|
|
||||||
and config.compilation_config.splitting_ops is not None
|
|
||||||
):
|
|
||||||
# Sort them so we're not dependent on the ordering.
|
|
||||||
self.inductor_splitting_ops = sorted(
|
|
||||||
config.compilation_config.splitting_ops
|
|
||||||
)
|
|
||||||
|
|
||||||
def add(self, pass_: InductorPass):
|
def add(self, pass_: InductorPass):
|
||||||
assert isinstance(pass_, InductorPass)
|
assert isinstance(pass_, InductorPass)
|
||||||
self.passes.append(pass_)
|
self.passes.append(pass_)
|
||||||
@@ -144,16 +123,9 @@ class PostGradPassManager(CustomGraphPass):
|
|||||||
affects compilation caching. Its uuid depends on the UUIDs of all
|
affects compilation caching. Its uuid depends on the UUIDs of all
|
||||||
dependent passes and the pass config. See InductorPass for more info.
|
dependent passes and the pass config. See InductorPass for more info.
|
||||||
"""
|
"""
|
||||||
state = {
|
state = {"pass_config": self.pass_config.uuid(), "passes": []}
|
||||||
"pass_config": self.pass_config.uuid(),
|
|
||||||
"passes": [],
|
|
||||||
"inductor_splitting_ops": [],
|
|
||||||
}
|
|
||||||
for pass_ in self.passes:
|
for pass_ in self.passes:
|
||||||
state["passes"].append(pass_.uuid())
|
state["passes"].append(pass_.uuid())
|
||||||
state["passes"].append(self.fix_functionalization.uuid())
|
state["passes"].append(self.fix_functionalization.uuid())
|
||||||
|
|
||||||
# See [HACK: Bug with Inductor graph partition and torch.compile cache]
|
|
||||||
state["inductor_splitting_ops"].extend(self.inductor_splitting_ops)
|
|
||||||
|
|
||||||
return InductorPass.hash_dict(state)
|
return InductorPass.hash_dict(state)
|
||||||
|
|||||||
@@ -272,7 +272,6 @@ def should_partition_patched(self, node, should_log: bool = False) -> bool:
|
|||||||
from torch._inductor.scheduler import (
|
from torch._inductor.scheduler import (
|
||||||
BaseSchedulerNode,
|
BaseSchedulerNode,
|
||||||
FusedSchedulerNode,
|
FusedSchedulerNode,
|
||||||
_custom_should_partition_fns,
|
|
||||||
)
|
)
|
||||||
from torch._inductor.utils import (
|
from torch._inductor.utils import (
|
||||||
_unstable_customized_partition_wrapper,
|
_unstable_customized_partition_wrapper,
|
||||||
@@ -283,9 +282,21 @@ def should_partition_patched(self, node, should_log: bool = False) -> bool:
|
|||||||
# Allow users to manually specify if a node should be partitioned
|
# Allow users to manually specify if a node should be partitioned
|
||||||
# Can only do this for FallbackKernels
|
# Can only do this for FallbackKernels
|
||||||
ir_node = node.node
|
ir_node = node.node
|
||||||
if isinstance(ir_node, ir.FallbackKernel):
|
if isinstance(ir_node, torch._inductor.ir.FallbackKernel) and (
|
||||||
operator = ir_node.op_overload
|
op := ir_node.op_overload
|
||||||
if operator is not None and operator in _custom_should_partition_fns:
|
):
|
||||||
|
op_overload_packet_name = op.name()
|
||||||
|
op_overload_name = (
|
||||||
|
f"{op_overload_packet_name}.{op._overloadname}"
|
||||||
|
if isinstance(op, torch._ops.OpOverload)
|
||||||
|
else op_overload_packet_name
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
op_overload_packet_name
|
||||||
|
in torch._inductor.config.custom_should_partition_ops
|
||||||
|
or op_overload_name in torch._inductor.config.custom_should_partition_ops
|
||||||
|
):
|
||||||
|
assert isinstance(op, torch._ops.OpOverload)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# When not using cudagraphs, keep all kernels in the `call` function
|
# When not using cudagraphs, keep all kernels in the `call` function
|
||||||
@@ -355,6 +366,13 @@ def _update_scheduler_patched(self) -> None:
|
|||||||
if is_torch_equal("2.9.0"):
|
if is_torch_equal("2.9.0"):
|
||||||
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
|
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
|
||||||
from torch._inductor.graph import GraphLowering
|
from torch._inductor.graph import GraphLowering
|
||||||
|
from torch.utils._config_module import _Config, _ConfigEntry
|
||||||
|
|
||||||
|
# `custom_should_partition_ops` is a config introduced after 2.9.0, so this would
|
||||||
|
# not overwrite any user configs.
|
||||||
|
torch._inductor.config._config["custom_should_partition_ops"] = _ConfigEntry(
|
||||||
|
_Config(default=[])
|
||||||
|
)
|
||||||
|
|
||||||
PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched
|
PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched
|
||||||
GraphLowering._update_scheduler = _update_scheduler_patched
|
GraphLowering._update_scheduler = _update_scheduler_patched
|
||||||
|
|||||||
Reference in New Issue
Block a user