[aot_compile] Change vLLM backend to read fake args from example_value (#29104)
Signed-off-by: Laith Sakka <lsakka@meta.com>
This commit is contained in:
@@ -1,6 +1,8 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import multiprocessing
|
||||||
import tempfile
|
import tempfile
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
@@ -137,3 +139,67 @@ def test_shape_env(monkeypatch: pytest.MonkeyPatch):
|
|||||||
artifacts = compiled_mod.aot_compiled_fn._artifacts
|
artifacts = compiled_mod.aot_compiled_fn._artifacts
|
||||||
guards_string = artifacts.compiled_fn.shape_env.format_guards()
|
guards_string = artifacts.compiled_fn.shape_env.format_guards()
|
||||||
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
|
assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
@use_vllm_config(make_vllm_config())
def test_gpt2_cache_hit(monkeypatch: pytest.MonkeyPatch):
    """
    Compile gpt2 twice against one cache root and count symbol creations.

    ``make_symbol`` on ``torch.fx.experimental.symbolic_shapes`` is replaced
    by a counting wrapper for the duration of the test. The first run is
    expected to invoke it exactly twice; the second run, performed with
    ``VLLM_FORCE_AOT_LOAD`` set, is expected to invoke it zero times, i.e.
    no new dynamic symbols are created on a cache hit.
    """
    import torch.fx.experimental.symbolic_shapes as symbolic_shapes_module
    from torch.utils._sympy.symbol import make_symbol as real_make_symbol

    from vllm import LLM

    prompt = "Hello, my name is"

    # A lock-protected shared integer; presumably chosen so increments stay
    # visible if symbol creation happens in a forked worker -- TODO confirm.
    symbol_creations = multiprocessing.Value("i", 0)

    @functools.wraps(real_make_symbol)
    def tracked_make_symbol(prefix, idx, **kwargs):
        # Bump the counter, then defer to the real implementation.
        with symbol_creations.get_lock():
            symbol_creations.value += 1
        return real_make_symbol(prefix, idx, **kwargs)

    def build_llm():
        # Both runs must use the exact same model and compilation settings.
        return LLM(
            model="gpt2",
            compilation_config=CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
            ),
            max_model_len=256,
        )

    symbolic_shapes_module.make_symbol = tracked_make_symbol
    try:
        with monkeypatch.context() as env:
            with tempfile.TemporaryDirectory() as cache_root:
                env.setenv("VLLM_CACHE_ROOT", cache_root)
                env.setenv("VLLM_USE_AOT_COMPILE", "1")

                # Cold run: build the model and generate once.
                llm_model = build_llm()
                llm_model.generate(prompt)
                assert symbol_creations.value == 2
                symbol_creations.value = 0

                # Drop the first model before building the second one.
                del llm_model

                # Warm run: same cache root, AOT load forced.
                env.setenv("VLLM_FORCE_AOT_LOAD", "1")
                llm_model = build_llm()
                llm_model.generate(prompt)

                assert symbol_creations.value == 0
    finally:
        # Always put the imported make_symbol back on the module.
        symbolic_shapes_module.make_symbol = real_make_symbol
|
||||||
|
|||||||
@@ -402,6 +402,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
|||||||
self.extra_traceback = False
|
self.extra_traceback = False
|
||||||
|
|
||||||
def run(self, *args):
|
def run(self, *args):
|
||||||
|
# TODO: consider asserting that the inputs are already fake tensors instead of converting them here.
|
||||||
fake_args = [
|
fake_args = [
|
||||||
self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
|
self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
|
||||||
for t in args
|
for t in args
|
||||||
@@ -416,11 +417,13 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
|||||||
kwargs: dict[str, Any],
|
kwargs: dict[str, Any],
|
||||||
) -> Any:
|
) -> Any:
|
||||||
assert isinstance(target, str)
|
assert isinstance(target, str)
|
||||||
|
|
||||||
output = super().call_module(target, args, kwargs)
|
output = super().call_module(target, args, kwargs)
|
||||||
|
|
||||||
if target in self.compile_submod_names:
|
if target in self.compile_submod_names:
|
||||||
index = self.compile_submod_names.index(target)
|
index = self.compile_submod_names.index(target)
|
||||||
submod = self.fetch_attr(target)
|
submod = self.fetch_attr(target)
|
||||||
|
|
||||||
sym_shape_indices = [
|
sym_shape_indices = [
|
||||||
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
|
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
|
||||||
]
|
]
|
||||||
@@ -746,11 +749,21 @@ class VllmBackend:
|
|||||||
if not item.is_splitting_graph
|
if not item.is_splitting_graph
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Extract fake values from the graph to use them when needed.
|
||||||
|
all_fake_values = []
|
||||||
|
for i in graph.graph.find_nodes(op="placeholder"):
|
||||||
|
all_fake_values.append(i.meta["example_value"])
|
||||||
|
|
||||||
|
fake_args = [
|
||||||
|
all_fake_values[i] if isinstance(t, torch.Tensor) else t
|
||||||
|
for i, t in enumerate(example_inputs)
|
||||||
|
]
|
||||||
|
|
||||||
# propagate the split graph to the piecewise backend,
|
# propagate the split graph to the piecewise backend,
|
||||||
# compile submodules with symbolic shapes
|
# compile submodules with symbolic shapes
|
||||||
PiecewiseCompileInterpreter(
|
PiecewiseCompileInterpreter(
|
||||||
self.split_gm, submod_names_to_compile, self.vllm_config, self
|
self.split_gm, submod_names_to_compile, self.vllm_config, self
|
||||||
).run(*example_inputs)
|
).run(*fake_args)
|
||||||
|
|
||||||
graph_path = os.path.join(local_cache_dir, "computation_graph.py")
|
graph_path = os.path.join(local_cache_dir, "computation_graph.py")
|
||||||
if not os.path.exists(graph_path):
|
if not os.path.exists(graph_path):
|
||||||
@@ -780,14 +793,7 @@ class VllmBackend:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# if we need to copy input buffers for cudagraph
|
# if we need to copy input buffers for cudagraph
|
||||||
from torch._guards import detect_fake_mode
|
#
|
||||||
|
|
||||||
fake_mode = detect_fake_mode()
|
|
||||||
fake_args = [
|
|
||||||
fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
|
|
||||||
for t in example_inputs
|
|
||||||
]
|
|
||||||
|
|
||||||
# index of tensors that have symbolic shapes (batch size)
|
# index of tensors that have symbolic shapes (batch size)
|
||||||
# for weights and static buffers, they will have concrete shapes.
|
# for weights and static buffers, they will have concrete shapes.
|
||||||
# symbolic shape only happens for input tensors.
|
# symbolic shape only happens for input tensors.
|
||||||
|
|||||||
@@ -433,7 +433,6 @@ def _support_torch_compile(
|
|||||||
return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)
|
return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)
|
||||||
|
|
||||||
# This is the path for the first compilation.
|
# This is the path for the first compilation.
|
||||||
|
|
||||||
# the first compilation needs to have dynamic shapes marked
|
# the first compilation needs to have dynamic shapes marked
|
||||||
_mark_dynamic_inputs(
|
_mark_dynamic_inputs(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Reference in New Issue
Block a user