[Bugfix] Add monkeypatch to prevent race condition from writing (#35420)

Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
This commit is contained in:
Lucas Kabela
2026-02-27 11:51:36 -08:00
committed by GitHub
parent 2decec9856
commit 234a65b781

View File

@@ -184,6 +184,47 @@ def is_compile_cache_enabled(
)
def _patch_standalone_compile_atomic_save() -> None:
    """Install an atomic binary ``save`` on torch's ``CompiledArtifact``.

    Backport of pytorch/pytorch#162432 for torch < 2.10.0. The binary
    format is written through ``write_atomic`` so that processes compiling
    concurrently cannot leave a corrupt cache file behind. Any other format
    is forwarded unchanged to the original ``save``.

    The replacement method carries a ``_vllm_patched`` marker, so calling
    this function again after a successful patch does nothing.
    """
    from torch._inductor.codecache import write_atomic
    from torch._inductor.standalone_compile import CompiledArtifact as cls

    # The marker is only present on the replacement installed below.
    already_patched = getattr(cls.save, "_vllm_patched", False)
    if already_patched:
        return

    # Keep a handle on the unpatched method for the non-binary formats.
    original_save = cls.save

    def _save(
        self: Any, *, path: str, format: Literal["binary", "unpacked"] = "binary"
    ) -> None:
        if format == "binary":
            from torch._dynamo.utils import dynamo_timed
            from torch._inductor.codecache import torch_key
            from torch.utils._appending_byte_serializer import BytesWriter

            with dynamo_timed("CompiledArtifact.save"):
                assert self._artifacts is not None
                blob, info = self._artifacts
                assert len(info.aot_autograd_artifacts) == 1, info
                artifact_key = info.aot_autograd_artifacts[0]
                # A directory can never be the target of a binary artifact.
                assert not os.path.isdir(path)

                # Payload layout: torch key, artifact key, artifact bytes.
                serializer = BytesWriter()
                serializer.write_bytes(torch_key())
                serializer.write_str(artifact_key)
                serializer.write_bytes(blob)
                payload = serializer.to_bytes()
                write_atomic(path, payload)
            return None

        # Everything except "binary" keeps the stock behaviour.
        return original_save(self, path=path, format=format)

    _save._vllm_patched = True  # type: ignore[attr-defined]
    cls.save = _save  # type: ignore[assignment]
    logger.debug("Patched %s.save for atomic writes (torch < 2.10)", cls.__name__)
class InductorStandaloneAdaptor(CompilerInterface):
"""
The adaptor for the Inductor compiler.
@@ -197,6 +238,8 @@ class InductorStandaloneAdaptor(CompilerInterface):
name = "inductor_standalone"
def __init__(self, save_format: Literal["binary", "unpacked"]) -> None:
    """Store the artifact save format, patching older torch builds first.

    When the installed torch is older than 2.10.0,
    ``_patch_standalone_compile_atomic_save`` is invoked before the
    format is recorded on the instance.
    """
    torch_is_2_10_or_newer = is_torch_equal_or_newer("2.10.0")
    if not torch_is_2_10_or_newer:
        _patch_standalone_compile_atomic_save()

    self.save_format = save_format
def compute_hash(self, vllm_config: VllmConfig) -> str: