[fix][torch.compile] Fix cold-start compilation time increase by adding kv cache update to splitting ops (#33441)
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Richard Zou <zou3519@gmail.com>
(cherry picked from commit 15f40b20aa)
This commit is contained in:
48
tests/compile/test_cold_start.py
Normal file
48
tests/compile/test_cold_start.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from torch._dynamo.utils import counters
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode
|
||||
|
||||
|
||||
def test_moe_compilation_cold_start(monkeypatch, use_fresh_inductor_cache):
    """Cold-start compile of a small MoE model should hit the AOTAutograd
    cache for all but the unique subgraphs.

    vLLM-compile cold start does one full dynamo capture of the entire
    forward pass (32 transformer layers here), then splits on the attention
    operation, yielding 33 subgraphs (not counting attention itself) that
    are standalone_compile'd. Only 3 of those subgraphs are unique (all
    transformer layers are identical modulo weights, as in most vLLM
    models), so we expect 3 cache misses and 30 cache hits.
    """
    env_overrides = {
        # Run in same process so we can access PyTorch's internal counters
        "VLLM_ENABLE_V1_MULTIPROCESSING": "0",
        # I'm not sure if this is going to affect the numbers
        "VLLM_USE_AOT_COMPILE": "0",
        # Force cold compilation
        "VLLM_DISABLE_COMPILE_CACHE": "1",
    }
    for env_name, env_value in env_overrides.items():
        monkeypatch.setenv(env_name, env_value)

    compilation_config = CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        cudagraph_mode=CUDAGraphMode.NONE,  # make the model loading faster
    )

    # Reset dynamo's global counters so we only measure this engine build.
    counters.clear()

    _ = LLM(
        model="microsoft/Phi-tiny-MoE-instruct",
        max_model_len=256,
        load_format="dummy",  # make the model loading faster
        compilation_config=compilation_config,
        num_gpu_blocks_override=8,  # make the model loading faster
    )

    # 3 unique subgraphs miss; the remaining 30 hit the aot_autograd cache.
    assert counters["aot_autograd"]["autograd_cache_miss"] == 3
    assert counters["aot_autograd"]["autograd_cache_hit"] == 30
|
||||
@@ -8,6 +8,10 @@ import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
|
||||
from vllm.compilation.backends import split_graph
|
||||
from vllm.compilation.fx_utils import find_op_nodes
|
||||
|
||||
# This import automatically registers `torch.ops.silly.attention`
|
||||
from . import silly_attention # noqa: F401
|
||||
|
||||
|
||||
def test_getitem_moved_to_producer_subgraph():
|
||||
@@ -122,3 +126,61 @@ def test_no_tuple_inputs_with_multiple_consumers():
|
||||
output_split = split_gm(new_x)
|
||||
|
||||
assert torch.allclose(output_original, output_split), "Output mismatch after split"
|
||||
|
||||
|
||||
def test_consecutive_ops_in_split():
    """
    Test that consecutive splitting operations are grouped into the same subgraph.
    """

    def model_fn(x: torch.Tensor) -> torch.Tensor:
        """
        Define a simple model where consecutive operations create opportunities
        for splitting subgraphs.
        """
        # Apply silly attention followed by consecutive operations
        intermediate = torch.relu(x)
        attn_inout = torch.sqrt(intermediate)
        torch.ops.silly.attention(intermediate, intermediate, attn_inout, attn_inout)
        final_result = torch.sigmoid(attn_inout)
        return final_result

    torch.set_default_device("cuda")

    # Create the traced FX graph for the model
    x = torch.randn(8, 4)
    gm = make_fx(model_fn)(x)

    # Assert presence of the expected operations in the setup
    assert (
        len(list(find_op_nodes(torch.ops.aten.relu, gm.graph))) == 1
        and len(list(find_op_nodes(torch.ops.aten.sqrt, gm.graph))) == 1
    ), "Test setup failed: Expected sqrt and relu operations in the graph."

    # Configure split operations to test: sqrt immediately precedes attention
    # in model_fn, so the two splitting ops are consecutive in the graph.
    splitting_ops = ["silly::attention", "aten::sqrt"]
    split_gm, split_items = split_graph(gm, splitting_ops)

    # Validate the number of partitions: prefix subgraph, one merged
    # splitting subgraph (sqrt + attention), and a suffix subgraph.
    assert len(split_items) == 3, (
        "Consecutive splitting operations were not grouped correctly."
    )

    # Validate that correctness is preserved
    new_x = torch.randn(8, 4)
    output_original = gm(new_x)
    output_split = split_gm(new_x)
    assert torch.allclose(output_original, output_split), (
        "Output mismatch after splitting."
    )

    # Check the splitting subgraph holds exactly the 2 consecutive splitting
    # ops (sqrt and attention), i.e. 4 FX nodes once the placeholder and
    # output nodes are counted.
    splitting_items = [s for s in split_items if s.is_splitting_graph]
    assert len(splitting_items) == 1, "Expecting a single splitting graph"
    splitting_gm = splitting_items[0].graph
    assert len(splitting_gm.graph.nodes) == 4, "Expecting 4 nodes in splitting graph"
    assert [node.op for node in splitting_gm.graph.nodes] == ["placeholder"] + 2 * [
        "call_function"
    ] + ["output"]
|
||||
|
||||
@@ -361,7 +361,14 @@ def split_graph(
|
||||
subgraph_id += 1
|
||||
node_to_subgraph_id[node] = subgraph_id
|
||||
split_op_graphs.append(subgraph_id)
|
||||
subgraph_id += 1
|
||||
|
||||
# keep consecutive splitting ops together
|
||||
# (we know node.next exists because node isn't the last (output) node)
|
||||
if should_split(node.next, splitting_ops):
|
||||
# this will get incremented by the next node
|
||||
subgraph_id -= 1
|
||||
else:
|
||||
subgraph_id += 1
|
||||
else:
|
||||
node_to_subgraph_id[node] = subgraph_id
|
||||
|
||||
|
||||
@@ -925,6 +925,15 @@ class CompilationConfig:
|
||||
# for details. Make a copy to avoid mutating the class-level
|
||||
# list via reference.
|
||||
self.splitting_ops = list(self._attention_ops)
|
||||
|
||||
# unified_kv_cache_update has a string param that prevents Inductor
|
||||
# from reusing piecewise graphs. Remove it from the compiled graph.
|
||||
# This has the side-effect of excluding cache from cudagraphs but
|
||||
# that doesn't seem to affect performance.
|
||||
# https://github.com/vllm-project/vllm/issues/33267
|
||||
if not self.use_inductor_graph_partition:
|
||||
self.splitting_ops.append("vllm::unified_kv_cache_update")
|
||||
|
||||
elif len(self.splitting_ops) == 0:
|
||||
if (
|
||||
self.cudagraph_mode == CUDAGraphMode.PIECEWISE
|
||||
|
||||
Reference in New Issue
Block a user