remove resolve_op_overloads and use splitting_ops directly (#28081)
Signed-off-by: Boyuan Feng <boyuan@meta.com>
This commit is contained in:
@@ -19,7 +19,7 @@ import vllm.envs as envs
|
||||
from vllm.compilation.inductor_pass import pass_context
|
||||
from vllm.compilation.partition_rules import (
|
||||
inductor_partition_rule_context,
|
||||
resolve_defined_ops,
|
||||
should_split,
|
||||
)
|
||||
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
@@ -303,7 +303,7 @@ class SplitItem:
|
||||
|
||||
|
||||
def split_graph(
|
||||
graph: fx.GraphModule, resolved_ops: list[torch._ops.OpOverload]
|
||||
graph: fx.GraphModule, splitting_ops: list[str]
|
||||
) -> tuple[fx.GraphModule, list[SplitItem]]:
|
||||
# split graph by ops
|
||||
subgraph_id = 0
|
||||
@@ -312,12 +312,8 @@ def split_graph(
|
||||
for node in graph.graph.nodes:
|
||||
if node.op in ("output", "placeholder"):
|
||||
continue
|
||||
# Match node.target against resolved_ops
|
||||
# node.target can be OpOverloadPacket, need to check .default
|
||||
if node.op == "call_function" and (
|
||||
node.target in resolved_ops
|
||||
or (hasattr(node.target, "default") and node.target.default in resolved_ops)
|
||||
):
|
||||
|
||||
if should_split(node, splitting_ops):
|
||||
subgraph_id += 1
|
||||
node_to_subgraph_id[node] = subgraph_id
|
||||
split_op_graphs.append(subgraph_id)
|
||||
@@ -653,8 +649,7 @@ class VllmBackend:
|
||||
else:
|
||||
fx_split_ops = self.compilation_config.splitting_ops or []
|
||||
|
||||
resolved_split_ops = resolve_defined_ops(fx_split_ops)
|
||||
self.split_gm, self.piecewise_graphs = split_graph(graph, resolved_split_ops)
|
||||
self.split_gm, self.piecewise_graphs = split_graph(graph, fx_split_ops)
|
||||
|
||||
from torch._dynamo.utils import lazy_format_graph_code
|
||||
|
||||
|
||||
@@ -2,54 +2,39 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch._library.utils import lookup_op
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def resolve_defined_ops(op_names: list[str]) -> list["torch._ops.OpOverload"]:
|
||||
"""Resolve operator names to OpOverload objects.
|
||||
|
||||
Skips operators that fail to resolve (e.g., operators not registered or
|
||||
model-specific operators not present in the current model).
|
||||
|
||||
Note: Users should inspect the operator graph before lowering and ensure
|
||||
the specified operators are present in the final graph. Built-in PyTorch
|
||||
operators (aten::*, torch::*) may be decomposed, fused, or transformed
|
||||
during Inductor's compilation passes, so use them with caution.
|
||||
|
||||
Args:
|
||||
op_names: List of operator names in PyTorch format
|
||||
(e.g., "vllm::unified_attention")
|
||||
|
||||
Returns:
|
||||
List of successfully resolved operator overloads
|
||||
def should_split(node: torch.fx.Node, splitting_ops: list[str]) -> bool:
|
||||
"""
|
||||
Check if a node should be split for dynamo graph partition.
|
||||
It operates on dynamo graph, so the node.target can be anything.
|
||||
We need to check and split only on OpOverload and OpOverloadPacket.
|
||||
"""
|
||||
resolved = []
|
||||
for op_name in op_names:
|
||||
try:
|
||||
resolved.append(lookup_op(op_name))
|
||||
except Exception:
|
||||
# Skip operators that don't exist (e.g., model-specific ops)
|
||||
# Do not warn for attention ops, warn for others
|
||||
# (most likely manually specified)
|
||||
from vllm.config import CompilationConfig
|
||||
|
||||
logger.log(
|
||||
logging.DEBUG
|
||||
if op_name in CompilationConfig._attention_ops
|
||||
else logging.WARNING,
|
||||
"Failed to resolve operator for CUDAGraph partition: %s",
|
||||
op_name,
|
||||
)
|
||||
continue
|
||||
if node.op != "call_function":
|
||||
return False
|
||||
|
||||
return resolved
|
||||
target = node.target
|
||||
|
||||
if isinstance(target, torch._ops.OpOverloadPacket):
|
||||
# Example: "aten::add"
|
||||
return target._qualified_op_name in splitting_ops
|
||||
|
||||
if isinstance(target, torch._ops.OpOverload):
|
||||
# Example: "aten::add"
|
||||
packet_name = target.name()
|
||||
|
||||
# Example: "aten::add.default"
|
||||
op_overload_name = f"{packet_name}.{target._overloadname}"
|
||||
return op_overload_name in splitting_ops or packet_name in splitting_ops
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
|
||||
Reference in New Issue
Block a user