Add option to restrict media domains (#25783)
Signed-off-by: Chenheli Hua <huachenheli@outlook.com> Signed-off-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Chenheli Hua <huachenheli@outlook.com> Signed-off-by: simon-mo <simon.mo@hey.com>
This commit is contained in:
@@ -137,6 +137,9 @@ class ModelConfig:
|
||||
"""Allowing API requests to read local images or videos from directories
|
||||
specified by the server file system. This is a security risk. Should only
|
||||
be enabled in trusted environments."""
|
||||
allowed_media_domains: Optional[list[str]] = None
|
||||
"""If set, only media URLs that belong to this domain can be used for
|
||||
multi-modal inputs. """
|
||||
revision: Optional[str] = None
|
||||
"""The specific model version to use. It can be a branch name, a tag name,
|
||||
or a commit id. If unspecified, will use the default version."""
|
||||
|
||||
@@ -281,6 +281,8 @@ class SpeculativeConfig:
|
||||
trust_remote_code,
|
||||
allowed_local_media_path=self.target_model_config.
|
||||
allowed_local_media_path,
|
||||
allowed_media_domains=self.target_model_config.
|
||||
allowed_media_domains,
|
||||
dtype=self.target_model_config.dtype,
|
||||
seed=self.target_model_config.seed,
|
||||
revision=self.revision,
|
||||
|
||||
@@ -297,6 +297,8 @@ class EngineArgs:
|
||||
tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
|
||||
trust_remote_code: bool = ModelConfig.trust_remote_code
|
||||
allowed_local_media_path: str = ModelConfig.allowed_local_media_path
|
||||
allowed_media_domains: Optional[
|
||||
list[str]] = ModelConfig.allowed_media_domains
|
||||
download_dir: Optional[str] = LoadConfig.download_dir
|
||||
safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
|
||||
load_format: Union[str, LoadFormats] = LoadConfig.load_format
|
||||
@@ -531,6 +533,8 @@ class EngineArgs:
|
||||
**model_kwargs["hf_config_path"])
|
||||
model_group.add_argument("--allowed-local-media-path",
|
||||
**model_kwargs["allowed_local_media_path"])
|
||||
model_group.add_argument("--allowed-media-domains",
|
||||
**model_kwargs["allowed_media_domains"])
|
||||
model_group.add_argument("--revision", **model_kwargs["revision"])
|
||||
model_group.add_argument("--code-revision",
|
||||
**model_kwargs["code_revision"])
|
||||
@@ -997,6 +1001,7 @@ class EngineArgs:
|
||||
tokenizer_mode=self.tokenizer_mode,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
allowed_local_media_path=self.allowed_local_media_path,
|
||||
allowed_media_domains=self.allowed_media_domains,
|
||||
dtype=self.dtype,
|
||||
seed=self.seed,
|
||||
revision=self.revision,
|
||||
|
||||
@@ -637,6 +637,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
def allowed_local_media_path(self):
|
||||
return self._model_config.allowed_local_media_path
|
||||
|
||||
@property
|
||||
def allowed_media_domains(self):
|
||||
return self._model_config.allowed_media_domains
|
||||
|
||||
@property
|
||||
def mm_registry(self):
|
||||
return MULTIMODAL_REGISTRY
|
||||
@@ -837,6 +841,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
self._connector = MediaConnector(
|
||||
media_io_kwargs=media_io_kwargs,
|
||||
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||
allowed_media_domains=tracker.allowed_media_domains,
|
||||
)
|
||||
|
||||
def parse_image(
|
||||
@@ -921,6 +926,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
self._connector = MediaConnector(
|
||||
media_io_kwargs=media_io_kwargs,
|
||||
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||
allowed_media_domains=tracker.allowed_media_domains,
|
||||
)
|
||||
|
||||
def parse_image(
|
||||
|
||||
@@ -86,6 +86,8 @@ class LLM:
|
||||
or videos from directories specified by the server file system.
|
||||
This is a security risk. Should only be enabled in trusted
|
||||
environments.
|
||||
allowed_media_domains: If set, only media URLs that belong to this
|
||||
domain can be used for multi-modal inputs.
|
||||
tensor_parallel_size: The number of GPUs to use for distributed
|
||||
execution with tensor parallelism.
|
||||
dtype: The data type for the model weights and activations. Currently,
|
||||
@@ -169,6 +171,7 @@ class LLM:
|
||||
skip_tokenizer_init: bool = False,
|
||||
trust_remote_code: bool = False,
|
||||
allowed_local_media_path: str = "",
|
||||
allowed_media_domains: Optional[list[str]] = None,
|
||||
tensor_parallel_size: int = 1,
|
||||
dtype: ModelDType = "auto",
|
||||
quantization: Optional[QuantizationMethods] = None,
|
||||
@@ -264,6 +267,7 @@ class LLM:
|
||||
skip_tokenizer_init=skip_tokenizer_init,
|
||||
trust_remote_code=trust_remote_code,
|
||||
allowed_local_media_path=allowed_local_media_path,
|
||||
allowed_media_domains=allowed_media_domains,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
dtype=dtype,
|
||||
quantization=quantization,
|
||||
|
||||
@@ -50,6 +50,7 @@ class MediaConnector:
|
||||
connection: HTTPConnection = global_http_connection,
|
||||
*,
|
||||
allowed_local_media_path: str = "",
|
||||
allowed_media_domains: Optional[list[str]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
@@ -82,6 +83,9 @@ class MediaConnector:
|
||||
allowed_local_media_path_ = None
|
||||
|
||||
self.allowed_local_media_path = allowed_local_media_path_
|
||||
if allowed_media_domains is None:
|
||||
allowed_media_domains = []
|
||||
self.allowed_media_domains = allowed_media_domains
|
||||
|
||||
def _load_data_url(
|
||||
self,
|
||||
@@ -115,6 +119,14 @@ class MediaConnector:
|
||||
|
||||
return media_io.load_file(filepath)
|
||||
|
||||
def _assert_url_in_allowed_media_domains(self, url_spec) -> None:
|
||||
if self.allowed_media_domains and url_spec.hostname not in \
|
||||
self.allowed_media_domains:
|
||||
raise ValueError(
|
||||
f"The URL must be from one of the allowed domains: "
|
||||
f"{self.allowed_media_domains}. Input URL domain: "
|
||||
f"{url_spec.hostname}")
|
||||
|
||||
def load_from_url(
|
||||
self,
|
||||
url: str,
|
||||
@@ -125,6 +137,8 @@ class MediaConnector:
|
||||
url_spec = urlparse(url)
|
||||
|
||||
if url_spec.scheme.startswith("http"):
|
||||
self._assert_url_in_allowed_media_domains(url_spec)
|
||||
|
||||
connection = self.connection
|
||||
data = connection.get_bytes(url, timeout=fetch_timeout)
|
||||
|
||||
@@ -150,6 +164,8 @@ class MediaConnector:
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
if url_spec.scheme.startswith("http"):
|
||||
self._assert_url_in_allowed_media_domains(url_spec)
|
||||
|
||||
connection = self.connection
|
||||
data = await connection.async_get_bytes(url, timeout=fetch_timeout)
|
||||
future = loop.run_in_executor(global_thread_pool,
|
||||
|
||||
Reference in New Issue
Block a user