diff --git a/vllm/connections.py b/vllm/connections.py index f79d681ce..8ef715f80 100644 --- a/vllm/connections.py +++ b/vllm/connections.py @@ -1,15 +1,201 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Mapping, MutableMapping +import asyncio +import functools +import time +from collections.abc import Callable, Coroutine, Mapping, MutableMapping from pathlib import Path +from typing import Any, ParamSpec, TypeVar import aiohttp import requests from urllib3.util import parse_url +import vllm.envs as envs +from vllm.logger import init_logger from vllm.version import __version__ as VLLM_VERSION +logger = init_logger(__name__) + +_P = ParamSpec("_P") +_T = TypeVar("_T") + +# Multiplier applied to timeout and sleep on each retry attempt. +# Attempt N uses: base_timeout * (_RETRY_BACKOFF_FACTOR ** N) for the +# per-attempt timeout and sleeps _RETRY_BACKOFF_FACTOR ** N seconds. +_RETRY_BACKOFF_FACTOR = 4 + + +def _is_retryable(exc: Exception) -> bool: + """Return True for transient errors that are worth retrying. + + Retryable: + - Timeouts (aiohttp, requests, stdlib) + - Connection-level failures (refused, reset, DNS) + - Server errors (5xx) -- includes S3 503 SlowDown + Not retryable: + - Client errors (4xx) -- bad URL, auth, not-found + - Programming errors (ValueError, TypeError, ...) 
+ """ + # Timeouts + if isinstance( + exc, + ( + TimeoutError, + asyncio.TimeoutError, + requests.exceptions.Timeout, + aiohttp.ServerTimeoutError, + ), + ): + return True + # Connection-level failures + if isinstance( + exc, + ( + ConnectionError, + aiohttp.ClientConnectionError, + requests.exceptions.ConnectionError, + ), + ): + return True + # aiohttp server-side disconnects + if isinstance(exc, aiohttp.ServerDisconnectedError): + return True + # requests 5xx -- raise_for_status() throws HTTPError + if ( + isinstance(exc, requests.exceptions.HTTPError) + and exc.response is not None + and exc.response.status_code >= 500 + ): + return True + # aiohttp 5xx -- raise_for_status() throws ClientResponseError + return isinstance(exc, aiohttp.ClientResponseError) and exc.status >= 500 + + +def _log_retry( + args: tuple, + kwargs: dict, + attempt: int, + max_retries: int, + attempt_timeout: float | None, + exc: Exception, + backoff: float, + base_timeout: float | None, +) -> None: + # args[0] is `self` (bound method), args[1] is the URL + url = args[1] if len(args) > 1 else kwargs.get("url") + timeout_info = ( + f"timeout={attempt_timeout:.3f}s" if base_timeout is not None else "no timeout" + ) + next_timeout = ( + f" with timeout={base_timeout * (_RETRY_BACKOFF_FACTOR ** (attempt + 1)):.3f}s" + if base_timeout is not None + else "" + ) + logger.warning( + "HTTP fetch failed for %s (attempt %d/%d, %s): %s -- retrying in %.3fs%s", + url, + attempt + 1, + max_retries, + timeout_info, + exc, + backoff, + next_timeout, + ) + + +def _sync_retry( + fn: Callable[_P, _T], +) -> Callable[_P, _T]: + """Add retry logic with exponential backoff to a sync method. + + The decorated method must accept ``timeout`` as a keyword argument. + The decorator replaces it with a per-attempt timeout that grows by + ``_RETRY_BACKOFF_FACTOR`` on each retry so transient slowness on busy + hosts is absorbed. 
+ """ + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> _T: + base_timeout: float | None = kwargs.get("timeout") + max_retries = max(envs.VLLM_MEDIA_FETCH_MAX_RETRIES, 1) + + for attempt in range(max_retries): + attempt_timeout = ( + base_timeout * (_RETRY_BACKOFF_FACTOR**attempt) + if base_timeout is not None + else None + ) + kwargs["timeout"] = attempt_timeout + try: + return fn(*args, **kwargs) + except Exception as e: + if not _is_retryable(e) or attempt + 1 >= max_retries: + raise + backoff = _RETRY_BACKOFF_FACTOR**attempt + _log_retry( + args, + kwargs, + attempt, + max_retries, + attempt_timeout, + e, + backoff, + base_timeout, + ) + time.sleep(backoff) + + raise AssertionError("unreachable") + + return wrapper # type: ignore[return-value] + + +def _async_retry( + fn: Callable[_P, Coroutine[Any, Any, _T]], +) -> Callable[_P, Coroutine[Any, Any, _T]]: + """Add retry logic with exponential backoff to an async method. + + The decorated method must accept ``timeout`` as a keyword argument. + The decorator replaces it with a per-attempt timeout that grows by + ``_RETRY_BACKOFF_FACTOR`` on each retry so transient slowness on busy + hosts is absorbed. 
+ """ + + @functools.wraps(fn) + async def wrapper(*args: Any, **kwargs: Any) -> _T: + base_timeout: float | None = kwargs.get("timeout") + max_retries = max(envs.VLLM_MEDIA_FETCH_MAX_RETRIES, 1) + + for attempt in range(max_retries): + attempt_timeout = ( + base_timeout * (_RETRY_BACKOFF_FACTOR**attempt) + if base_timeout is not None + else None + ) + kwargs["timeout"] = attempt_timeout + try: + return await fn(*args, **kwargs) + except Exception as e: + if not _is_retryable(e) or attempt + 1 >= max_retries: + raise + backoff = _RETRY_BACKOFF_FACTOR**attempt + _log_retry( + args, + kwargs, + attempt, + max_retries, + attempt_timeout, + e, + backoff, + base_timeout, + ) + await asyncio.sleep(backoff) + + raise AssertionError("unreachable") + + return wrapper # type: ignore[return-value] + class HTTPConnection: """Helper class to send HTTP requests.""" @@ -89,6 +275,7 @@ class HTTPConnection: allow_redirects=allow_redirects, ) + @_sync_retry def get_bytes( self, url: str, *, timeout: float | None = None, allow_redirects: bool = True ) -> bytes: @@ -99,6 +286,7 @@ class HTTPConnection: return r.content + @_async_retry async def async_get_bytes( self, url: str, @@ -147,6 +335,7 @@ class HTTPConnection: return await r.json() + @_sync_retry def download_file( self, url: str, @@ -155,15 +344,22 @@ class HTTPConnection: timeout: float | None = None, chunk_size: int = 128, ) -> Path: - with self.get_response(url, timeout=timeout) as r: - r.raise_for_status() + try: + with self.get_response(url, timeout=timeout) as r: + r.raise_for_status() - with save_path.open("wb") as f: - for chunk in r.iter_content(chunk_size): - f.write(chunk) + with save_path.open("wb") as f: + for chunk in r.iter_content(chunk_size): + f.write(chunk) - return save_path + return save_path + except Exception: + # Clean up partial downloads before retrying or propagating + if save_path.exists(): + save_path.unlink() + raise + @_async_retry async def async_download_file( self, url: str, @@ -172,14 
+368,23 @@ class HTTPConnection: timeout: float | None = None, chunk_size: int = 128, ) -> Path: - async with await self.get_async_response(url, timeout=timeout) as r: - r.raise_for_status() + try: + async with await self.get_async_response( + url, + timeout=timeout, + ) as r: + r.raise_for_status() - with save_path.open("wb") as f: - async for chunk in r.content.iter_chunked(chunk_size): - f.write(chunk) + with save_path.open("wb") as f: + async for chunk in r.content.iter_chunked(chunk_size): + f.write(chunk) - return save_path + return save_path + except Exception: + # Clean up partial downloads before retrying or propagating + if save_path.exists(): + save_path.unlink() + raise global_http_connection = HTTPConnection() diff --git a/vllm/envs.py b/vllm/envs.py index d6240df36..2f93b2cb3 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -64,6 +64,7 @@ if TYPE_CHECKING: VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_VIDEO_FETCH_TIMEOUT: int = 30 VLLM_AUDIO_FETCH_TIMEOUT: int = 10 + VLLM_MEDIA_FETCH_MAX_RETRIES: int = 3 VLLM_MEDIA_URL_ALLOW_REDIRECTS: bool = True VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8 VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25 @@ -773,6 +774,11 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_AUDIO_FETCH_TIMEOUT": lambda: int( os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10") ), + # Maximum number of attempts (first try plus retries) when fetching media + # (images, audio, video) from URLs. Each retry quadruples the timeout. Default is 3. + "VLLM_MEDIA_FETCH_MAX_RETRIES": lambda: int( + os.getenv("VLLM_MEDIA_FETCH_MAX_RETRIES", "3") + ), # Whether to allow HTTP redirects when fetching from media URLs. # Default to True "VLLM_MEDIA_URL_ALLOW_REDIRECTS": lambda: bool( @@ -1768,6 +1774,7 @@ def compile_factors() -> dict[str, object]: "VLLM_IMAGE_FETCH_TIMEOUT", "VLLM_VIDEO_FETCH_TIMEOUT", "VLLM_AUDIO_FETCH_TIMEOUT", + "VLLM_MEDIA_FETCH_MAX_RETRIES", "VLLM_MEDIA_URL_ALLOW_REDIRECTS", "VLLM_MEDIA_LOADING_THREAD_COUNT", "VLLM_MAX_AUDIO_CLIP_FILESIZE_MB",