[Bugfix] Eliminate tuple inputs to submodules in graph partitioning (#28533)
Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
This commit is contained in:
@@ -445,6 +445,7 @@ steps:
|
|||||||
- vllm/
|
- vllm/
|
||||||
- tests/compile
|
- tests/compile
|
||||||
commands:
|
commands:
|
||||||
|
- pytest -v -s compile/test_graph_partition.py
|
||||||
- pytest -v -s compile/test_config.py
|
- pytest -v -s compile/test_config.py
|
||||||
- pytest -v -s compile/test_pass_manager.py
|
- pytest -v -s compile/test_pass_manager.py
|
||||||
- pytest -v -s compile/test_fusion.py
|
- pytest -v -s compile/test_fusion.py
|
||||||
|
|||||||
124
tests/compile/test_graph_partition.py
Normal file
124
tests/compile/test_graph_partition.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import operator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from torch.fx.experimental.proxy_tensor import make_fx
|
||||||
|
|
||||||
|
from vllm.compilation.backends import split_graph
|
||||||
|
|
||||||
|
|
||||||
|
def test_getitem_moved_to_producer_subgraph():
    """
    Check that splitting on a tuple-producing op keeps every getitem in the
    same subgraph as the node it indexes, so no submodule takes a tuple input.
    """

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        # torch.split yields a tuple, so the traced graph has real getitem nodes.
        # The split itself is expected to become the first (producer) submodule.
        halves = torch.split(x, x.shape[0] // 2, dim=0)

        # Everything below is expected to become the second (consumer) submodule.
        activated_0 = torch.relu(halves[0])
        activated_1 = torch.relu(halves[1])
        return torch.cat([activated_0, activated_1], dim=0)

    x = torch.randn(4, 3)
    gm = make_fx(model_fn)(x)

    # Sanity check: the scenario is only meaningful if tracing produced getitems.
    has_getitem = any(
        n.op == "call_function" and n.target == operator.getitem
        for n in gm.graph.nodes
    )
    assert has_getitem, "Test setup failed: graph should contain getitem operations"

    # Partition at the tuple producer, aten::split.
    split_ops = ["aten::split.Tensor"]
    split_gm, split_items = split_graph(gm, split_ops)
    assert len(split_items) == 2, "Graph should be split into 2 submodules"

    for split_item in split_items:
        # A getitem whose source is a placeholder means the tuple itself was
        # passed into this submodule as an argument.
        getitem_on_placeholder = [
            n
            for n in split_item.graph.graph.nodes
            if n.op == "call_function"
            and n.target == operator.getitem
            and n.args[0].op == "placeholder"
        ]

        assert not getitem_on_placeholder, (
            f"Submodule {split_item.submod_name} has getitem operations on "
            f"placeholder nodes: {[n.name for n in getitem_on_placeholder]}. "
            "This means tuple inputs were not properly eliminated."
        )

    # The partitioned module must compute the same result as the original one.
    new_x = torch.randn(4, 3)
    output_original = gm(new_x)
    output_split = split_gm(new_x)

    assert torch.allclose(output_original, output_split), "Output mismatch"
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_tuple_inputs_with_multiple_consumers():
    """
    Check that when several partitions consume elements of the same tuple,
    the getitem nodes are still relocated so that no submodule is handed
    the tuple itself as an input.
    """

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        # torch.split yields a tuple, so the traced graph has real getitem nodes.
        # The split itself is expected to become the first (producer) submodule.
        chunks = torch.split(x, x.shape[0] // 2, dim=0)

        # Expected to become the second submodule, consuming tuple elements.
        branch_a = torch.relu(chunks[0])
        branch_b = torch.relu(chunks[1])

        # Artificial splitting point: sigmoid is a split op, so it is expected
        # to become a third submodule and to separate the tuple's later
        # consumer from the earlier one.
        branch_a = torch.sigmoid(branch_a)

        # Expected to become the fourth submodule, consuming the tuple again.
        merged = torch.cat([chunks[0], chunks[1], branch_a, branch_b])
        return merged

    x = torch.randn(4, 3)
    gm = make_fx(model_fn)(x)

    # Sanity check: the scenario is only meaningful if tracing produced getitems.
    has_getitem = any(
        n.op == "call_function" and n.target == operator.getitem
        for n in gm.graph.nodes
    )
    assert has_getitem, "Test setup failed: graph should contain getitem operations"

    split_ops = ["aten::split.Tensor", "aten::sigmoid"]
    split_gm, split_items = split_graph(gm, split_ops)
    assert len(split_items) == 4, "Graph should be split into 4 submodules"

    for split_item in split_items:
        for node in split_item.graph.graph.nodes:
            # Only getitem calls are of interest here.
            if not (node.op == "call_function" and node.target == operator.getitem):
                continue
            indexed = node.args[0]
            if indexed.op != "placeholder":
                continue
            # Indexing a placeholder means a tuple crossed the submodule boundary.
            pytest.fail(
                f"Submodule {split_item.submod_name} has getitem on "
                f"placeholder {node.args[0].name}, indicating it receives "
                "a tuple input"
            )

    # The partitioned module must compute the same result as the original one.
    new_x = torch.randn(4, 3)
    output_original = gm(new_x)
    output_split = split_gm(new_x)

    assert torch.allclose(output_original, output_split), "Output mismatch after split"
|
||||||
@@ -4,6 +4,7 @@
|
|||||||
import ast
|
import ast
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import operator
|
||||||
import os
|
import os
|
||||||
import pprint
|
import pprint
|
||||||
import time
|
import time
|
||||||
@@ -307,12 +308,24 @@ def split_graph(
|
|||||||
) -> tuple[fx.GraphModule, list[SplitItem]]:
|
) -> tuple[fx.GraphModule, list[SplitItem]]:
|
||||||
# split graph by ops
|
# split graph by ops
|
||||||
subgraph_id = 0
|
subgraph_id = 0
|
||||||
node_to_subgraph_id = {}
|
node_to_subgraph_id: dict[fx.Node, int] = {}
|
||||||
split_op_graphs = []
|
split_op_graphs: list[int] = []
|
||||||
for node in graph.graph.nodes:
|
for node in graph.graph.nodes:
|
||||||
if node.op in ("output", "placeholder"):
|
if node.op in ("output", "placeholder"):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Check if this is a getitem operation on a node from an earlier subgraph.
|
||||||
|
# If so, assign it to the same subgraph as its input to avoid passing entire
|
||||||
|
# tuple as input to submodules, which is against standalone_compile and
|
||||||
|
# AoTAutograd input requirement.
|
||||||
|
if node.op == "call_function" and node.target == operator.getitem:
|
||||||
|
# Assign this getitem to the same subgraph as its input
|
||||||
|
input_node = node.args[0]
|
||||||
|
if input_node.op != "placeholder":
|
||||||
|
assert input_node in node_to_subgraph_id
|
||||||
|
node_to_subgraph_id[node] = node_to_subgraph_id[input_node]
|
||||||
|
continue
|
||||||
|
|
||||||
if should_split(node, splitting_ops):
|
if should_split(node, splitting_ops):
|
||||||
subgraph_id += 1
|
subgraph_id += 1
|
||||||
node_to_subgraph_id[node] = subgraph_id
|
node_to_subgraph_id[node] = subgraph_id
|
||||||
|
|||||||
Reference in New Issue
Block a user