[compile][graph_partition]Add tensor size handling (#36038)
Signed-off-by: Xiao Fu <xiaofu@meta.com>
This commit is contained in:
@@ -5,6 +5,8 @@ import operator
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch._dynamo
|
||||
import torch.fx as fx
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
|
||||
from vllm.compilation.backends import _is_empty_allocation_node, split_graph
|
||||
@@ -327,3 +329,296 @@ def test_builtin_empty_only_partition_is_merged():
|
||||
output_original = gm(x)
|
||||
output_split = split_gm(x)
|
||||
assert torch.allclose(output_original, output_split), "Output mismatch after split"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_sym_size_whole_shape_boundary():
    """Compile a split whose boundary is crossed by x.size() (a whole shape).

    The dynamo graph looks like:
        shape = x.size()
        y = sigmoid(x)          # split point
        z = y.clone().view(shape)

    which partitions into a producer that returns the shape, the sigmoid
    subgraph, and a consumer applying the view. A raw torch.Size must not
    flow between submodules; either the sym_size computation moves into the
    consumer (at the memory cost of also passing x there), or the shape is
    decomposed into individual int/SymInt scalars. Either way,
    standalone_compile must accept the first submodule.
    """
    from torch._inductor import standalone_compile

    recorded = {}

    def capturing_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule:
        recorded["gm"] = gm
        return gm

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        whole_shape = x.size()
        activated = torch.ops.aten.sigmoid.default(x)
        return activated.clone().view(whole_shape)

    inp = torch.randn(4, 8)
    torch._dynamo.mark_dynamic(inp, 0)
    torch.compile(model_fn, backend=capturing_backend)(inp)

    split_gm, split_items = split_graph(recorded["gm"], ["aten::sigmoid"])
    assert len(split_items) == 3

    compiled = standalone_compile(
        split_gm.submod_0,
        [torch.randn(4, 8), 4],
        dynamic_shapes="from_example_inputs",
    )
    assert compiled is not None
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_symint_crosses_split_boundary():
    """SymInt placeholders from torch.compile + mark_dynamic may cross split
    boundaries safely: split_module threads them through subgraphs naturally
    and inductor handles them — no special replacement is needed.
    """
    recorded = {}

    def capturing_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule:
        recorded["gm"] = gm
        return gm

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        batch, hidden = x.shape[0], x.shape[1]
        # Three sigmoid/view rounds -> three split points (dynamo unrolls
        # this loop at trace time, so the captured graph is identical to
        # writing the statements out by hand).
        for _ in range(3):
            x = torch.ops.aten.sigmoid.default(x)
            x = x.clone().view(batch, hidden)
        return x

    inp = torch.randn(4, 8)
    torch._dynamo.mark_dynamic(inp, 0)

    torch.compile(model_fn, backend=capturing_backend)(inp)

    captured = recorded.get("gm")
    assert captured is not None, "Graph should be captured by backend"

    # mark_dynamic must surface as SymInt placeholders in the captured graph.
    sym_placeholders = [
        n
        for n in captured.graph.nodes
        if n.op == "placeholder"
        and isinstance(n.meta.get("example_value"), torch.SymInt)
    ]
    assert len(sym_placeholders) > 0, (
        "Captured graph should have SymInt placeholders from mark_dynamic."
    )

    # split_graph should handle SymInt placeholders without error.
    split_gm, split_items = split_graph(captured, ["aten::sigmoid"])

    splitting = [item for item in split_items if item.is_splitting_graph]
    assert len(splitting) == 3, (
        f"Expected 3 splitting subgraphs (3 sigmoids), got {len(splitting)}"
    )
    assert len(split_items) >= 6, (
        f"Expected at least 6 total subgraphs, got {len(split_items)}"
    )
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_shape_boundary_standalone_compile():
    """Repro for the original production bug:

        AssertionError: out_spec mismatch
        TreeSpec(tuple, None, [*, *, TreeSpec(Size, None, [*, *]), *])
        vs
        TreeSpec(tuple, None, [*, *, *, *])

    A subgraph emitted torch.Size (e.g. torch.Size([s72, 2048])) as one of
    its outputs when shape info crossed a split boundary, but aot_autograd /
    inductor require all submodule outputs to be flat tensors or scalars.
    With the fix, x.size() is decomposed into sym_size.int calls so only
    scalar SymInts cross the boundary.
    """
    from torch._inductor import standalone_compile

    recorded = {}

    def capturing_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule:
        recorded["gm"] = gm
        return gm

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        shape = x.size()
        x = torch.ops.aten.sigmoid.default(x)
        return x.clone().view(shape)

    inp = torch.randn(4, 8)
    torch._dynamo.mark_dynamic(inp, 0)
    torch.compile(model_fn, backend=capturing_backend)(inp)

    split_gm, split_items = split_graph(recorded["gm"], ["aten::sigmoid"])
    assert len(split_items) == 3

    # The consumer subgraph should carry a placeholder only for the dynamic
    # dim (SymInt); the static dim (8) must be inlined as a literal rather
    # than threaded through as a placeholder.
    consumer = split_items[-1]  # valid since len == 3: [producer, sigmoid, consumer]
    placeholders = [n for n in consumer.graph.graph.nodes if n.op == "placeholder"]
    dynamic_dims = [
        n
        for n in placeholders
        if isinstance(n.meta.get("example_value"), torch.SymInt)
    ]
    static_dims = [
        n
        for n in placeholders
        if isinstance(n.meta.get("example_value"), int)
        and not isinstance(n.meta.get("example_value"), torch.SymInt)
    ]
    assert len(dynamic_dims) >= 1, (
        "Consumer should have a SymInt placeholder for the dynamic dim."
    )
    assert len(static_dims) == 0, (
        "Static dims should be inlined as literals, not threaded as placeholders."
    )

    standalone_compile(
        split_gm.submod_0,
        [torch.randn(4, 8), 4],
        dynamic_shapes="from_example_inputs",
    )
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_size_used_in_multiple_consumer_subgraphs():
    """x.size() consumed by multiple downstream subgraphs must not leak a
    torch.Size across any split boundary.

    Model:
        shape = x.size()      # whole shape — must not cross as torch.Size
        z1 = sigmoid(x)       # split point 1
        y1 = y.view(shape)    # consumer 1 uses shape
        z2 = sigmoid(z1)      # split point 2
        y2 = y.view(shape)    # consumer 2 uses shape again

    Without the fix, torch.Size crosses the boundary as a submodule output,
    which aot_autograd / standalone_compile rejects.
    """
    recorded = {}

    def capturing_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule:
        recorded["gm"] = gm
        recorded["inputs"] = example_inputs
        return gm

    def model_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        shape = x.size()
        z1 = torch.ops.aten.sigmoid.default(x)
        y1 = y.view(shape)
        z2 = torch.ops.aten.sigmoid.default(z1)
        y2 = y.view(shape)
        return z2 + y1 + y2

    a = torch.randn(4, 8)
    b = torch.randn(4, 8)  # same shape as a so view(shape) doesn't specialize dim 0
    torch._dynamo.mark_dynamic(a, 0)
    torch._dynamo.mark_dynamic(b, 0)
    torch.compile(model_fn, backend=capturing_backend)(a, b)

    split_gm, split_items = split_graph(recorded["gm"], ["aten::sigmoid"])

    assert sum(1 for item in split_items if item.is_splitting_graph) == 2

    # Functional check — fails without the fix because torch.Size would
    # cross a split boundary as a submodule output.
    expected = model_fn(a, b)
    actual = split_gm(*recorded["inputs"])
    if isinstance(actual, tuple):
        actual = next(o for o in actual if isinstance(o, torch.Tensor))
    assert torch.allclose(expected, actual), "Output mismatch after split"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_sym_size_metadata_propagated():
    """sym_size.int nodes created by the pre-pass must carry example_value
    metadata. Without it, placeholder metadata in consumer subgraphs would
    be None, breaking any code that dynamically builds example inputs from
    metadata (e.g. standalone_compile per-submodule).
    """
    from torch._inductor import standalone_compile

    recorded = {}

    def capturing_backend(gm: fx.GraphModule, example_inputs: list) -> fx.GraphModule:
        recorded["gm"] = gm
        return gm

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        shape = x.size()
        x = torch.ops.aten.sigmoid.default(x)
        return x.clone().view(shape)

    inp = torch.randn(4, 8)
    torch._dynamo.mark_dynamic(inp, 0)
    torch.compile(model_fn, backend=capturing_backend)(inp)

    split_gm, split_items = split_graph(recorded["gm"], ["aten::sigmoid"])

    # Build example inputs for every submodule purely from placeholder
    # metadata. This fails if example_value is None on any placeholder
    # (i.e. metadata was not propagated to the sym_size.int nodes created).
    for item in split_items:
        submod = item.graph
        example_inputs = []
        placeholders = (n for n in submod.graph.nodes if n.op == "placeholder")
        for node in placeholders:
            ev = node.meta.get("example_value")
            assert ev is not None, (
                f"Placeholder '{node.name}' in {item.submod_name} has no "
                "example_value metadata. sym_size.int nodes must propagate "
                "metadata so consumer subgraphs can be introspected."
            )
            if isinstance(ev, torch.Tensor):
                example_inputs.append(torch.randn(*(int(d) for d in ev.shape)))
            else:
                example_inputs.append(int(ev))
        standalone_compile(submod, example_inputs, dynamic_shapes="from_example_inputs")
|
||||
|
||||
@@ -473,9 +473,65 @@ def _merge_empty_only_subgraphs(
|
||||
prev_non_splitting_subgraph_id = subgraph_id
|
||||
|
||||
|
||||
def _decompose_size_nodes(graph: fx.GraphModule) -> None:
|
||||
"""Decompose x.size() into per-dim sym_size.int calls.
|
||||
|
||||
torch.Size objects cannot cross split boundaries because aot_autograd
|
||||
cannot handle them as submodule outputs. This replaces each size() call
|
||||
with individual sym_size.int(x, dim) nodes:
|
||||
- Dynamic dims (SymInt) → new sym_size.int node
|
||||
- Static dims (plain int) → inlined as literal constant
|
||||
"""
|
||||
# Dynamo captures x.size()/x.shape as call_method target="size".
|
||||
size_nodes = list(graph.graph.find_nodes(op="call_method", target="size"))
|
||||
|
||||
for node in size_nodes:
|
||||
tensor_node = node.args[0]
|
||||
ev = tensor_node.meta.get("example_value")
|
||||
assert ev is not None, (
|
||||
f"Tensor node '{tensor_node.name}' has no example_value metadata. "
|
||||
f"Cannot decompose size node '{node.name}'."
|
||||
)
|
||||
|
||||
# Build per-dim replacements: sym_size.int node or literal int.
|
||||
dims: list[fx.Node | int] = []
|
||||
with graph.graph.inserting_after(tensor_node):
|
||||
for i in range(ev.dim()):
|
||||
dim_val = ev.shape[i]
|
||||
if isinstance(dim_val, torch.SymInt):
|
||||
dn = graph.graph.call_function(
|
||||
torch.ops.aten.sym_size.int, args=(tensor_node, i)
|
||||
)
|
||||
dn.meta["example_value"] = dim_val
|
||||
dims.append(dn)
|
||||
elif isinstance(dim_val, int):
|
||||
dims.append(dim_val)
|
||||
else:
|
||||
raise AssertionError(
|
||||
f"dim_val is either torch.SymInt or int, "
|
||||
f"got {type(dim_val)} for dim {i} of "
|
||||
f"'{node.name}'"
|
||||
)
|
||||
|
||||
# Replace size node in each user's args.
|
||||
# Dynamo always passes size as a direct arg: view(clone, size)
|
||||
# → view(clone, d0, d1, ...)
|
||||
for user in list(node.users):
|
||||
new_args = []
|
||||
for arg in user.args:
|
||||
if arg is node:
|
||||
new_args.extend(dims)
|
||||
else:
|
||||
new_args.append(arg)
|
||||
user.args = tuple(new_args)
|
||||
graph.graph.erase_node(node)
|
||||
|
||||
|
||||
def split_graph(
|
||||
graph: fx.GraphModule, splitting_ops: list[str]
|
||||
) -> tuple[fx.GraphModule, list[SplitItem]]:
|
||||
_decompose_size_nodes(graph)
|
||||
|
||||
# split graph by ops
|
||||
subgraph_id = 0
|
||||
node_to_subgraph_id: dict[fx.Node, int] = {}
|
||||
|
||||
Reference in New Issue
Block a user