[Core][CI] Add opt-in media URL caching via VLLM_MEDIA_CACHE (#37123)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -4,6 +4,8 @@
|
||||
import asyncio
|
||||
import mimetypes
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
from tempfile import NamedTemporaryFile, TemporaryDirectory
|
||||
|
||||
import aiohttp
|
||||
@@ -375,3 +377,113 @@ async def test_ssrf_bypass_backslash_disallowed_domain():
|
||||
|
||||
with pytest.raises(ValueError, match="allowed domains"):
|
||||
await connector.fetch_image_async(bypass_url)
|
||||
|
||||
|
||||
def _make_cached_connector(cache_dir, *, max_mb=10, ttl_hours=24):
    """Build a ``MediaConnector`` whose media cache points at ``cache_dir``.

    The connector is constructed normally and its private cache fields are
    then overwritten, so the cache configuration used by a test comes from
    the arguments here rather than from environment variables. The cache
    tests only use URLs as cache keys; they never fetch them.

    Args:
        cache_dir: Directory that holds the cached files.
        max_mb: Cache size budget in megabytes.
        ttl_hours: Entry time-to-live in hours.

    Returns:
        The configured connector.
    """
    cached = MediaConnector()
    overrides = {
        "_media_cache_dir": cache_dir,
        "_media_cache_max_bytes": max_mb * 1024 * 1024,
        "_media_cache_ttl_secs": ttl_hours * 3600,
    }
    for field, value in overrides.items():
        setattr(cached, field, value)
    return cached
|
||||
|
||||
|
||||
def test_cache_put_and_get():
    """Bytes stored under a URL come back unchanged from a later lookup."""
    image_url = "https://example.com/image.png"
    payload = b"fake-image-bytes"

    with TemporaryDirectory() as tmp_dir:
        cache = _make_cached_connector(tmp_dir)
        cache._put_cached_bytes(image_url, payload)

        assert cache._get_cached_bytes(image_url) == payload
|
||||
|
||||
|
||||
def test_cache_ttl_expiry():
    """A lookup of an entry older than the TTL misses and deletes the file."""
    stale_url = "https://example.com/old.png"

    with TemporaryDirectory() as tmp_dir:
        cache = _make_cached_connector(tmp_dir, ttl_hours=24)
        cache._put_cached_bytes(stale_url, b"old-data")

        # Push the entry's timestamps 25 hours into the past, i.e. beyond
        # the 24 hour TTL configured above.
        entry_file = cache._media_cache_path(stale_url)
        stamp = time.time() - (25 * 3600)
        os.utime(entry_file, (stamp, stamp))

        lookup = cache._get_cached_bytes(stale_url)

        assert lookup is None
        assert not entry_file.exists()
|
||||
|
||||
|
||||
def test_cache_lru_eviction():
    """Only the oldest entry is evicted when the cache exceeds its budget."""
    with TemporaryDirectory() as cache_dir:
        connector = _make_cached_connector(cache_dir)
        # The helper only takes whole megabytes, so set a 100-byte budget
        # directly on the connector.
        connector._media_cache_max_bytes = 100

        # Three 50-byte entries: 150 bytes in total, 50 over the budget.
        payload = b"x" * 50
        urls = [f"https://example.com/{i}.png" for i in range(3)]
        base_time = time.time()
        for i, url in enumerate(urls):
            connector._put_cached_bytes(url, payload)
            # Stagger mtimes from one base timestamp so the eviction order
            # is deterministic.
            stamp = base_time + i
            os.utime(connector._media_cache_path(url), (stamp, stamp))

        # The oldest entry (urls[0]) should have been evicted ...
        assert connector._get_cached_bytes(urls[0]) is None
        # ... and nothing else: removing it brings the total down to exactly
        # the budget, so both newer entries must still be present.
        assert connector._get_cached_bytes(urls[1]) == payload
        assert connector._get_cached_bytes(urls[2]) == payload
|
||||
|
||||
|
||||
def test_cache_ttl_eviction_during_write():
    """Storing a new entry removes expired files even when under budget."""
    stale_url = "https://example.com/stale.png"
    fresh_url = "https://example.com/fresh.png"

    with TemporaryDirectory() as tmp_dir:
        cache = _make_cached_connector(tmp_dir, ttl_hours=1)

        cache._put_cached_bytes(stale_url, b"stale")
        # Age the first entry to two hours, twice the one hour TTL.
        stale_file = cache._media_cache_path(stale_url)
        two_hours_ago = time.time() - (2 * 3600)
        os.utime(stale_file, (two_hours_ago, two_hours_ago))

        # The second write runs the eviction pass.
        cache._put_cached_bytes(fresh_url, b"fresh")

        assert not stale_file.exists()
        assert cache._get_cached_bytes(fresh_url) == b"fresh"
|
||||
|
||||
|
||||
def test_put_cached_bytes_missing_dir():
    """Storing into a cache whose directory was removed must not raise."""
    with TemporaryDirectory() as tmp_dir:
        cache = _make_cached_connector(tmp_dir)

        # Drop the directory out from under the already-configured connector.
        shutil.rmtree(tmp_dir)

        # No assertion needed: the test passes if this call returns without
        # raising.
        cache._put_cached_bytes("https://example.com/x.png", b"data")
|
||||
|
||||
|
||||
def test_get_cached_bytes_file_deleted_before_read():
    """A lookup misses when the cached file was removed before the read."""
    vanished_url = "https://example.com/vanish.png"

    with TemporaryDirectory() as tmp_dir:
        cache = _make_cached_connector(tmp_dir)
        cache._put_cached_bytes(vanished_url, b"data")

        # Remove the entry behind the connector's back, as a concurrent
        # eviction would.
        entry_file = cache._media_cache_path(vanished_url)
        entry_file.unlink()

        assert cache._get_cached_bytes(vanished_url) is None
|
||||
|
||||
19
vllm/envs.py
19
vllm/envs.py
@@ -64,6 +64,9 @@ if TYPE_CHECKING:
|
||||
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
|
||||
VLLM_VIDEO_FETCH_TIMEOUT: int = 30
|
||||
VLLM_AUDIO_FETCH_TIMEOUT: int = 10
|
||||
VLLM_MEDIA_CACHE: str = ""
|
||||
VLLM_MEDIA_CACHE_MAX_SIZE_MB: int = 5120
|
||||
VLLM_MEDIA_CACHE_TTL_HOURS: float = 24
|
||||
VLLM_MEDIA_FETCH_MAX_RETRIES: int = 3
|
||||
VLLM_MEDIA_URL_ALLOW_REDIRECTS: bool = True
|
||||
VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8
|
||||
@@ -776,6 +779,19 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_AUDIO_FETCH_TIMEOUT": lambda: int(
|
||||
os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")
|
||||
),
|
||||
# Directory for caching media downloads (images, video, audio fetched
|
||||
# from URLs during inference). Empty string disables caching.
|
||||
"VLLM_MEDIA_CACHE": lambda: os.getenv("VLLM_MEDIA_CACHE", ""),
|
||||
# Maximum cache size in MB. When exceeded, least-recently-used entries
|
||||
# are evicted. Default is 5120 (5 GB).
|
||||
"VLLM_MEDIA_CACHE_MAX_SIZE_MB": lambda: int(
|
||||
os.getenv("VLLM_MEDIA_CACHE_MAX_SIZE_MB", "5120")
|
||||
),
|
||||
# Time-to-live in hours for cached media files. Entries older than this
|
||||
# are evicted regardless of cache size. Default is 24 hours.
|
||||
"VLLM_MEDIA_CACHE_TTL_HOURS": lambda: float(
|
||||
os.getenv("VLLM_MEDIA_CACHE_TTL_HOURS", "24")
|
||||
),
|
||||
# Maximum number of retries for fetching media (images, audio, video)
|
||||
# from URLs. Each retry quadruples the timeout. Default is 3.
|
||||
"VLLM_MEDIA_FETCH_MAX_RETRIES": lambda: int(
|
||||
@@ -1777,6 +1793,9 @@ def compile_factors() -> dict[str, object]:
|
||||
"VLLM_IMAGE_FETCH_TIMEOUT",
|
||||
"VLLM_VIDEO_FETCH_TIMEOUT",
|
||||
"VLLM_AUDIO_FETCH_TIMEOUT",
|
||||
"VLLM_MEDIA_CACHE",
|
||||
"VLLM_MEDIA_CACHE_MAX_SIZE_MB",
|
||||
"VLLM_MEDIA_CACHE_TTL_HOURS",
|
||||
"VLLM_MEDIA_FETCH_MAX_RETRIES",
|
||||
"VLLM_MEDIA_URL_ALLOW_REDIRECTS",
|
||||
"VLLM_MEDIA_LOADING_THREAD_COUNT",
|
||||
|
||||
@@ -3,6 +3,11 @@
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
import contextlib
|
||||
import hashlib
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar
|
||||
@@ -16,6 +21,7 @@ from urllib3.util import Url, parse_url
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.connections import HTTPConnection, global_http_connection
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.registry import ExtensionManager
|
||||
|
||||
from .audio import AudioEmbeddingMediaIO, AudioMediaIO
|
||||
@@ -23,6 +29,8 @@ from .base import MediaIO
|
||||
from .image import ImageEmbeddingMediaIO, ImageMediaIO
|
||||
from .video import VideoMediaIO
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_M = TypeVar("_M")
|
||||
|
||||
global_thread_pool = ThreadPoolExecutor(
|
||||
@@ -116,6 +124,115 @@ class MediaConnector:
|
||||
allowed_media_domains = []
|
||||
self.allowed_media_domains = allowed_media_domains
|
||||
|
||||
# Media download cache (opt-in via VLLM_MEDIA_CACHE)
|
||||
self._media_cache_dir: str | None = None
|
||||
self._media_cache_max_bytes: int = 0
|
||||
self._media_cache_ttl_secs: float = 0
|
||||
media_cache = envs.VLLM_MEDIA_CACHE
|
||||
if media_cache:
|
||||
try:
|
||||
os.makedirs(media_cache, exist_ok=True)
|
||||
# Verify the directory is writable before enabling caching
|
||||
with tempfile.NamedTemporaryFile(dir=media_cache, delete=True):
|
||||
pass
|
||||
self._media_cache_dir = media_cache
|
||||
self._media_cache_max_bytes = (
|
||||
envs.VLLM_MEDIA_CACHE_MAX_SIZE_MB * 1024 * 1024
|
||||
)
|
||||
self._media_cache_ttl_secs = envs.VLLM_MEDIA_CACHE_TTL_HOURS * 3600
|
||||
logger.info(
|
||||
"Media cache enabled at %s (max %d MB, TTL %s hours)",
|
||||
media_cache,
|
||||
envs.VLLM_MEDIA_CACHE_MAX_SIZE_MB,
|
||||
envs.VLLM_MEDIA_CACHE_TTL_HOURS,
|
||||
)
|
||||
except OSError:
|
||||
logger.warning(
|
||||
"VLLM_MEDIA_CACHE path %s is not writable, media caching disabled",
|
||||
media_cache,
|
||||
)
|
||||
|
||||
def _get_cached_bytes(self, url: str) -> bytes | None:
|
||||
"""Return cached bytes for a URL, or None if not cached/expired."""
|
||||
if not self._media_cache_dir:
|
||||
return None
|
||||
cache_path = self._media_cache_path(url)
|
||||
# Check TTL
|
||||
try:
|
||||
age = time.time() - cache_path.stat().st_mtime
|
||||
except OSError:
|
||||
return None
|
||||
if age > self._media_cache_ttl_secs:
|
||||
cache_path.unlink(missing_ok=True)
|
||||
return None
|
||||
# Touch mtime for LRU ordering
|
||||
try:
|
||||
cache_path.touch()
|
||||
return cache_path.read_bytes()
|
||||
except OSError:
|
||||
return None
|
||||
|
||||
def _put_cached_bytes(self, url: str, data: bytes) -> None:
    """Store downloaded bytes and evict if over budget.

    The payload is written to a temporary file inside the cache directory
    and then moved onto its final name, so readers never see a partially
    written entry. Any ``OSError`` while writing or moving leaves the cache
    unchanged: the temporary file is removed and the method returns without
    raising. Does nothing when caching is disabled.

    Args:
        url: The URL the bytes were downloaded from (the cache key).
        data: The raw payload to store.
    """
    if not self._media_cache_dir:
        return
    cache_path = self._media_cache_path(url)

    # Atomic write via temp file + replace
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="wb", dir=self._media_cache_dir, delete=False
        ) as tmp_file:
            # Record the name before writing so that a failed write
            # (e.g. disk full) still gets its temp file cleaned up below.
            tmp_path = tmp_file.name
            tmp_file.write(data)
        # os.replace overwrites an existing destination on every platform;
        # os.rename raises on Windows when the destination exists.
        os.replace(tmp_path, cache_path)
    except OSError:
        # Cache directory missing/unwritable or another disk issue
        if tmp_path is not None:
            with contextlib.suppress(OSError):
                os.remove(tmp_path)
        return

    self._maybe_evict(exclude=cache_path)
|
||||
|
||||
def _maybe_evict(self, exclude: Path | None = None) -> None:
|
||||
"""Evict expired entries first, then LRU until under size limit."""
|
||||
cache_dir = Path(self._media_cache_dir) # type: ignore[arg-type]
|
||||
entries = []
|
||||
expired = []
|
||||
total_size = 0
|
||||
now = time.time()
|
||||
for f in cache_dir.iterdir():
|
||||
if f.name.startswith("."):
|
||||
continue
|
||||
try:
|
||||
stat = f.stat()
|
||||
except OSError:
|
||||
continue
|
||||
age = now - stat.st_mtime
|
||||
if age > self._media_cache_ttl_secs:
|
||||
expired.append(f)
|
||||
continue
|
||||
total_size += stat.st_size
|
||||
# Never evict the file we just wrote
|
||||
if exclude is not None and f.name == exclude.name:
|
||||
continue
|
||||
entries.append((stat.st_mtime, stat.st_size, f))
|
||||
|
||||
# Evict items according to LRU policy
|
||||
entries.sort(key=lambda e: e[0], reverse=True)
|
||||
while total_size > self._media_cache_max_bytes and entries:
|
||||
mtime, size, f = entries.pop()
|
||||
expired.append(f)
|
||||
total_size -= size
|
||||
|
||||
for f in expired:
|
||||
f.unlink(missing_ok=True)
|
||||
|
||||
def _media_cache_path(self, url: str) -> Path:
    """Map a URL to the file that holds its entry in the cache directory.

    The file name is the first 20 hex characters of the SHA-256 digest of
    the full URL, followed by the suffix of the URL with everything from
    the first ``?`` onwards removed (empty when there is no suffix).
    """
    digest = hashlib.sha256(url.encode()).hexdigest()
    without_query, _, _ = url.partition("?")
    suffix = Path(without_query).suffix
    file_name = digest[:20] + suffix
    return Path(self._media_cache_dir) / file_name  # type: ignore[arg-type]
|
||||
|
||||
def _load_data_url(
|
||||
self,
|
||||
url_spec: Url,
|
||||
@@ -178,6 +295,10 @@ class MediaConnector:
|
||||
if url_spec.scheme and url_spec.scheme.startswith("http"):
|
||||
self._assert_url_in_allowed_media_domains(url_spec)
|
||||
|
||||
cached = self._get_cached_bytes(url)
|
||||
if cached is not None:
|
||||
return media_io.load_bytes(cached)
|
||||
|
||||
connection = self.connection
|
||||
data = connection.get_bytes(
|
||||
url_spec.url,
|
||||
@@ -185,6 +306,7 @@ class MediaConnector:
|
||||
allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
|
||||
)
|
||||
|
||||
self._put_cached_bytes(url, data)
|
||||
return media_io.load_bytes(data)
|
||||
|
||||
if url_spec.scheme == "data":
|
||||
@@ -209,12 +331,25 @@ class MediaConnector:
|
||||
if url_spec.scheme and url_spec.scheme.startswith("http"):
|
||||
self._assert_url_in_allowed_media_domains(url_spec)
|
||||
|
||||
cached = await loop.run_in_executor(
|
||||
global_thread_pool, self._get_cached_bytes, url
|
||||
)
|
||||
if cached is not None:
|
||||
future = loop.run_in_executor(
|
||||
global_thread_pool, media_io.load_bytes, cached
|
||||
)
|
||||
return await future
|
||||
|
||||
connection = self.connection
|
||||
data = await connection.async_get_bytes(
|
||||
url_spec.url,
|
||||
timeout=fetch_timeout,
|
||||
allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
|
||||
)
|
||||
|
||||
await loop.run_in_executor(
|
||||
global_thread_pool, self._put_cached_bytes, url, data
|
||||
)
|
||||
future = loop.run_in_executor(global_thread_pool, media_io.load_bytes, data)
|
||||
return await future
|
||||
|
||||
|
||||
Reference in New Issue
Block a user