[Bugfix] Avoid merging empty-only partitions into splitting-op subgraphs (#36595)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
@@ -7,7 +7,7 @@ import pytest
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
|
||||
from vllm.compilation.backends import split_graph
|
||||
from vllm.compilation.backends import _is_empty_allocation_node, split_graph
|
||||
from vllm.compilation.passes.fx_utils import find_op_nodes
|
||||
|
||||
# This import automatically registers `torch.ops.silly.attention`
|
||||
@@ -186,10 +186,25 @@ def test_consecutive_ops_in_split():
|
||||
] + ["output"]
|
||||
|
||||
|
||||
def test_empty_only_partition_is_merged():
|
||||
def _get_empty_nodes(split_item):
|
||||
return [
|
||||
node for node in split_item.graph.graph.nodes if _is_empty_allocation_node(node)
|
||||
]
|
||||
|
||||
|
||||
def _subgraphs_with_empty_nodes(split_items, *, is_splitting_graph):
|
||||
return [
|
||||
split_item
|
||||
for split_item in split_items
|
||||
if split_item.is_splitting_graph == is_splitting_graph
|
||||
and _get_empty_nodes(split_item)
|
||||
]
|
||||
|
||||
|
||||
def test_empty_only_partition_stays_separate_after_splitting_predecessor():
|
||||
"""
|
||||
Test that an empty-allocation-only partition is merged into its previous
|
||||
partition during Dynamo FX splitting.
|
||||
Empty-only subgraphs should not be merged when the only predecessor is
|
||||
a splitting-op subgraph.
|
||||
"""
|
||||
|
||||
def model_fn(x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -204,9 +219,65 @@ def test_empty_only_partition_is_merged():
|
||||
split_ops = ["aten::sin", "aten::cos.out"]
|
||||
split_gm, split_items = split_graph(gm, split_ops)
|
||||
|
||||
# Without the merge, this graph is split into 3 partitions where the
|
||||
# middle partition contains only aten::empty_like.
|
||||
assert len(split_items) == 2, "Empty-only partition should be merged"
|
||||
# Graph partitioning for this pattern is:
|
||||
# [sin], [empty_like], [cos.out].
|
||||
assert len(split_items) == 3, (
|
||||
"Empty-only partition should not merge into splitting-op subgraph"
|
||||
)
|
||||
|
||||
splitting_with_empty = _subgraphs_with_empty_nodes(
|
||||
split_items, is_splitting_graph=True
|
||||
)
|
||||
assert len(splitting_with_empty) == 0, (
|
||||
"Splitting-op subgraphs should not contain empty allocation nodes: "
|
||||
f"{[item.submod_name for item in splitting_with_empty]}"
|
||||
)
|
||||
|
||||
output_original = gm(x)
|
||||
output_split = split_gm(x)
|
||||
assert torch.allclose(output_original, output_split), "Output mismatch after split"
|
||||
|
||||
|
||||
def test_empty_only_partition_is_merged():
|
||||
"""
|
||||
Empty-only subgraphs should still be merged when a non-splitting predecessor
|
||||
exists. The merged empty node must remain outside splitting-op subgraphs.
|
||||
"""
|
||||
|
||||
def model_fn(x: torch.Tensor) -> torch.Tensor:
|
||||
base = x + 1
|
||||
y = torch.sin(base)
|
||||
out = torch.empty_like(base)
|
||||
torch.ops.aten.cos.out(base, out=out)
|
||||
return out + y
|
||||
|
||||
x = torch.randn(4, 3)
|
||||
gm = make_fx(model_fn)(x)
|
||||
split_gm, split_items = split_graph(gm, ["aten::sin", "aten::cos.out"])
|
||||
|
||||
# Partitioning should be:
|
||||
# [add, empty_like], [sin], [cos.out], [add].
|
||||
assert len(split_items) == 4, (
|
||||
"Empty-only partition should be merged into non-splitting predecessor"
|
||||
)
|
||||
|
||||
splitting_with_empty = _subgraphs_with_empty_nodes(
|
||||
split_items, is_splitting_graph=True
|
||||
)
|
||||
assert len(splitting_with_empty) == 0, (
|
||||
"Splitting-op subgraphs should not contain empty allocation nodes: "
|
||||
f"{[item.submod_name for item in splitting_with_empty]}"
|
||||
)
|
||||
|
||||
non_splitting_with_empty = _subgraphs_with_empty_nodes(
|
||||
split_items, is_splitting_graph=False
|
||||
)
|
||||
assert len(non_splitting_with_empty) == 1, (
|
||||
"Exactly one non-splitting subgraph should contain the merged empty node"
|
||||
)
|
||||
assert len(_get_empty_nodes(non_splitting_with_empty[0])) == 1, (
|
||||
"Expected exactly one empty allocation node in merged subgraph"
|
||||
)
|
||||
|
||||
output_original = gm(x)
|
||||
output_split = split_gm(x)
|
||||
@@ -220,18 +291,37 @@ def test_builtin_empty_only_partition_is_merged():
|
||||
"""
|
||||
|
||||
def model_fn(x: torch.Tensor) -> torch.Tensor:
|
||||
out1 = torch.empty_like(x)
|
||||
torch.ops.silly.attention(x, x, x, out1)
|
||||
out2 = torch.empty_like(x)
|
||||
torch.ops.silly.attention(out1, out1, out1, out2)
|
||||
return out2
|
||||
hidden = x + 1
|
||||
out1 = torch.empty_like(hidden)
|
||||
torch.ops.silly.attention(hidden, hidden, hidden, out1)
|
||||
out2 = torch.empty_like(hidden)
|
||||
torch.ops.silly.attention(out1, out1, hidden, out2)
|
||||
return out2 + hidden
|
||||
|
||||
gm = torch.fx.symbolic_trace(model_fn)
|
||||
split_gm, split_items = split_graph(gm, ["silly::attention"])
|
||||
|
||||
# Without the empty-only merge, this graph creates 4 partitions:
|
||||
# [empty_like], [attention], [empty_like], [attention].
|
||||
assert len(split_items) == 3, "Builtin empty-only partition should be merged"
|
||||
# Without empty-only merge, this graph would split into:
|
||||
# [add, empty_like], [attention], [empty_like], [attention], [add].
|
||||
assert len(split_items) == 4, "Builtin empty-only partition should be merged"
|
||||
|
||||
splitting_with_empty = _subgraphs_with_empty_nodes(
|
||||
split_items, is_splitting_graph=True
|
||||
)
|
||||
assert len(splitting_with_empty) == 0, (
|
||||
"Splitting-op subgraphs should not contain empty allocation nodes: "
|
||||
f"{[item.submod_name for item in splitting_with_empty]}"
|
||||
)
|
||||
|
||||
non_splitting_with_empty = _subgraphs_with_empty_nodes(
|
||||
split_items, is_splitting_graph=False
|
||||
)
|
||||
assert len(non_splitting_with_empty) == 1, (
|
||||
"Exactly one non-splitting subgraph should contain merged empty nodes"
|
||||
)
|
||||
assert len(_get_empty_nodes(non_splitting_with_empty[0])) == 2, (
|
||||
"Expected two builtin empty_like nodes in merged non-splitting subgraph"
|
||||
)
|
||||
|
||||
x = torch.randn(2, 3, device="cuda")
|
||||
output_original = gm(x)
|
||||
|
||||
@@ -431,6 +431,7 @@ def _is_empty_allocation_node(node: fx.Node) -> bool:
|
||||
|
||||
def _merge_empty_only_subgraphs(
|
||||
node_to_subgraph_id: dict[fx.Node, int],
|
||||
split_op_graphs: list[int],
|
||||
) -> None:
|
||||
"""
|
||||
Merge a partition that only contains an empty allocation op into the
|
||||
@@ -439,23 +440,35 @@ def _merge_empty_only_subgraphs(
|
||||
"""
|
||||
|
||||
nodes_by_subgraph_id: dict[int, list[fx.Node]] = defaultdict(list)
|
||||
subgraph_id_order: list[int] = []
|
||||
for node, subgraph_id in node_to_subgraph_id.items():
|
||||
if subgraph_id not in nodes_by_subgraph_id:
|
||||
subgraph_id_order.append(subgraph_id)
|
||||
nodes_by_subgraph_id[subgraph_id].append(node)
|
||||
|
||||
prev_subgraph_id: int | None = None
|
||||
for subgraph_id in subgraph_id_order:
|
||||
nodes = nodes_by_subgraph_id[subgraph_id]
|
||||
if (
|
||||
len(nodes) == 1
|
||||
and _is_empty_allocation_node(nodes[0])
|
||||
and prev_subgraph_id is not None
|
||||
):
|
||||
node_to_subgraph_id[nodes[0]] = prev_subgraph_id
|
||||
splitting_subgraphs = set(split_op_graphs)
|
||||
prev_non_splitting_subgraph_id: int | None = None
|
||||
|
||||
max_subgraph_id = max(node_to_subgraph_id.values(), default=-1)
|
||||
for subgraph_id in range(max_subgraph_id + 1):
|
||||
nodes = nodes_by_subgraph_id.get(subgraph_id, [])
|
||||
if not nodes:
|
||||
continue
|
||||
prev_subgraph_id = subgraph_id
|
||||
|
||||
is_non_splitting_subgraph = subgraph_id not in splitting_subgraphs
|
||||
is_empty_only_subgraph = len(nodes) == 1 and _is_empty_allocation_node(nodes[0])
|
||||
merged = False
|
||||
|
||||
if is_empty_only_subgraph and prev_non_splitting_subgraph_id is not None:
|
||||
# Safety check: don't move allocation before any input producer.
|
||||
empty_node = nodes[0]
|
||||
if all(
|
||||
input_node.op == "placeholder"
|
||||
or node_to_subgraph_id[input_node] <= prev_non_splitting_subgraph_id
|
||||
for input_node in empty_node.all_input_nodes
|
||||
):
|
||||
node_to_subgraph_id[empty_node] = prev_non_splitting_subgraph_id
|
||||
merged = True
|
||||
|
||||
if not merged and is_non_splitting_subgraph:
|
||||
prev_non_splitting_subgraph_id = subgraph_id
|
||||
|
||||
|
||||
def split_graph(
|
||||
@@ -496,7 +509,7 @@ def split_graph(
|
||||
else:
|
||||
node_to_subgraph_id[node] = subgraph_id
|
||||
|
||||
_merge_empty_only_subgraphs(node_to_subgraph_id)
|
||||
_merge_empty_only_subgraphs(node_to_subgraph_id, split_op_graphs)
|
||||
|
||||
# `keep_original_order` is important!
|
||||
# otherwise pytorch might reorder the nodes and
|
||||
|
||||
Reference in New Issue
Block a user