remove resolve_op_overloads and use splitting_ops directly (#28081)
Signed-off-by: Boyuan Feng <boyuan@meta.com>
@@ -214,28 +214,72 @@ def test_splitting_ops_dynamic():
 assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE


-def test_resolve_operator_overload():
+def test_should_split():
     import torch

-    from vllm.compilation.partition_rules import resolve_defined_ops
+    from vllm.compilation.partition_rules import should_split

-    # Test valid operator names
-    resolved = resolve_defined_ops(["aten::mm.default", "aten::addmm.default"])
-    assert len(resolved) == 2
-    assert resolved[0] is torch.ops.aten.mm.default
-    assert resolved[1] is torch.ops.aten.addmm.default
-
-    # Test that invalid operators are skipped (not raising exceptions)
-    resolved = resolve_defined_ops(
-        [
-            "aten::mm.default",
-            "aten::nonexistent_op.default",  # This should be skipped
-            "aten::addmm.default",
-        ]
+    graph = torch.fx.Graph()
+    node = torch.fx.Node(
+        graph=graph,
+        name="dummy_node",
+        op="call_function",
+        target=torch.ops.aten.add.default,
+        args=(),
+        kwargs={},
     )
-    assert len(resolved) == 2  # Only 2 valid ops
-    assert resolved[0] is torch.ops.aten.mm.default
-    assert resolved[1] is torch.ops.aten.addmm.default
+
+    # supports OpOverloadPacket
+    splitting_ops = ["aten::add"]
+    assert should_split(node, splitting_ops)
+
+    # supports OpOverload
+    splitting_ops = ["aten::add.default"]
+    assert should_split(node, splitting_ops)
+
+    # a different overload of the same packet does not match
+    splitting_ops = ["aten::add.Tensor"]
+    assert not should_split(node, splitting_ops)
+
+    @torch.library.custom_op(
+        "silly::attention",
+        mutates_args=["out"],
+    )
+    def attention(
+        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
+    ) -> None:
+        out.copy_(q + k + v)
+
+    q, k, v, out = [torch.randn(1)] * 4
+
+    # supports custom ops as OpOverloadPacket
+    node = torch.fx.Node(
+        graph=graph,
+        name="dummy_node",
+        op="call_function",
+        target=torch.ops.silly.attention,
+        args=(q, k, v, out),
+        kwargs={},
+    )
+
+    splitting_ops = ["silly::attention"]
+    assert should_split(node, splitting_ops)
+
+    # supports custom ops as OpOverload
+    node = torch.fx.Node(
+        graph=graph,
+        name="dummy_node",
+        op="call_function",
+        target=torch.ops.silly.attention.default,
+        args=(q, k, v, out),
+        kwargs={},
+    )
+
+    splitting_ops = ["silly::attention"]
+    assert should_split(node, splitting_ops)
+
+    splitting_ops = ["silly::attention.default"]
+    assert should_split(node, splitting_ops)


 @pytest.mark.skipif(
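For readers following the change: the new test exercises should_split, which decides whether an FX node's target matches one of the configured splitting_ops, accepting both packet-level names ("aten::add", matching every overload) and overload-level names ("aten::add.default"). Below is a minimal standalone sketch of that matching rule; the name should_split_sketch is hypothetical, the private torch._ops classes are an assumption, and the real implementation lives in vllm.compilation.partition_rules and may differ.

import torch
from torch._ops import OpOverload, OpOverloadPacket


def should_split_sketch(node: torch.fx.Node, splitting_ops: list[str]) -> bool:
    # Hypothetical re-implementation for illustration only; not vLLM's code.
    target = node.target
    if not isinstance(target, (OpOverload, OpOverloadPacket)):
        return False
    # str() yields "aten.add.default" for an overload and "aten.add" for a
    # packet; rewrite the first separator to the "aten::add" convention.
    qualname = str(target).replace(".", "::", 1)
    if qualname in splitting_ops:
        return True
    # A packet-level entry such as "aten::add" matches every overload.
    if isinstance(target, OpOverload):
        return qualname.rsplit(".", 1)[0] in splitting_ops
    return False


graph = torch.fx.Graph()
node = graph.call_function(torch.ops.aten.add.default)
assert should_split_sketch(node, ["aten::add"])
assert should_split_sketch(node, ["aten::add.default"])
assert not should_split_sketch(node, ["aten::add.Tensor"])

Supporting both granularities lets a config split on every overload of an op without enumerating them, while still allowing a pinpoint match on a single overload.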
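For the custom-op half of the test: registering "silly::attention" with the torch.library.custom_op decorator (PyTorch 2.4+) exposes it both as the OpOverloadPacket torch.ops.silly.attention and as its OpOverload torch.ops.silly.attention.default, which is why the test builds a node for each target. A quick standalone check, again a sketch rather than vLLM code: the registration mirrors the test, while the assertions and the call are added for illustration.

import torch
from torch._ops import OpOverload, OpOverloadPacket


@torch.library.custom_op("silly::attention", mutates_args=["out"])
def attention(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
) -> None:
    out.copy_(q + k + v)


# The packet and its default overload are both valid FX call targets.
assert isinstance(torch.ops.silly.attention, OpOverloadPacket)
assert isinstance(torch.ops.silly.attention.default, OpOverload)

# The registered op is callable like any other torch.ops entry.
q, k, v, out = (torch.randn(1) for _ in range(4))
torch.ops.silly.attention(q, k, v, out)
assert torch.allclose(out, q + k + v)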