remove resolve_op_overloads and use splitting_ops directly (#28081)

Signed-off-by: Boyuan Feng <boyuan@meta.com>
Author: Boyuan Feng
Date: 2025-11-07 17:13:13 -08:00
Committed by: GitHub
parent 1aaecda078
commit b158df2813
3 changed files with 89 additions and 65 deletions
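
The change replaces the eager resolve_defined_ops() lookup (which resolved string
names to op objects up front and silently skipped unknown ones, as the removed test
below shows) with a per-node membership check against splitting_ops at partition
time. The sketch below mirrors the matching behavior the updated test pins down; it
is an illustration under assumptions, not the code from this commit, and it leans on
PyTorch's internal OpOverload._schema and OpOverloadPacket._qualified_op_name
attributes.

import torch
from torch._ops import OpOverload, OpOverloadPacket

def should_split(node: torch.fx.Node, splitting_ops: list[str]) -> bool:
    """Return True if the node calls an op named in splitting_ops.

    Entries may name an OpOverloadPacket ("aten::add") or a specific
    OpOverload ("aten::add.default" / "silly::attention.default").
    """
    if node.op != "call_function":
        return False
    target = node.target
    if isinstance(target, OpOverload):
        packet_name = target._schema.name  # e.g. "aten::add"
        overload = target._schema.overload_name or "default"
        qualified_name = f"{packet_name}.{overload}"  # e.g. "aten::add.default"
        return packet_name in splitting_ops or qualified_name in splitting_ops
    if isinstance(target, OpOverloadPacket):
        # e.g. torch.ops.silly.attention -> "silly::attention"
        return target._qualified_op_name in splitting_ops
    return False

Under a rule like this, a packet-style entry such as "aten::add" splits on every
overload of the op, while an overload-style entry splits only on that exact
overload, which is what the three aten::add assertions in the new test check.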

@@ -214,28 +214,72 @@ def test_splitting_ops_dynamic():
     assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE


-def test_resolve_operator_overload():
+def test_should_split():
     import torch

-    from vllm.compilation.partition_rules import resolve_defined_ops
+    from vllm.compilation.partition_rules import should_split

-    # Test valid operator names
-    resolved = resolve_defined_ops(["aten::mm.default", "aten::addmm.default"])
-    assert len(resolved) == 2
-    assert resolved[0] is torch.ops.aten.mm.default
-    assert resolved[1] is torch.ops.aten.addmm.default
-
-    # Test that invalid operators are skipped (not raising exceptions)
-    resolved = resolve_defined_ops(
-        [
-            "aten::mm.default",
-            "aten::nonexistent_op.default",  # This should be skipped
-            "aten::addmm.default",
-        ]
-    )
-    assert len(resolved) == 2  # Only 2 valid ops
-    assert resolved[0] is torch.ops.aten.mm.default
-    assert resolved[1] is torch.ops.aten.addmm.default
+    graph = torch.fx.Graph()
+    node = torch.fx.Node(
+        graph=graph,
+        name="dummy_node",
+        op="call_function",
+        target=torch.ops.aten.add.default,
+        args=(),
+        kwargs={},
+    )
+
+    # supports OpOverloadPacket
+    splitting_ops = ["aten::add"]
+    assert should_split(node, splitting_ops)
+
+    # supports OpOverload
+    splitting_ops = ["aten::add.default"]
+    assert should_split(node, splitting_ops)
+
+    # does not split on a non-matching OpOverload
+    splitting_ops = ["aten::add.Tensor"]
+    assert not should_split(node, splitting_ops)
+
+    @torch.library.custom_op(
+        "silly::attention",
+        mutates_args=["out"],
+    )
+    def attention(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
+    ) -> None:
+        out.copy_(q + k + v)
+
+    q, k, v, out = [torch.randn(1)] * 4
+
+    # supports custom ops as OpOverloadPacket
+    node = torch.fx.Node(
+        graph=graph,
+        name="dummy_node",
+        op="call_function",
+        target=torch.ops.silly.attention,
+        args=(q, k, v, out),
+        kwargs={},
+    )
+    splitting_ops = ["silly::attention"]
+    assert should_split(node, splitting_ops)
+
+    # supports custom ops as OpOverload
+    node = torch.fx.Node(
+        graph=graph,
+        name="dummy_node",
+        op="call_function",
+        target=torch.ops.silly.attention.default,
+        args=(q, k, v, out),
+        kwargs={},
+    )
+    splitting_ops = ["silly::attention"]
+    assert should_split(node, splitting_ops)
+    splitting_ops = ["silly::attention.default"]
+    assert should_split(node, splitting_ops)
+
+
 @pytest.mark.skipif(