[Bugfix] Avoid merging empty-only partitions into splitting-op subgraphs (#36595)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
Jiangyun Zhu
2026-03-10 22:39:01 +08:00
committed by GitHub
parent cf88b23749
commit ca5fb4bbd8
2 changed files with 132 additions and 29 deletions

View File

@@ -7,7 +7,7 @@ import pytest
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from vllm.compilation.backends import split_graph
from vllm.compilation.backends import _is_empty_allocation_node, split_graph
from vllm.compilation.passes.fx_utils import find_op_nodes
# This import automatically registers `torch.ops.silly.attention`
@@ -186,10 +186,25 @@ def test_consecutive_ops_in_split():
] + ["output"]
def test_empty_only_partition_is_merged():
def _get_empty_nodes(split_item):
    """Return the empty-allocation nodes found in one split item's graph.

    Walks ``split_item.graph.graph.nodes`` in order and keeps every node for
    which ``_is_empty_allocation_node`` reports a truthy result.
    """
    fx_nodes = split_item.graph.graph.nodes
    return list(filter(_is_empty_allocation_node, fx_nodes))
def _subgraphs_with_empty_nodes(split_items, *, is_splitting_graph):
    """Select split items of one kind that contain empty-allocation nodes.

    An item is kept when its ``is_splitting_graph`` attribute equals the
    keyword-only ``is_splitting_graph`` argument and ``_get_empty_nodes``
    returns a non-empty list for it. Input order is preserved.
    """
    matching = []
    for candidate in split_items:
        # Skip items of the other kind before inspecting their nodes.
        if candidate.is_splitting_graph != is_splitting_graph:
            continue
        if _get_empty_nodes(candidate):
            matching.append(candidate)
    return matching
def test_empty_only_partition_stays_separate_after_splitting_predecessor():
"""
Test that an empty-allocation-only partition is merged into its previous
partition during Dynamo FX splitting.
Empty-only subgraphs should not be merged when the only predecessor is
a splitting-op subgraph.
"""
def model_fn(x: torch.Tensor) -> torch.Tensor:
@@ -204,9 +219,65 @@ def test_empty_only_partition_is_merged():
split_ops = ["aten::sin", "aten::cos.out"]
split_gm, split_items = split_graph(gm, split_ops)
# Without the merge, this graph is split into 3 partitions where the
# middle partition contains only aten::empty_like.
assert len(split_items) == 2, "Empty-only partition should be merged"
# Graph partitioning for this pattern is:
# [sin], [empty_like], [cos.out].
assert len(split_items) == 3, (
"Empty-only partition should not merge into splitting-op subgraph"
)
splitting_with_empty = _subgraphs_with_empty_nodes(
split_items, is_splitting_graph=True
)
assert len(splitting_with_empty) == 0, (
"Splitting-op subgraphs should not contain empty allocation nodes: "
f"{[item.submod_name for item in splitting_with_empty]}"
)
output_original = gm(x)
output_split = split_gm(x)
assert torch.allclose(output_original, output_split), "Output mismatch after split"
def test_empty_only_partition_is_merged():
"""
Empty-only subgraphs should still be merged when a non-splitting predecessor
exists. The merged empty node must remain outside splitting-op subgraphs.
"""
def model_fn(x: torch.Tensor) -> torch.Tensor:
base = x + 1
y = torch.sin(base)
out = torch.empty_like(base)
torch.ops.aten.cos.out(base, out=out)
return out + y
x = torch.randn(4, 3)
gm = make_fx(model_fn)(x)
split_gm, split_items = split_graph(gm, ["aten::sin", "aten::cos.out"])
# Partitioning should be:
# [add, empty_like], [sin], [cos.out], [add].
assert len(split_items) == 4, (
"Empty-only partition should be merged into non-splitting predecessor"
)
splitting_with_empty = _subgraphs_with_empty_nodes(
split_items, is_splitting_graph=True
)
assert len(splitting_with_empty) == 0, (
"Splitting-op subgraphs should not contain empty allocation nodes: "
f"{[item.submod_name for item in splitting_with_empty]}"
)
non_splitting_with_empty = _subgraphs_with_empty_nodes(
split_items, is_splitting_graph=False
)
assert len(non_splitting_with_empty) == 1, (
"Exactly one non-splitting subgraph should contain the merged empty node"
)
assert len(_get_empty_nodes(non_splitting_with_empty[0])) == 1, (
"Expected exactly one empty allocation node in merged subgraph"
)
output_original = gm(x)
output_split = split_gm(x)
@@ -220,18 +291,37 @@ def test_builtin_empty_only_partition_is_merged():
"""
def model_fn(x: torch.Tensor) -> torch.Tensor:
out1 = torch.empty_like(x)
torch.ops.silly.attention(x, x, x, out1)
out2 = torch.empty_like(x)
torch.ops.silly.attention(out1, out1, out1, out2)
return out2
hidden = x + 1
out1 = torch.empty_like(hidden)
torch.ops.silly.attention(hidden, hidden, hidden, out1)
out2 = torch.empty_like(hidden)
torch.ops.silly.attention(out1, out1, hidden, out2)
return out2 + hidden
gm = torch.fx.symbolic_trace(model_fn)
split_gm, split_items = split_graph(gm, ["silly::attention"])
# Without the empty-only merge, this graph creates 4 partitions:
# [empty_like], [attention], [empty_like], [attention].
assert len(split_items) == 3, "Builtin empty-only partition should be merged"
# Without empty-only merge, this graph would split into:
# [add, empty_like], [attention], [empty_like], [attention], [add].
assert len(split_items) == 4, "Builtin empty-only partition should be merged"
splitting_with_empty = _subgraphs_with_empty_nodes(
split_items, is_splitting_graph=True
)
assert len(splitting_with_empty) == 0, (
"Splitting-op subgraphs should not contain empty allocation nodes: "
f"{[item.submod_name for item in splitting_with_empty]}"
)
non_splitting_with_empty = _subgraphs_with_empty_nodes(
split_items, is_splitting_graph=False
)
assert len(non_splitting_with_empty) == 1, (
"Exactly one non-splitting subgraph should contain merged empty nodes"
)
assert len(_get_empty_nodes(non_splitting_with_empty[0])) == 2, (
"Expected two builtin empty_like nodes in merged non-splitting subgraph"
)
x = torch.randn(2, 3, device="cuda")
output_original = gm(x)