use skip_all_guards_unsafe to drop global_state and torch_function_mode_stack guards instead of previous hacks (#36204)

Signed-off-by: Laith Sakka <lsakka@meta.com>

Author: Laith Sakka
Date: 2026-03-16 01:52:31 -07:00
Committed by: GitHub
Parent: 821eb80c0d
Commit: 52131f88d9


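The previous hack monkey-patched GuardManager methods into no-ops; the replacement leans on Dynamo's guard_filter_fn hook instead. A minimal sketch of the new mechanism, with a toy function standing in for the wrapped forward (this is illustrative, not vLLM's actual wiring):

import torch

def forward(x: torch.Tensor) -> torch.Tensor:
    return x * 2

# skip_all_guards_unsafe rejects every guard entry, so the compiled
# artifact is reused without re-checking global state or the torch
# function mode stack (hence "unsafe").
compiled = torch.compile(
    forward,
    options={"guard_filter_fn": torch.compiler.skip_all_guards_unsafe},
)
compiled(torch.randn(4))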
@@ -10,7 +10,6 @@ from types import CodeType
 from typing import Any, ParamSpec, TypeVar
 
 import torch
-import torch._C._dynamo.guards
 
 import vllm.envs as envs
 from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
@@ -24,65 +23,23 @@ R = TypeVar("R")
 P = ParamSpec("P")
 
-def _noop_add_global_state_guard(
-    self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
-) -> None:
-    """No-op to skip the GLOBAL_STATE guard entirely"""
-    pass
-
-
-def _noop_add_torch_function_mode_stack_guard(
-    self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
-) -> None:
-    """No-op to skip the TORCH_FUNCTION_MODE_STACK guard entirely"""
-    pass
-
-
 @contextmanager
 def _compilation_context() -> Generator[None, None, None]:
-    """Context manager for compilation settings and patches.
-
-    This manager:
-    1. Sets higher dynamo cache limits for compilation. (Needed for
-       qwen2_5_vl see test_qwen2_5_vl_evs_functionality).
-       Generally a recompilation can happen whenever we use a new
-       backend instance in torch.compile.
-    2. Patches out add_global_state_guard to skip GLOBAL_STATE guards
-    3. Patches out add_torch_function_mode_stack_guard to skip
-       TORCH_FUNCTION_MODE_STACK guards.
-    4. Restores everything when compilation completes
+    """Context manager for compilation settings.
+
+    This manager sets higher dynamo cache limits for compilation.
+    (Needed for qwen2_5_vl see test_qwen2_5_vl_evs_functionality).
+    Generally a recompilation can happen whenever we use a new
+    backend instance in torch.compile.
     """
-    # Save original values
-    original_global_state_guard = (
-        torch._C._dynamo.guards.GuardManager.add_global_state_guard
-    )
-    original_torch_function_mode_stack_guard = (
-        torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard
-    )
     original_cache_size = torch._dynamo.config.cache_size_limit
     original_accumulated_cache = torch._dynamo.config.accumulated_cache_size_limit
 
     try:
         # Set higher cache limits for compilation
         torch._dynamo.config.cache_size_limit = 2048
         torch._dynamo.config.accumulated_cache_size_limit = 8192
-
-        # Patch guard manager
-        torch._C._dynamo.guards.GuardManager.add_global_state_guard = (
-            _noop_add_global_state_guard
-        )
-        torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard = (
-            _noop_add_torch_function_mode_stack_guard
-        )
         yield
     finally:
         # Restore original values
-        torch._C._dynamo.guards.GuardManager.add_global_state_guard = (
-            original_global_state_guard
-        )
-        torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard = (
-            original_torch_function_mode_stack_guard
-        )
         torch._dynamo.config.cache_size_limit = original_cache_size
         torch._dynamo.config.accumulated_cache_size_limit = original_accumulated_cache
@@ -155,7 +112,7 @@ class TorchCompileWithNoGuardsWrapper:
                 entry.guard_type == "SHAPE_ENV" for entry in x
             ]
         else:
-            options["guard_filter_fn"] = lambda x: [False for _ in x]
+            options["guard_filter_fn"] = torch.compiler.skip_all_guards_unsafe
 
         compiled_ptr: Any = self.forward
         # Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
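For comparison, a guard filter does not have to drop everything: the branch above the changed line keeps only SHAPE_ENV guards. A hypothetical standalone version of that filter (the function name and toy forward are illustrative, not from the commit):

import torch

def keep_shape_env_guards(entries):
    # One bool per guard entry: True keeps the guard, False drops it.
    # Only symbolic-shape (SHAPE_ENV) guards survive.
    return [entry.guard_type == "SHAPE_ENV" for entry in entries]

def forward(x: torch.Tensor) -> torch.Tensor:
    return x.relu()

compiled = torch.compile(
    forward,
    options={"guard_filter_fn": keep_shape_env_guards},
)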