[Bugfix] Fix for builtins (forward fix of pytorch/177558) (#37234)
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
This commit is contained in:
@@ -31,6 +31,7 @@ CHECK_IMPORTS = {
|
||||
"vllm/transformers_utils/config.py",
|
||||
"vllm/model_executor/models/registry.py",
|
||||
"vllm/compilation/caching.py",
|
||||
"vllm/env_override.py",
|
||||
"vllm/compilation/piecewise_backend.py",
|
||||
"vllm/distributed/utils.py",
|
||||
"vllm/distributed/parallel_state.py",
|
||||
|
||||
@@ -87,7 +87,7 @@ _maybe_set_cuda_compatibility_path()
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.torch_utils import is_torch_equal
|
||||
from vllm.utils.torch_utils import is_torch_equal, is_torch_equal_or_newer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -490,3 +490,45 @@ if is_torch_equal("2.9.0"):
|
||||
|
||||
PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched
|
||||
GraphLowering._update_scheduler = _update_scheduler_patched
|
||||
|
||||
# ===================================================
|
||||
# torch <2.12 GraphCaptureOutput.get_runtime_env monkeypatch
|
||||
# ===================================================
|
||||
# PyTorch's AOT compile path omits builtins from used_globals, causing
|
||||
# 'Missing required external references' errors for refs like 'type'.
|
||||
# (which happens in transformers code)
|
||||
# This mirrors the fix in https://github.com/pytorch/pytorch/pull/177558
|
||||
# and can be removed once torch >=2.12 is the minimum supported version.
|
||||
|
||||
if not is_torch_equal_or_newer("2.12.0"):
|
||||
import builtins as _builtins
|
||||
import pickle
|
||||
|
||||
from torch._dynamo.convert_frame import GraphCaptureOutput
|
||||
|
||||
_original_get_runtime_env = GraphCaptureOutput.get_runtime_env
|
||||
|
||||
def _safe_builtins_dict(builtins_dict: dict) -> dict:
|
||||
"""Filter a builtins dict to only picklable entries for serialization."""
|
||||
result = {}
|
||||
for k, v in builtins_dict.items():
|
||||
try:
|
||||
pickle.dumps(v)
|
||||
result[k] = v
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
|
||||
def _patched_get_runtime_env(self): # type: ignore[no-untyped-def]
|
||||
runtime_env = _original_get_runtime_env(self)
|
||||
for ref in runtime_env.external_refs:
|
||||
if ref not in runtime_env.used_globals:
|
||||
if ref.startswith("__builtins_dict__") and ref in self.f_globals:
|
||||
runtime_env.used_globals[ref] = _safe_builtins_dict(
|
||||
self.f_globals[ref]
|
||||
)
|
||||
elif hasattr(_builtins, ref):
|
||||
runtime_env.used_globals[ref] = getattr(_builtins, ref)
|
||||
return runtime_env
|
||||
|
||||
GraphCaptureOutput.get_runtime_env = _patched_get_runtime_env
|
||||
|
||||
Reference in New Issue
Block a user