From 29152683699f87480b12aa96875f025b7c5137fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Sat, 31 Jan 2026 09:48:34 -0500 Subject: [PATCH] [fix][torch.compile] Fix cold-start compilation time increase by adding kv cache update to splitting ops (#33441) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Luka Govedič Co-authored-by: Richard Zou (cherry picked from commit 15f40b20aadf27af43bdae38bf4644b82e21634f) --- tests/compile/test_cold_start.py | 48 +++++++++++++++++++++ tests/compile/test_graph_partition.py | 62 +++++++++++++++++++++++++++ vllm/compilation/backends.py | 9 +++- vllm/config/compilation.py | 9 ++++ 4 files changed, 127 insertions(+), 1 deletion(-) create mode 100644 tests/compile/test_cold_start.py diff --git a/tests/compile/test_cold_start.py b/tests/compile/test_cold_start.py new file mode 100644 index 000000000..1d24d1839 --- /dev/null +++ b/tests/compile/test_cold_start.py @@ -0,0 +1,48 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from torch._dynamo.utils import counters + +from vllm import LLM +from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode + + +def test_moe_compilation_cold_start(monkeypatch, use_fresh_inductor_cache): + # Run in same process so we can access PyTorch's internal counters + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + + # I'm not sure if this is going to affect the numbers + monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "0") + + # Force cold compilation + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + + compilation_config = CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + cudagraph_mode=CUDAGraphMode.NONE, # make the model loading faster + ) + + counters.clear() + + _ = LLM( + model="microsoft/Phi-tiny-MoE-instruct", + max_model_len=256, + load_format="dummy", # make the model loading faster + compilation_config=compilation_config, + num_gpu_blocks_override=8, # make the model loading faster + ) + + # vLLM-compile cold start is special. By default, we do + # one full dynamo capture of the entire forward pass. + # The forward pass consists of 32 transformer layers. + # Then, we split on the attention operation. This results in + # 33 subgraphs (not including the attention operation). + # The 33 subgraphs then get standalone_compile'd. + # + # There are actually only 3 unique subgraphs for this model + # (all of its transformer layers are the same modulo weights); + # this is true for most vLLM models. + # So we test that during cold start, the aot_autograd cache + # misses for 3 subgraphs and hits for the rest. + assert counters["aot_autograd"]["autograd_cache_miss"] == 3 + assert counters["aot_autograd"]["autograd_cache_hit"] == 30 diff --git a/tests/compile/test_graph_partition.py b/tests/compile/test_graph_partition.py index 1cd783843..38e3e038a 100644 --- a/tests/compile/test_graph_partition.py +++ b/tests/compile/test_graph_partition.py @@ -8,6 +8,10 @@ import torch from torch.fx.experimental.proxy_tensor import make_fx from vllm.compilation.backends import split_graph +from vllm.compilation.fx_utils import find_op_nodes + +# This import automatically registers `torch.ops.silly.attention` +from . import silly_attention # noqa: F401 def test_getitem_moved_to_producer_subgraph(): @@ -122,3 +126,61 @@ def test_no_tuple_inputs_with_multiple_consumers(): output_split = split_gm(new_x) assert torch.allclose(output_original, output_split), "Output mismatch after split" + + +def test_consecutive_ops_in_split(): + """ + Test that consecutive splitting operations are grouped into the same subgraph + """ + + def model_fn(x: torch.Tensor) -> torch.Tensor: + """ + Define a simple model where consecutive operations create opportunities + for splitting subgraphs. + """ + # Apply silly attention followed by consecutive operations + intermediate = torch.relu(x) + attn_inout = torch.sqrt(intermediate) + torch.ops.silly.attention(intermediate, intermediate, attn_inout, attn_inout) + final_result = torch.sigmoid(attn_inout) + return final_result + + torch.set_default_device("cuda") + + # Create the traced FX graph for the model + x = torch.randn(8, 4) + + gm = make_fx(model_fn)(x) + + # Assert presence of the expected operations in the setup + assert ( + len(list(find_op_nodes(torch.ops.aten.relu, gm.graph))) == 1 + and len(list(find_op_nodes(torch.ops.aten.sqrt, gm.graph))) == 1 + ), "Test setup failed: Expected sqrt and relu operations in the graph." + + # Configure split operations to test + splitting_ops = ["silly::attention", "aten::sqrt"] + split_gm, split_items = split_graph(gm, splitting_ops) + + # Validate the number of partitions + assert len(split_items) == 3, ( + "Consecutive splitting operations were not grouped correctly." + ) + + # Validate that correctness is preserved + new_x = torch.randn(8, 4) + output_original = gm(new_x) + output_split = split_gm(new_x) + assert torch.allclose(output_original, output_split), ( + "Output mismatch after splitting." + ) + + # Check the splitting item has 2 nodes exactly (relu and attn) + splitting_items = list(s for s in split_items if s.is_splitting_graph) + assert len(splitting_items) == 1, "Expecting a single splitting graph" + print(splitting_items[0].graph.graph) + splitting_gm = splitting_items[0].graph + assert len(splitting_gm.graph.nodes) == 4, "Expecting 4 nodes in splitting graph" + assert [node.op for node in splitting_gm.graph.nodes] == ["placeholder"] + 2 * [ + "call_function" + ] + ["output"] diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 454d81317..ad7993261 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -361,7 +361,14 @@ def split_graph( subgraph_id += 1 node_to_subgraph_id[node] = subgraph_id split_op_graphs.append(subgraph_id) - subgraph_id += 1 + + # keep consecutive splitting ops together + # (we know node.next exists because node isn't the last (output) node) + if should_split(node.next, splitting_ops): + # this will get incremented by the next node + subgraph_id -= 1 + else: + subgraph_id += 1 else: node_to_subgraph_id[node] = subgraph_id diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 9d86a4bae..8ca8eaa3b 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -925,6 +925,15 @@ class CompilationConfig: # for details. Make a copy to avoid mutating the class-level # list via reference. self.splitting_ops = list(self._attention_ops) + + # unified_kv_cache_update has a string param that prevents Inductor + # from reusing piecewise graphs. Remove it from the compiled graph. + # This has the side-effect of excluding cache from cudagraphs but + # that doesn't seem to affect performance. + # https://github.com/vllm-project/vllm/issues/33267 + if not self.use_inductor_graph_partition: + self.splitting_ops.append("vllm::unified_kv_cache_update") + elif len(self.splitting_ops) == 0: if ( self.cudagraph_mode == CUDAGraphMode.PIECEWISE