[Bugfix] Fix for builtins (forward fix of pytorch/177558) (#37234)

Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
This commit is contained in:
Lucas Kabela
2026-03-30 18:08:11 -07:00
committed by GitHub
parent 29e48707e8
commit e31915063d
2 changed files with 44 additions and 1 deletions

View File

@@ -31,6 +31,7 @@ CHECK_IMPORTS = {
"vllm/transformers_utils/config.py",
"vllm/model_executor/models/registry.py",
"vllm/compilation/caching.py",
"vllm/env_override.py",
"vllm/compilation/piecewise_backend.py",
"vllm/distributed/utils.py",
"vllm/distributed/parallel_state.py",

View File

@@ -87,7 +87,7 @@ _maybe_set_cuda_compatibility_path()
import torch
from vllm.logger import init_logger
from vllm.utils.torch_utils import is_torch_equal
from vllm.utils.torch_utils import is_torch_equal, is_torch_equal_or_newer
logger = init_logger(__name__)
@@ -490,3 +490,45 @@ if is_torch_equal("2.9.0"):
PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched
GraphLowering._update_scheduler = _update_scheduler_patched
# ===================================================
# torch <2.12 GraphCaptureOutput.get_runtime_env monkeypatch
# ===================================================
# PyTorch's AOT compile path omits builtins from used_globals, causing
# 'Missing required external references' errors for refs like 'type'.
# (which happens in transformers code)
# This mirrors the fix in https://github.com/pytorch/pytorch/pull/177558
# and can be removed once torch >=2.12 is the minimum supported version.
if not is_torch_equal_or_newer("2.12.0"):
import builtins as _builtins
import pickle
from torch._dynamo.convert_frame import GraphCaptureOutput
_original_get_runtime_env = GraphCaptureOutput.get_runtime_env
def _safe_builtins_dict(builtins_dict: dict) -> dict:
"""Filter a builtins dict to only picklable entries for serialization."""
result = {}
for k, v in builtins_dict.items():
try:
pickle.dumps(v)
result[k] = v
except Exception:
pass
return result
def _patched_get_runtime_env(self): # type: ignore[no-untyped-def]
runtime_env = _original_get_runtime_env(self)
for ref in runtime_env.external_refs:
if ref not in runtime_env.used_globals:
if ref.startswith("__builtins_dict__") and ref in self.f_globals:
runtime_env.used_globals[ref] = _safe_builtins_dict(
self.f_globals[ref]
)
elif hasattr(_builtins, ref):
runtime_env.used_globals[ref] = getattr(_builtins, ref)
return runtime_env
GraphCaptureOutput.get_runtime_env = _patched_get_runtime_env