[compile][graph_partition]Add tensor size handling (#36038)

Signed-off-by: Xiao Fu <xiaofu@meta.com>
This commit is contained in:
Xiao
2026-03-19 19:55:25 -07:00
committed by GitHub
parent 47b7af0d87
commit ea2c148fa7
2 changed files with 351 additions and 0 deletions

View File

@@ -5,6 +5,8 @@ import operator
import pytest
import torch
import torch._dynamo
import torch.fx as fx
from torch.fx.experimental.proxy_tensor import make_fx
from vllm.compilation.backends import _is_empty_allocation_node, split_graph
@@ -327,3 +329,296 @@ def test_builtin_empty_only_partition_is_merged():
output_original = gm(x)
output_split = split_gm(x)
assert torch.allclose(output_original, output_split), "Output mismatch after split"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_sym_size_whole_shape_boundary():
    """
    Test that using x.size() (whole shape) across a split boundary can be
    compiled by standalone_compile.
    The dynamo graph looks like:
        shape = x.size()
        y = sigmoid(x)          # split point
        z = y.clone().view(shape)
    Which splits into:
        subgraph0(x) -> shape          # returns torch.Size — problematic
        subgraph1(x) -> y              # sigmoid
        subgraph2(y, shape) -> z       # view
    Two approaches to fix the torch.Size crossing:
    Approach 1 — move sym_size to consumer (memory implication: x passed to
    subgraph2 just for .size()):
        subgraph0(x) ->                # empty
        subgraph1(x) -> y
        subgraph2(y, x) -> z           # computes shape locally from x
    Approach 2 — decompose shape into individual int/SymInt values:
        subgraph0(x) -> s0, val        # returns individual scalars, not Size
        subgraph1(x) -> y
        subgraph2(y, s0, val) -> z     # reconstructs view args from scalars
    """
    from torch._inductor import standalone_compile

    captured_graph = None

    def capturing_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule:
        # Stash the dynamo-captured graph so the test can split it directly.
        nonlocal captured_graph
        captured_graph = gm
        return gm

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        shape = x.size()
        x = torch.ops.aten.sigmoid.default(x)
        x = x.clone().view(shape)
        return x

    # NOTE(review): marked CUDA-only like the sibling tests although the
    # inputs are CPU tensors — presumably standalone_compile needs a GPU
    # backend; confirm.
    x = torch.randn(4, 8)
    torch._dynamo.mark_dynamic(x, 0)
    compiled_fn = torch.compile(model_fn, backend=capturing_backend)
    compiled_fn(x)
    # Fail fast with a clear message if dynamo never invoked the backend
    # (consistent with test_symint_crosses_split_boundary); otherwise
    # split_graph(None) raises a confusing AttributeError.
    assert captured_graph is not None, "Graph should be captured by backend"
    split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"])
    assert len(split_items) == 3
    submod_0 = split_gm.submod_0
    example_input = torch.randn(4, 8)
    # With the decomposition fix, submod_0 outputs scalar SymInts rather than
    # a torch.Size, so standalone_compile can handle it.
    compiled = standalone_compile(
        submod_0, [example_input, 4], dynamic_shapes="from_example_inputs"
    )
    assert compiled is not None
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_symint_crosses_split_boundary():
    """
    Check that SymInt placeholders produced by torch.compile + mark_dynamic
    survive split boundaries without any special handling.
    split_module threads SymInt values through subgraphs naturally, and
    inductor consumes them correctly — no replacement pass is required.
    """
    seen_graph = None

    def grab_graph(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule:
        nonlocal seen_graph
        seen_graph = gm
        return gm

    def triple_sigmoid_view(x: torch.Tensor) -> torch.Tensor:
        b = x.shape[0]
        h = x.shape[1]
        x = torch.ops.aten.sigmoid.default(x)
        x = x.clone().view(b, h)
        x = torch.ops.aten.sigmoid.default(x)
        x = x.clone().view(b, h)
        x = torch.ops.aten.sigmoid.default(x)
        x = x.clone().view(b, h)
        return x

    inp = torch.randn(4, 8)
    torch._dynamo.mark_dynamic(inp, 0)
    torch.compile(triple_sigmoid_view, backend=grab_graph)(inp)
    assert seen_graph is not None, "Graph should be captured by backend"
    # The captured graph must carry SymInt placeholders for the dynamic dim.
    sym_placeholders = [
        ph
        for ph in seen_graph.graph.nodes
        if ph.op == "placeholder"
        and isinstance(ph.meta.get("example_value"), torch.SymInt)
    ]
    assert len(sym_placeholders) > 0, (
        "Captured graph should have SymInt placeholders from mark_dynamic."
    )
    # Splitting must succeed even with SymInt placeholders present.
    split_gm, split_items = split_graph(seen_graph, ["aten::sigmoid"])
    # One splitting subgraph per sigmoid call.
    splitting = [it for it in split_items if it.is_splitting_graph]
    assert len(splitting) == 3, (
        f"Expected 3 splitting subgraphs (3 sigmoids), got {len(splitting)}"
    )
    assert len(split_items) >= 6, (
        f"Expected at least 6 total subgraphs, got {len(split_items)}"
    )
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_shape_boundary_standalone_compile():
    """
    Repro for the original production bug:
        AssertionError: out_spec mismatch
        TreeSpec(tuple, None, [*, *, TreeSpec(Size, None, [*, *]), *])
        vs
        TreeSpec(tuple, None, [*, *, *, *])
    A subgraph outputs torch.Size (e.g. torch.Size([s72, 2048])) as one of
    its values when shape info crosses a split boundary. aot_autograd / inductor
    expect all submodule outputs to be flat tensors or scalars, not torch.Size.
    With the fix, x.size() is decomposed into individual sym_size.int calls
    so only scalar SymInts cross the boundary — not the torch.Size.
    """
    from torch._inductor import standalone_compile

    captured_graph = None

    def capturing_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule:
        # Stash the dynamo-captured graph so the test can split it directly.
        nonlocal captured_graph
        captured_graph = gm
        return gm

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        shape = x.size()
        x = torch.ops.aten.sigmoid.default(x)
        x = x.clone().view(shape)
        return x

    x = torch.randn(4, 8)
    torch._dynamo.mark_dynamic(x, 0)
    torch.compile(model_fn, backend=capturing_backend)(x)
    # Fail fast with a clear message if dynamo never invoked the backend
    # (consistent with test_symint_crosses_split_boundary); otherwise
    # split_graph(None) raises a confusing AttributeError.
    assert captured_graph is not None, "Graph should be captured by backend"
    split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"])
    assert len(split_items) == 3
    # Verify that the consumer subgraph only has a placeholder for the dynamic
    # dim (SymInt) — the static dim (8) should be inlined as a literal, not
    # threaded as a placeholder.
    consumer = split_items[-1]  # valid since len == 3: [producer, sigmoid, consumer]
    symint_placeholders = [
        n
        for n in consumer.graph.graph.nodes
        if n.op == "placeholder"
        and isinstance(n.meta.get("example_value"), torch.SymInt)
    ]
    static_int_placeholders = [
        n
        for n in consumer.graph.graph.nodes
        if n.op == "placeholder"
        and isinstance(n.meta.get("example_value"), int)
        and not isinstance(n.meta.get("example_value"), torch.SymInt)
    ]
    assert len(symint_placeholders) >= 1, (
        "Consumer should have a SymInt placeholder for the dynamic dim."
    )
    assert len(static_int_placeholders) == 0, (
        "Static dims should be inlined as literals, not threaded as placeholders."
    )
    submod_0 = split_gm.submod_0
    standalone_compile(
        submod_0, [torch.randn(4, 8), 4], dynamic_shapes="from_example_inputs"
    )
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_size_used_in_multiple_consumer_subgraphs():
    """
    Validates that x.size() (whole shape) used by multiple downstream subgraphs
    does not cause torch.Size to cross split boundaries.
    Model:
        shape = x.size()    # whole shape — must not cross as torch.Size
        z1 = sigmoid(x)     # split point 1
        y1 = y.view(shape)  # consumer 1 uses shape
        z2 = sigmoid(z1)    # split point 2
        y2 = y.view(shape)  # consumer 2 uses shape again
    Without the fix, torch.Size crosses the boundary as a submodule output,
    which aot_autograd / standalone_compile rejects.
    """
    captured_graph = None
    captured_inputs = None

    def capturing_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule:
        # Stash both the graph and its example inputs; the inputs are needed
        # to run the split module below.
        nonlocal captured_graph, captured_inputs
        captured_graph = gm
        captured_inputs = example_inputs
        return gm

    def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        shape = x.size()
        z1 = torch.ops.aten.sigmoid.default(x)
        y1 = y.view(shape)
        z2 = torch.ops.aten.sigmoid.default(z1)
        y2 = y.view(shape)
        return z2 + y1 + y2

    x = torch.randn(4, 8)
    y = torch.randn(4, 8)  # same shape as x so view(shape) doesn't specialize dim 0
    torch._dynamo.mark_dynamic(x, 0)
    torch._dynamo.mark_dynamic(y, 0)
    torch.compile(model_fn, backend=capturing_backend)(x, y)
    # Fail fast with a clear message if dynamo never invoked the backend
    # (consistent with test_symint_crosses_split_boundary).
    assert captured_graph is not None, "Graph should be captured by backend"
    assert captured_inputs is not None
    split_gm, split_items = split_graph(captured_graph, ["aten::sigmoid"])
    splitting_items = [item for item in split_items if item.is_splitting_graph]
    assert len(splitting_items) == 2
    # Verify functional correctness — fails without the fix because torch.Size
    # would cross a split boundary as a submodule output
    output_original = model_fn(x, y)
    output_split = split_gm(*captured_inputs)
    if isinstance(output_split, tuple):
        output_split = next(o for o in output_split if isinstance(o, torch.Tensor))
    assert torch.allclose(output_original, output_split), "Output mismatch after split"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_sym_size_metadata_propagated():
    """
    Ensure the sym_size.int nodes created by the pre-pass carry
    example_value metadata. If that metadata were missing, consumer
    subgraphs would get placeholders with None metadata, breaking any code
    that builds example inputs from metadata (e.g. per-submodule
    standalone_compile).
    """
    from torch._inductor import standalone_compile

    seen_graph = None

    def grab_graph(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule:
        nonlocal seen_graph
        seen_graph = gm
        return gm

    def sigmoid_then_view(x: torch.Tensor) -> torch.Tensor:
        whole_shape = x.size()
        x = torch.ops.aten.sigmoid.default(x)
        return x.clone().view(whole_shape)

    inp = torch.randn(4, 8)
    torch._dynamo.mark_dynamic(inp, 0)
    torch.compile(sigmoid_then_view, backend=grab_graph)(inp)
    split_gm, split_items = split_graph(seen_graph, ["aten::sigmoid"])
    # Construct example inputs for every submodule purely from placeholder
    # metadata. Any placeholder with example_value == None means the pre-pass
    # failed to propagate metadata onto its sym_size.int nodes.
    for item in split_items:
        submod = item.graph
        placeholders = [ph for ph in submod.graph.nodes if ph.op == "placeholder"]
        inputs = []
        for ph in placeholders:
            ev = ph.meta.get("example_value")
            assert ev is not None, (
                f"Placeholder '{ph.name}' in {item.submod_name} has no "
                "example_value metadata. sym_size.int nodes must propagate "
                "metadata so consumer subgraphs can be introspected."
            )
            if isinstance(ev, torch.Tensor):
                inputs.append(torch.randn(*(int(d) for d in ev.shape)))
            else:
                inputs.append(int(ev))
        standalone_compile(submod, inputs, dynamic_shapes="from_example_inputs")

View File

@@ -473,9 +473,65 @@ def _merge_empty_only_subgraphs(
prev_non_splitting_subgraph_id = subgraph_id
def _decompose_size_nodes(graph: fx.GraphModule) -> None:
"""Decompose x.size() into per-dim sym_size.int calls.
torch.Size objects cannot cross split boundaries because aot_autograd
cannot handle them as submodule outputs. This replaces each size() call
with individual sym_size.int(x, dim) nodes:
- Dynamic dims (SymInt) → new sym_size.int node
- Static dims (plain int) → inlined as literal constant
"""
# Dynamo captures x.size()/x.shape as call_method target="size".
size_nodes = list(graph.graph.find_nodes(op="call_method", target="size"))
for node in size_nodes:
tensor_node = node.args[0]
ev = tensor_node.meta.get("example_value")
assert ev is not None, (
f"Tensor node '{tensor_node.name}' has no example_value metadata. "
f"Cannot decompose size node '{node.name}'."
)
# Build per-dim replacements: sym_size.int node or literal int.
dims: list[fx.Node | int] = []
with graph.graph.inserting_after(tensor_node):
for i in range(ev.dim()):
dim_val = ev.shape[i]
if isinstance(dim_val, torch.SymInt):
dn = graph.graph.call_function(
torch.ops.aten.sym_size.int, args=(tensor_node, i)
)
dn.meta["example_value"] = dim_val
dims.append(dn)
elif isinstance(dim_val, int):
dims.append(dim_val)
else:
raise AssertionError(
f"dim_val is either torch.SymInt or int, "
f"got {type(dim_val)} for dim {i} of "
f"'{node.name}'"
)
# Replace size node in each user's args.
# Dynamo always passes size as a direct arg: view(clone, size)
# → view(clone, d0, d1, ...)
for user in list(node.users):
new_args = []
for arg in user.args:
if arg is node:
new_args.extend(dims)
else:
new_args.append(arg)
user.args = tuple(new_args)
graph.graph.erase_node(node)
def split_graph(
graph: fx.GraphModule, splitting_ops: list[str]
) -> tuple[fx.GraphModule, list[SplitItem]]:
_decompose_size_nodes(graph)
# split graph by ops
subgraph_id = 0
node_to_subgraph_id: dict[fx.Node, int] = {}