[Misc] Replace urllib's urlparse with urllib3's parse_url (#32746)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -3,10 +3,10 @@
|
||||
|
||||
from collections.abc import Mapping, MutableMapping
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from urllib3.util import parse_url
|
||||
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
@@ -37,7 +37,7 @@ class HTTPConnection:
|
||||
return self._async_client
|
||||
|
||||
def _validate_http_url(self, url: str):
|
||||
parsed_url = urlparse(url)
|
||||
parsed_url = parse_url(url)
|
||||
|
||||
if parsed_url.scheme not in ("http", "https"):
|
||||
raise ValueError(
|
||||
|
||||
@@ -439,9 +439,9 @@ def get_vllm_port() -> int | None:
|
||||
try:
|
||||
return int(port)
|
||||
except ValueError as err:
|
||||
from urllib.parse import urlparse
|
||||
from urllib3.util import parse_url
|
||||
|
||||
parsed = urlparse(port)
|
||||
parsed = parse_url(port)
|
||||
if parsed.scheme:
|
||||
raise ValueError(
|
||||
f"VLLM_PORT '{port}' appears to be a URI. "
|
||||
|
||||
@@ -9,13 +9,13 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
from itertools import groupby
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
from urllib.parse import ParseResult, urlparse
|
||||
from urllib.request import url2pathname
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
from urllib3.util import Url, parse_url
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.connections import HTTPConnection, global_http_connection
|
||||
@@ -105,11 +105,14 @@ class MediaConnector:
|
||||
|
||||
def _load_data_url(
|
||||
self,
|
||||
url_spec: ParseResult,
|
||||
url_spec: Url,
|
||||
media_io: MediaIO[_M],
|
||||
) -> _M: # type: ignore[type-var]
|
||||
data_spec, data = url_spec.path.split(",", 1)
|
||||
url_spec_path = url_spec.path or ""
|
||||
data_spec, data = url_spec_path.split(",", 1)
|
||||
media_type, data_type = data_spec.split(";", 1)
|
||||
# media_type starts with a leading "/" (e.g., "/video/jpeg")
|
||||
media_type = media_type.lstrip("/")
|
||||
|
||||
if data_type != "base64":
|
||||
msg = "Only base64 data URLs are supported for now."
|
||||
@@ -119,7 +122,7 @@ class MediaConnector:
|
||||
|
||||
def _load_file_url(
|
||||
self,
|
||||
url_spec: ParseResult,
|
||||
url_spec: Url,
|
||||
media_io: MediaIO[_M],
|
||||
) -> _M: # type: ignore[type-var]
|
||||
allowed_local_media_path = self.allowed_local_media_path
|
||||
@@ -128,7 +131,9 @@ class MediaConnector:
|
||||
"Cannot load local files without `--allowed-local-media-path`."
|
||||
)
|
||||
|
||||
filepath = Path(url2pathname(url_spec.netloc + url_spec.path))
|
||||
url_spec_path = url_spec.path or ""
|
||||
url_spec_netloc = url_spec.netloc or ""
|
||||
filepath = Path(url2pathname(url_spec_netloc + url_spec_path))
|
||||
if allowed_local_media_path not in filepath.resolve().parents:
|
||||
raise ValueError(
|
||||
f"The file path {filepath} must be a subpath "
|
||||
@@ -137,7 +142,7 @@ class MediaConnector:
|
||||
|
||||
return media_io.load_file(filepath)
|
||||
|
||||
def _assert_url_in_allowed_media_domains(self, url_spec: ParseResult) -> None:
|
||||
def _assert_url_in_allowed_media_domains(self, url_spec: Url) -> None:
|
||||
if (
|
||||
self.allowed_media_domains
|
||||
and url_spec.hostname not in self.allowed_media_domains
|
||||
@@ -155,9 +160,9 @@ class MediaConnector:
|
||||
*,
|
||||
fetch_timeout: int | None = None,
|
||||
) -> _M: # type: ignore[type-var]
|
||||
url_spec = urlparse(url)
|
||||
url_spec = parse_url(url)
|
||||
|
||||
if url_spec.scheme.startswith("http"):
|
||||
if url_spec.scheme and url_spec.scheme.startswith("http"):
|
||||
self._assert_url_in_allowed_media_domains(url_spec)
|
||||
|
||||
connection = self.connection
|
||||
@@ -185,10 +190,10 @@ class MediaConnector:
|
||||
*,
|
||||
fetch_timeout: int | None = None,
|
||||
) -> _M:
|
||||
url_spec = urlparse(url)
|
||||
url_spec = parse_url(url)
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
if url_spec.scheme.startswith("http"):
|
||||
if url_spec.scheme and url_spec.scheme.startswith("http"):
|
||||
self._assert_url_in_allowed_media_domains(url_spec)
|
||||
|
||||
connection = self.connection
|
||||
|
||||
@@ -11,12 +11,12 @@ from collections.abc import (
|
||||
Sequence,
|
||||
)
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
import psutil
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
from urllib3.util import parse_url
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
@@ -217,13 +217,15 @@ def find_process_using_port(port: int) -> psutil.Process | None:
|
||||
|
||||
def split_zmq_path(path: str) -> tuple[str, str, str]:
|
||||
"""Split a zmq path into its parts."""
|
||||
parsed = urlparse(path)
|
||||
parsed = parse_url(path)
|
||||
if not parsed.scheme:
|
||||
raise ValueError(f"Invalid zmq path: {path}")
|
||||
|
||||
scheme = parsed.scheme
|
||||
host = parsed.hostname or ""
|
||||
port = str(parsed.port or "")
|
||||
if host.startswith("[") and host.endswith("]"):
|
||||
host = host[1:-1] # Remove brackets for IPv6 address
|
||||
|
||||
if scheme == "tcp" and not all((host, port)):
|
||||
# The host and port fields are required for tcp
|
||||
|
||||
Reference in New Issue
Block a user