use skip_all_guards_unsafe to drop global_state and torch_function_mode_stack guards instead of previous hacks (#36204)
Signed-off-by: Laith Sakka <lsakka@meta.com>
This commit is contained in:
@@ -10,7 +10,6 @@ from types import CodeType
|
|||||||
from typing import Any, ParamSpec, TypeVar
|
from typing import Any, ParamSpec, TypeVar
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch._C._dynamo.guards
|
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
|
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
|
||||||
@@ -24,65 +23,23 @@ R = TypeVar("R")
|
|||||||
P = ParamSpec("P")
|
P = ParamSpec("P")
|
||||||
|
|
||||||
|
|
||||||
def _noop_add_global_state_guard(
|
|
||||||
self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""No-op to skip the GLOBAL_STATE guard entirely"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def _noop_add_torch_function_mode_stack_guard(
|
|
||||||
self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""No-op to skip the TORCH_FUNCTION_MODE_STACK guard entirely"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _compilation_context() -> Generator[None, None, None]:
|
def _compilation_context() -> Generator[None, None, None]:
|
||||||
"""Context manager for compilation settings and patches.
|
"""Context manager for compilation settings.
|
||||||
|
|
||||||
This manager:
|
This manager sets higher dynamo cache limits for compilation.
|
||||||
1. Sets higher dynamo cache limits for compilation. (Needed for
|
(Needed for qwen2_5_vl see test_qwen2_5_vl_evs_functionality).
|
||||||
qwen2_5_vl see test_qwen2_5_vl_evs_functionality).
|
Generally a recompilation can happen whenever we use a new
|
||||||
Generally a recompilation can happen whenever we use a new
|
backend instance in torch.compile.
|
||||||
backend instance in torch.compile.
|
|
||||||
2. Patches out add_global_state_guard to skip GLOBAL_STATE guards
|
|
||||||
3. Patches out add_torch_function_mode_stack_guard to skip
|
|
||||||
TORCH_FUNCTION_MODE_STACK guards.
|
|
||||||
4. Restores everything when compilation completes
|
|
||||||
"""
|
"""
|
||||||
# Save original values
|
|
||||||
original_global_state_guard = (
|
|
||||||
torch._C._dynamo.guards.GuardManager.add_global_state_guard
|
|
||||||
)
|
|
||||||
original_torch_function_mode_stack_guard = (
|
|
||||||
torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard
|
|
||||||
)
|
|
||||||
original_cache_size = torch._dynamo.config.cache_size_limit
|
original_cache_size = torch._dynamo.config.cache_size_limit
|
||||||
original_accumulated_cache = torch._dynamo.config.accumulated_cache_size_limit
|
original_accumulated_cache = torch._dynamo.config.accumulated_cache_size_limit
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Set higher cache limits for compilation
|
|
||||||
torch._dynamo.config.cache_size_limit = 2048
|
torch._dynamo.config.cache_size_limit = 2048
|
||||||
torch._dynamo.config.accumulated_cache_size_limit = 8192
|
torch._dynamo.config.accumulated_cache_size_limit = 8192
|
||||||
|
|
||||||
# Patch guard manager
|
|
||||||
torch._C._dynamo.guards.GuardManager.add_global_state_guard = (
|
|
||||||
_noop_add_global_state_guard
|
|
||||||
)
|
|
||||||
torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard = (
|
|
||||||
_noop_add_torch_function_mode_stack_guard
|
|
||||||
)
|
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
# Restore original values
|
|
||||||
torch._C._dynamo.guards.GuardManager.add_global_state_guard = (
|
|
||||||
original_global_state_guard
|
|
||||||
)
|
|
||||||
torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard = (
|
|
||||||
original_torch_function_mode_stack_guard
|
|
||||||
)
|
|
||||||
torch._dynamo.config.cache_size_limit = original_cache_size
|
torch._dynamo.config.cache_size_limit = original_cache_size
|
||||||
torch._dynamo.config.accumulated_cache_size_limit = original_accumulated_cache
|
torch._dynamo.config.accumulated_cache_size_limit = original_accumulated_cache
|
||||||
|
|
||||||
@@ -155,7 +112,7 @@ class TorchCompileWithNoGuardsWrapper:
|
|||||||
entry.guard_type == "SHAPE_ENV" for entry in x
|
entry.guard_type == "SHAPE_ENV" for entry in x
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
options["guard_filter_fn"] = lambda x: [False for _ in x]
|
options["guard_filter_fn"] = torch.compiler.skip_all_guards_unsafe
|
||||||
|
|
||||||
compiled_ptr: Any = self.forward
|
compiled_ptr: Any = self.forward
|
||||||
# Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
|
# Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
|
||||||
|
|||||||
Reference in New Issue
Block a user