[fix][torch.compile] Fix cold-start compilation time increase by adding kv cache update to splitting ops (#33441)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Richard Zou <zou3519@gmail.com>
Luka Govedič
2026-01-31 09:48:34 -05:00
committed by GitHub
parent 793af538a3
commit 15f40b20aa
4 changed files with 127 additions and 1 deletion

@@ -359,7 +359,14 @@ def split_graph(
             subgraph_id += 1
             node_to_subgraph_id[node] = subgraph_id
             split_op_graphs.append(subgraph_id)
-            subgraph_id += 1
+            # keep consecutive splitting ops together
+            # (we know node.next exists because node isn't the last (output) node)
+            if should_split(node.next, splitting_ops):
+                # this will get incremented by the next node
+                subgraph_id -= 1
+            else:
+                subgraph_id += 1
         else:
             node_to_subgraph_id[node] = subgraph_id
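
The effect of this hunk can be illustrated with a toy model of the partitioning loop. This is only a sketch, not the vLLM implementation: it works on a flat list of op names rather than FX nodes, and `assign_subgraph_ids`, `group_consecutive`, and the op names are hypothetical. Because a plain list has no trailing output node, the lookahead is bounds-checked here instead of relying on `node.next`.

```python
def assign_subgraph_ids(ops, splitting_ops, group_consecutive=True):
    """Toy version of the subgraph-id assignment loop (hypothetical helper).

    ops: flat list of op names, standing in for graph nodes.
    splitting_ops: set of op names that must live in their own subgraph.
    group_consecutive: True mimics the new behavior, False the old one.
    """
    subgraph_id = 0
    ids = []
    for i, op in enumerate(ops):
        if op in splitting_ops:
            # A splitting op always starts a new subgraph.
            subgraph_id += 1
            ids.append(subgraph_id)
            # Bounds-checked lookahead (the real loop can rely on node.next).
            next_is_split = i + 1 < len(ops) and ops[i + 1] in splitting_ops
            if group_consecutive and next_is_split:
                # Undo the increment the next splitting op will apply,
                # so both ops end up with the same subgraph id.
                subgraph_id -= 1
            else:
                subgraph_id += 1
        else:
            ids.append(subgraph_id)
    return ids


# Hypothetical op sequence: two adjacent splitting ops between regular ops.
ops = ["matmul", "kv_cache_update", "attention", "matmul"]
splitting = {"kv_cache_update", "attention"}

new_ids = assign_subgraph_ids(ops, splitting, group_consecutive=True)
old_ids = assign_subgraph_ids(ops, splitting, group_consecutive=False)
print(new_ids)  # [0, 1, 1, 2]
print(old_ids)  # [0, 1, 3, 4]
```

In this toy run the old logic places the two adjacent splitting ops in separate subgraphs, while the new logic keeps them in one, so the graph is cut into three pieces instead of four.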