[fix][torch.compile] Fix cold-start compilation time increase by adding kv cache update to splitting ops (#33441)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
Luka Govedič
2026-01-31 09:48:34 -05:00
committed by GitHub
parent 793af538a3
commit 15f40b20aa
4 changed files with 127 additions and 1 deletion

View File

@@ -8,6 +8,10 @@ import torch
from torch.fx.experimental.proxy_tensor import make_fx
from vllm.compilation.backends import split_graph
from vllm.compilation.fx_utils import find_op_nodes
# This import automatically registers `torch.ops.silly.attention`
from . import silly_attention # noqa: F401
def test_getitem_moved_to_producer_subgraph():
@@ -122,3 +126,61 @@ def test_no_tuple_inputs_with_multiple_consumers():
output_split = split_gm(new_x)
assert torch.allclose(output_original, output_split), "Output mismatch after split"
def test_consecutive_ops_in_split():
    """
    Test that consecutive splitting operations are grouped into the same subgraph
    """

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        """
        Define a simple model where consecutive operations create opportunities
        for splitting subgraphs.
        """
        # Apply silly attention followed by consecutive operations.
        # sqrt and attention are consecutive in the graph and are both listed
        # in splitting_ops below, so they should land in ONE splitting subgraph.
        intermediate = torch.relu(x)
        attn_inout = torch.sqrt(intermediate)
        torch.ops.silly.attention(intermediate, intermediate, attn_inout, attn_inout)
        final_result = torch.sigmoid(attn_inout)
        return final_result

    torch.set_default_device("cuda")
    # Create the traced FX graph for the model
    x = torch.randn(8, 4)
    gm = make_fx(model_fn)(x)
    # Assert presence of the expected operations in the setup
    assert (
        len(list(find_op_nodes(torch.ops.aten.relu, gm.graph))) == 1
        and len(list(find_op_nodes(torch.ops.aten.sqrt, gm.graph))) == 1
    ), "Test setup failed: Expected sqrt and relu operations in the graph."
    # Configure split operations to test
    splitting_ops = ["silly::attention", "aten::sqrt"]
    split_gm, split_items = split_graph(gm, splitting_ops)
    # Validate the number of partitions: prefix (relu) / splitting subgraph
    # (sqrt + attention, grouped) / suffix (sigmoid) => exactly 3 pieces.
    assert len(split_items) == 3, (
        "Consecutive splitting operations were not grouped correctly."
    )
    # Validate that correctness is preserved
    new_x = torch.randn(8, 4)
    output_original = gm(new_x)
    output_split = split_gm(new_x)
    assert torch.allclose(output_original, output_split), (
        "Output mismatch after splitting."
    )
    # Check the splitting item has 2 call nodes exactly (sqrt and attn);
    # relu is NOT a splitting op, so it must not be in this subgraph.
    splitting_items = list(s for s in split_items if s.is_splitting_graph)
    assert len(splitting_items) == 1, "Expecting a single splitting graph"
    print(splitting_items[0].graph.graph)
    splitting_gm = splitting_items[0].graph
    # 4 nodes total: 1 placeholder + 2 call_function (sqrt, attention) + 1 output
    assert len(splitting_gm.graph.nodes) == 4, "Expecting 4 nodes in splitting graph"
    assert [node.op for node in splitting_gm.graph.nodes] == ["placeholder"] + 2 * [
        "call_function"
    ] + ["output"]