(security) Fix SSRF in batch runner download_bytes_from_url (#38482)

Signed-off-by: jperezde <jperezde@redhat.com>
This commit is contained in:
Juan Pérez de Algaba
2026-03-30 09:10:01 +02:00
committed by GitHub
parent ac30a8311e
commit 57861ae48d
3 changed files with 183 additions and 8 deletions

View File

@@ -4,11 +4,15 @@
import json
import subprocess
import tempfile
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from vllm.assets.audio import AudioAsset
from vllm.entrypoints.openai.run_batch import BatchRequestOutput
from vllm.entrypoints.openai.run_batch import (
BatchRequestOutput,
download_bytes_from_url,
)
CHAT_MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
EMBEDDING_MODEL_NAME = "intfloat/multilingual-e5-small"
@@ -746,3 +750,131 @@ def test_tool_calling():
assert "arguments" in tool_call["function"]
# Verify the tool name matches our tool definition
assert tool_call["function"]["name"] == "get_current_weather"
# ---------------------------------------------------------------------------
# Unit tests for download_bytes_from_url SSRF protection
# ---------------------------------------------------------------------------
def _make_aiohttp_mocks(response_data: bytes = b"fake-data", status: int = 200):
"""Create mock objects that simulate aiohttp.ClientSession context managers."""
mock_resp = MagicMock()
mock_resp.status = status
mock_resp.read = AsyncMock(return_value=response_data)
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
mock_resp.__aexit__ = AsyncMock(return_value=False)
mock_session = MagicMock()
mock_session.get = MagicMock(return_value=mock_resp)
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=False)
return mock_session
@pytest.mark.asyncio
async def test_download_bytes_data_url_bypasses_domain_check():
"""data: URLs must work regardless of the domain allowlist."""
data_url = f"data:audio/wav;base64,{MINIMAL_WAV_BASE64}"
result = await download_bytes_from_url(
data_url, allowed_media_domains=["example.com"]
)
assert isinstance(result, bytes)
assert len(result) > 0
@pytest.mark.asyncio
async def test_download_bytes_rejects_disallowed_domain():
"""HTTP URLs whose hostname is not in the allowlist must be rejected."""
url = "https://evil.internal/secret"
with pytest.raises(ValueError, match="allowed domains"):
await download_bytes_from_url(url, allowed_media_domains=["example.com"])
@pytest.mark.asyncio
async def test_download_bytes_rejects_cloud_metadata_ip():
"""Cloud metadata endpoints must be blocked when an allowlist is set."""
url = "http://169.254.169.254/latest/meta-data/"
with pytest.raises(ValueError, match="allowed domains"):
await download_bytes_from_url(url, allowed_media_domains=["example.com"])
@pytest.mark.asyncio
async def test_download_bytes_rejects_internal_ip():
"""Private-range IPs must be blocked when an allowlist is set."""
for internal_url in [
"http://10.0.0.1/secret",
"http://192.168.1.1/admin",
"http://127.0.0.1:8080/internal",
]:
with pytest.raises(ValueError, match="allowed domains"):
await download_bytes_from_url(
internal_url, allowed_media_domains=["example.com"]
)
@pytest.mark.asyncio
async def test_download_bytes_allows_permitted_domain():
"""HTTP URLs whose hostname IS in the allowlist must be fetched."""
url = "https://example.com/audio.wav"
expected = b"audio-bytes"
mock_session = _make_aiohttp_mocks(expected)
with patch(
"vllm.entrypoints.openai.run_batch.aiohttp.ClientSession",
return_value=mock_session,
):
result = await download_bytes_from_url(
url, allowed_media_domains=["example.com"]
)
assert result == expected
@pytest.mark.asyncio
async def test_download_bytes_no_allowlist_permits_any_domain():
"""Without an allowlist all HTTP URLs must be attempted (backward compat)."""
url = "https://any-domain.example.org/file.wav"
expected = b"some-data"
mock_session = _make_aiohttp_mocks(expected)
with patch(
"vllm.entrypoints.openai.run_batch.aiohttp.ClientSession",
return_value=mock_session,
):
result = await download_bytes_from_url(url, allowed_media_domains=None)
assert result == expected
@pytest.mark.asyncio
async def test_download_bytes_empty_allowlist_denies_all():
"""An empty allowlist must deny all HTTP URLs (least privilege)."""
url = "https://any-domain.example.org/file.wav"
with pytest.raises(ValueError, match="allowed domains"):
await download_bytes_from_url(url, allowed_media_domains=[])
@pytest.mark.asyncio
async def test_download_bytes_unsupported_scheme():
"""Unsupported URL schemes must be rejected regardless of allowlist."""
with pytest.raises(ValueError, match="Unsupported URL scheme"):
await download_bytes_from_url("ftp://example.com/file.wav")
with pytest.raises(ValueError, match="Unsupported URL scheme"):
await download_bytes_from_url(
"ftp://example.com/file.wav",
allowed_media_domains=["example.com"],
)
@pytest.mark.asyncio
async def test_download_bytes_backslash_bypass():
"""Backslash-@ URL confusion must not bypass the allowlist.
urllib3.parse_url() and aiohttp/yarl disagree on backslash-before-@.
The fix normalizes through urllib3 before handing to aiohttp.
"""
bypass_url = "http://allowed.example.com\\@evil.internal/secret"
with pytest.raises(ValueError, match="allowed domains"):
await download_bytes_from_url(
bypass_url, allowed_media_domains=["evil.internal"]
)