[Core][CI] Add opt-in media URL caching via VLLM_MEDIA_CACHE (#37123)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-03-30 06:58:53 -05:00
committed by GitHub
parent 1031c84c36
commit 677424c7ac
3 changed files with 266 additions and 0 deletions

View File

@@ -4,6 +4,8 @@
import asyncio
import mimetypes
import os
import shutil
import time
from tempfile import NamedTemporaryFile, TemporaryDirectory
import aiohttp
@@ -375,3 +377,113 @@ async def test_ssrf_bypass_backslash_disallowed_domain():
with pytest.raises(ValueError, match="allowed domains"):
await connector.fetch_image_async(bypass_url)
def _make_cached_connector(cache_dir, *, max_mb=10, ttl_hours=24):
    """Build a MediaConnector whose cache settings are set by hand.

    The connector is constructed normally and its private cache fields are
    then overwritten, so the tests control the cache directory, size budget
    and TTL directly instead of going through environment variables. URLs
    passed to the cache helpers in these tests act purely as cache keys
    (they are hashed into file names); no HTTP request is issued.
    """
    bytes_per_mb = 1024 * 1024
    secs_per_hour = 3600

    cached = MediaConnector()
    cached._media_cache_dir = cache_dir
    cached._media_cache_max_bytes = max_mb * bytes_per_mb
    cached._media_cache_ttl_secs = ttl_hours * secs_per_hour
    return cached
def test_cache_put_and_get():
    """Bytes stored under a URL key come back unchanged on lookup."""
    key_url = "https://example.com/image.png"
    payload = b"fake-image-bytes"
    with TemporaryDirectory() as tmp_root:
        conn = _make_cached_connector(tmp_root)
        conn._put_cached_bytes(key_url, payload)
        assert conn._get_cached_bytes(key_url) == payload
def test_cache_ttl_expiry():
    """Reading an entry whose mtime is past the TTL misses and removes it."""
    key_url = "https://example.com/old.png"
    with TemporaryDirectory() as tmp_root:
        conn = _make_cached_connector(tmp_root, ttl_hours=24)
        conn._put_cached_bytes(key_url, b"old-data")

        # Push atime/mtime 25 hours into the past: one hour beyond the TTL.
        entry = conn._media_cache_path(key_url)
        stale_stamp = time.time() - (25 * 3600)
        os.utime(entry, (stale_stamp, stale_stamp))

        assert conn._get_cached_bytes(key_url) is None
        assert not entry.exists()
def test_cache_lru_eviction():
    """Only the oldest entry is evicted when the cache exceeds its budget."""
    with TemporaryDirectory() as cache_dir:
        connector = _make_cached_connector(cache_dir, max_mb=0)
        # Override the budget directly to get a very small limit: 100 bytes.
        connector._media_cache_max_bytes = 100

        # Write three 50-byte entries (total 150 > 100 budget).
        payload = b"x" * 50
        urls = [f"https://example.com/{i}.png" for i in range(3)]
        for i, url in enumerate(urls):
            connector._put_cached_bytes(url, payload)
            # Stagger mtime so eviction order is deterministic. Use a single
            # timestamp so atime and mtime are guaranteed to be identical.
            stamp = time.time() + i
            os.utime(connector._media_cache_path(url), (stamp, stamp))

        # The oldest entry (urls[0]) should have been evicted ...
        assert connector._get_cached_bytes(urls[0]) is None
        # ... and evicting it alone brings the cache back to the budget, so
        # both newer entries must still be present.
        assert connector._get_cached_bytes(urls[1]) == payload
        assert connector._get_cached_bytes(urls[2]) == payload
def test_cache_ttl_eviction_during_write():
    """Storing a new entry sweeps out an expired one even under the budget."""
    stale_url = "https://example.com/stale.png"
    fresh_url = "https://example.com/fresh.png"
    with TemporaryDirectory() as tmp_root:
        conn = _make_cached_connector(tmp_root, ttl_hours=1)
        conn._put_cached_bytes(stale_url, b"stale")

        # Age the first entry to two hours old, i.e. beyond the 1-hour TTL.
        stale_entry = conn._media_cache_path(stale_url)
        two_hours_ago = time.time() - (2 * 3600)
        os.utime(stale_entry, (two_hours_ago, two_hours_ago))

        # The write path runs eviction after storing the new entry.
        conn._put_cached_bytes(fresh_url, b"fresh")

        assert not stale_entry.exists()
        assert conn._get_cached_bytes(fresh_url) == b"fresh"
def test_put_cached_bytes_missing_dir():
    """Storing into a cache directory that no longer exists must not raise."""
    with TemporaryDirectory() as tmp_root:
        conn = _make_cached_connector(tmp_root)

        # Simulate the directory disappearing while the connector is live.
        shutil.rmtree(tmp_root)

        # The only expectation is graceful degradation: no exception.
        conn._put_cached_bytes("https://example.com/x.png", b"data")
def test_get_cached_bytes_file_deleted_before_read():
    """A lookup returns None when the entry's file was removed beforehand."""
    key_url = "https://example.com/vanish.png"
    with TemporaryDirectory() as tmp_root:
        conn = _make_cached_connector(tmp_root)
        conn._put_cached_bytes(key_url, b"data")

        # Remove the file behind the connector's back, as an eviction by
        # another worker would.
        entry = conn._media_cache_path(key_url)
        entry.unlink()

        assert conn._get_cached_bytes(key_url) is None

View File

@@ -64,6 +64,9 @@ if TYPE_CHECKING:
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
VLLM_VIDEO_FETCH_TIMEOUT: int = 30
VLLM_AUDIO_FETCH_TIMEOUT: int = 10
VLLM_MEDIA_CACHE: str = ""
VLLM_MEDIA_CACHE_MAX_SIZE_MB: int = 5120
VLLM_MEDIA_CACHE_TTL_HOURS: float = 24
VLLM_MEDIA_FETCH_MAX_RETRIES: int = 3
VLLM_MEDIA_URL_ALLOW_REDIRECTS: bool = True
VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8
@@ -776,6 +779,19 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_AUDIO_FETCH_TIMEOUT": lambda: int(
os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")
),
# Directory for caching media downloads (images, video, audio fetched
# from URLs during inference). Empty string disables caching.
"VLLM_MEDIA_CACHE": lambda: os.getenv("VLLM_MEDIA_CACHE", ""),
# Maximum cache size in MB. When exceeded, least-recently-used entries
# are evicted. Default is 5120 (5 GB).
"VLLM_MEDIA_CACHE_MAX_SIZE_MB": lambda: int(
os.getenv("VLLM_MEDIA_CACHE_MAX_SIZE_MB", "5120")
),
# Time-to-live in hours for cached media files. Entries older than this
# are evicted regardless of cache size. Default is 24 hours.
"VLLM_MEDIA_CACHE_TTL_HOURS": lambda: float(
os.getenv("VLLM_MEDIA_CACHE_TTL_HOURS", "24")
),
# Maximum number of retries for fetching media (images, audio, video)
# from URLs. Each retry quadruples the timeout. Default is 3.
"VLLM_MEDIA_FETCH_MAX_RETRIES": lambda: int(
@@ -1777,6 +1793,9 @@ def compile_factors() -> dict[str, object]:
"VLLM_IMAGE_FETCH_TIMEOUT",
"VLLM_VIDEO_FETCH_TIMEOUT",
"VLLM_AUDIO_FETCH_TIMEOUT",
"VLLM_MEDIA_CACHE",
"VLLM_MEDIA_CACHE_MAX_SIZE_MB",
"VLLM_MEDIA_CACHE_TTL_HOURS",
"VLLM_MEDIA_FETCH_MAX_RETRIES",
"VLLM_MEDIA_URL_ALLOW_REDIRECTS",
"VLLM_MEDIA_LOADING_THREAD_COUNT",

View File

@@ -3,6 +3,11 @@
import asyncio
import atexit
import contextlib
import hashlib
import os
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, TypeVar
@@ -16,6 +21,7 @@ from urllib3.util import Url, parse_url
import vllm.envs as envs
from vllm.connections import HTTPConnection, global_http_connection
from vllm.logger import init_logger
from vllm.utils.registry import ExtensionManager
from .audio import AudioEmbeddingMediaIO, AudioMediaIO
@@ -23,6 +29,8 @@ from .base import MediaIO
from .image import ImageEmbeddingMediaIO, ImageMediaIO
from .video import VideoMediaIO
logger = init_logger(__name__)
_M = TypeVar("_M")
global_thread_pool = ThreadPoolExecutor(
@@ -116,6 +124,115 @@ class MediaConnector:
allowed_media_domains = []
self.allowed_media_domains = allowed_media_domains
# Media download cache (opt-in via VLLM_MEDIA_CACHE)
self._media_cache_dir: str | None = None
self._media_cache_max_bytes: int = 0
self._media_cache_ttl_secs: float = 0
media_cache = envs.VLLM_MEDIA_CACHE
if media_cache:
try:
os.makedirs(media_cache, exist_ok=True)
# Verify the directory is writable before enabling caching
with tempfile.NamedTemporaryFile(dir=media_cache, delete=True):
pass
self._media_cache_dir = media_cache
self._media_cache_max_bytes = (
envs.VLLM_MEDIA_CACHE_MAX_SIZE_MB * 1024 * 1024
)
self._media_cache_ttl_secs = envs.VLLM_MEDIA_CACHE_TTL_HOURS * 3600
logger.info(
"Media cache enabled at %s (max %d MB, TTL %s hours)",
media_cache,
envs.VLLM_MEDIA_CACHE_MAX_SIZE_MB,
envs.VLLM_MEDIA_CACHE_TTL_HOURS,
)
except OSError:
logger.warning(
"VLLM_MEDIA_CACHE path %s is not writable, media caching disabled",
media_cache,
)
def _get_cached_bytes(self, url: str) -> bytes | None:
"""Return cached bytes for a URL, or None if not cached/expired."""
if not self._media_cache_dir:
return None
cache_path = self._media_cache_path(url)
# Check TTL
try:
age = time.time() - cache_path.stat().st_mtime
except OSError:
return None
if age > self._media_cache_ttl_secs:
cache_path.unlink(missing_ok=True)
return None
# Touch mtime for LRU ordering
try:
cache_path.touch()
return cache_path.read_bytes()
except OSError:
return None
def _put_cached_bytes(self, url: str, data: bytes) -> None:
    """Best-effort store of ``data`` under ``url``, then trim the cache.

    The payload is written to a temporary file inside the cache directory
    and then moved onto its final name, so readers never observe a
    partially written entry. Any ``OSError`` along the way (missing
    directory, full disk, ...) silently skips caching; it never propagates
    to the caller.

    Args:
        url: Media URL used as the cache key.
        data: Downloaded bytes to store.
    """
    if not self._media_cache_dir:
        return

    cache_path = self._media_cache_path(url)

    # Atomic write via temp file + replace.
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="wb", dir=self._media_cache_dir, delete=False
        ) as tmp_file:
            # Record the name *before* writing so that a failed write
            # (e.g. disk full) still gets its partial file cleaned up.
            tmp_path = tmp_file.name
            tmp_file.write(data)
        # os.replace() atomically overwrites an existing entry on every
        # platform; os.rename() fails on Windows when the target exists.
        os.replace(tmp_path, cache_path)
    except OSError:
        # Cache directory missing, disk issue, or the target is unusable.
        if tmp_path is not None:
            with contextlib.suppress(OSError):
                os.remove(tmp_path)
        return

    # Eviction is housekeeping: the entry is already stored, so a failure
    # while trimming must not turn a successful download into an error.
    with contextlib.suppress(OSError):
        self._maybe_evict(exclude=cache_path)
def _maybe_evict(self, exclude: Path | None = None) -> None:
"""Evict expired entries first, then LRU until under size limit."""
cache_dir = Path(self._media_cache_dir) # type: ignore[arg-type]
entries = []
expired = []
total_size = 0
now = time.time()
for f in cache_dir.iterdir():
if f.name.startswith("."):
continue
try:
stat = f.stat()
except OSError:
continue
age = now - stat.st_mtime
if age > self._media_cache_ttl_secs:
expired.append(f)
continue
total_size += stat.st_size
# Never evict the file we just wrote
if exclude is not None and f.name == exclude.name:
continue
entries.append((stat.st_mtime, stat.st_size, f))
# Evict items according to LRU policy
entries.sort(key=lambda e: e[0], reverse=True)
while total_size > self._media_cache_max_bytes and entries:
mtime, size, f = entries.pop()
expired.append(f)
total_size -= size
for f in expired:
f.unlink(missing_ok=True)
def _media_cache_path(self, url: str) -> Path:
    """Map ``url`` to the path of its cache file.

    The file name is the first 20 hex digits of the SHA-256 digest of the
    full URL, followed by the URL's file extension when it has a simple
    one. The extension is cosmetic only; the hash alone identifies the
    entry.

    Args:
        url: Media URL used as the cache key.

    Returns:
        Path of the entry inside the cache directory (the file need not
        exist).
    """
    url_hash = hashlib.sha256(url.encode()).hexdigest()[:20]

    # Strip query string and fragment so neither leaks into the suffix.
    resource = url.split("?", 1)[0].split("#", 1)[0]
    ext = Path(resource).suffix

    # The URL is caller-supplied: only keep short, plain ASCII alphanumeric
    # extensions so no unexpected characters (NUL, ':', ...) end up in the
    # on-disk file name.
    ext_body = ext[1:]
    if not (ext_body.isascii() and ext_body.isalnum() and len(ext_body) <= 16):
        ext = ""

    return Path(self._media_cache_dir) / f"{url_hash}{ext}"  # type: ignore[arg-type]
def _load_data_url(
self,
url_spec: Url,
@@ -178,6 +295,10 @@ class MediaConnector:
if url_spec.scheme and url_spec.scheme.startswith("http"):
self._assert_url_in_allowed_media_domains(url_spec)
cached = self._get_cached_bytes(url)
if cached is not None:
return media_io.load_bytes(cached)
connection = self.connection
data = connection.get_bytes(
url_spec.url,
@@ -185,6 +306,7 @@ class MediaConnector:
allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
)
self._put_cached_bytes(url, data)
return media_io.load_bytes(data)
if url_spec.scheme == "data":
@@ -209,12 +331,25 @@ class MediaConnector:
if url_spec.scheme and url_spec.scheme.startswith("http"):
self._assert_url_in_allowed_media_domains(url_spec)
cached = await loop.run_in_executor(
global_thread_pool, self._get_cached_bytes, url
)
if cached is not None:
future = loop.run_in_executor(
global_thread_pool, media_io.load_bytes, cached
)
return await future
connection = self.connection
data = await connection.async_get_bytes(
url_spec.url,
timeout=fetch_timeout,
allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
)
await loop.run_in_executor(
global_thread_pool, self._put_cached_bytes, url, data
)
future = loop.run_in_executor(global_thread_pool, media_io.load_bytes, data)
return await future