[torch.compile] Make inductor partition rules respect splitting_ops #25691 (#25845)

Signed-off-by: baonudesifeizhai <baonudesifeizhai@gmail.com>
Signed-off-by: baonudesifeizhai <85092850+baonudesifeizhai@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Author: baonudesifeizhai
Date: 2025-10-10 12:35:28 -04:00 (committed by GitHub)
Parent: e519281920
Commit: cddce79fda
9 changed files with 267 additions and 112 deletions


@@ -209,8 +209,23 @@ class CompilationConfig:
     disabled when running with Inductor: level>=PIECEWISE and use_inductor=True.
     Inductor generates (fused) Triton kernels for disabled custom ops."""

     splitting_ops: Optional[list[str]] = None
-    """A list of ops to split the full graph into subgraphs, used in piecewise
-    compilation."""
+    """A list of ops to exclude from cudagraphs, used in piecewise compilation.
+
+    The behavior depends on use_inductor_graph_partition:
+    - When use_inductor_graph_partition=False (default):
+        These ops are used for Dynamo FX-level graph splitting. The graph is
+        split at these ops before Inductor compilation, creating separate
+        subgraphs for cudagraph capture.
+    - When use_inductor_graph_partition=True:
+        These ops are used to register Inductor partition rules. The graph
+        partitioning happens at Inductor codegen time after all passes and
+        fusions are finished, allowing compilation and custom passes to operate
+        on the full graph while still excluding these ops from cudagraphs.
+
+    If None, defaults to attention ops for piecewise cudagraphs.
+    If empty list [], no ops are excluded (suitable for full cudagraphs)."""

     # Inductor capture
     use_inductor: bool = True
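
To make the three documented settings concrete, here is a minimal sketch (not part of the diff) of how each behavior is selected through CompilationConfig; it assumes the constructor accepts the fields exactly as declared in the hunk above.

from vllm.config import CompilationConfig

# None (the default): vLLM falls back to the attention ops, so the graph is
# captured piecewise with attention excluded from cudagraphs.
cfg_default = CompilationConfig(splitting_ops=None)

# Explicit list: only the named ops are excluded from cudagraphs. With
# use_inductor_graph_partition=True the split happens at Inductor codegen
# time, so fusion passes still see the full graph.
cfg_partition = CompilationConfig(
    use_inductor_graph_partition=True,
    splitting_ops=["vllm::unified_attention"],
)

# Empty list: nothing is excluded, suitable for full cudagraphs.
cfg_full = CompilationConfig(splitting_ops=[])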
@@ -367,18 +382,19 @@ class CompilationConfig:
     model code, e.g., Attention, FusedMOE when dp_size>1."""

     # Attention ops; used for piecewise cudagraphs
+    # Use PyTorch operator format: "namespace::name"
     _attention_ops: ClassVar[list[str]] = [
-        "vllm.unified_attention",
-        "vllm.unified_attention_with_output",
-        "vllm.unified_mla_attention",
-        "vllm.unified_mla_attention_with_output",
-        "vllm.mamba_mixer2",
-        "vllm.mamba_mixer",
-        "vllm.short_conv",
-        "vllm.linear_attention",
-        "vllm.plamo2_mamba_mixer",
-        "vllm.gdn_attention",
-        "vllm.sparse_attn_indexer",
+        "vllm::unified_attention",
+        "vllm::unified_attention_with_output",
+        "vllm::unified_mla_attention",
+        "vllm::unified_mla_attention_with_output",
+        "vllm::mamba_mixer2",
+        "vllm::mamba_mixer",
+        "vllm::short_conv",
+        "vllm::linear_attention",
+        "vllm::plamo2_mamba_mixer",
+        "vllm::gdn_attention",
+        "vllm::sparse_attn_indexer",
     ]

     def compute_hash(self) -> str:
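
The list switches from "vllm.unified_attention" style names to the "namespace::name" form noted in the new comment, which is the string format PyTorch uses to identify registered operators. A hypothetical helper (not from this PR) showing how such a name resolves back to a callable op, assuming the vllm custom ops have been registered:

import torch

def resolve_op(qualname: str):
    """Resolve a 'namespace::name' string to the registered torch op."""
    namespace, name = qualname.split("::")
    # e.g. "vllm::unified_attention" -> torch.ops.vllm.unified_attention
    return getattr(getattr(torch.ops, namespace), name)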
@@ -654,31 +670,25 @@ class CompilationConfig:
     def set_splitting_ops_for_inductor_graph_partition(self):
         assert self.use_inductor_graph_partition
-        use_inductor_graph_partition_msg = (
-            "When use_inductor_graph_partition=True, splitting_ops "
-            "are ignored and set to an empty list. Instead, "
-            '"tags=(torch._C.Tag.cudagraph_unsafe, )," is '
-            "used to annotate custom ops for graph partition."
-        )
-        if self.splitting_ops is not None and len(self.splitting_ops) > 0:
-            logger.warning_once(use_inductor_graph_partition_msg)
-            self.splitting_ops = []
+        if self.splitting_ops is None:
+            self.splitting_ops = list(self._attention_ops)

     def set_splitting_ops_for_attn_fusion(self):
         assert self.pass_config.enable_attn_fusion
-        if self.splitting_ops is None:
-            self.splitting_ops = []
-            if self.cudagraph_mode.has_piecewise_cudagraphs():
-                logger.warning_once(
-                    "enable_attn_fusion is incompatible with piecewise "
-                    "cudagraph when use_inductor_graph_partition is off."
-                    "In this case, splitting_ops will be set to empty "
-                    "list, and cudagraph_mode will be set to FULL. "
-                    "Please ensure you are using attention backends that "
-                    "support cudagraph or set cudagraph_mode to NONE "
-                    "explicitly if encountering any problems."
-                )
-                self.cudagraph_mode = CUDAGraphMode.FULL
+        # For dynamo-partition (non-inductor) attention fusion,
+        # set splitting_ops to empty to avoid splitting at attention ops
+        self.splitting_ops = []
+        if self.cudagraph_mode.has_piecewise_cudagraphs():
+            logger.warning_once(
+                "enable_attn_fusion is incompatible with piecewise "
+                "cudagraph when use_inductor_graph_partition is off. "
+                "In this case, splitting_ops will be set to empty "
+                "list, and cudagraph_mode will be set to FULL. "
+                "Please ensure you are using attention backends that "
+                "support cudagraph or set cudagraph_mode to NONE "
+                "explicitly if encountering any problems."
+            )
+            self.cudagraph_mode = CUDAGraphMode.FULL

         assert not self.splitting_ops_contain_attention(), (
             "attention ops should not be in splitting_ops "
@@ -691,23 +701,17 @@ class CompilationConfig:
         )

     def is_attention_compiled_piecewise(self) -> bool:
-        use_fx_graph_piecewise_compilation = (
-            self.level == CompilationLevel.PIECEWISE
-            and self.splitting_ops_contain_attention()
-        )
+        if not self.splitting_ops_contain_attention():
+            return False

-        inductor_used = (
-            self.level == CompilationLevel.PIECEWISE and self.use_inductor
-        ) or (
-            self.level >= CompilationLevel.DYNAMO_AS_IS and self.backend == "inductor"
-        )
-        use_inductor_piecewise_compilation = (
-            inductor_used
-            and self.use_inductor_graph_partition
-            and not self.splitting_ops_contain_attention()
-        )
+        if not self.use_inductor_graph_partition:
+            # Dynamo-level FX split case
+            return self.level == CompilationLevel.PIECEWISE

-        return use_fx_graph_piecewise_compilation or use_inductor_piecewise_compilation
+        # Inductor partition case
+        return (
+            self.level > CompilationLevel.NO_COMPILATION and self.backend == "inductor"
+        )

     def custom_op_log_check(self):
         """