[Graph Partition][Cache] Use inductor partition ops config (#27702)
Signed-off-by: Boyuan Feng <boyuan@meta.com>
This commit is contained in:
@@ -272,7 +272,6 @@ def should_partition_patched(self, node, should_log: bool = False) -> bool:
|
||||
from torch._inductor.scheduler import (
|
||||
BaseSchedulerNode,
|
||||
FusedSchedulerNode,
|
||||
_custom_should_partition_fns,
|
||||
)
|
||||
from torch._inductor.utils import (
|
||||
_unstable_customized_partition_wrapper,
|
||||
@@ -283,9 +282,21 @@ def should_partition_patched(self, node, should_log: bool = False) -> bool:
|
||||
# Allow users to manually specify if a node should be partitioned
|
||||
# Can only do this for FallbackKernels
|
||||
ir_node = node.node
|
||||
if isinstance(ir_node, ir.FallbackKernel):
|
||||
operator = ir_node.op_overload
|
||||
if operator is not None and operator in _custom_should_partition_fns:
|
||||
if isinstance(ir_node, torch._inductor.ir.FallbackKernel) and (
|
||||
op := ir_node.op_overload
|
||||
):
|
||||
op_overload_packet_name = op.name()
|
||||
op_overload_name = (
|
||||
f"{op_overload_packet_name}.{op._overloadname}"
|
||||
if isinstance(op, torch._ops.OpOverload)
|
||||
else op_overload_packet_name
|
||||
)
|
||||
if (
|
||||
op_overload_packet_name
|
||||
in torch._inductor.config.custom_should_partition_ops
|
||||
or op_overload_name in torch._inductor.config.custom_should_partition_ops
|
||||
):
|
||||
assert isinstance(op, torch._ops.OpOverload)
|
||||
return True
|
||||
|
||||
# When not using cudagraphs, keep all kernels in the `call` function
|
||||
@@ -355,6 +366,13 @@ def _update_scheduler_patched(self) -> None:
|
||||
if is_torch_equal("2.9.0"):
|
||||
from torch._inductor.codegen.wrapper import PythonWrapperCodegen
|
||||
from torch._inductor.graph import GraphLowering
|
||||
from torch.utils._config_module import _Config, _ConfigEntry
|
||||
|
||||
# `custom_should_partition_ops` is a new config after 2.9.0. So this would
|
||||
# not overwrite any user configs.
|
||||
torch._inductor.config._config["custom_should_partition_ops"] = _ConfigEntry(
|
||||
_Config(default=[])
|
||||
)
|
||||
|
||||
PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched
|
||||
GraphLowering._update_scheduler = _update_scheduler_patched
|
||||
|
||||
Reference in New Issue
Block a user