# vllm/tests/compile/test_graph_partition.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import operator
import pytest
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from vllm.compilation.backends import _is_empty_allocation_node, split_graph
from vllm.compilation.passes.fx_utils import find_op_nodes
# This import automatically registers `torch.ops.silly.attention`
from . import silly_attention # noqa: F401
def test_getitem_moved_to_producer_subgraph():
    """
    Test that getitem operations are moved to the same subgraph as their input,
    preventing tuple inputs to submodules.
    """

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        # torch.split returns a tuple, so tracing yields real getitem nodes.
        # The split call should land in the first (tuple-producing) submodule.
        halves = torch.split(x, x.shape[0] // 2, dim=0)
        # Consumers of the tuple should land in the second submodule.
        relu_0 = torch.relu(halves[0])
        relu_1 = torch.relu(halves[1])
        return torch.cat([relu_0, relu_1], dim=0)

    sample = torch.randn(4, 3)
    gm = make_fx(model_fn)(sample)

    # Sanity-check the setup: the traced graph must actually contain getitem.
    getitem_nodes = [
        n
        for n in gm.graph.nodes
        if n.op == "call_function" and n.target == operator.getitem
    ]
    assert getitem_nodes, "Test setup failed: graph should contain getitem operations"

    # Split on the tuple-producing aten::split op.
    split_gm, split_items = split_graph(gm, ["aten::split.Tensor"])
    assert len(split_items) == 2, "Graph should be split into 2 submodules"

    # No submodule may apply getitem directly to a placeholder: that would
    # mean it received a tuple as input.
    for split_item in split_items:
        offending = [
            node
            for node in split_item.graph.graph.nodes
            if node.op == "call_function"
            and node.target == operator.getitem
            and node.args[0].op == "placeholder"
        ]
        assert not offending, (
            f"Submodule {split_item.submod_name} has getitem operations on "
            f"placeholder nodes: {[n.name for n in offending]}. "
            "This means tuple inputs were not properly eliminated."
        )

    # The split graph must compute the same result as the original.
    fresh = torch.randn(4, 3)
    assert torch.allclose(gm(fresh), split_gm(fresh)), "Output mismatch"
def test_no_tuple_inputs_with_multiple_consumers():
    """
    Test that when a tuple is consumed by multiple split operations,
    getitem operations are properly moved to avoid tuple inputs.
    """

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        # torch.split returns a tuple, so tracing yields real getitem nodes.
        # The split becomes the first submodule (tuple producer).
        halves = torch.split(x, x.shape[0] // 2, dim=0)
        # Second submodule: first consumers of the tuple.
        branch_a = torch.relu(halves[0])
        branch_b = torch.relu(halves[1])
        # Artificial splitting point (sigmoid) creates a third, independent
        # submodule so the tuple is consumed again later.
        branch_a = torch.sigmoid(branch_a)
        # Fourth submodule consumes the tuple once more.
        return torch.cat([halves[0], halves[1], branch_a, branch_b])

    sample = torch.randn(4, 3)
    gm = make_fx(model_fn)(sample)

    # Sanity-check the setup: the traced graph must actually contain getitem.
    assert any(
        n.op == "call_function" and n.target == operator.getitem
        for n in gm.graph.nodes
    ), "Test setup failed: graph should contain getitem operations"

    split_gm, split_items = split_graph(gm, ["aten::split.Tensor", "aten::sigmoid"])
    assert len(split_items) == 4, "Graph should be split into 4 submodules"

    # Fail if any submodule applies getitem to a placeholder — that would
    # mean it received a tuple as input.
    for split_item in split_items:
        for node in split_item.graph.graph.nodes:
            is_tuple_unpack = (
                node.op == "call_function"
                and node.target == operator.getitem
                and node.args[0].op == "placeholder"
            )
            if is_tuple_unpack:
                pytest.fail(
                    f"Submodule {split_item.submod_name} has getitem on "
                    f"placeholder {node.args[0].name}, indicating it receives "
                    "a tuple input"
                )

    # The split graph must compute the same result as the original.
    fresh = torch.randn(4, 3)
    assert torch.allclose(gm(fresh), split_gm(fresh)), "Output mismatch after split"
def test_consecutive_ops_in_split():
    """
    Test that consecutive splitting operations are grouped into the same subgraph
    """

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        """
        Define a simple model where consecutive operations create opportunities
        for splitting subgraphs.
        """
        # Apply silly attention followed by consecutive operations
        intermediate = torch.relu(x)
        attn_inout = torch.sqrt(intermediate)
        torch.ops.silly.attention(intermediate, intermediate, attn_inout, attn_inout)
        final_result = torch.sigmoid(attn_inout)
        return final_result

    # Trace and run on CUDA (the silly attention op is exercised here).
    # Reset the default device afterwards so this test does not leak global
    # state into the rest of the test session.
    torch.set_default_device("cuda")
    try:
        # Create the traced FX graph for the model
        x = torch.randn(8, 4)
        gm = make_fx(model_fn)(x)
        # Assert presence of the expected operations in the setup
        assert (
            len(list(find_op_nodes(torch.ops.aten.relu, gm.graph))) == 1
            and len(list(find_op_nodes(torch.ops.aten.sqrt, gm.graph))) == 1
        ), "Test setup failed: Expected sqrt and relu operations in the graph."
        # Configure split operations to test: sqrt and attention are adjacent
        # in the traced graph, so they should be fused into one partition.
        splitting_ops = ["silly::attention", "aten::sqrt"]
        split_gm, split_items = split_graph(gm, splitting_ops)
        # Validate the number of partitions
        assert len(split_items) == 3, (
            "Consecutive splitting operations were not grouped correctly."
        )
        # Validate that correctness is preserved
        new_x = torch.randn(8, 4)
        output_original = gm(new_x)
        output_split = split_gm(new_x)
        assert torch.allclose(output_original, output_split), (
            "Output mismatch after splitting."
        )
        # Check the splitting item has 2 call nodes exactly (sqrt and attn)
        splitting_items = [s for s in split_items if s.is_splitting_graph]
        assert len(splitting_items) == 1, "Expecting a single splitting graph"
        splitting_gm = splitting_items[0].graph
        assert len(splitting_gm.graph.nodes) == 4, (
            "Expecting 4 nodes in splitting graph"
        )
        assert [node.op for node in splitting_gm.graph.nodes] == [
            "placeholder"
        ] + 2 * ["call_function"] + ["output"]
    finally:
        # Restore the default device so later tests allocate where they expect.
        torch.set_default_device(None)
def _get_empty_nodes(split_item):
    """Return all empty-allocation nodes found in the split item's subgraph."""
    nodes = split_item.graph.graph.nodes
    return list(filter(_is_empty_allocation_node, nodes))
def _subgraphs_with_empty_nodes(split_items, *, is_splitting_graph):
    """Return split items matching `is_splitting_graph` that contain at least
    one empty-allocation node."""
    matches = []
    for split_item in split_items:
        if split_item.is_splitting_graph != is_splitting_graph:
            continue
        if _get_empty_nodes(split_item):
            matches.append(split_item)
    return matches
def test_empty_only_partition_stays_separate_after_splitting_predecessor():
    """
    Empty-only subgraphs should not be merged when the only predecessor is
    a splitting-op subgraph.
    """

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        y = torch.sin(x)
        out = torch.empty_like(y)
        torch.ops.aten.cos.out(y, out=out)
        return out

    inp = torch.randn(4, 3)
    gm = make_fx(model_fn)(inp)
    split_gm, split_items = split_graph(gm, ["aten::sin", "aten::cos.out"])
    # Graph partitioning for this pattern is:
    # [sin], [empty_like], [cos.out].
    assert len(split_items) == 3, (
        "Empty-only partition should not merge into splitting-op subgraph"
    )
    # The empty allocation must never end up inside a splitting-op subgraph.
    offenders = _subgraphs_with_empty_nodes(split_items, is_splitting_graph=True)
    assert not offenders, (
        "Splitting-op subgraphs should not contain empty allocation nodes: "
        f"{[item.submod_name for item in offenders]}"
    )
    # The split graph must compute the same result as the original.
    assert torch.allclose(gm(inp), split_gm(inp)), "Output mismatch after split"
def test_empty_only_partition_is_merged():
    """
    Empty-only subgraphs should still be merged when a non-splitting predecessor
    exists. The merged empty node must remain outside splitting-op subgraphs.
    """

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        base = x + 1
        y = torch.sin(base)
        out = torch.empty_like(base)
        torch.ops.aten.cos.out(base, out=out)
        return out + y

    inp = torch.randn(4, 3)
    gm = make_fx(model_fn)(inp)
    split_gm, split_items = split_graph(gm, ["aten::sin", "aten::cos.out"])
    # Partitioning should be:
    # [add, empty_like], [sin], [cos.out], [add].
    assert len(split_items) == 4, (
        "Empty-only partition should be merged into non-splitting predecessor"
    )
    # The empty allocation must never end up inside a splitting-op subgraph.
    in_splitting = _subgraphs_with_empty_nodes(split_items, is_splitting_graph=True)
    assert not in_splitting, (
        "Splitting-op subgraphs should not contain empty allocation nodes: "
        f"{[item.submod_name for item in in_splitting]}"
    )
    # ...and it should live in exactly one non-splitting subgraph.
    in_regular = _subgraphs_with_empty_nodes(split_items, is_splitting_graph=False)
    assert len(in_regular) == 1, (
        "Exactly one non-splitting subgraph should contain the merged empty node"
    )
    assert len(_get_empty_nodes(in_regular[0])) == 1, (
        "Expected exactly one empty allocation node in merged subgraph"
    )
    # The split graph must compute the same result as the original.
    assert torch.allclose(gm(inp), split_gm(inp)), "Output mismatch after split"
def test_builtin_empty_only_partition_is_merged():
    """
    In Dynamo graphs, torch.empty/empty_like may appear as builtin call targets
    (not aten OpOverload). Ensure empty-only partitions are still merged.
    """

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        hidden = x + 1
        out1 = torch.empty_like(hidden)
        torch.ops.silly.attention(hidden, hidden, hidden, out1)
        out2 = torch.empty_like(hidden)
        torch.ops.silly.attention(out1, out1, hidden, out2)
        return out2 + hidden

    # symbolic_trace (unlike make_fx) leaves empty_like as a builtin target.
    gm = torch.fx.symbolic_trace(model_fn)
    split_gm, split_items = split_graph(gm, ["silly::attention"])
    # Without empty-only merge, this graph would split into:
    # [add, empty_like], [attention], [empty_like], [attention], [add].
    assert len(split_items) == 4, "Builtin empty-only partition should be merged"
    # The empty allocations must never end up inside splitting-op subgraphs.
    in_splitting = _subgraphs_with_empty_nodes(split_items, is_splitting_graph=True)
    assert not in_splitting, (
        "Splitting-op subgraphs should not contain empty allocation nodes: "
        f"{[item.submod_name for item in in_splitting]}"
    )
    # ...and both should live in a single non-splitting subgraph.
    in_regular = _subgraphs_with_empty_nodes(split_items, is_splitting_graph=False)
    assert len(in_regular) == 1, (
        "Exactly one non-splitting subgraph should contain merged empty nodes"
    )
    assert len(_get_empty_nodes(in_regular[0])) == 2, (
        "Expected two builtin empty_like nodes in merged non-splitting subgraph"
    )
    # The split graph must compute the same result as the original.
    inp = torch.randn(2, 3, device="cuda")
    assert torch.allclose(gm(inp), split_gm(inp)), "Output mismatch after split"