Add evaluate_guards option to DynamicShapesConfig (#27432)

Signed-off-by: Laith Sakka <lsakka@meta.com>
Author: Laith Sakka
Date: 2025-12-08 07:46:15 -08:00 (committed by GitHub)
Commit: 87aee9ed2b (parent 184076c3fe)
6 changed files with 222 additions and 31 deletions
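
For reference, a minimal usage sketch of the new option. Only the DynamicShapesType import path appears in this diff; the DynamicShapesConfig import path and exact field spelling are assumptions for illustration:

    # Hypothetical usage sketch -- names other than DynamicShapesType and
    # CompilationConfig are assumed, not taken from this diff.
    from vllm.config import CompilationConfig
    from vllm.config.compilation import DynamicShapesConfig, DynamicShapesType

    compilation_config = CompilationConfig(
        dynamic_shapes_config=DynamicShapesConfig(
            type=DynamicShapesType.BACKED,
            # New in this commit: instead of dropping all dynamo guards,
            # keep the shape guards and evaluate them, failing loudly on
            # recompilation rather than silently mis-specializing.
            evaluate_guards=True,
        ),
    )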


@@ -26,6 +26,7 @@ from vllm.compilation.partition_rules import (
     should_split,
 )
 from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
+from vllm.config.compilation import DynamicShapesType
 from vllm.config.utils import Range, hash_factors
 from vllm.logger import init_logger
 from vllm.logging_utils import lazy
@@ -722,6 +723,29 @@ class VllmBackend:
             self.split_gm, submod_names_to_compile, self.vllm_config, self
         ).run(*fake_args)
 
+        from torch._guards import detect_fake_mode
+
+        fake_mode = detect_fake_mode()
+        if (
+            self.compilation_config.dynamic_shapes_config.evaluate_guards
+            and self.compilation_config.dynamic_shapes_config.type
+            == DynamicShapesType.BACKED
+        ):
+            from torch.utils._sympy.value_ranges import ValueRanges
+
+            # Relax the 0/1 specialization guards: with backed dynamic
+            # shapes, torch.compile either specializes on 0/1 inputs or
+            # guards that the shape is >= 2, because tracing almost always
+            # hits a check against 0 or 1. Evaluated as-is, those guards
+            # would always fail, so we exclude them by widening the range
+            # of every backed size whose minimum is 2 down to 0.
+            for s, r in fake_mode.shape_env.var_to_range.items():
+                if r.lower == 2:
+                    fake_mode.shape_env.var_to_range[s] = ValueRanges(0, r.upper)
+
         graph_path = os.path.join(local_cache_dir, "computation_graph.py")
         if not os.path.exists(graph_path):
             # code adapted from
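
The widening above can be illustrated in isolation. Here a toy dict stands in for fake_mode.shape_env.var_to_range; this is a sketch, not vLLM code:

    # Standalone sketch of the range-widening trick, using a toy mapping
    # instead of a real ShapeEnv.
    import sympy
    from torch.utils._sympy.value_ranges import ValueRanges

    s0, s1 = sympy.symbols("s0 s1", positive=True, integer=True)
    var_to_range = {
        s0: ValueRanges(2, 8192),   # backed size, 0/1-specialized to >= 2
        s1: ValueRanges(64, 4096),  # size with an explicit user range
    }
    for s, r in var_to_range.items():
        if r.lower == 2:
            # Widen to include 0/1 so guard evaluation does not trip on
            # the automatic 0/1 specialization.
            var_to_range[s] = ValueRanges(0, r.upper)
    assert var_to_range[s0].lower == 0
    assert var_to_range[s1].lower == 64  # untouched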
@@ -749,8 +773,6 @@ class VllmBackend:
             graph, example_inputs, self.prefix, self.split_gm
         )
 
-        # if we need to copy input buffers for cudagraph
-        #
         # index of tensors that have symbolic shapes (batch size)
         # for weights and static buffers, they will have concrete shapes.
         # symbolic shape only happens for input tensors.


@@ -392,7 +392,6 @@ def _support_torch_compile(
         factors.append(_model_hash_key(self.forward))
         hash_key = hashlib.sha256(str(factors).encode()).hexdigest()
-
         cache_dir = os.path.join(
             envs.VLLM_CACHE_ROOT,
             "torch_aot_compile",
@@ -413,7 +412,8 @@ def _support_torch_compile(
                 f, f_globals=self.forward.__globals__
             )
             _verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
-            loaded_fn.disable_guard_check()
+            if not self.compilation_config.dynamic_shapes_config.evaluate_guards:
+                loaded_fn.disable_guard_check()
             self.aot_compiled_fn = loaded_fn
         except Exception as e:
             if os.path.exists(aot_compilation_path):


@@ -4,7 +4,7 @@
 import os
 import sys
 from abc import abstractmethod
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from types import CodeType
 from typing import Any
@@ -13,6 +13,7 @@ import torch._C._dynamo.guards
 import vllm.envs as envs
 from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
+from vllm.config.compilation import DynamicShapesType
 from vllm.logger import init_logger
 from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context
@@ -125,23 +126,49 @@ class TorchCompileWithNoGuardsWrapper:
         if isinstance(backend, str) and backend == "inductor":
             options = vllm_config.compilation_config.inductor_compile_config
-            if mode != CompilationMode.STOCK_TORCH_COMPILE:
-                # Drop all the guards.
-                options["guard_filter_fn"] = lambda x: [False for _ in x]
-
-        # Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
-        from vllm.compilation.decorators import DynamicShapesType
-
+        self.first_compile = True
+        self.evaluate_guards = (
+            vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards
+        )
         ds_type = vllm_config.compilation_config.dynamic_shapes_config.type
-        compiled_ptr: Any = self.forward
-        if ds_type == DynamicShapesType.UNBACKED:
-            if envs.VLLM_USE_BYTECODE_HOOK:
-                # reason is that bytecode does this hack torch._dynamo.eval_frame.
-                # remove_from_cache(self.original_code_object()) to force a new
-                # re-compilation.
-                raise ValueError(
-                    "UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. "
-                )
+        if mode != CompilationMode.STOCK_TORCH_COMPILE:
+            # Drop all the guards.
+            if self.evaluate_guards:
+                assert not envs.VLLM_USE_BYTECODE_HOOK, (
+                    "compilation_config.dynamic_shapes_config.evaluate_guards "
+                    "requires VLLM_USE_BYTECODE_HOOK=0. "
+                )
+                if envs.VLLM_USE_AOT_COMPILE:
+                    # disabled until https://github.com/pytorch/pytorch/pull/169239
+                    # is picked up.
+                    assert ds_type != DynamicShapesType.BACKED, (
+                        "evaluate_guards for backed shapes requires "
+                        "VLLM_USE_AOT_COMPILE=False. "
+                    )
+                options["guard_filter_fn"] = lambda x: [
+                    entry.guard_type == "SHAPE_ENV" for entry in x
+                ]
+            else:
+                options["guard_filter_fn"] = lambda x: [False for _ in x]
+        compiled_ptr: Any = self.forward
+        # Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
+        if ds_type == DynamicShapesType.UNBACKED:
+            # The bytecode hook calls torch._dynamo.eval_frame.
+            # remove_from_cache(self.original_code_object()) to force a new
+            # re-compilation, and with
+            # compiled_ptr = self.check_invariants_and_forward
+            # that would reset all cache entries.
+            assert not envs.VLLM_USE_BYTECODE_HOOK, (
+                "UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. "
+            )
+            assert not self.evaluate_guards, "UNBACKED dynamic shapes do not add guards"
             compiled_ptr = self.check_invariants_and_forward
         if envs.VLLM_USE_AOT_COMPILE:
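
In isolation, the guard-filtering pattern above looks like this. A sketch assuming a PyTorch build where torch.compile accepts a guard_filter_fn option and each entry exposes .guard_type, as the diff relies on:

    # Sketch: keep only dynamic-shape (SHAPE_ENV) guards, as the
    # evaluate_guards branch above does; everything else is dropped.
    import torch

    def keep_shape_guards(entries):
        return [entry.guard_type == "SHAPE_ENV" for entry in entries]

    def f(x):
        return x * 2

    compiled = torch.compile(
        f, dynamic=True, options={"guard_filter_fn": keep_shape_guards}
    )
    compiled(torch.randn(8))  # shape guards are kept; all others dropped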
@@ -195,7 +222,13 @@ class TorchCompileWithNoGuardsWrapper:
                     self.forward, *args, **kwargs
                 )
             else:
-                with _compilation_context():
+                ctx = (
+                    nullcontext()
+                    if self.first_compile or not self.evaluate_guards
+                    else torch.compiler.set_stance("fail_on_recompile")
+                )
+                self.first_compile = False
+                with _compilation_context(), ctx:
                     return self._call_with_optional_nvtx_range(
                         self._compiled_callable, *args, **kwargs
                     )
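
The fail_on_recompile stance used above, combined with keeping only shape guards, is what turns a failed guard into a hard error instead of a silent recompile. A standalone sketch in plain PyTorch, independent of vLLM:

    import torch

    @torch.compile
    def f(x):
        return x + 1

    f(torch.randn(4))  # first compilation is allowed
    with torch.compiler.set_stance("fail_on_recompile"):
        f(torch.randn(4))      # cache hit: fine
        # f(torch.randn(8))    # would raise, since it triggers a recompile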