def _patch_standalone_compile_atomic_save() -> None:
    """Backport of pytorch/pytorch#162432 for torch < 2.10.0.

    Patches CompiledArtifact.save() to use write_atomic for binary format,
    preventing corrupt cache files when multiple processes compile
    concurrently.

    The patch is idempotent: the replacement method is tagged with a
    ``_vllm_patched`` attribute, and later calls return early when they see
    that tag. Saves in any format other than ``"binary"`` are forwarded
    unchanged to the original ``save``.
    """
    import functools

    from torch._inductor.codecache import write_atomic
    from torch._inductor.standalone_compile import CompiledArtifact as cls

    # Already patched by an earlier call; wrapping again would only stack
    # another layer of indirection on top of the first wrapper.
    if getattr(cls.save, "_vllm_patched", False):
        return

    original_save = cls.save

    # functools.wraps copies __name__/__qualname__/__doc__ from the original
    # method and records it as __wrapped__, so introspection of
    # CompiledArtifact.save still describes ``save`` rather than this
    # private ``_save`` closure.
    @functools.wraps(original_save)
    def _save(
        self: Any, *, path: str, format: Literal["binary", "unpacked"] = "binary"
    ) -> None:
        # Only the binary format is reimplemented here; every other format
        # keeps the stock implementation.
        if format != "binary":
            return original_save(self, path=path, format=format)

        # Resolved lazily, i.e. only when a binary save actually happens.
        from torch._dynamo.utils import dynamo_timed
        from torch._inductor.codecache import torch_key
        from torch.utils._appending_byte_serializer import BytesWriter

        with dynamo_timed("CompiledArtifact.save"):
            assert self._artifacts is not None
            artifact_bytes, cache_info = self._artifacts
            assert len(cache_info.aot_autograd_artifacts) == 1, cache_info
            key = cache_info.aot_autograd_artifacts[0]
            # Only an existing directory is rejected; the target file itself
            # may not exist yet.
            assert not os.path.isdir(path)
            # Payload order: torch_key() bytes, the artifact key string, then
            # the artifact bytes.
            # NOTE(review): presumably this must stay in sync with the layout
            # the stock binary loader reads -- confirm against upstream
            # before changing it.
            writer = BytesWriter()
            writer.write_bytes(torch_key())
            writer.write_str(key)
            writer.write_bytes(artifact_bytes)
            # The whole payload goes through one write_atomic call; this is
            # the behavior the backport exists to provide.
            write_atomic(path, writer.to_bytes())

    # Tag the replacement so the early-return check above recognizes it.
    _save._vllm_patched = True  # type: ignore[attr-defined]
    cls.save = _save  # type: ignore[assignment]
    logger.debug("Patched %s.save for atomic writes (torch < 2.10)", cls.__name__)
@@ -197,6 +238,8 @@ class InductorStandaloneAdaptor(CompilerInterface): name = "inductor_standalone" def __init__(self, save_format: Literal["binary", "unpacked"]) -> None: + if not is_torch_equal_or_newer("2.10.0"): + _patch_standalone_compile_atomic_save() self.save_format = save_format def compute_hash(self, vllm_config: VllmConfig) -> str: