diff --git a/tests/compile/test_graph_partition.py b/tests/compile/test_graph_partition.py index 49bb54824..0b490e97f 100644 --- a/tests/compile/test_graph_partition.py +++ b/tests/compile/test_graph_partition.py @@ -5,6 +5,8 @@ import operator import pytest import torch +import torch._dynamo +import torch.fx as fx from torch.fx.experimental.proxy_tensor import make_fx from vllm.compilation.backends import _is_empty_allocation_node, split_graph @@ -327,3 +329,296 @@ def test_builtin_empty_only_partition_is_merged(): output_original = gm(x) output_split = split_gm(x) assert torch.allclose(output_original, output_split), "Output mismatch after split" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +def test_sym_size_whole_shape_boundary(): + """ + Test that using x.size() (whole shape) across a split boundary can be + compiled by standalone_compile. + + The dynamo graph looks like: + shape = x.size() + y = sigmoid(x) # split point + z = y.clone().view(shape) + + Which splits into: + subgraph0(x) -> shape # returns torch.Size — problematic + subgraph1(x) -> y # sigmoid + subgraph2(y, shape) -> z # view + + Two approaches to fix the torch.Size crossing: + + Approach 1 — move sym_size to consumer (memory implication: x passed to + subgraph2 just for .size()): + subgraph0(x) -> # empty + subgraph1(x) -> y + subgraph2(y, x) -> z # computes shape locally from x + + Approach 2 — decompose shape into individual int/SymInt values: + subgraph0(x) -> s0, val # returns individual scalars, not Size + subgraph1(x) -> y + subgraph2(y, s0, val) -> z # reconstructs view args from scalars + """ + from torch._inductor import standalone_compile + + captured_graph = None + + def capturing_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule: + nonlocal captured_graph + captured_graph = gm + return gm + + def model_fn(x: torch.Tensor) -> torch.Tensor: + shape = x.size() + x = torch.ops.aten.sigmoid.default(x) + x = x.clone().view(shape) + return x + + x = torch.randn(4, 8) + torch._dynamo.mark_dynamic(x, 0) + compiled_fn = torch.compile(model_fn, backend=capturing_backend) + compiled_fn(x) + + split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"]) + assert len(split_items) == 3 + + submod_0 = split_gm.submod_0 + example_input = torch.randn(4, 8) + compiled = standalone_compile( + submod_0, [example_input, 4], dynamic_shapes="from_example_inputs" + ) + assert compiled is not None + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +def test_symint_crosses_split_boundary(): + """ + Test that SymInt placeholders from torch.compile + mark_dynamic + cross split boundaries safely via split_module's natural threading. + + SymInt values are threaded through subgraphs by split_module and + handled correctly by inductor — no special replacement is needed. + """ + captured_graph = None + + def capturing_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule: + nonlocal captured_graph + captured_graph = gm + return gm + + def model_fn(x: torch.Tensor) -> torch.Tensor: + batch_size = x.shape[0] + hidden_size = x.shape[1] + x = torch.ops.aten.sigmoid.default(x) + x = x.clone().view(batch_size, hidden_size) + x = torch.ops.aten.sigmoid.default(x) + x = x.clone().view(batch_size, hidden_size) + x = torch.ops.aten.sigmoid.default(x) + x = x.clone().view(batch_size, hidden_size) + return x + + x = torch.randn(4, 8) + torch._dynamo.mark_dynamic(x, 0) + + compiled_fn = torch.compile(model_fn, backend=capturing_backend) + compiled_fn(x) + + assert captured_graph is not None, "Graph should be captured by backend" + + # SymInt placeholders should exist in the captured graph + symint_placeholders = [ + node + for node in captured_graph.graph.nodes + if node.op == "placeholder" + and isinstance(node.meta.get("example_value"), torch.SymInt) + ] + assert len(symint_placeholders) > 0, ( + "Captured graph should have SymInt placeholders from mark_dynamic." + ) + + # split_graph should handle SymInt placeholders without error + split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"]) + + # Should have 3 splitting subgraphs (3 sigmoids) + splitting_subgraphs = [item for item in split_items if item.is_splitting_graph] + assert len(splitting_subgraphs) == 3, ( + f"Expected 3 splitting subgraphs (3 sigmoids), got {len(splitting_subgraphs)}" + ) + assert len(split_items) >= 6, ( + f"Expected at least 6 total subgraphs, got {len(split_items)}" + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +def test_shape_boundary_standalone_compile(): + """ + Repro for the original production bug: + + AssertionError: out_spec mismatch + TreeSpec(tuple, None, [*, *, TreeSpec(Size, None, [*, *]), *]) + vs + TreeSpec(tuple, None, [*, *, *, *]) + + A subgraph outputs torch.Size (e.g. torch.Size([s72, 2048])) as one of + its values when shape info crosses a split boundary. aot_autograd / inductor + expect all submodule outputs to be flat tensors or scalars, not torch.Size. + + With the fix, x.size() is decomposed into individual sym_size.int calls + so only scalar SymInts cross the boundary — not the torch.Size. + """ + from torch._inductor import standalone_compile + + captured_graph = None + + def capturing_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule: + nonlocal captured_graph + captured_graph = gm + return gm + + def model_fn(x: torch.Tensor) -> torch.Tensor: + shape = x.size() + x = torch.ops.aten.sigmoid.default(x) + x = x.clone().view(shape) + return x + + x = torch.randn(4, 8) + torch._dynamo.mark_dynamic(x, 0) + torch.compile(model_fn, backend=capturing_backend)(x) + + split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"]) + assert len(split_items) == 3 + + # Verify that the consumer subgraph only has a placeholder for the dynamic + # dim (SymInt) — the static dim (8) should be inlined as a literal, not + # threaded as a placeholder. + consumer = split_items[-1] # valid since len == 3: [producer, sigmoid, consumer] + symint_placeholders = [ + n + for n in consumer.graph.graph.nodes + if n.op == "placeholder" + and isinstance(n.meta.get("example_value"), torch.SymInt) + ] + static_int_placeholders = [ + n + for n in consumer.graph.graph.nodes + if n.op == "placeholder" + and isinstance(n.meta.get("example_value"), int) + and not isinstance(n.meta.get("example_value"), torch.SymInt) + ] + assert len(symint_placeholders) >= 1, ( + "Consumer should have a SymInt placeholder for the dynamic dim." + ) + assert len(static_int_placeholders) == 0, ( + "Static dims should be inlined as literals, not threaded as placeholders." + ) + + submod_0 = split_gm.submod_0 + + standalone_compile( + submod_0, [torch.randn(4, 8), 4], dynamic_shapes="from_example_inputs" + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +def test_size_used_in_multiple_consumer_subgraphs(): + """ + Validates that x.size() (whole shape) used by multiple downstream subgraphs + does not cause torch.Size to cross split boundaries. + + Model: + shape = x.size() # whole shape — must not cross as torch.Size + z1 = sigmoid(x) # split point 1 + y1 = y.view(shape) # consumer 1 uses shape + z2 = sigmoid(z1) # split point 2 + y2 = y.view(shape) # consumer 2 uses shape again + + Without the fix, torch.Size crosses the boundary as a submodule output, + which aot_autograd / standalone_compile rejects. + """ + captured_graph = None + captured_inputs = None + + def capturing_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule: + nonlocal captured_graph, captured_inputs + captured_graph = gm + captured_inputs = example_inputs + return gm + + def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + shape = x.size() + z1 = torch.ops.aten.sigmoid.default(x) + y1 = y.view(shape) + z2 = torch.ops.aten.sigmoid.default(z1) + y2 = y.view(shape) + return z2 + y1 + y2 + + x = torch.randn(4, 8) + y = torch.randn(4, 8) # same shape as x so view(shape) doesn't specialize dim 0 + torch._dynamo.mark_dynamic(x, 0) + torch._dynamo.mark_dynamic(y, 0) + torch.compile(model_fn, backend=capturing_backend)(x, y) + + split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"]) + + splitting_items = [item for item in split_items if item.is_splitting_graph] + assert len(splitting_items) == 2 + + # Verify functional correctness — fails without the fix because torch.Size + # would cross a split boundary as a submodule output + output_original = model_fn(x, y) + output_split = split_gm(*captured_inputs) + if isinstance(output_split, tuple): + output_split = next(o for o in output_split if isinstance(o, torch.Tensor)) + assert torch.allclose(output_original, output_split), "Output mismatch after split" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +def test_sym_size_metadata_propagated(): + """ + Validates that new sym_size.int nodes created by the pre-pass have + example_value metadata set. Without it, placeholder metadata in consumer + subgraphs would be None, breaking any code that dynamically builds + example inputs from metadata (e.g. standalone_compile per-submodule). + """ + from torch._inductor import standalone_compile + + captured_graph = None + + def capturing_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule: + nonlocal captured_graph + captured_graph = gm + return gm + + def model_fn(x: torch.Tensor) -> torch.Tensor: + shape = x.size() + x = torch.ops.aten.sigmoid.default(x) + x = x.clone().view(shape) + return x + + x = torch.randn(4, 8) + torch._dynamo.mark_dynamic(x, 0) + torch.compile(model_fn, backend=capturing_backend)(x) + + split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"]) + + # For each submodule, build example inputs purely from placeholder metadata. + # This fails if example_value is None on any placeholder (i.e. metadata + # was not propagated to the sym_size.int nodes we created). + for item in split_items: + submod = item.graph + example_inputs = [] + for n in submod.graph.nodes: + if n.op != "placeholder": + continue + ev = n.meta.get("example_value") + assert ev is not None, ( + f"Placeholder '{n.name}' in {item.submod_name} has no " + "example_value metadata. sym_size.int nodes must propagate " + "metadata so consumer subgraphs can be introspected." + ) + if isinstance(ev, torch.Tensor): + example_inputs.append(torch.randn(*(int(d) for d in ev.shape))) + else: + example_inputs.append(int(ev)) + standalone_compile(submod, example_inputs, dynamic_shapes="from_example_inputs") diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 3526099dc..e049ef345 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -473,9 +473,65 @@ def _merge_empty_only_subgraphs( prev_non_splitting_subgraph_id = subgraph_id +def _decompose_size_nodes(graph: fx.GraphModule) -> None: + """Decompose x.size() into per-dim sym_size.int calls. + + torch.Size objects cannot cross split boundaries because aot_autograd + cannot handle them as submodule outputs. This replaces each size() call + with individual sym_size.int(x, dim) nodes: + - Dynamic dims (SymInt) → new sym_size.int node + - Static dims (plain int) → inlined as literal constant + """ + # Dynamo captures x.size()/x.shape as call_method target="size". + size_nodes = list(graph.graph.find_nodes(op="call_method", target="size")) + + for node in size_nodes: + tensor_node = node.args[0] + ev = tensor_node.meta.get("example_value") + assert ev is not None, ( + f"Tensor node '{tensor_node.name}' has no example_value metadata. " + f"Cannot decompose size node '{node.name}'." + ) + + # Build per-dim replacements: sym_size.int node or literal int. + dims: list[fx.Node | int] = [] + with graph.graph.inserting_after(tensor_node): + for i in range(ev.dim()): + dim_val = ev.shape[i] + if isinstance(dim_val, torch.SymInt): + dn = graph.graph.call_function( + torch.ops.aten.sym_size.int, args=(tensor_node, i) + ) + dn.meta["example_value"] = dim_val + dims.append(dn) + elif isinstance(dim_val, int): + dims.append(dim_val) + else: + raise AssertionError( + f"dim_val is either torch.SymInt or int, " + f"got {type(dim_val)} for dim {i} of " + f"'{node.name}'" + ) + + # Replace size node in each user's args. + # Dynamo always passes size as a direct arg: view(clone, size) + # → view(clone, d0, d1, ...) + for user in list(node.users): + new_args = [] + for arg in user.args: + if arg is node: + new_args.extend(dims) + else: + new_args.append(arg) + user.args = tuple(new_args) + graph.graph.erase_node(node) + + def split_graph( graph: fx.GraphModule, splitting_ops: list[str] ) -> tuple[fx.GraphModule, list[SplitItem]]: + _decompose_size_nodes(graph) + # split graph by ops subgraph_id = 0 node_to_subgraph_id: dict[fx.Node, int] = {}