diff --git a/docs/usage/security.md b/docs/usage/security.md
index b126d2a1e..4879ddbf6 100644
--- a/docs/usage/security.md
+++ b/docs/usage/security.md
@@ -66,6 +66,10 @@ Restrict domains that vLLM can access for media URLs by setting
 `--allowed-media-domains` to prevent Server-Side Request Forgery (SSRF) attacks.
 (e.g. `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`)
 
+This protection applies to both the online serving API (multimodal inputs) and
+the **batch runner** (`vllm run-batch`), where `file_url` values in batch
+transcription/translation requests are validated against the same allowlist.
+
 Without domain restrictions, a malicious user could supply URLs that:
 
 - **Target internal services**: Access internal network endpoints, cloud metadata
diff --git a/tests/entrypoints/openai/test_run_batch.py b/tests/entrypoints/openai/test_run_batch.py
index bf670105b..cd1daf0bb 100644
--- a/tests/entrypoints/openai/test_run_batch.py
+++ b/tests/entrypoints/openai/test_run_batch.py
@@ -4,11 +4,15 @@
 import json
 import subprocess
 import tempfile
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 
 from vllm.assets.audio import AudioAsset
-from vllm.entrypoints.openai.run_batch import BatchRequestOutput
+from vllm.entrypoints.openai.run_batch import (
+    BatchRequestOutput,
+    download_bytes_from_url,
+)
 
 CHAT_MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
 EMBEDDING_MODEL_NAME = "intfloat/multilingual-e5-small"
@@ -746,3 +750,131 @@ def test_tool_calling():
             assert "arguments" in tool_call["function"]
             # Verify the tool name matches our tool definition
             assert tool_call["function"]["name"] == "get_current_weather"
+
+
+# ---------------------------------------------------------------------------
+# Unit tests for download_bytes_from_url SSRF protection
+# ---------------------------------------------------------------------------
+
+
+def _make_aiohttp_mocks(response_data: bytes = b"fake-data", status: int = 200):
+    """Create mock objects that simulate aiohttp.ClientSession context managers."""
+    mock_resp = MagicMock()
+    mock_resp.status = status
+    mock_resp.read = AsyncMock(return_value=response_data)
+    mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
+    mock_resp.__aexit__ = AsyncMock(return_value=False)
+
+    mock_session = MagicMock()
+    mock_session.get = MagicMock(return_value=mock_resp)
+    mock_session.__aenter__ = AsyncMock(return_value=mock_session)
+    mock_session.__aexit__ = AsyncMock(return_value=False)
+    return mock_session
+
+
+@pytest.mark.asyncio
+async def test_download_bytes_data_url_bypasses_domain_check():
+    """data: URLs must work regardless of the domain allowlist."""
+    data_url = f"data:audio/wav;base64,{MINIMAL_WAV_BASE64}"
+    result = await download_bytes_from_url(
+        data_url, allowed_media_domains=["example.com"]
+    )
+    assert isinstance(result, bytes)
+    assert len(result) > 0
+
+
+@pytest.mark.asyncio
+async def test_download_bytes_rejects_disallowed_domain():
+    """HTTP URLs whose hostname is not in the allowlist must be rejected."""
+    url = "https://evil.internal/secret"
+    with pytest.raises(ValueError, match="allowed domains"):
+        await download_bytes_from_url(url, allowed_media_domains=["example.com"])
+
+
+@pytest.mark.asyncio
+async def test_download_bytes_rejects_cloud_metadata_ip():
+    """Cloud metadata endpoints must be blocked when an allowlist is set."""
+    url = "http://169.254.169.254/latest/meta-data/"
+    with pytest.raises(ValueError, match="allowed domains"):
+        await download_bytes_from_url(url, allowed_media_domains=["example.com"])
+
+
+@pytest.mark.asyncio
+async def test_download_bytes_rejects_internal_ip():
+    """Private-range IPs must be blocked when an allowlist is set."""
+    for internal_url in [
+        "http://10.0.0.1/secret",
+        "http://192.168.1.1/admin",
+        "http://127.0.0.1:8080/internal",
+    ]:
+        with pytest.raises(ValueError, match="allowed domains"):
+            await download_bytes_from_url(
+                internal_url, allowed_media_domains=["example.com"]
+            )
+
+
+@pytest.mark.asyncio
+async def test_download_bytes_allows_permitted_domain():
+    """HTTP URLs whose hostname IS in the allowlist must be fetched."""
+    url = "https://example.com/audio.wav"
+    expected = b"audio-bytes"
+    mock_session = _make_aiohttp_mocks(expected)
+
+    with patch(
+        "vllm.entrypoints.openai.run_batch.aiohttp.ClientSession",
+        return_value=mock_session,
+    ):
+        result = await download_bytes_from_url(
+            url, allowed_media_domains=["example.com"]
+        )
+    assert result == expected
+
+
+@pytest.mark.asyncio
+async def test_download_bytes_no_allowlist_permits_any_domain():
+    """Without an allowlist all HTTP URLs must be attempted (backward compat)."""
+    url = "https://any-domain.example.org/file.wav"
+    expected = b"some-data"
+    mock_session = _make_aiohttp_mocks(expected)
+
+    with patch(
+        "vllm.entrypoints.openai.run_batch.aiohttp.ClientSession",
+        return_value=mock_session,
+    ):
+        result = await download_bytes_from_url(url, allowed_media_domains=None)
+    assert result == expected
+
+
+@pytest.mark.asyncio
+async def test_download_bytes_empty_allowlist_denies_all():
+    """An empty allowlist must deny all HTTP URLs (least privilege)."""
+    url = "https://any-domain.example.org/file.wav"
+    with pytest.raises(ValueError, match="allowed domains"):
+        await download_bytes_from_url(url, allowed_media_domains=[])
+
+
+@pytest.mark.asyncio
+async def test_download_bytes_unsupported_scheme():
+    """Unsupported URL schemes must be rejected regardless of allowlist."""
+    with pytest.raises(ValueError, match="Unsupported URL scheme"):
+        await download_bytes_from_url("ftp://example.com/file.wav")
+
+    with pytest.raises(ValueError, match="Unsupported URL scheme"):
+        await download_bytes_from_url(
+            "ftp://example.com/file.wav",
+            allowed_media_domains=["example.com"],
+        )
+
+
+@pytest.mark.asyncio
+async def test_download_bytes_backslash_bypass():
+    """Backslash-@ URL confusion must not bypass the allowlist.
+
+    urllib3.parse_url() and aiohttp/yarl disagree on backslash-before-@.
+    The fix normalizes through urllib3 before handing to aiohttp.
+    """
+    bypass_url = "http://allowed.example.com\\@evil.internal/secret"
+    with pytest.raises(ValueError, match="allowed domains"):
+        await download_bytes_from_url(
+            bypass_url, allowed_media_domains=["evil.internal"]
+        )
diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py
index 03a15991d..3afd9b8ca 100644
--- a/vllm/entrypoints/openai/run_batch.py
+++ b/vllm/entrypoints/openai/run_batch.py
@@ -20,7 +20,9 @@ from pydantic import Field, TypeAdapter, field_validator, model_validator
 from pydantic_core.core_schema import ValidationInfo
 from starlette.datastructures import State
 from tqdm import tqdm
+from urllib3.util import parse_url
 
+import vllm.envs as envs
 from vllm.config import config
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.protocol import EngineClient
@@ -439,19 +441,25 @@ async def write_file(
         await write_local_file(path_or_url, batch_outputs)
 
 
-async def download_bytes_from_url(url: str) -> bytes:
+async def download_bytes_from_url(
+    url: str,
+    allowed_media_domains: list[str] | None = None,
+) -> bytes:
     """
     Download data from a URL or decode from a data URL.
 
     Args:
         url: Either an HTTP/HTTPS URL or a data URL (data:...;base64,...)
+        allowed_media_domains: If set, only HTTP/HTTPS URLs whose hostname
+            is in this list are permitted. data: URLs are not subject to
+            this restriction.
 
     Returns:
         Data as bytes
     """
     parsed = urlparse(url)
 
-    # Handle data URLs (base64 encoded)
+    # Handle data URLs (base64 encoded) - not subject to domain restrictions
     if parsed.scheme == "data":
         # Format: data:...;base64,
         if "," in url:
@@ -465,9 +473,24 @@ async def download_bytes_from_url(url: str) -> bytes:
 
     # Handle HTTP/HTTPS URLs
     elif parsed.scheme in ("http", "https"):
+        if allowed_media_domains is not None:
+            url_spec = parse_url(url)
+            if url_spec.hostname not in allowed_media_domains:
+                raise ValueError(
+                    f"The URL must be from one of the allowed domains: "
+                    f"{allowed_media_domains}. Input URL domain: "
+                    f"{url_spec.hostname}"
+                )
+            # Use the normalized URL to prevent parsing discrepancies
+            # between urllib3 and aiohttp (e.g. backslash-@ attacks).
+            url = url_spec.url
+
         async with (
             aiohttp.ClientSession() as session,
-            session.get(url) as resp,
+            session.get(
+                url,
+                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
+            ) as resp,
         ):
             if resp.status != 200:
                 raise Exception(
@@ -593,7 +616,10 @@ def handle_endpoint_request(
     return run_request(handler_fn, request, tracker)
 
 
-def make_transcription_wrapper(is_translation: bool) -> WrapperFn:
+def make_transcription_wrapper(
+    is_translation: bool,
+    allowed_media_domains: list[str] | None = None,
+) -> WrapperFn:
     """
     Factory function to create a wrapper for transcription/translation handlers.
     The wrapper converts BatchTranscriptionRequest or BatchTranslationRequest
@@ -602,6 +628,8 @@ def make_transcription_wrapper(is_translation: bool) -> WrapperFn:
     Args:
         is_translation: If True, process as translation; otherwise process
             as transcription
+        allowed_media_domains: If set, only URLs from these domains are
+            permitted for HTTP/HTTPS fetches.
 
     Returns:
         A function that takes a handler and returns a wrapped handler
@@ -619,7 +647,10 @@ def make_transcription_wrapper(is_translation: bool) -> WrapperFn:
         ):
             try:
                 # Download data from URL
-                audio_data = await download_bytes_from_url(batch_request_body.file_url)
+                audio_data = await download_bytes_from_url(
+                    batch_request_body.file_url,
+                    allowed_media_domains=allowed_media_domains,
+                )
 
                 # Create a mock file from the downloaded audio data
                 mock_file = UploadFile(
@@ -691,6 +722,8 @@ async def build_endpoint_registry(
     serving_embedding = getattr(state, "serving_embedding", None)
     serving_scores = getattr(state, "serving_scores", None)
 
+    allowed_media_domains = getattr(args, "allowed_media_domains", None)
+
     # Registry of endpoint configurations
     endpoint_registry: dict[str, dict[str, Any]] = {
         "completions": {
@@ -730,7 +763,10 @@ async def build_endpoint_registry(
                 if openai_serving_transcription is not None
                 else None
             ),
-            "wrapper_fn": make_transcription_wrapper(is_translation=False),
+            "wrapper_fn": make_transcription_wrapper(
+                is_translation=False,
+                allowed_media_domains=allowed_media_domains,
+            ),
         },
         "translations": {
             "url_matcher": lambda url: url == "/v1/audio/translations",
@@ -739,7 +775,10 @@ async def build_endpoint_registry(
                 if openai_serving_translation is not None
                 else None
             ),
-            "wrapper_fn": make_transcription_wrapper(is_translation=True),
+            "wrapper_fn": make_transcription_wrapper(
+                is_translation=True,
+                allowed_media_domains=allowed_media_domains,
+            ),
         },
     }
 