diff --git a/tests/distributed/test_shm_buffer.py b/tests/distributed/test_shm_buffer.py index 9fe409edc..e56c71b1c 100644 --- a/tests/distributed/test_shm_buffer.py +++ b/tests/distributed/test_shm_buffer.py @@ -22,7 +22,7 @@ class TestSingleWriterShmRingBuffer(unittest.TestCase): def tearDown(self): """Clean up after tests""" if self.ring_buffer: - del self.ring_buffer + self.ring_buffer.close() def test_buffer_opening(self): """Test opening an existing buffer""" diff --git a/tests/distributed/test_shm_storage.py b/tests/distributed/test_shm_storage.py index 9ab35a292..ea63f4a29 100644 --- a/tests/distributed/test_shm_storage.py +++ b/tests/distributed/test_shm_storage.py @@ -56,7 +56,7 @@ class TestSingleWriterShmObjectStorage(unittest.TestCase): def tearDown(self): """Clean up after each test.""" if self.storage: - del self.storage + self.storage.close() def test_minimal_put_get_cycle(self): """Test basic put and get operations.""" diff --git a/vllm/distributed/device_communicators/shm_object_storage.py b/vllm/distributed/device_communicators/shm_object_storage.py index 5da261fbc..3d6048052 100644 --- a/vllm/distributed/device_communicators/shm_object_storage.py +++ b/vllm/distributed/device_communicators/shm_object_storage.py @@ -4,7 +4,7 @@ import pickle from abc import ABC, abstractmethod from collections.abc import Callable, Iterable -from contextlib import contextmanager +from contextlib import contextmanager, suppress from dataclasses import dataclass from itertools import chain from multiprocessing import shared_memory @@ -126,6 +126,7 @@ class SingleWriterShmRingBuffer: self.data_buffer_end = 0 if create: + logger.debug("Creating new shared memory buffer: %s", name) # we are creating a buffer self.metadata: dict[int, int] = {} # monotonic_id -> start address self.shared_memory = shared_memory.SharedMemory( @@ -169,11 +170,16 @@ class SingleWriterShmRingBuffer: self.data_buffer_start = 0 self.data_buffer_end = 0 - def __del__(self): + def close(self) -> None: + """Close the shared memory.""" if hasattr(self, "shared_memory"): self.shared_memory.close() if self.is_writer: - self.shared_memory.unlink() + with suppress(FileNotFoundError): + self.shared_memory.unlink() + + def __del__(self): + self.close() def int2byte(self, integer: int) -> bytes: """Convert an integer to bytes.""" @@ -663,6 +669,10 @@ class SingleWriterShmObjectStorage: if reader_count >= self.n_readers: self.increment_reader_flag(data_view[: self.flag_bytes]) + def close(self) -> None: + """Close the shared memory.""" + self.ring_buffer.close() + def handle(self): """Get handle for sharing across processes.""" return ShmObjectStorageHandle( diff --git a/vllm/envs.py b/vllm/envs.py index f82dae108..ed7b0362b 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -7,6 +7,7 @@ import logging import os import sys import tempfile +import uuid from collections.abc import Callable from typing import TYPE_CHECKING, Any, Literal @@ -457,6 +458,27 @@ def get_vllm_port() -> int | None: raise ValueError(f"VLLM_PORT '{port}' must be a valid integer") from err +def get_env_or_set_default( + env_name: str, + default_factory: Callable[[], str], +) -> Callable[[], str]: + """ + Create a lambda that returns an environment variable value if set, + or generates and sets a default value using the provided factory function. + """ + + def _get_or_set_default() -> str: + value = os.getenv(env_name) + if value is not None: + return value + + default_value = default_factory() + os.environ[env_name] = default_value + return default_value + + return _get_or_set_default + + # The start-* and end* here are used by the documentation generator # to extract the used env vars. @@ -1558,8 +1580,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ), # Name of the shared memory buffer used for object storage. # Only effective when mm_config.mm_processor_cache_type == "shm". - "VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME": lambda: os.getenv( - "VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME", "VLLM_OBJECT_STORAGE_SHM_BUFFER" + # Automatically generates a unique UUID-based name per process tree + # if not explicitly set. + "VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME": get_env_or_set_default( + "VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME", + lambda: f"VLLM_OBJECT_STORAGE_SHM_BUFFER_{uuid.uuid4().hex}", ), # The size in MB of the buffers (NVL and RDMA) used by DeepEP "VLLM_DEEPEP_BUFFER_SIZE_MB": lambda: int( diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py index cb17f7fdd..2a0f59099 100644 --- a/vllm/multimodal/cache.py +++ b/vllm/multimodal/cache.py @@ -295,6 +295,10 @@ class BaseMultiModalProcessorCache( """ return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes] + def close(self) -> None: + """Close the underlying cache, if needed.""" + pass + @abstractmethod def touch_sender_cache_item(self, mm_hash: str) -> None: """ @@ -534,6 +538,10 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache): def make_stats(self, *, delta: bool = False) -> CacheInfo: return self._stat(delta=delta) + @override + def close(self) -> None: + self._shm_cache.close() + def remove_dangling_items(self) -> None: """Remove items that are no longer in the shared memory cache.""" cached_hashes = self._shm_cache.key_index.keys() diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 855758d21..cceb51796 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -249,6 +249,9 @@ class AsyncLLM(EngineClient): if engine_core := getattr(self, "engine_core", None): engine_core.shutdown() + if input_processor := getattr(self, "input_processor", None): + input_processor.close() + handler = getattr(self, "output_handler", None) if handler is not None: cancel_task_threadsafe(handler) diff --git a/vllm/v1/engine/input_processor.py b/vllm/v1/engine/input_processor.py index 318aa51ce..06a8c4b69 100644 --- a/vllm/v1/engine/input_processor.py +++ b/vllm/v1/engine/input_processor.py @@ -712,3 +712,7 @@ class InputProcessor: def clear_mm_cache(self) -> None: self.input_preprocessor.clear_mm_cache() + + def close(self) -> None: + if self.mm_processor_cache is not None: + self.mm_processor_cache.close()