# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import operator import pytest import torch import torch._dynamo import torch.fx as fx from torch.fx.experimental.proxy_tensor import make_fx from vllm.compilation.backends import _is_empty_allocation_node, split_graph from vllm.compilation.passes.fx_utils import find_op_nodes # This import automatically registers `torch.ops.silly.attention` from . import silly_attention # noqa: F401 def test_getitem_moved_to_producer_subgraph(): """ Test that getitem operations are moved to the same subgraph as their input, preventing tuple inputs to submodules. """ def model_fn(x: torch.Tensor) -> torch.Tensor: # torch.split returns a tuple, creating real getitem operations # Should become first submodule that produces tuple chunks = torch.split(x, x.shape[0] // 2, dim=0) # Following ops should become second submodule that consumes tuple result_0 = torch.relu(chunks[0]) result_1 = torch.relu(chunks[1]) return torch.cat([result_0, result_1], dim=0) x = torch.randn(4, 3) gm = make_fx(model_fn)(x) has_getitem = any( node.op == "call_function" and node.target == operator.getitem for node in gm.graph.nodes ) assert has_getitem, "Test setup failed: graph should contain getitem operations" # Split on tuple producer aten::split split_ops = ["aten::split.Tensor"] split_gm, split_items = split_graph(gm, split_ops) assert len(split_items) == 2, "Graph should be split into 2 submodules" for split_item in split_items: submodule = split_item.graph getitem_on_placeholder = [] for node in submodule.graph.nodes: if ( node.op == "call_function" and node.target == operator.getitem and node.args[0].op == "placeholder" ): getitem_on_placeholder.append(node) assert len(getitem_on_placeholder) == 0, ( f"Submodule {split_item.submod_name} has getitem operations on " f"placeholder nodes: {[n.name for n in getitem_on_placeholder]}. " "This means tuple inputs were not properly eliminated." ) new_x = torch.randn(4, 3) output_original = gm(new_x) output_split = split_gm(new_x) assert torch.allclose(output_original, output_split), "Output mismatch" def test_no_tuple_inputs_with_multiple_consumers(): """ Test that when a tuple is consumed by multiple split operations, getitem operations are properly moved to avoid tuple inputs. """ def model_fn(x: torch.Tensor) -> torch.Tensor: # torch.split returns a tuple, creating real getitem operations # Should become first submodule that produces tuple chunks = torch.split(x, x.shape[0] // 2, dim=0) # These should become second submodule consuming tuple result_1 = torch.relu(chunks[0]) result_2 = torch.relu(chunks[1]) # Artificial graph splitting point to create another # independent submodule that consumes tuple later # This would become the third submodule result_1 = torch.sigmoid(result_1) # Fourth submodule that consumes tuple result = torch.cat([chunks[0], chunks[1], result_1, result_2]) return result x = torch.randn(4, 3) gm = make_fx(model_fn)(x) has_getitem = any( node.op == "call_function" and node.target == operator.getitem for node in gm.graph.nodes ) assert has_getitem, "Test setup failed: graph should contain getitem operations" split_ops = ["aten::split.Tensor", "aten::sigmoid"] split_gm, split_items = split_graph(gm, split_ops) assert len(split_items) == 4, "Graph should be split into 4 submodules" for split_item in split_items: submodule = split_item.graph for node in submodule.graph.nodes: if ( node.op == "call_function" and node.target == operator.getitem and node.args[0].op == "placeholder" ): pytest.fail( f"Submodule {split_item.submod_name} has getitem on " f"placeholder {node.args[0].name}, indicating it receives " "a tuple input" ) new_x = torch.randn(4, 3) output_original = gm(new_x) output_split = split_gm(new_x) assert torch.allclose(output_original, output_split), "Output mismatch after split" def test_consecutive_ops_in_split(): """ Test that consecutive splitting operations are grouped into the same subgraph """ def model_fn(x: torch.Tensor) -> torch.Tensor: """ Define a simple model where consecutive operations create opportunities for splitting subgraphs. """ # Apply silly attention followed by consecutive operations intermediate = torch.relu(x) attn_inout = torch.sqrt(intermediate) torch.ops.silly.attention(intermediate, intermediate, attn_inout, attn_inout) final_result = torch.sigmoid(attn_inout) return final_result torch.set_default_device("cuda") # Create the traced FX graph for the model x = torch.randn(8, 4) gm = make_fx(model_fn)(x) # Assert presence of the expected operations in the setup assert ( len(list(find_op_nodes(torch.ops.aten.relu, gm.graph))) == 1 and len(list(find_op_nodes(torch.ops.aten.sqrt, gm.graph))) == 1 ), "Test setup failed: Expected sqrt and relu operations in the graph." # Configure split operations to test splitting_ops = ["silly::attention", "aten::sqrt"] split_gm, split_items = split_graph(gm, splitting_ops) # Validate the number of partitions assert len(split_items) == 3, ( "Consecutive splitting operations were not grouped correctly." ) # Validate that correctness is preserved new_x = torch.randn(8, 4) output_original = gm(new_x) output_split = split_gm(new_x) assert torch.allclose(output_original, output_split), ( "Output mismatch after splitting." ) # Check the splitting item has 2 nodes exactly (relu and attn) splitting_items = list(s for s in split_items if s.is_splitting_graph) assert len(splitting_items) == 1, "Expecting a single splitting graph" print(splitting_items[0].graph.graph) splitting_gm = splitting_items[0].graph assert len(splitting_gm.graph.nodes) == 4, "Expecting 4 nodes in splitting graph" assert [node.op for node in splitting_gm.graph.nodes] == ["placeholder"] + 2 * [ "call_function" ] + ["output"] def _get_empty_nodes(split_item): return [ node for node in split_item.graph.graph.nodes if _is_empty_allocation_node(node) ] def _subgraphs_with_empty_nodes(split_items, *, is_splitting_graph): return [ split_item for split_item in split_items if split_item.is_splitting_graph == is_splitting_graph and _get_empty_nodes(split_item) ] def test_empty_only_partition_stays_separate_after_splitting_predecessor(): """ Empty-only subgraphs should not be merged when the only predecessor is a splitting-op subgraph. """ def model_fn(x: torch.Tensor) -> torch.Tensor: y = torch.sin(x) out = torch.empty_like(y) torch.ops.aten.cos.out(y, out=out) return out x = torch.randn(4, 3) gm = make_fx(model_fn)(x) split_ops = ["aten::sin", "aten::cos.out"] split_gm, split_items = split_graph(gm, split_ops) # Graph partitioning for this pattern is: # [sin], [empty_like], [cos.out]. assert len(split_items) == 3, ( "Empty-only partition should not merge into splitting-op subgraph" ) splitting_with_empty = _subgraphs_with_empty_nodes( split_items, is_splitting_graph=True ) assert len(splitting_with_empty) == 0, ( "Splitting-op subgraphs should not contain empty allocation nodes: " f"{[item.submod_name for item in splitting_with_empty]}" ) output_original = gm(x) output_split = split_gm(x) assert torch.allclose(output_original, output_split), "Output mismatch after split" def test_empty_only_partition_is_merged(): """ Empty-only subgraphs should still be merged when a non-splitting predecessor exists. The merged empty node must remain outside splitting-op subgraphs. """ def model_fn(x: torch.Tensor) -> torch.Tensor: base = x + 1 y = torch.sin(base) out = torch.empty_like(base) torch.ops.aten.cos.out(base, out=out) return out + y x = torch.randn(4, 3) gm = make_fx(model_fn)(x) split_gm, split_items = split_graph(gm, ["aten::sin", "aten::cos.out"]) # Partitioning should be: # [add, empty_like], [sin], [cos.out], [add]. assert len(split_items) == 4, ( "Empty-only partition should be merged into non-splitting predecessor" ) splitting_with_empty = _subgraphs_with_empty_nodes( split_items, is_splitting_graph=True ) assert len(splitting_with_empty) == 0, ( "Splitting-op subgraphs should not contain empty allocation nodes: " f"{[item.submod_name for item in splitting_with_empty]}" ) non_splitting_with_empty = _subgraphs_with_empty_nodes( split_items, is_splitting_graph=False ) assert len(non_splitting_with_empty) == 1, ( "Exactly one non-splitting subgraph should contain the merged empty node" ) assert len(_get_empty_nodes(non_splitting_with_empty[0])) == 1, ( "Expected exactly one empty allocation node in merged subgraph" ) output_original = gm(x) output_split = split_gm(x) assert torch.allclose(output_original, output_split), "Output mismatch after split" def test_builtin_empty_only_partition_is_merged(): """ In Dynamo graphs, torch.empty/empty_like may appear as builtin call targets (not aten OpOverload). Ensure empty-only partitions are still merged. """ def model_fn(x: torch.Tensor) -> torch.Tensor: hidden = x + 1 out1 = torch.empty_like(hidden) torch.ops.silly.attention(hidden, hidden, hidden, out1) out2 = torch.empty_like(hidden) torch.ops.silly.attention(out1, out1, hidden, out2) return out2 + hidden gm = torch.fx.symbolic_trace(model_fn) split_gm, split_items = split_graph(gm, ["silly::attention"]) # Without empty-only merge, this graph would split into: # [add, empty_like], [attention], [empty_like], [attention], [add]. assert len(split_items) == 4, "Builtin empty-only partition should be merged" splitting_with_empty = _subgraphs_with_empty_nodes( split_items, is_splitting_graph=True ) assert len(splitting_with_empty) == 0, ( "Splitting-op subgraphs should not contain empty allocation nodes: " f"{[item.submod_name for item in splitting_with_empty]}" ) non_splitting_with_empty = _subgraphs_with_empty_nodes( split_items, is_splitting_graph=False ) assert len(non_splitting_with_empty) == 1, ( "Exactly one non-splitting subgraph should contain merged empty nodes" ) assert len(_get_empty_nodes(non_splitting_with_empty[0])) == 2, ( "Expected two builtin empty_like nodes in merged non-splitting subgraph" ) x = torch.randn(2, 3, device="cuda") output_original = gm(x) output_split = split_gm(x) assert torch.allclose(output_original, output_split), "Output mismatch after split" @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") def test_sym_size_whole_shape_boundary(): """ Test that using x.size() (whole shape) across a split boundary can be compiled by standalone_compile. The dynamo graph looks like: shape = x.size() y = sigmoid(x) # split point z = y.clone().view(shape) Which splits into: subgraph0(x) -> shape # returns torch.Size — problematic subgraph1(x) -> y # sigmoid subgraph2(y, shape) -> z # view Two approaches to fix the torch.Size crossing: Approach 1 — move sym_size to consumer (memory implication: x passed to subgraph2 just for .size()): subgraph0(x) -> # empty subgraph1(x) -> y subgraph2(y, x) -> z # computes shape locally from x Approach 2 — decompose shape into individual int/SymInt values: subgraph0(x) -> s0, val # returns individual scalars, not Size subgraph1(x) -> y subgraph2(y, s0, val) -> z # reconstructs view args from scalars """ from torch._inductor import standalone_compile captured_graph = None def capturing_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule: nonlocal captured_graph captured_graph = gm return gm def model_fn(x: torch.Tensor) -> torch.Tensor: shape = x.size() x = torch.ops.aten.sigmoid.default(x) x = x.clone().view(shape) return x x = torch.randn(4, 8) torch._dynamo.mark_dynamic(x, 0) compiled_fn = torch.compile(model_fn, backend=capturing_backend) compiled_fn(x) split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"]) assert len(split_items) == 3 submod_0 = split_gm.submod_0 example_input = torch.randn(4, 8) compiled = standalone_compile( submod_0, [example_input, 4], dynamic_shapes="from_example_inputs" ) assert compiled is not None @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") def test_symint_crosses_split_boundary(): """ Test that SymInt placeholders from torch.compile + mark_dynamic cross split boundaries safely via split_module's natural threading. SymInt values are threaded through subgraphs by split_module and handled correctly by inductor — no special replacement is needed. """ captured_graph = None def capturing_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule: nonlocal captured_graph captured_graph = gm return gm def model_fn(x: torch.Tensor) -> torch.Tensor: batch_size = x.shape[0] hidden_size = x.shape[1] x = torch.ops.aten.sigmoid.default(x) x = x.clone().view(batch_size, hidden_size) x = torch.ops.aten.sigmoid.default(x) x = x.clone().view(batch_size, hidden_size) x = torch.ops.aten.sigmoid.default(x) x = x.clone().view(batch_size, hidden_size) return x x = torch.randn(4, 8) torch._dynamo.mark_dynamic(x, 0) compiled_fn = torch.compile(model_fn, backend=capturing_backend) compiled_fn(x) assert captured_graph is not None, "Graph should be captured by backend" # SymInt placeholders should exist in the captured graph symint_placeholders = [ node for node in captured_graph.graph.nodes if node.op == "placeholder" and isinstance(node.meta.get("example_value"), torch.SymInt) ] assert len(symint_placeholders) > 0, ( "Captured graph should have SymInt placeholders from mark_dynamic." ) # split_graph should handle SymInt placeholders without error split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"]) # Should have 3 splitting subgraphs (3 sigmoids) splitting_subgraphs = [item for item in split_items if item.is_splitting_graph] assert len(splitting_subgraphs) == 3, ( f"Expected 3 splitting subgraphs (3 sigmoids), got {len(splitting_subgraphs)}" ) assert len(split_items) >= 6, ( f"Expected at least 6 total subgraphs, got {len(split_items)}" ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") def test_shape_boundary_standalone_compile(): """ Repro for the original production bug: AssertionError: out_spec mismatch TreeSpec(tuple, None, [*, *, TreeSpec(Size, None, [*, *]), *]) vs TreeSpec(tuple, None, [*, *, *, *]) A subgraph outputs torch.Size (e.g. torch.Size([s72, 2048])) as one of its values when shape info crosses a split boundary. aot_autograd / inductor expect all submodule outputs to be flat tensors or scalars, not torch.Size. With the fix, x.size() is decomposed into individual sym_size.int calls so only scalar SymInts cross the boundary — not the torch.Size. """ from torch._inductor import standalone_compile captured_graph = None def capturing_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule: nonlocal captured_graph captured_graph = gm return gm def model_fn(x: torch.Tensor) -> torch.Tensor: shape = x.size() x = torch.ops.aten.sigmoid.default(x) x = x.clone().view(shape) return x x = torch.randn(4, 8) torch._dynamo.mark_dynamic(x, 0) torch.compile(model_fn, backend=capturing_backend)(x) split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"]) assert len(split_items) == 3 # Verify that the consumer subgraph only has a placeholder for the dynamic # dim (SymInt) — the static dim (8) should be inlined as a literal, not # threaded as a placeholder. consumer = split_items[-1] # valid since len == 3: [producer, sigmoid, consumer] symint_placeholders = [ n for n in consumer.graph.graph.nodes if n.op == "placeholder" and isinstance(n.meta.get("example_value"), torch.SymInt) ] static_int_placeholders = [ n for n in consumer.graph.graph.nodes if n.op == "placeholder" and isinstance(n.meta.get("example_value"), int) and not isinstance(n.meta.get("example_value"), torch.SymInt) ] assert len(symint_placeholders) >= 1, ( "Consumer should have a SymInt placeholder for the dynamic dim." ) assert len(static_int_placeholders) == 0, ( "Static dims should be inlined as literals, not threaded as placeholders." ) submod_0 = split_gm.submod_0 standalone_compile( submod_0, [torch.randn(4, 8), 4], dynamic_shapes="from_example_inputs" ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") def test_size_used_in_multiple_consumer_subgraphs(): """ Validates that x.size() (whole shape) used by multiple downstream subgraphs does not cause torch.Size to cross split boundaries. Model: shape = x.size() # whole shape — must not cross as torch.Size z1 = sigmoid(x) # split point 1 y1 = y.view(shape) # consumer 1 uses shape z2 = sigmoid(z1) # split point 2 y2 = y.view(shape) # consumer 2 uses shape again Without the fix, torch.Size crosses the boundary as a submodule output, which aot_autograd / standalone_compile rejects. """ captured_graph = None captured_inputs = None def capturing_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule: nonlocal captured_graph, captured_inputs captured_graph = gm captured_inputs = example_inputs return gm def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: shape = x.size() z1 = torch.ops.aten.sigmoid.default(x) y1 = y.view(shape) z2 = torch.ops.aten.sigmoid.default(z1) y2 = y.view(shape) return z2 + y1 + y2 x = torch.randn(4, 8) y = torch.randn(4, 8) # same shape as x so view(shape) doesn't specialize dim 0 torch._dynamo.mark_dynamic(x, 0) torch._dynamo.mark_dynamic(y, 0) torch.compile(model_fn, backend=capturing_backend)(x, y) split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"]) splitting_items = [item for item in split_items if item.is_splitting_graph] assert len(splitting_items) == 2 # Verify functional correctness — fails without the fix because torch.Size # would cross a split boundary as a submodule output output_original = model_fn(x, y) output_split = split_gm(*captured_inputs) if isinstance(output_split, tuple): output_split = next(o for o in output_split if isinstance(o, torch.Tensor)) assert torch.allclose(output_original, output_split), "Output mismatch after split" @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") def test_sym_size_metadata_propagated(): """ Validates that new sym_size.int nodes created by the pre-pass have example_value metadata set. Without it, placeholder metadata in consumer subgraphs would be None, breaking any code that dynamically builds example inputs from metadata (e.g. standalone_compile per-submodule). """ from torch._inductor import standalone_compile captured_graph = None def capturing_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule: nonlocal captured_graph captured_graph = gm return gm def model_fn(x: torch.Tensor) -> torch.Tensor: shape = x.size() x = torch.ops.aten.sigmoid.default(x) x = x.clone().view(shape) return x x = torch.randn(4, 8) torch._dynamo.mark_dynamic(x, 0) torch.compile(model_fn, backend=capturing_backend)(x) split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"]) # For each submodule, build example inputs purely from placeholder metadata. # This fails if example_value is None on any placeholder (i.e. metadata # was not propagated to the sym_size.int nodes we created). for item in split_items: submod = item.graph example_inputs = [] for n in submod.graph.nodes: if n.op != "placeholder": continue ev = n.meta.get("example_value") assert ev is not None, ( f"Placeholder '{n.name}' in {item.submod_name} has no " "example_value metadata. sym_size.int nodes must propagate " "metadata so consumer subgraphs can be introspected." ) if isinstance(ev, torch.Tensor): example_inputs.append(torch.randn(*(int(d) for d in ev.shape))) else: example_inputs.append(int(ev)) standalone_compile(submod, example_inputs, dynamic_shapes="from_example_inputs")