diff --git a/tests/compile/test_graph_partition.py b/tests/compile/test_graph_partition.py index 9aa11dbe2..49bb54824 100644 --- a/tests/compile/test_graph_partition.py +++ b/tests/compile/test_graph_partition.py @@ -7,7 +7,7 @@ import pytest import torch from torch.fx.experimental.proxy_tensor import make_fx -from vllm.compilation.backends import split_graph +from vllm.compilation.backends import _is_empty_allocation_node, split_graph from vllm.compilation.passes.fx_utils import find_op_nodes # This import automatically registers `torch.ops.silly.attention` @@ -186,10 +186,25 @@ def test_consecutive_ops_in_split(): ] + ["output"] -def test_empty_only_partition_is_merged(): +def _get_empty_nodes(split_item): + return [ + node for node in split_item.graph.graph.nodes if _is_empty_allocation_node(node) + ] + + +def _subgraphs_with_empty_nodes(split_items, *, is_splitting_graph): + return [ + split_item + for split_item in split_items + if split_item.is_splitting_graph == is_splitting_graph + and _get_empty_nodes(split_item) + ] + + +def test_empty_only_partition_stays_separate_after_splitting_predecessor(): """ - Test that an empty-allocation-only partition is merged into its previous - partition during Dynamo FX splitting. + Empty-only subgraphs should not be merged when the only predecessor is + a splitting-op subgraph. """ def model_fn(x: torch.Tensor) -> torch.Tensor: @@ -204,9 +219,65 @@ def test_empty_only_partition_is_merged(): split_ops = ["aten::sin", "aten::cos.out"] split_gm, split_items = split_graph(gm, split_ops) - # Without the merge, this graph is split into 3 partitions where the - # middle partition contains only aten::empty_like. - assert len(split_items) == 2, "Empty-only partition should be merged" + # Graph partitioning for this pattern is: + # [sin], [empty_like], [cos.out]. + assert len(split_items) == 3, ( + "Empty-only partition should not merge into splitting-op subgraph" + ) + + splitting_with_empty = _subgraphs_with_empty_nodes( + split_items, is_splitting_graph=True + ) + assert len(splitting_with_empty) == 0, ( + "Splitting-op subgraphs should not contain empty allocation nodes: " + f"{[item.submod_name for item in splitting_with_empty]}" + ) + + output_original = gm(x) + output_split = split_gm(x) + assert torch.allclose(output_original, output_split), "Output mismatch after split" + + +def test_empty_only_partition_is_merged(): + """ + Empty-only subgraphs should still be merged when a non-splitting predecessor + exists. The merged empty node must remain outside splitting-op subgraphs. + """ + + def model_fn(x: torch.Tensor) -> torch.Tensor: + base = x + 1 + y = torch.sin(base) + out = torch.empty_like(base) + torch.ops.aten.cos.out(base, out=out) + return out + y + + x = torch.randn(4, 3) + gm = make_fx(model_fn)(x) + split_gm, split_items = split_graph(gm, ["aten::sin", "aten::cos.out"]) + + # Partitioning should be: + # [add, empty_like], [sin], [cos.out], [add]. + assert len(split_items) == 4, ( + "Empty-only partition should be merged into non-splitting predecessor" + ) + + splitting_with_empty = _subgraphs_with_empty_nodes( + split_items, is_splitting_graph=True + ) + assert len(splitting_with_empty) == 0, ( + "Splitting-op subgraphs should not contain empty allocation nodes: " + f"{[item.submod_name for item in splitting_with_empty]}" + ) + + non_splitting_with_empty = _subgraphs_with_empty_nodes( + split_items, is_splitting_graph=False + ) + assert len(non_splitting_with_empty) == 1, ( + "Exactly one non-splitting subgraph should contain the merged empty node" + ) + assert len(_get_empty_nodes(non_splitting_with_empty[0])) == 1, ( + "Expected exactly one empty allocation node in merged subgraph" + ) output_original = gm(x) output_split = split_gm(x) @@ -220,18 +291,37 @@ def test_builtin_empty_only_partition_is_merged(): """ def model_fn(x: torch.Tensor) -> torch.Tensor: - out1 = torch.empty_like(x) - torch.ops.silly.attention(x, x, x, out1) - out2 = torch.empty_like(x) - torch.ops.silly.attention(out1, out1, out1, out2) - return out2 + hidden = x + 1 + out1 = torch.empty_like(hidden) + torch.ops.silly.attention(hidden, hidden, hidden, out1) + out2 = torch.empty_like(hidden) + torch.ops.silly.attention(out1, out1, hidden, out2) + return out2 + hidden gm = torch.fx.symbolic_trace(model_fn) split_gm, split_items = split_graph(gm, ["silly::attention"]) - # Without the empty-only merge, this graph creates 4 partitions: - # [empty_like], [attention], [empty_like], [attention]. - assert len(split_items) == 3, "Builtin empty-only partition should be merged" + # Without empty-only merge, this graph would split into: + # [add, empty_like], [attention], [empty_like], [attention], [add]. + assert len(split_items) == 4, "Builtin empty-only partition should be merged" + + splitting_with_empty = _subgraphs_with_empty_nodes( + split_items, is_splitting_graph=True + ) + assert len(splitting_with_empty) == 0, ( + "Splitting-op subgraphs should not contain empty allocation nodes: " + f"{[item.submod_name for item in splitting_with_empty]}" + ) + + non_splitting_with_empty = _subgraphs_with_empty_nodes( + split_items, is_splitting_graph=False + ) + assert len(non_splitting_with_empty) == 1, ( + "Exactly one non-splitting subgraph should contain merged empty nodes" + ) + assert len(_get_empty_nodes(non_splitting_with_empty[0])) == 2, ( + "Expected two builtin empty_like nodes in merged non-splitting subgraph" + ) x = torch.randn(2, 3, device="cuda") output_original = gm(x) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index c0c46d9e7..51dff720b 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -431,6 +431,7 @@ def _is_empty_allocation_node(node: fx.Node) -> bool: def _merge_empty_only_subgraphs( node_to_subgraph_id: dict[fx.Node, int], + split_op_graphs: list[int], ) -> None: """ Merge a partition that only contains an empty allocation op into the @@ -439,23 +440,35 @@ def _merge_empty_only_subgraphs( """ nodes_by_subgraph_id: dict[int, list[fx.Node]] = defaultdict(list) - subgraph_id_order: list[int] = [] for node, subgraph_id in node_to_subgraph_id.items(): - if subgraph_id not in nodes_by_subgraph_id: - subgraph_id_order.append(subgraph_id) nodes_by_subgraph_id[subgraph_id].append(node) - prev_subgraph_id: int | None = None - for subgraph_id in subgraph_id_order: - nodes = nodes_by_subgraph_id[subgraph_id] - if ( - len(nodes) == 1 - and _is_empty_allocation_node(nodes[0]) - and prev_subgraph_id is not None - ): - node_to_subgraph_id[nodes[0]] = prev_subgraph_id + splitting_subgraphs = set(split_op_graphs) + prev_non_splitting_subgraph_id: int | None = None + + max_subgraph_id = max(node_to_subgraph_id.values(), default=-1) + for subgraph_id in range(max_subgraph_id + 1): + nodes = nodes_by_subgraph_id.get(subgraph_id, []) + if not nodes: continue - prev_subgraph_id = subgraph_id + + is_non_splitting_subgraph = subgraph_id not in splitting_subgraphs + is_empty_only_subgraph = len(nodes) == 1 and _is_empty_allocation_node(nodes[0]) + merged = False + + if is_empty_only_subgraph and prev_non_splitting_subgraph_id is not None: + # Safety check: don't move allocation before any input producer. + empty_node = nodes[0] + if all( + input_node.op == "placeholder" + or node_to_subgraph_id[input_node] <= prev_non_splitting_subgraph_id + for input_node in empty_node.all_input_nodes + ): + node_to_subgraph_id[empty_node] = prev_non_splitting_subgraph_id + merged = True + + if not merged and is_non_splitting_subgraph: + prev_non_splitting_subgraph_id = subgraph_id def split_graph( @@ -496,7 +509,7 @@ def split_graph( else: node_to_subgraph_id[node] = subgraph_id - _merge_empty_only_subgraphs(node_to_subgraph_id) + _merge_empty_only_subgraphs(node_to_subgraph_id, split_op_graphs) # `keep_original_order` is important! # otherwise pytorch might reorder the nodes and