[Bugfix] Eliminate tuple inputs to submodules in graph partitioning (#28533)
Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
import ast
|
||||
import dataclasses
|
||||
import hashlib
|
||||
import operator
|
||||
import os
|
||||
import pprint
|
||||
import time
|
||||
@@ -307,12 +308,24 @@ def split_graph(
|
||||
) -> tuple[fx.GraphModule, list[SplitItem]]:
|
||||
# split graph by ops
|
||||
subgraph_id = 0
|
||||
node_to_subgraph_id = {}
|
||||
split_op_graphs = []
|
||||
node_to_subgraph_id: dict[fx.Node, int] = {}
|
||||
split_op_graphs: list[int] = []
|
||||
for node in graph.graph.nodes:
|
||||
if node.op in ("output", "placeholder"):
|
||||
continue
|
||||
|
||||
# Check if this is a getitem operation on a node from an earlier subgraph.
|
||||
# If so, assign it to the same subgraph as its input to avoid passing entire
|
||||
# tuple as input to submodules, which is against standalone_compile and
|
||||
# AoTAutograd input requirement.
|
||||
if node.op == "call_function" and node.target == operator.getitem:
|
||||
# Assign this getitem to the same subgraph as its input
|
||||
input_node = node.args[0]
|
||||
if input_node.op != "placeholder":
|
||||
assert input_node in node_to_subgraph_id
|
||||
node_to_subgraph_id[node] = node_to_subgraph_id[input_node]
|
||||
continue
|
||||
|
||||
if should_split(node, splitting_ops):
|
||||
subgraph_id += 1
|
||||
node_to_subgraph_id[node] = subgraph_id
|
||||
|
||||
Reference in New Issue
Block a user