[Bugfix] Eliminate tuple inputs to submodules in graph partitioning (#28533)

Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
This commit is contained in:
Yanan Cao
2025-11-13 12:09:05 -08:00
committed by GitHub
parent 968060c15a
commit 262d263f6c
3 changed files with 140 additions and 2 deletions

View File

@@ -4,6 +4,7 @@
import ast
import dataclasses
import hashlib
import operator
import os
import pprint
import time
@@ -307,12 +308,24 @@ def split_graph(
) -> tuple[fx.GraphModule, list[SplitItem]]:
# split graph by ops
subgraph_id = 0
node_to_subgraph_id = {}
split_op_graphs = []
node_to_subgraph_id: dict[fx.Node, int] = {}
split_op_graphs: list[int] = []
for node in graph.graph.nodes:
if node.op in ("output", "placeholder"):
continue
# Check if this is a getitem operation on a node from an earlier subgraph.
# If so, assign it to the same subgraph as its input to avoid passing entire
# tuple as input to submodules, which is against standalone_compile and
# AoTAutograd input requirement.
if node.op == "call_function" and node.target == operator.getitem:
# Assign this getitem to the same subgraph as its input
input_node = node.args[0]
if input_node.op != "placeholder":
assert input_node in node_to_subgraph_id
node_to_subgraph_id[node] = node_to_subgraph_id[input_node]
continue
if should_split(node, splitting_ops):
subgraph_id += 1
node_to_subgraph_id[node] = subgraph_id