# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import operator import pytest import torch from torch.fx.experimental.proxy_tensor import make_fx from vllm.compilation.backends import split_graph from vllm.compilation.passes.fx_utils import find_op_nodes # This import automatically registers `torch.ops.silly.attention` from . import silly_attention # noqa: F401 def test_getitem_moved_to_producer_subgraph(): """ Test that getitem operations are moved to the same subgraph as their input, preventing tuple inputs to submodules. """ def model_fn(x: torch.Tensor) -> torch.Tensor: # torch.split returns a tuple, creating real getitem operations # Should become first submodule that produces tuple chunks = torch.split(x, x.shape[0] // 2, dim=0) # Following ops should become second submodule that consumes tuple result_0 = torch.relu(chunks[0]) result_1 = torch.relu(chunks[1]) return torch.cat([result_0, result_1], dim=0) x = torch.randn(4, 3) gm = make_fx(model_fn)(x) has_getitem = any( node.op == "call_function" and node.target == operator.getitem for node in gm.graph.nodes ) assert has_getitem, "Test setup failed: graph should contain getitem operations" # Split on tuple producer aten::split split_ops = ["aten::split.Tensor"] split_gm, split_items = split_graph(gm, split_ops) assert len(split_items) == 2, "Graph should be split into 2 submodules" for split_item in split_items: submodule = split_item.graph getitem_on_placeholder = [] for node in submodule.graph.nodes: if ( node.op == "call_function" and node.target == operator.getitem and node.args[0].op == "placeholder" ): getitem_on_placeholder.append(node) assert len(getitem_on_placeholder) == 0, ( f"Submodule {split_item.submod_name} has getitem operations on " f"placeholder nodes: {[n.name for n in getitem_on_placeholder]}. " "This means tuple inputs were not properly eliminated." ) new_x = torch.randn(4, 3) output_original = gm(new_x) output_split = split_gm(new_x) assert torch.allclose(output_original, output_split), "Output mismatch" def test_no_tuple_inputs_with_multiple_consumers(): """ Test that when a tuple is consumed by multiple split operations, getitem operations are properly moved to avoid tuple inputs. """ def model_fn(x: torch.Tensor) -> torch.Tensor: # torch.split returns a tuple, creating real getitem operations # Should become first submodule that produces tuple chunks = torch.split(x, x.shape[0] // 2, dim=0) # These should become second submodule consuming tuple result_1 = torch.relu(chunks[0]) result_2 = torch.relu(chunks[1]) # Artificial graph splitting point to create another # independent submodule that consumes tuple later # This would become the third submodule result_1 = torch.sigmoid(result_1) # Fourth submodule that consumes tuple result = torch.cat([chunks[0], chunks[1], result_1, result_2]) return result x = torch.randn(4, 3) gm = make_fx(model_fn)(x) has_getitem = any( node.op == "call_function" and node.target == operator.getitem for node in gm.graph.nodes ) assert has_getitem, "Test setup failed: graph should contain getitem operations" split_ops = ["aten::split.Tensor", "aten::sigmoid"] split_gm, split_items = split_graph(gm, split_ops) assert len(split_items) == 4, "Graph should be split into 4 submodules" for split_item in split_items: submodule = split_item.graph for node in submodule.graph.nodes: if ( node.op == "call_function" and node.target == operator.getitem and node.args[0].op == "placeholder" ): pytest.fail( f"Submodule {split_item.submod_name} has getitem on " f"placeholder {node.args[0].name}, indicating it receives " "a tuple input" ) new_x = torch.randn(4, 3) output_original = gm(new_x) output_split = split_gm(new_x) assert torch.allclose(output_original, output_split), "Output mismatch after split" def test_consecutive_ops_in_split(): """ Test that consecutive splitting operations are grouped into the same subgraph """ def model_fn(x: torch.Tensor) -> torch.Tensor: """ Define a simple model where consecutive operations create opportunities for splitting subgraphs. """ # Apply silly attention followed by consecutive operations intermediate = torch.relu(x) attn_inout = torch.sqrt(intermediate) torch.ops.silly.attention(intermediate, intermediate, attn_inout, attn_inout) final_result = torch.sigmoid(attn_inout) return final_result torch.set_default_device("cuda") # Create the traced FX graph for the model x = torch.randn(8, 4) gm = make_fx(model_fn)(x) # Assert presence of the expected operations in the setup assert ( len(list(find_op_nodes(torch.ops.aten.relu, gm.graph))) == 1 and len(list(find_op_nodes(torch.ops.aten.sqrt, gm.graph))) == 1 ), "Test setup failed: Expected sqrt and relu operations in the graph." # Configure split operations to test splitting_ops = ["silly::attention", "aten::sqrt"] split_gm, split_items = split_graph(gm, splitting_ops) # Validate the number of partitions assert len(split_items) == 3, ( "Consecutive splitting operations were not grouped correctly." ) # Validate that correctness is preserved new_x = torch.randn(8, 4) output_original = gm(new_x) output_split = split_gm(new_x) assert torch.allclose(output_original, output_split), ( "Output mismatch after splitting." ) # Check the splitting item has 2 nodes exactly (relu and attn) splitting_items = list(s for s in split_items if s.is_splitting_graph) assert len(splitting_items) == 1, "Expecting a single splitting graph" print(splitting_items[0].graph.graph) splitting_gm = splitting_items[0].graph assert len(splitting_gm.graph.nodes) == 4, "Expecting 4 nodes in splitting graph" assert [node.op for node in splitting_gm.graph.nodes] == ["placeholder"] + 2 * [ "call_function" ] + ["output"]