Add evaluate_guards option to DynamicShapesConfig (#27432)
Signed-off-by: Laith Sakka <lsakka@meta.com>
This commit is contained in:
@@ -26,6 +26,7 @@ from vllm.compilation.partition_rules import (
|
||||
should_split,
|
||||
)
|
||||
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
|
||||
from vllm.config.compilation import DynamicShapesType
|
||||
from vllm.config.utils import Range, hash_factors
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logging_utils import lazy
|
||||
@@ -722,6 +723,29 @@ class VllmBackend:
|
||||
self.split_gm, submod_names_to_compile, self.vllm_config, self
|
||||
).run(*fake_args)
|
||||
|
||||
from torch._guards import detect_fake_mode
|
||||
|
||||
fake_mode = detect_fake_mode()
|
||||
|
||||
if (
|
||||
self.compilation_config.dynamic_shapes_config.evaluate_guards
|
||||
and self.compilation_config.dynamic_shapes_config.type
|
||||
== DynamicShapesType.BACKED
|
||||
):
|
||||
from torch.utils._sympy.value_ranges import ValueRanges
|
||||
|
||||
# Drop counter-0/1 specializations guards; for backed dynamic shapes,
|
||||
# torch.compile will specialize for 0/1 inputs or otherwise guards that
|
||||
# shape is >= 2. This is because it's really hard not to hit a check
|
||||
# against 0/1. When we evaluate shape guards, we exclude checking those
|
||||
# guards (We would fail always otherwise).
|
||||
|
||||
# We avoid that by updating the ranges of backed sizes when the min is
|
||||
# 2 for any, we assume it's 0.
|
||||
for s, r in fake_mode.shape_env.var_to_range.items():
|
||||
if r.lower == 2:
|
||||
fake_mode.shape_env.var_to_range[s] = ValueRanges(0, r.upper)
|
||||
|
||||
graph_path = os.path.join(local_cache_dir, "computation_graph.py")
|
||||
if not os.path.exists(graph_path):
|
||||
# code adapted from
|
||||
@@ -749,8 +773,6 @@ class VllmBackend:
|
||||
graph, example_inputs, self.prefix, self.split_gm
|
||||
)
|
||||
|
||||
# if we need to copy input buffers for cudagraph
|
||||
#
|
||||
# index of tensors that have symbolic shapes (batch size)
|
||||
# for weights and static buffers, they will have concrete shapes.
|
||||
# symbolic shape only happens for input tensors.
|
||||
|
||||
@@ -392,7 +392,6 @@ def _support_torch_compile(
|
||||
|
||||
factors.append(_model_hash_key(self.forward))
|
||||
hash_key = hashlib.sha256(str(factors).encode()).hexdigest()
|
||||
|
||||
cache_dir = os.path.join(
|
||||
envs.VLLM_CACHE_ROOT,
|
||||
"torch_aot_compile",
|
||||
@@ -413,7 +412,8 @@ def _support_torch_compile(
|
||||
f, f_globals=self.forward.__globals__
|
||||
)
|
||||
_verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
|
||||
loaded_fn.disable_guard_check()
|
||||
if not self.compilation_config.dynamic_shapes_config.evaluate_guards:
|
||||
loaded_fn.disable_guard_check()
|
||||
self.aot_compiled_fn = loaded_fn
|
||||
except Exception as e:
|
||||
if os.path.exists(aot_compilation_path):
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import os
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from types import CodeType
|
||||
from typing import Any
|
||||
|
||||
@@ -13,6 +13,7 @@ import torch._C._dynamo.guards
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
|
||||
from vllm.config.compilation import DynamicShapesType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context
|
||||
|
||||
@@ -125,23 +126,49 @@ class TorchCompileWithNoGuardsWrapper:
|
||||
if isinstance(backend, str) and backend == "inductor":
|
||||
options = vllm_config.compilation_config.inductor_compile_config
|
||||
|
||||
if mode != CompilationMode.STOCK_TORCH_COMPILE:
|
||||
# Drop all the guards.
|
||||
options["guard_filter_fn"] = lambda x: [False for _ in x]
|
||||
|
||||
# Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
|
||||
from vllm.compilation.decorators import DynamicShapesType
|
||||
self.first_compile = True
|
||||
self.evaluate_guards = (
|
||||
vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards
|
||||
)
|
||||
|
||||
ds_type = vllm_config.compilation_config.dynamic_shapes_config.type
|
||||
compiled_ptr: Any = self.forward
|
||||
if ds_type == DynamicShapesType.UNBACKED:
|
||||
if envs.VLLM_USE_BYTECODE_HOOK:
|
||||
# reason is that bytecode does this hack torch._dynamo.eval_frame.
|
||||
# remove_from_cache(self.original_code_object()) to force a new
|
||||
# re-compilation.
|
||||
raise ValueError(
|
||||
"UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. "
|
||||
|
||||
if mode != CompilationMode.STOCK_TORCH_COMPILE:
|
||||
# Drop all the guards.
|
||||
if self.evaluate_guards:
|
||||
assert not envs.VLLM_USE_BYTECODE_HOOK, (
|
||||
"compilation_config.dynamic_shapes_config.evaluate_guards "
|
||||
"requires VLLM_USE_BYTECODE_HOOK=0. "
|
||||
)
|
||||
|
||||
if envs.VLLM_USE_AOT_COMPILE:
|
||||
# disabled until https://github.com/pytorch/pytorch/pull/169239
|
||||
# is picked up.
|
||||
assert ds_type != DynamicShapesType.BACKED, (
|
||||
"evaluate_guards for backed shapes requires "
|
||||
"VLLM_USE_AOT_COMPILE=False. "
|
||||
)
|
||||
|
||||
options["guard_filter_fn"] = lambda x: [
|
||||
entry.guard_type == "SHAPE_ENV" for entry in x
|
||||
]
|
||||
else:
|
||||
options["guard_filter_fn"] = lambda x: [False for _ in x]
|
||||
|
||||
compiled_ptr: Any = self.forward
|
||||
# Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
|
||||
|
||||
if ds_type == DynamicShapesType.UNBACKED:
|
||||
# reason is that bytecode does torch._dynamo.eval_frame.
|
||||
# remove_from_cache(self.original_code_object()) to force a new
|
||||
# re-compilation. And if we use
|
||||
# compiled_ptr = self.check_invariants_and_forward
|
||||
# it will reset all entries.
|
||||
assert not envs.VLLM_USE_BYTECODE_HOOK, (
|
||||
"UNBACKED dynamic shapes requires VLLM_USE_BYTECODE_HOOK=0. "
|
||||
)
|
||||
assert not self.evaluate_guards, "UNBACKED dynamic shapes do not add guards"
|
||||
|
||||
compiled_ptr = self.check_invariants_and_forward
|
||||
|
||||
if envs.VLLM_USE_AOT_COMPILE:
|
||||
@@ -195,7 +222,13 @@ class TorchCompileWithNoGuardsWrapper:
|
||||
self.forward, *args, **kwargs
|
||||
)
|
||||
else:
|
||||
with _compilation_context():
|
||||
ctx = (
|
||||
nullcontext()
|
||||
if self.first_compile or not self.evaluate_guards
|
||||
else torch.compiler.set_stance("fail_on_recompile")
|
||||
)
|
||||
self.first_compile = False
|
||||
with _compilation_context(), ctx:
|
||||
return self._call_with_optional_nvtx_range(
|
||||
self._compiled_callable, *args, **kwargs
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user