[Graph Partition][Cache] Use inductor partition ops config (#27702)

Signed-off-by: Boyuan Feng <boyuan@meta.com>
This commit is contained in:
Boyuan Feng
2025-11-05 05:04:48 -08:00
committed by GitHub
parent 6b7a81185d
commit 6ab183813c
4 changed files with 37 additions and 63 deletions

View File

@@ -272,7 +272,6 @@ def should_partition_patched(self, node, should_log: bool = False) -> bool:
from torch._inductor.scheduler import (
BaseSchedulerNode,
FusedSchedulerNode,
_custom_should_partition_fns,
)
from torch._inductor.utils import (
_unstable_customized_partition_wrapper,
@@ -283,9 +282,21 @@ def should_partition_patched(self, node, should_log: bool = False) -> bool:
# Allow users to manually specify if a node should be partitioned
# Can only do this for FallbackKernels
ir_node = node.node
if isinstance(ir_node, ir.FallbackKernel):
operator = ir_node.op_overload
if operator is not None and operator in _custom_should_partition_fns:
if isinstance(ir_node, torch._inductor.ir.FallbackKernel) and (
op := ir_node.op_overload
):
op_overload_packet_name = op.name()
op_overload_name = (
f"{op_overload_packet_name}.{op._overloadname}"
if isinstance(op, torch._ops.OpOverload)
else op_overload_packet_name
)
if (
op_overload_packet_name
in torch._inductor.config.custom_should_partition_ops
or op_overload_name in torch._inductor.config.custom_should_partition_ops
):
assert isinstance(op, torch._ops.OpOverload)
return True
# When not using cudagraphs, keep all kernels in the `call` function
@@ -355,6 +366,13 @@ def _update_scheduler_patched(self) -> None:
if is_torch_equal("2.9.0"):
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
from torch._inductor.graph import GraphLowering
from torch.utils._config_module import _Config, _ConfigEntry
# `custom_should_partition_ops` is a new config introduced after 2.9.0, so
# registering it here does not overwrite any user configs.
torch._inductor.config._config["custom_should_partition_ops"] = _ConfigEntry(
_Config(default=[])
)
PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched
GraphLowering._update_scheduler = _update_scheduler_patched