(security) Fix SSRF in batch runner download_bytes_from_url (#38482)
Signed-off-by: jperezde <jperezde@redhat.com>
This commit is contained in:
committed by
GitHub
parent
ac30a8311e
commit
57861ae48d
@@ -4,11 +4,15 @@
|
||||
import json
|
||||
import subprocess
|
||||
import tempfile
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.assets.audio import AudioAsset
|
||||
from vllm.entrypoints.openai.run_batch import BatchRequestOutput
|
||||
from vllm.entrypoints.openai.run_batch import (
|
||||
BatchRequestOutput,
|
||||
download_bytes_from_url,
|
||||
)
|
||||
|
||||
CHAT_MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
|
||||
EMBEDDING_MODEL_NAME = "intfloat/multilingual-e5-small"
|
||||
@@ -746,3 +750,131 @@ def test_tool_calling():
|
||||
assert "arguments" in tool_call["function"]
|
||||
# Verify the tool name matches our tool definition
|
||||
assert tool_call["function"]["name"] == "get_current_weather"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests for download_bytes_from_url SSRF protection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_aiohttp_mocks(response_data: bytes = b"fake-data", status: int = 200):
    """Build a stand-in for ``aiohttp.ClientSession`` usable with ``async with``.

    Args:
        response_data: Bytes produced by awaiting the fake response's ``read()``.
        status: Value exposed as the fake response's ``status`` attribute.

    Returns:
        A ``MagicMock`` session. Entering it with ``async with`` yields the
        session itself, and its ``get()`` returns a response mock that also
        yields itself when entered with ``async with``.
    """

    def _as_async_context(mock: MagicMock) -> MagicMock:
        # ``async with mock as x`` binds ``x`` to ``mock``; ``__aexit__``
        # resolves to False so exceptions raised in the body are not swallowed.
        mock.__aenter__ = AsyncMock(return_value=mock)
        mock.__aexit__ = AsyncMock(return_value=False)
        return mock

    response = _as_async_context(MagicMock())
    response.status = status
    response.read = AsyncMock(return_value=response_data)

    session = _as_async_context(MagicMock())
    session.get = MagicMock(return_value=response)
    return session
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_download_bytes_data_url_bypasses_domain_check():
    """Inline ``data:`` URLs must be accepted even with an allowlist configured.

    The allowlist names only ``example.com``; the call is still expected to
    return a non-empty ``bytes`` payload.
    """
    allowlist = ["example.com"]
    inline_audio = f"data:audio/wav;base64,{MINIMAL_WAV_BASE64}"

    payload = await download_bytes_from_url(
        inline_audio, allowed_media_domains=allowlist
    )

    assert isinstance(payload, bytes)
    assert len(payload) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_download_bytes_rejects_disallowed_domain():
    """A hostname absent from the allowlist must raise ``ValueError``."""
    allowlist = ["example.com"]
    blocked_url = "https://evil.internal/secret"

    with pytest.raises(ValueError, match="allowed domains"):
        await download_bytes_from_url(
            blocked_url, allowed_media_domains=allowlist
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_download_bytes_rejects_cloud_metadata_ip():
    """The link-local metadata address must be refused under an allowlist."""
    allowlist = ["example.com"]
    metadata_url = "http://169.254.169.254/latest/meta-data/"

    with pytest.raises(ValueError, match="allowed domains"):
        await download_bytes_from_url(
            metadata_url, allowed_media_domains=allowlist
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "internal_url",
    [
        "http://10.0.0.1/secret",
        "http://192.168.1.1/admin",
        "http://127.0.0.1:8080/internal",
    ],
)
async def test_download_bytes_rejects_internal_ip(internal_url: str):
    """Private-range and loopback IPs must be blocked when an allowlist is set.

    Each address is its own parametrized case, so a failure for one address
    is reported individually and does not prevent the others from running
    (a plain ``for`` loop would stop at the first failing address).

    Args:
        internal_url: URL pointing at a private or loopback IP literal that
            is not on the allowlist.
    """
    with pytest.raises(ValueError, match="allowed domains"):
        await download_bytes_from_url(
            internal_url, allowed_media_domains=["example.com"]
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_download_bytes_allows_permitted_domain():
    """A URL whose hostname is on the allowlist must be fetched.

    ``aiohttp.ClientSession`` is patched inside ``run_batch`` so no network
    access happens; the mocked response body must be returned unchanged.
    """
    body = b"audio-bytes"
    session_patch = patch(
        "vllm.entrypoints.openai.run_batch.aiohttp.ClientSession",
        return_value=_make_aiohttp_mocks(body),
    )

    with session_patch:
        fetched = await download_bytes_from_url(
            "https://example.com/audio.wav",
            allowed_media_domains=["example.com"],
        )

    assert fetched == body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_download_bytes_no_allowlist_permits_any_domain():
    """With ``allowed_media_domains=None`` any HTTP(S) host is attempted.

    This pins the backward-compatible behaviour: no allowlist means no
    domain filtering. The session is mocked, so nothing is actually fetched.
    """
    body = b"some-data"
    session_patch = patch(
        "vllm.entrypoints.openai.run_batch.aiohttp.ClientSession",
        return_value=_make_aiohttp_mocks(body),
    )

    with session_patch:
        fetched = await download_bytes_from_url(
            "https://any-domain.example.org/file.wav",
            allowed_media_domains=None,
        )

    assert fetched == body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_download_bytes_empty_allowlist_denies_all():
    """An empty allowlist (``[]``) must reject every HTTP URL.

    Unlike ``None``, an empty list is an explicit allowlist with no entries,
    so the least-privilege outcome is to deny the request.
    """
    no_domains_allowed = []
    target = "https://any-domain.example.org/file.wav"

    with pytest.raises(ValueError, match="allowed domains"):
        await download_bytes_from_url(
            target, allowed_media_domains=no_domains_allowed
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_download_bytes_unsupported_scheme():
    """A non-HTTP, non-data scheme must be rejected with or without an allowlist."""
    ftp_url = "ftp://example.com/file.wav"

    # Case 1: no allowlist argument supplied at all.
    with pytest.raises(ValueError, match="Unsupported URL scheme"):
        await download_bytes_from_url(ftp_url)

    # Case 2: the host is allowlisted, yet the scheme must still be refused.
    with pytest.raises(ValueError, match="Unsupported URL scheme"):
        await download_bytes_from_url(
            ftp_url,
            allowed_media_domains=["example.com"],
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_download_bytes_backslash_bypass():
    """A backslash before ``@`` must not let a URL slip past the allowlist.

    urllib3.parse_url() and aiohttp/yarl disagree on how to read a backslash
    placed before ``@``; the fix normalizes the URL through urllib3 before it
    is handed to aiohttp.

    Only ``evil.internal`` is allowlisted here, so the expected ``ValueError``
    shows that the host being checked is not the part after the ``@``.
    """
    only_evil_allowed = ["evil.internal"]
    crafted_url = "http://allowed.example.com\\@evil.internal/secret"

    with pytest.raises(ValueError, match="allowed domains"):
        await download_bytes_from_url(
            crafted_url, allowed_media_domains=only_evil_allowed
        )
|
||||
|
||||
Reference in New Issue
Block a user