[fix][torch.compile] Fix cold-start compilation time increase by adding kv cache update to splitting ops (#33441)
Signed-off-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
@@ -359,7 +359,14 @@ def split_graph(
|
||||
subgraph_id += 1
|
||||
node_to_subgraph_id[node] = subgraph_id
|
||||
split_op_graphs.append(subgraph_id)
|
||||
subgraph_id += 1
|
||||
|
||||
# keep consecutive splitting ops together
|
||||
# (we know node.next exists because node isn't the last (output) node)
|
||||
if should_split(node.next, splitting_ops):
|
||||
# this will get incremented by the next node
|
||||
subgraph_id -= 1
|
||||
else:
|
||||
subgraph_id += 1
|
||||
else:
|
||||
node_to_subgraph_id[node] = subgraph_id
|
||||
|
||||
|
||||
Reference in New Issue
Block a user