[Bugfix] Add monkeypatch to prevent race condition from writing (#35420)
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
This commit is contained in:
@@ -184,6 +184,47 @@ def is_compile_cache_enabled(
|
||||
)
|
||||
|
||||
|
||||
def _patch_standalone_compile_atomic_save() -> None:
    """Backport of pytorch/pytorch#162432 for torch < 2.10.0.

    Swaps CompiledArtifact.save() for a wrapper that routes the
    binary-format write through write_atomic, so two processes compiling
    concurrently cannot leave a partially-written (corrupt) cache file.
    """
    from torch._inductor.codecache import write_atomic
    from torch._inductor.standalone_compile import CompiledArtifact as artifact_cls

    # Idempotence guard: patching twice would wrap the wrapper.
    if getattr(artifact_cls.save, "_vllm_patched", False):
        return

    unpatched_save = artifact_cls.save

    def _atomic_save(
        self: Any, *, path: str, format: Literal["binary", "unpacked"] = "binary"
    ) -> None:
        # Only the binary format needs the atomic-write fix; everything
        # else is deferred to the stock implementation.
        if format != "binary":
            return unpatched_save(self, path=path, format=format)

        from torch._dynamo.utils import dynamo_timed
        from torch._inductor.codecache import torch_key
        from torch.utils._appending_byte_serializer import BytesWriter

        with dynamo_timed("CompiledArtifact.save"):
            assert self._artifacts is not None
            artifact_bytes, cache_info = self._artifacts
            # The binary layout stores exactly one AOT-autograd key.
            assert len(cache_info.aot_autograd_artifacts) == 1, cache_info
            key = cache_info.aot_autograd_artifacts[0]
            assert not os.path.isdir(path)

            # Serialize torch version key + artifact key + payload, then
            # hand the finished buffer to torch's atomic file writer
            # (the actual race-condition fix).
            buf = BytesWriter()
            buf.write_bytes(torch_key())
            buf.write_str(key)
            buf.write_bytes(artifact_bytes)
            write_atomic(path, buf.to_bytes())

    _atomic_save._vllm_patched = True  # type: ignore[attr-defined]
    artifact_cls.save = _atomic_save  # type: ignore[assignment]
    logger.debug(
        "Patched %s.save for atomic writes (torch < 2.10)", artifact_cls.__name__
    )
|
||||
|
||||
|
||||
class InductorStandaloneAdaptor(CompilerInterface):
|
||||
"""
|
||||
The adaptor for the Inductor compiler.
|
||||
@@ -197,6 +238,8 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
name = "inductor_standalone"
|
||||
|
||||
def __init__(self, save_format: Literal["binary", "unpacked"]) -> None:
    self.save_format = save_format
    # torch < 2.10 lacks the upstream atomic-save fix
    # (pytorch/pytorch#162432), so backport it before any compile
    # artifact is ever written.
    if not is_torch_equal_or_newer("2.10.0"):
        _patch_standalone_compile_atomic_save()
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
|
||||
Reference in New Issue
Block a user