feat: expose media_io_kwargs at runtime (#34778)

Signed-off-by: Alexandre Milesi <milesial@users.noreply.github.com>
This commit is contained in:
milesial
2026-03-06 20:27:04 -08:00
committed by GitHub
parent 58928475e4
commit 755356b3d1
20 changed files with 298 additions and 16 deletions

View File

@@ -35,6 +35,8 @@ def server():
"--trust-remote-code",
"--limit-mm-per-prompt",
json.dumps({"video": MAXIMUM_VIDEOS}),
"--media-io-kwargs",
json.dumps({"video": {"num_frames": 32}}),
]
# ROCm: Increase timeouts to handle potential network delays and slower
@@ -127,6 +129,73 @@ async def test_single_chat_session_video(
assert message.content is not None and len(message.content) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("video_url", [TEST_VIDEO_URLS[0]])
async def test_request_media_io_kwargs_override_uses_fewer_video_frames(
client: openai.AsyncOpenAI, model_name: str, video_url: str
):
messages = dummy_messages_from_video_url(video_url)
default_resp = await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=1,
temperature=0.0,
)
override_resp = await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=1,
temperature=0.0,
extra_body={
"media_io_kwargs": {
"video": {
"num_frames": 4,
}
}
},
)
assert default_resp.usage is not None
assert override_resp.usage is not None
assert override_resp.usage.prompt_tokens < default_resp.usage.prompt_tokens
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("video_url", [TEST_VIDEO_URLS[0]])
async def test_invalid_num_frames_request_recoverable(
client: openai.AsyncOpenAI, model_name: str, video_url: str
):
messages = dummy_messages_from_video_url(video_url)
with pytest.raises((openai.BadRequestError, openai.APIStatusError)):
await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=1,
temperature=0.0,
extra_body={
"media_io_kwargs": {
"video": {
"num_frames": "invalid",
}
}
},
)
# Server should still handle subsequent requests after the failed one.
recovery_resp = await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=1,
temperature=0.0,
)
recovery_msg = recovery_resp.choices[0].message
assert recovery_msg.content is not None and len(recovery_msg.content) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)

View File

@@ -127,6 +127,39 @@ def test_chat_image_base64_request(server: RemoteOpenAIServer, model_name: str):
assert output.usage.prompt_tokens == 767
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_chat_image_with_media_io_kwargs(server: RemoteOpenAIServer, model_name: str):
rgba_image_url = (
"https://vllm-public-assets.s3.us-west-2.amazonaws.com"
"/vision_model_images/RGBA_comp.png"
)
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "Represent the user's input."},
{"type": "image_url", "image_url": {"url": rgba_image_url}},
],
}
]
response = requests.post(
server.url_for("v1/embeddings"),
json={
"model": model_name,
"messages": messages,
"media_io_kwargs": {
"image": {"rgba_background_color": [0, 0, 0]},
},
},
)
response.raise_for_status()
output = EmbeddingResponse.model_validate(response.json())
assert len(output.data) == 1
assert len(output.data[0].embedding) == 3072
def get_hf_prompt_tokens(model_name, content, image_url):
processor = AutoProcessor.from_pretrained(
model_name, trust_remote_code=True, num_crops=4

View File

@@ -462,10 +462,15 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
maximum per prompt.
"""
def __init__(self, model_config: ModelConfig):
def __init__(
self,
model_config: ModelConfig,
media_io_kwargs: dict[str, dict[str, Any]] | None = None,
):
super().__init__()
self._model_config = model_config
self._media_io_kwargs = media_io_kwargs
self._items_by_modality = defaultdict[str, list[_T]](list)
# Track original modality for each vision_chunk item (image or video)
@@ -487,6 +492,14 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
model_cls = get_model_cls(self.model_config)
return cast(type[SupportsMultiModal], model_cls)
@property
def media_io_kwargs(self) -> dict[str, dict[str, Any]] | None:
return self._media_io_kwargs or (
self._model_config.multimodal_config.media_io_kwargs
if self._model_config.multimodal_config
else None
)
@property
def allowed_local_media_path(self):
return self._model_config.allowed_local_media_path
@@ -769,12 +782,10 @@ class MultiModalContentParser(BaseMultiModalContentParser):
super().__init__()
self._tracker = tracker
multimodal_config = self._tracker.model_config.multimodal_config
media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
envs.VLLM_MEDIA_CONNECTOR,
media_io_kwargs=media_io_kwargs,
media_io_kwargs=tracker.media_io_kwargs,
allowed_local_media_path=tracker.allowed_local_media_path,
allowed_media_domains=tracker.allowed_media_domains,
)
@@ -881,11 +892,9 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
super().__init__()
self._tracker = tracker
multimodal_config = self._tracker.model_config.multimodal_config
media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
envs.VLLM_MEDIA_CONNECTOR,
media_io_kwargs=media_io_kwargs,
media_io_kwargs=tracker.media_io_kwargs,
allowed_local_media_path=tracker.allowed_local_media_path,
allowed_media_domains=tracker.allowed_media_domains,
)
@@ -1530,13 +1539,14 @@ def parse_chat_messages(
messages: list[ChatCompletionMessageParam],
model_config: ModelConfig,
content_format: ChatTemplateContentFormat,
media_io_kwargs: dict[str, dict[str, Any]] | None = None,
) -> tuple[
list[ConversationMessage],
MultiModalDataDict | None,
MultiModalUUIDDict | None,
]:
conversation: list[ConversationMessage] = []
mm_tracker = MultiModalItemTracker(model_config)
mm_tracker = MultiModalItemTracker(model_config, media_io_kwargs=media_io_kwargs)
for msg in messages:
sub_messages = _parse_chat_message_content(
@@ -1563,13 +1573,16 @@ async def parse_chat_messages_async(
messages: list[ChatCompletionMessageParam],
model_config: ModelConfig,
content_format: ChatTemplateContentFormat,
media_io_kwargs: dict[str, dict[str, Any]] | None = None,
) -> tuple[
list[ConversationMessage],
MultiModalDataDict | None,
MultiModalUUIDDict | None,
]:
conversation: list[ConversationMessage] = []
mm_tracker = AsyncMultiModalItemTracker(model_config)
mm_tracker = AsyncMultiModalItemTracker(
model_config, media_io_kwargs=media_io_kwargs
)
for msg in messages:
sub_messages = _parse_chat_message_content(

View File

@@ -268,6 +268,13 @@ class ChatCompletionRequest(OpenAIBaseModel):
"Will be accessible by the chat template."
),
)
media_io_kwargs: dict[str, dict[str, Any]] | None = Field(
default=None,
description=(
"Additional kwargs to pass to the media IO connectors, "
"keyed by modality. Merged with engine-level media_io_kwargs."
),
)
mm_processor_kwargs: dict[str, Any] | None = Field(
default=None,
description=("Additional kwargs to pass to the HF processor."),
@@ -366,6 +373,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
reasoning_effort=self.reasoning_effort,
),
),
media_io_kwargs=self.media_io_kwargs,
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:

View File

@@ -900,10 +900,15 @@ class OpenAIServing:
),
)
mm_config = self.model_config.multimodal_config
tok_params = request.build_tok_params(self.model_config)
chat_params = request.build_chat_params(
default_template, default_template_content_format
).with_defaults(default_template_kwargs)
).with_defaults(
default_template_kwargs,
default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None),
)
(conversation,), (engine_prompt,) = await renderer.render_chat_async(
[messages],

View File

@@ -197,6 +197,13 @@ class ResponsesRequest(OpenAIBaseModel):
"through out the inference process and return in response."
),
)
media_io_kwargs: dict[str, dict[str, Any]] | None = Field(
default=None,
description=(
"Additional kwargs to pass to the media IO connectors, "
"keyed by modality. Merged with engine-level media_io_kwargs."
),
)
mm_processor_kwargs: dict[str, Any] | None = Field(
default=None,
description=("Additional kwargs to pass to the HF processor."),
@@ -276,6 +283,7 @@ class ResponsesRequest(OpenAIBaseModel):
reasoning_effort=None if reasoning is None else reasoning.effort,
),
),
media_io_kwargs=self.media_io_kwargs,
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:

View File

@@ -123,10 +123,15 @@ class PoolingIOProcessor:
),
)
mm_config = self.model_config.multimodal_config
tok_params = request.build_tok_params(self.model_config)
chat_params = request.build_chat_params(
default_template, default_template_content_format
).with_defaults(default_template_kwargs)
).with_defaults(
default_template_kwargs,
default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None),
)
(conversation,), (engine_prompt,) = renderer.render_chat(
[messages],

View File

@@ -124,6 +124,13 @@ class ChatRequestMixin(OpenAIBaseModel):
"Will be accessible by the chat template."
),
)
media_io_kwargs: dict[str, dict[str, Any]] | None = Field(
default=None,
description=(
"Additional kwargs to pass to the media IO connectors, "
"keyed by modality. Merged with engine-level media_io_kwargs."
),
)
# --8<-- [end:chat-extra-params]
@model_validator(mode="before")
@@ -151,6 +158,7 @@ class ChatRequestMixin(OpenAIBaseModel):
continue_final_message=self.continue_final_message,
),
),
media_io_kwargs=self.media_io_kwargs,
)

View File

@@ -100,6 +100,13 @@ class TokenizeChatRequest(OpenAIBaseModel):
"Will be accessible by the chat template."
),
)
media_io_kwargs: dict[str, dict[str, Any]] | None = Field(
default=None,
description=(
"Additional kwargs to pass to the media IO connectors, "
"keyed by modality. Merged with engine-level media_io_kwargs."
),
)
mm_processor_kwargs: dict[str, Any] | None = Field(
default=None,
description="Additional kwargs to pass to the HF processor.",
@@ -134,6 +141,7 @@ class TokenizeChatRequest(OpenAIBaseModel):
continue_final_message=self.continue_final_message,
),
),
media_io_kwargs=self.media_io_kwargs,
)
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:

View File

@@ -83,11 +83,17 @@ def extract_audio_from_video_bytes(
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
"""Configuration values can be user-provided either by --media-io-kwargs or
by the runtime API field "media_io_kwargs". Ensure proper validation and
error handling.
"""
def __init__(self, **kwargs) -> None:
super().__init__()
# `kwargs` contains custom arguments from
# --media-io-kwargs for this modality.
# --media-io-kwargs for this modality, merged with
# per-request runtime media_io_kwargs via merge_kwargs().
# They can be passed to the underlying
# media loaders (e.g. custom implementations)
# for flexible control.
@@ -122,6 +128,11 @@ class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]):
"""Configuration values can be user-provided either by --media-io-kwargs or
by the runtime API field "media_io_kwargs". Ensure proper validation and
error handling.
"""
def __init__(self) -> None:
super().__init__()

View File

@@ -44,6 +44,28 @@ class MediaWithBytes(Generic[_T]):
class MediaIO(ABC, Generic[_T]):
"""Configuration values can be user-provided either by --media-io-kwargs or
by the runtime API field "media_io_kwargs". Ensure proper validation and
error handling.
"""
@classmethod
def merge_kwargs(
cls,
default_kwargs: dict[str, Any] | None,
runtime_kwargs: dict[str, Any] | None,
) -> dict[str, Any]:
"""Merge config-level kwargs and request-level kwargs.
By default this performs a shallow merge where runtime kwargs override
keys in default kwargs. Subclasses may override to apply modality-
specific behavior.
"""
merged = dict(default_kwargs or {})
if runtime_kwargs:
merged.update(runtime_kwargs)
return merged
@abstractmethod
def load_bytes(self, data: bytes) -> _T:
raise NotImplementedError

View File

@@ -32,9 +32,43 @@ atexit.register(global_thread_pool.shutdown)
MEDIA_CONNECTOR_REGISTRY = ExtensionManager()
MODALITY_IO_MAP: dict[str, type[MediaIO]] = {
"audio": AudioMediaIO,
"image": ImageMediaIO,
"video": VideoMediaIO,
}
def merge_media_io_kwargs(
defaults: dict[str, dict[str, Any]] | None,
overrides: dict[str, dict[str, Any]] | None,
) -> dict[str, dict[str, Any]] | None:
"""Merge config-level and per-request media_io_kwargs per modality.
Each modality key is merged using the corresponding MediaIO subclass's
``merge_kwargs``, which may apply modality-specific logic (e.g.
VideoMediaIO clears cross-dependent fps/num_frames fields).
"""
if not defaults and not overrides:
return None
all_keys = set(defaults or {}) | set(overrides or {})
merged = {}
for key in all_keys:
io_cls = MODALITY_IO_MAP.get(key, MediaIO)
merged[key] = io_cls.merge_kwargs(
(defaults or {}).get(key),
(overrides or {}).get(key),
)
return merged or None
@MEDIA_CONNECTOR_REGISTRY.register("http")
class MediaConnector:
"""Configuration values can be user-provided either by --media-io-kwargs or
by the runtime API field "media_io_kwargs". Ensure proper validation and
error handling.
"""
def __init__(
self,
media_io_kwargs: dict[str, dict[str, Any]] | None = None,

View File

@@ -15,12 +15,18 @@ from .base import MediaIO, MediaWithBytes
class ImageMediaIO(MediaIO[Image.Image]):
"""Configuration values can be user-provided either by --media-io-kwargs or
by the runtime API field "media_io_kwargs". Ensure proper validation and
error handling.
"""
def __init__(self, image_mode: str = "RGB", **kwargs) -> None:
super().__init__()
self.image_mode = image_mode
# `kwargs` contains custom arguments from
# --media-io-kwargs for this modality.
# --media-io-kwargs for this modality, merged with
# per-request runtime media_io_kwargs via merge_kwargs().
# They can be passed to the underlying
# media loaders (e.g. custom implementations)
# for flexible control.
@@ -88,6 +94,13 @@ class ImageMediaIO(MediaIO[Image.Image]):
class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
"""Image embedding MediaIO implementation.
Configuration values can be user-provided either by --media-io-kwargs or
by the runtime API field "media_io_kwargs". Ensure proper validation and
error handling.
"""
def __init__(self) -> None:
super().__init__()

View File

@@ -17,6 +17,28 @@ from .image import ImageMediaIO
class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]):
"""Configuration values can be user-provided either by --media-io-kwargs or
by the runtime API field "media_io_kwargs". Ensure proper validation and
error handling.
"""
@classmethod
def merge_kwargs(
cls,
default_kwargs: dict[str, Any] | None,
runtime_kwargs: dict[str, Any] | None,
) -> dict[str, Any]:
merged = super().merge_kwargs(default_kwargs, runtime_kwargs)
# fps and num_frames interact with each other, so if either is
# overridden at request time, wipe the other from defaults to
# avoid unintuitive cross-field interactions.
if runtime_kwargs:
if "num_frames" in runtime_kwargs and "fps" not in runtime_kwargs:
merged.pop("fps", None)
elif "fps" in runtime_kwargs and "num_frames" not in runtime_kwargs:
merged.pop("num_frames", None)
return merged
def __init__(
self,
image_io: ImageMediaIO,
@@ -28,7 +50,8 @@ class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]):
self.image_io = image_io
self.num_frames = num_frames
# `kwargs` contains custom arguments from
# --media-io-kwargs for this modality.
# --media-io-kwargs for this modality, merged with
# per-request runtime media_io_kwargs via merge_kwargs().
# They can be passed to the underlying
# media loaders (e.g. custom implementations)
# for flexible control.

View File

@@ -49,6 +49,7 @@ class DeepseekV32Renderer(BaseRenderer[DeepseekV32Tokenizer]):
messages,
self.model_config,
content_format="string",
media_io_kwargs=params.media_io_kwargs,
)
prompt_raw = tokenizer.apply_chat_template(
@@ -75,6 +76,7 @@ class DeepseekV32Renderer(BaseRenderer[DeepseekV32Tokenizer]):
messages,
self.model_config,
content_format="string",
media_io_kwargs=params.media_io_kwargs,
)
prompt_raw = tokenizer.apply_chat_template(

View File

@@ -49,6 +49,7 @@ class Grok2Renderer(BaseRenderer[Grok2Tokenizer]):
messages,
self.model_config,
content_format="string",
media_io_kwargs=params.media_io_kwargs,
)
prompt_raw = tokenizer.apply_chat_template(
@@ -75,6 +76,7 @@ class Grok2Renderer(BaseRenderer[Grok2Tokenizer]):
messages,
self.model_config,
content_format="string",
media_io_kwargs=params.media_io_kwargs,
)
prompt_raw = tokenizer.apply_chat_template(

View File

@@ -635,6 +635,7 @@ class HfRenderer(BaseRenderer[HfTokenizer]):
tokenizer=tokenizer,
model_config=model_config,
),
media_io_kwargs=params.media_io_kwargs,
)
prompt_raw = safe_apply_chat_template(
@@ -689,6 +690,7 @@ class HfRenderer(BaseRenderer[HfTokenizer]):
tokenizer=tokenizer,
model_config=model_config,
),
media_io_kwargs=params.media_io_kwargs,
)
prompt_raw = safe_apply_chat_template(

View File

@@ -90,6 +90,7 @@ class MistralRenderer(BaseRenderer[MistralTokenizer]):
messages,
self.model_config,
content_format="string",
media_io_kwargs=params.media_io_kwargs,
)
prompt_raw = safe_apply_chat_template(
@@ -116,6 +117,7 @@ class MistralRenderer(BaseRenderer[MistralTokenizer]):
messages,
self.model_config,
content_format="string",
media_io_kwargs=params.media_io_kwargs,
)
prompt_raw = await self._apply_chat_template_async(

View File

@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, TypeVar
from vllm.exceptions import VLLMValidationError
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.multimodal.media.connector import merge_media_io_kwargs
from vllm.tokenizers import TokenizerLike
from vllm.utils.import_utils import LazyLoader
@@ -52,8 +53,15 @@ class ChatParams:
chat_template_kwargs: dict[str, Any] = field(default_factory=dict)
"""The kwargs to pass to the chat template."""
def with_defaults(self, default_chat_template_kwargs: dict[str, Any] | None):
if not default_chat_template_kwargs:
media_io_kwargs: dict[str, dict[str, Any]] | None = None
"""Per-modality kwargs for media I/O (loading/decoding images, videos, etc.)."""
def with_defaults(
self,
default_chat_template_kwargs: dict[str, Any] | None = None,
default_media_io_kwargs: dict[str, dict[str, Any]] | None = None,
):
if not default_chat_template_kwargs and not default_media_io_kwargs:
return self
return ChatParams(
@@ -63,6 +71,10 @@ class ChatParams:
default_chat_template_kwargs,
self.chat_template_kwargs,
),
media_io_kwargs=merge_media_io_kwargs(
default_media_io_kwargs,
self.media_io_kwargs,
),
)
def get_apply_chat_template_kwargs(self) -> dict[str, Any]:

View File

@@ -43,6 +43,7 @@ class TerratorchRenderer(BaseRenderer):
messages,
model_config,
content_format="string",
media_io_kwargs=params.media_io_kwargs,
)
prompt = parse_dec_only_prompt([1]) # Dummy token IDs
@@ -64,6 +65,7 @@ class TerratorchRenderer(BaseRenderer):
messages,
model_config,
content_format="string",
media_io_kwargs=params.media_io_kwargs,
)
prompt = parse_dec_only_prompt([1]) # Dummy token IDs