Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Richard Zou <zou3519@gmail.com>
(cherry picked from commit 15f40b20aa)
187 lines
6.5 KiB
Python
187 lines
6.5 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import operator
|
|
|
|
import pytest
|
|
import torch
|
|
from torch.fx.experimental.proxy_tensor import make_fx
|
|
|
|
from vllm.compilation.backends import split_graph
|
|
from vllm.compilation.fx_utils import find_op_nodes
|
|
|
|
# This import automatically registers `torch.ops.silly.attention`
|
|
from . import silly_attention # noqa: F401
|
|
|
|
|
|
def test_getitem_moved_to_producer_subgraph():
    """
    Check that splitting on a tuple-producing op keeps every getitem in the
    producer's subgraph, so that no submodule ends up with a tuple input.
    """

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        # Should become the first submodule: torch.split returns a tuple,
        # which creates real getitem operations in the traced graph.
        halves = torch.split(x, x.shape[0] // 2, dim=0)

        # Should become the second submodule: these ops consume the tuple.
        activated = [torch.relu(halves[0]), torch.relu(halves[1])]
        return torch.cat(activated, dim=0)

    def is_getitem(fx_node) -> bool:
        return fx_node.op == "call_function" and fx_node.target == operator.getitem

    traced = make_fx(model_fn)(torch.randn(4, 3))

    # Sanity check: without getitem nodes the test would prove nothing.
    assert any(is_getitem(n) for n in traced.graph.nodes), (
        "Test setup failed: graph should contain getitem operations"
    )

    # The only splitting op is the tuple producer, aten::split.
    split_gm, split_items = split_graph(traced, ["aten::split.Tensor"])
    assert len(split_items) == 2, "Graph should be split into 2 submodules"

    for split_item in split_items:
        # A getitem that indexes a placeholder means the submodule was handed
        # the tuple itself rather than its already-extracted elements.
        getitem_on_placeholder = [
            n
            for n in split_item.graph.graph.nodes
            if is_getitem(n) and n.args[0].op == "placeholder"
        ]

        assert not getitem_on_placeholder, (
            f"Submodule {split_item.submod_name} has getitem operations on "
            f"placeholder nodes: {[n.name for n in getitem_on_placeholder]}. "
            "This means tuple inputs were not properly eliminated."
        )

    # The split module must compute the same result as the unsplit graph.
    fresh_input = torch.randn(4, 3)
    expected = traced(fresh_input)
    actual = split_gm(fresh_input)

    assert torch.allclose(expected, actual), "Output mismatch"
|
|
|
|
|
|
def test_no_tuple_inputs_with_multiple_consumers():
    """
    Check that a tuple consumed from several different subgraphs never
    crosses a subgraph boundary: the getitem operations must be moved so
    that no submodule receives the tuple itself as an input.
    """

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        # Should become the first submodule: torch.split returns a tuple,
        # which creates real getitem operations in the traced graph.
        parts = torch.split(x, x.shape[0] // 2, dim=0)

        # Should become the second submodule, which consumes the tuple.
        relu_0 = torch.relu(parts[0])
        relu_1 = torch.relu(parts[1])

        # sigmoid is an artificial splitting point (the third submodule), so
        # that another independent submodule consumes the tuple afterwards.
        gated_0 = torch.sigmoid(relu_0)

        # Fourth submodule, which consumes the tuple once more.
        return torch.cat([parts[0], parts[1], gated_0, relu_1])

    traced = make_fx(model_fn)(torch.randn(4, 3))

    # Sanity check: without getitem nodes the test would prove nothing.
    getitem_nodes = [
        n
        for n in traced.graph.nodes
        if n.op == "call_function" and n.target == operator.getitem
    ]
    assert getitem_nodes, "Test setup failed: graph should contain getitem operations"

    split_gm, split_items = split_graph(traced, ["aten::split.Tensor", "aten::sigmoid"])
    assert len(split_items) == 4, "Graph should be split into 4 submodules"

    for split_item in split_items:
        for node in split_item.graph.graph.nodes:
            if not (node.op == "call_function" and node.target == operator.getitem):
                continue
            if node.args[0].op != "placeholder":
                continue
            # Indexing a placeholder means the tuple was passed in as-is.
            pytest.fail(
                f"Submodule {split_item.submod_name} has getitem on "
                f"placeholder {node.args[0].name}, indicating it receives "
                "a tuple input"
            )

    # The split module must compute the same result as the unsplit graph.
    fresh_input = torch.randn(4, 3)
    expected = traced(fresh_input)
    actual = split_gm(fresh_input)

    assert torch.allclose(expected, actual), "Output mismatch after split"
|
|
|
|
|
|
def test_consecutive_ops_in_split():
    """
    Test that consecutive splitting operations are grouped into the same subgraph.

    The traced model is relu -> sqrt -> silly.attention -> sigmoid, with both
    sqrt and silly.attention configured as splitting ops. Because the two
    splitting ops are adjacent, the test expects them to share one splitting
    subgraph, for 3 partitions in total.

    Requires a CUDA device.
    """

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        """
        Define a simple model where consecutive operations create opportunities
        for splitting subgraphs.
        """
        # relu feeds two back-to-back splitting ops (sqrt, then silly
        # attention), followed by a non-splitting sigmoid.
        intermediate = torch.relu(x)
        attn_inout = torch.sqrt(intermediate)
        # NOTE(review): the result is not captured; presumably the op writes
        # into its last argument in place -- confirm against silly_attention.
        torch.ops.silly.attention(intermediate, intermediate, attn_inout, attn_inout)
        final_result = torch.sigmoid(attn_inout)
        return final_result

    # Scope the CUDA default device to this test. torch.set_default_device()
    # changes process-wide state and would leak into every test that runs
    # afterwards in the same process.
    with torch.device("cuda"):
        # Create the traced FX graph for the model
        x = torch.randn(8, 4)

        gm = make_fx(model_fn)(x)

        # Assert presence of the expected operations in the setup
        assert (
            len(list(find_op_nodes(torch.ops.aten.relu, gm.graph))) == 1
            and len(list(find_op_nodes(torch.ops.aten.sqrt, gm.graph))) == 1
        ), "Test setup failed: Expected sqrt and relu operations in the graph."

        # Configure split operations to test
        splitting_ops = ["silly::attention", "aten::sqrt"]
        split_gm, split_items = split_graph(gm, splitting_ops)

        # Validate the number of partitions
        assert len(split_items) == 3, (
            "Consecutive splitting operations were not grouped correctly."
        )

        # Validate that correctness is preserved
        new_x = torch.randn(8, 4)
        output_original = gm(new_x)
        output_split = split_gm(new_x)
        assert torch.allclose(output_original, output_split), (
            "Output mismatch after splitting."
        )

        # Check the splitting item has exactly 2 call_function nodes (sqrt and
        # attn), wrapped by one placeholder and the output node.
        splitting_items = [s for s in split_items if s.is_splitting_graph]
        assert len(splitting_items) == 1, "Expecting a single splitting graph"
        splitting_gm = splitting_items[0].graph
        # Printed so pytest shows the graph if the checks below fail.
        print(splitting_gm.graph)
        assert len(splitting_gm.graph.nodes) == 4, (
            "Expecting 4 nodes in splitting graph"
        )
        assert [node.op for node in splitting_gm.graph.nodes] == ["placeholder"] + 2 * [
            "call_function"
        ] + ["output"]
|