feat: expose media_io_kwargs at runtime (#34778)
Signed-off-by: Alexandre Milesi <milesial@users.noreply.github.com>
This commit is contained in:
@@ -35,6 +35,8 @@ def server():
|
||||
"--trust-remote-code",
|
||||
"--limit-mm-per-prompt",
|
||||
json.dumps({"video": MAXIMUM_VIDEOS}),
|
||||
"--media-io-kwargs",
|
||||
json.dumps({"video": {"num_frames": 32}}),
|
||||
]
|
||||
|
||||
# ROCm: Increase timeouts to handle potential network delays and slower
|
||||
@@ -127,6 +129,73 @@ async def test_single_chat_session_video(
|
||||
assert message.content is not None and len(message.content) >= 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("video_url", [TEST_VIDEO_URLS[0]])
|
||||
async def test_request_media_io_kwargs_override_uses_fewer_video_frames(
|
||||
client: openai.AsyncOpenAI, model_name: str, video_url: str
|
||||
):
|
||||
messages = dummy_messages_from_video_url(video_url)
|
||||
|
||||
default_resp = await client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
max_completion_tokens=1,
|
||||
temperature=0.0,
|
||||
)
|
||||
override_resp = await client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
max_completion_tokens=1,
|
||||
temperature=0.0,
|
||||
extra_body={
|
||||
"media_io_kwargs": {
|
||||
"video": {
|
||||
"num_frames": 4,
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert default_resp.usage is not None
|
||||
assert override_resp.usage is not None
|
||||
assert override_resp.usage.prompt_tokens < default_resp.usage.prompt_tokens
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("video_url", [TEST_VIDEO_URLS[0]])
|
||||
async def test_invalid_num_frames_request_recoverable(
|
||||
client: openai.AsyncOpenAI, model_name: str, video_url: str
|
||||
):
|
||||
messages = dummy_messages_from_video_url(video_url)
|
||||
|
||||
with pytest.raises((openai.BadRequestError, openai.APIStatusError)):
|
||||
await client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
max_completion_tokens=1,
|
||||
temperature=0.0,
|
||||
extra_body={
|
||||
"media_io_kwargs": {
|
||||
"video": {
|
||||
"num_frames": "invalid",
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Server should still handle subsequent requests after the failed one.
|
||||
recovery_resp = await client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
max_completion_tokens=1,
|
||||
temperature=0.0,
|
||||
)
|
||||
recovery_msg = recovery_resp.choices[0].message
|
||||
assert recovery_msg.content is not None and len(recovery_msg.content) >= 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
|
||||
|
||||
@@ -127,6 +127,39 @@ def test_chat_image_base64_request(server: RemoteOpenAIServer, model_name: str):
|
||||
assert output.usage.prompt_tokens == 767
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_chat_image_with_media_io_kwargs(server: RemoteOpenAIServer, model_name: str):
|
||||
rgba_image_url = (
|
||||
"https://vllm-public-assets.s3.us-west-2.amazonaws.com"
|
||||
"/vision_model_images/RGBA_comp.png"
|
||||
)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Represent the user's input."},
|
||||
{"type": "image_url", "image_url": {"url": rgba_image_url}},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
response = requests.post(
|
||||
server.url_for("v1/embeddings"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"messages": messages,
|
||||
"media_io_kwargs": {
|
||||
"image": {"rgba_background_color": [0, 0, 0]},
|
||||
},
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
output = EmbeddingResponse.model_validate(response.json())
|
||||
assert len(output.data) == 1
|
||||
assert len(output.data[0].embedding) == 3072
|
||||
|
||||
|
||||
def get_hf_prompt_tokens(model_name, content, image_url):
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_name, trust_remote_code=True, num_crops=4
|
||||
|
||||
@@ -462,10 +462,15 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
maximum per prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: ModelConfig):
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
media_io_kwargs: dict[str, dict[str, Any]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self._model_config = model_config
|
||||
self._media_io_kwargs = media_io_kwargs
|
||||
|
||||
self._items_by_modality = defaultdict[str, list[_T]](list)
|
||||
# Track original modality for each vision_chunk item (image or video)
|
||||
@@ -487,6 +492,14 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
model_cls = get_model_cls(self.model_config)
|
||||
return cast(type[SupportsMultiModal], model_cls)
|
||||
|
||||
@property
|
||||
def media_io_kwargs(self) -> dict[str, dict[str, Any]] | None:
|
||||
return self._media_io_kwargs or (
|
||||
self._model_config.multimodal_config.media_io_kwargs
|
||||
if self._model_config.multimodal_config
|
||||
else None
|
||||
)
|
||||
|
||||
@property
|
||||
def allowed_local_media_path(self):
|
||||
return self._model_config.allowed_local_media_path
|
||||
@@ -769,12 +782,10 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
super().__init__()
|
||||
|
||||
self._tracker = tracker
|
||||
multimodal_config = self._tracker.model_config.multimodal_config
|
||||
media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
|
||||
|
||||
self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
|
||||
envs.VLLM_MEDIA_CONNECTOR,
|
||||
media_io_kwargs=media_io_kwargs,
|
||||
media_io_kwargs=tracker.media_io_kwargs,
|
||||
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||
allowed_media_domains=tracker.allowed_media_domains,
|
||||
)
|
||||
@@ -881,11 +892,9 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
super().__init__()
|
||||
|
||||
self._tracker = tracker
|
||||
multimodal_config = self._tracker.model_config.multimodal_config
|
||||
media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
|
||||
self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
|
||||
envs.VLLM_MEDIA_CONNECTOR,
|
||||
media_io_kwargs=media_io_kwargs,
|
||||
media_io_kwargs=tracker.media_io_kwargs,
|
||||
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||
allowed_media_domains=tracker.allowed_media_domains,
|
||||
)
|
||||
@@ -1530,13 +1539,14 @@ def parse_chat_messages(
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
model_config: ModelConfig,
|
||||
content_format: ChatTemplateContentFormat,
|
||||
media_io_kwargs: dict[str, dict[str, Any]] | None = None,
|
||||
) -> tuple[
|
||||
list[ConversationMessage],
|
||||
MultiModalDataDict | None,
|
||||
MultiModalUUIDDict | None,
|
||||
]:
|
||||
conversation: list[ConversationMessage] = []
|
||||
mm_tracker = MultiModalItemTracker(model_config)
|
||||
mm_tracker = MultiModalItemTracker(model_config, media_io_kwargs=media_io_kwargs)
|
||||
|
||||
for msg in messages:
|
||||
sub_messages = _parse_chat_message_content(
|
||||
@@ -1563,13 +1573,16 @@ async def parse_chat_messages_async(
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
model_config: ModelConfig,
|
||||
content_format: ChatTemplateContentFormat,
|
||||
media_io_kwargs: dict[str, dict[str, Any]] | None = None,
|
||||
) -> tuple[
|
||||
list[ConversationMessage],
|
||||
MultiModalDataDict | None,
|
||||
MultiModalUUIDDict | None,
|
||||
]:
|
||||
conversation: list[ConversationMessage] = []
|
||||
mm_tracker = AsyncMultiModalItemTracker(model_config)
|
||||
mm_tracker = AsyncMultiModalItemTracker(
|
||||
model_config, media_io_kwargs=media_io_kwargs
|
||||
)
|
||||
|
||||
for msg in messages:
|
||||
sub_messages = _parse_chat_message_content(
|
||||
|
||||
@@ -268,6 +268,13 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
"Will be accessible by the chat template."
|
||||
),
|
||||
)
|
||||
media_io_kwargs: dict[str, dict[str, Any]] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Additional kwargs to pass to the media IO connectors, "
|
||||
"keyed by modality. Merged with engine-level media_io_kwargs."
|
||||
),
|
||||
)
|
||||
mm_processor_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=("Additional kwargs to pass to the HF processor."),
|
||||
@@ -366,6 +373,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
reasoning_effort=self.reasoning_effort,
|
||||
),
|
||||
),
|
||||
media_io_kwargs=self.media_io_kwargs,
|
||||
)
|
||||
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
|
||||
@@ -900,10 +900,15 @@ class OpenAIServing:
|
||||
),
|
||||
)
|
||||
|
||||
mm_config = self.model_config.multimodal_config
|
||||
|
||||
tok_params = request.build_tok_params(self.model_config)
|
||||
chat_params = request.build_chat_params(
|
||||
default_template, default_template_content_format
|
||||
).with_defaults(default_template_kwargs)
|
||||
).with_defaults(
|
||||
default_template_kwargs,
|
||||
default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None),
|
||||
)
|
||||
|
||||
(conversation,), (engine_prompt,) = await renderer.render_chat_async(
|
||||
[messages],
|
||||
|
||||
@@ -197,6 +197,13 @@ class ResponsesRequest(OpenAIBaseModel):
|
||||
"through out the inference process and return in response."
|
||||
),
|
||||
)
|
||||
media_io_kwargs: dict[str, dict[str, Any]] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Additional kwargs to pass to the media IO connectors, "
|
||||
"keyed by modality. Merged with engine-level media_io_kwargs."
|
||||
),
|
||||
)
|
||||
mm_processor_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=("Additional kwargs to pass to the HF processor."),
|
||||
@@ -276,6 +283,7 @@ class ResponsesRequest(OpenAIBaseModel):
|
||||
reasoning_effort=None if reasoning is None else reasoning.effort,
|
||||
),
|
||||
),
|
||||
media_io_kwargs=self.media_io_kwargs,
|
||||
)
|
||||
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
|
||||
@@ -123,10 +123,15 @@ class PoolingIOProcessor:
|
||||
),
|
||||
)
|
||||
|
||||
mm_config = self.model_config.multimodal_config
|
||||
|
||||
tok_params = request.build_tok_params(self.model_config)
|
||||
chat_params = request.build_chat_params(
|
||||
default_template, default_template_content_format
|
||||
).with_defaults(default_template_kwargs)
|
||||
).with_defaults(
|
||||
default_template_kwargs,
|
||||
default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None),
|
||||
)
|
||||
|
||||
(conversation,), (engine_prompt,) = renderer.render_chat(
|
||||
[messages],
|
||||
|
||||
@@ -124,6 +124,13 @@ class ChatRequestMixin(OpenAIBaseModel):
|
||||
"Will be accessible by the chat template."
|
||||
),
|
||||
)
|
||||
media_io_kwargs: dict[str, dict[str, Any]] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Additional kwargs to pass to the media IO connectors, "
|
||||
"keyed by modality. Merged with engine-level media_io_kwargs."
|
||||
),
|
||||
)
|
||||
# --8<-- [end:chat-extra-params]
|
||||
|
||||
@model_validator(mode="before")
|
||||
@@ -151,6 +158,7 @@ class ChatRequestMixin(OpenAIBaseModel):
|
||||
continue_final_message=self.continue_final_message,
|
||||
),
|
||||
),
|
||||
media_io_kwargs=self.media_io_kwargs,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -100,6 +100,13 @@ class TokenizeChatRequest(OpenAIBaseModel):
|
||||
"Will be accessible by the chat template."
|
||||
),
|
||||
)
|
||||
media_io_kwargs: dict[str, dict[str, Any]] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Additional kwargs to pass to the media IO connectors, "
|
||||
"keyed by modality. Merged with engine-level media_io_kwargs."
|
||||
),
|
||||
)
|
||||
mm_processor_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="Additional kwargs to pass to the HF processor.",
|
||||
@@ -134,6 +141,7 @@ class TokenizeChatRequest(OpenAIBaseModel):
|
||||
continue_final_message=self.continue_final_message,
|
||||
),
|
||||
),
|
||||
media_io_kwargs=self.media_io_kwargs,
|
||||
)
|
||||
|
||||
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
|
||||
|
||||
@@ -83,11 +83,17 @@ def extract_audio_from_video_bytes(
|
||||
|
||||
|
||||
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
|
||||
"""Configuration values can be user-provided either by --media-io-kwargs or
|
||||
by the runtime API field "media_io_kwargs". Ensure proper validation and
|
||||
error handling.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
# `kwargs` contains custom arguments from
|
||||
# --media-io-kwargs for this modality.
|
||||
# --media-io-kwargs for this modality, merged with
|
||||
# per-request runtime media_io_kwargs via merge_kwargs().
|
||||
# They can be passed to the underlying
|
||||
# media loaders (e.g. custom implementations)
|
||||
# for flexible control.
|
||||
@@ -122,6 +128,11 @@ class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
|
||||
|
||||
|
||||
class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]):
|
||||
"""Configuration values can be user-provided either by --media-io-kwargs or
|
||||
by the runtime API field "media_io_kwargs". Ensure proper validation and
|
||||
error handling.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -44,6 +44,28 @@ class MediaWithBytes(Generic[_T]):
|
||||
|
||||
|
||||
class MediaIO(ABC, Generic[_T]):
|
||||
"""Configuration values can be user-provided either by --media-io-kwargs or
|
||||
by the runtime API field "media_io_kwargs". Ensure proper validation and
|
||||
error handling.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def merge_kwargs(
|
||||
cls,
|
||||
default_kwargs: dict[str, Any] | None,
|
||||
runtime_kwargs: dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Merge config-level kwargs and request-level kwargs.
|
||||
|
||||
By default this performs a shallow merge where runtime kwargs override
|
||||
keys in default kwargs. Subclasses may override to apply modality-
|
||||
specific behavior.
|
||||
"""
|
||||
merged = dict(default_kwargs or {})
|
||||
if runtime_kwargs:
|
||||
merged.update(runtime_kwargs)
|
||||
return merged
|
||||
|
||||
@abstractmethod
|
||||
def load_bytes(self, data: bytes) -> _T:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -32,9 +32,43 @@ atexit.register(global_thread_pool.shutdown)
|
||||
|
||||
MEDIA_CONNECTOR_REGISTRY = ExtensionManager()
|
||||
|
||||
MODALITY_IO_MAP: dict[str, type[MediaIO]] = {
|
||||
"audio": AudioMediaIO,
|
||||
"image": ImageMediaIO,
|
||||
"video": VideoMediaIO,
|
||||
}
|
||||
|
||||
|
||||
def merge_media_io_kwargs(
|
||||
defaults: dict[str, dict[str, Any]] | None,
|
||||
overrides: dict[str, dict[str, Any]] | None,
|
||||
) -> dict[str, dict[str, Any]] | None:
|
||||
"""Merge config-level and per-request media_io_kwargs per modality.
|
||||
|
||||
Each modality key is merged using the corresponding MediaIO subclass's
|
||||
``merge_kwargs``, which may apply modality-specific logic (e.g.
|
||||
VideoMediaIO clears cross-dependent fps/num_frames fields).
|
||||
"""
|
||||
if not defaults and not overrides:
|
||||
return None
|
||||
all_keys = set(defaults or {}) | set(overrides or {})
|
||||
merged = {}
|
||||
for key in all_keys:
|
||||
io_cls = MODALITY_IO_MAP.get(key, MediaIO)
|
||||
merged[key] = io_cls.merge_kwargs(
|
||||
(defaults or {}).get(key),
|
||||
(overrides or {}).get(key),
|
||||
)
|
||||
return merged or None
|
||||
|
||||
|
||||
@MEDIA_CONNECTOR_REGISTRY.register("http")
|
||||
class MediaConnector:
|
||||
"""Configuration values can be user-provided either by --media-io-kwargs or
|
||||
by the runtime API field "media_io_kwargs". Ensure proper validation and
|
||||
error handling.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
media_io_kwargs: dict[str, dict[str, Any]] | None = None,
|
||||
|
||||
@@ -15,12 +15,18 @@ from .base import MediaIO, MediaWithBytes
|
||||
|
||||
|
||||
class ImageMediaIO(MediaIO[Image.Image]):
|
||||
"""Configuration values can be user-provided either by --media-io-kwargs or
|
||||
by the runtime API field "media_io_kwargs". Ensure proper validation and
|
||||
error handling.
|
||||
"""
|
||||
|
||||
def __init__(self, image_mode: str = "RGB", **kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.image_mode = image_mode
|
||||
# `kwargs` contains custom arguments from
|
||||
# --media-io-kwargs for this modality.
|
||||
# --media-io-kwargs for this modality, merged with
|
||||
# per-request runtime media_io_kwargs via merge_kwargs().
|
||||
# They can be passed to the underlying
|
||||
# media loaders (e.g. custom implementations)
|
||||
# for flexible control.
|
||||
@@ -88,6 +94,13 @@ class ImageMediaIO(MediaIO[Image.Image]):
|
||||
|
||||
|
||||
class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
|
||||
"""Image embedding MediaIO implementation.
|
||||
|
||||
Configuration values can be user-provided either by --media-io-kwargs or
|
||||
by the runtime API field "media_io_kwargs". Ensure proper validation and
|
||||
error handling.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -17,6 +17,28 @@ from .image import ImageMediaIO
|
||||
|
||||
|
||||
class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]):
|
||||
"""Configuration values can be user-provided either by --media-io-kwargs or
|
||||
by the runtime API field "media_io_kwargs". Ensure proper validation and
|
||||
error handling.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def merge_kwargs(
|
||||
cls,
|
||||
default_kwargs: dict[str, Any] | None,
|
||||
runtime_kwargs: dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
merged = super().merge_kwargs(default_kwargs, runtime_kwargs)
|
||||
# fps and num_frames interact with each other, so if either is
|
||||
# overridden at request time, wipe the other from defaults to
|
||||
# avoid unintuitive cross-field interactions.
|
||||
if runtime_kwargs:
|
||||
if "num_frames" in runtime_kwargs and "fps" not in runtime_kwargs:
|
||||
merged.pop("fps", None)
|
||||
elif "fps" in runtime_kwargs and "num_frames" not in runtime_kwargs:
|
||||
merged.pop("num_frames", None)
|
||||
return merged
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_io: ImageMediaIO,
|
||||
@@ -28,7 +50,8 @@ class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]):
|
||||
self.image_io = image_io
|
||||
self.num_frames = num_frames
|
||||
# `kwargs` contains custom arguments from
|
||||
# --media-io-kwargs for this modality.
|
||||
# --media-io-kwargs for this modality, merged with
|
||||
# per-request runtime media_io_kwargs via merge_kwargs().
|
||||
# They can be passed to the underlying
|
||||
# media loaders (e.g. custom implementations)
|
||||
# for flexible control.
|
||||
|
||||
@@ -49,6 +49,7 @@ class DeepseekV32Renderer(BaseRenderer[DeepseekV32Tokenizer]):
|
||||
messages,
|
||||
self.model_config,
|
||||
content_format="string",
|
||||
media_io_kwargs=params.media_io_kwargs,
|
||||
)
|
||||
|
||||
prompt_raw = tokenizer.apply_chat_template(
|
||||
@@ -75,6 +76,7 @@ class DeepseekV32Renderer(BaseRenderer[DeepseekV32Tokenizer]):
|
||||
messages,
|
||||
self.model_config,
|
||||
content_format="string",
|
||||
media_io_kwargs=params.media_io_kwargs,
|
||||
)
|
||||
|
||||
prompt_raw = tokenizer.apply_chat_template(
|
||||
|
||||
@@ -49,6 +49,7 @@ class Grok2Renderer(BaseRenderer[Grok2Tokenizer]):
|
||||
messages,
|
||||
self.model_config,
|
||||
content_format="string",
|
||||
media_io_kwargs=params.media_io_kwargs,
|
||||
)
|
||||
|
||||
prompt_raw = tokenizer.apply_chat_template(
|
||||
@@ -75,6 +76,7 @@ class Grok2Renderer(BaseRenderer[Grok2Tokenizer]):
|
||||
messages,
|
||||
self.model_config,
|
||||
content_format="string",
|
||||
media_io_kwargs=params.media_io_kwargs,
|
||||
)
|
||||
|
||||
prompt_raw = tokenizer.apply_chat_template(
|
||||
|
||||
@@ -635,6 +635,7 @@ class HfRenderer(BaseRenderer[HfTokenizer]):
|
||||
tokenizer=tokenizer,
|
||||
model_config=model_config,
|
||||
),
|
||||
media_io_kwargs=params.media_io_kwargs,
|
||||
)
|
||||
|
||||
prompt_raw = safe_apply_chat_template(
|
||||
@@ -689,6 +690,7 @@ class HfRenderer(BaseRenderer[HfTokenizer]):
|
||||
tokenizer=tokenizer,
|
||||
model_config=model_config,
|
||||
),
|
||||
media_io_kwargs=params.media_io_kwargs,
|
||||
)
|
||||
|
||||
prompt_raw = safe_apply_chat_template(
|
||||
|
||||
@@ -90,6 +90,7 @@ class MistralRenderer(BaseRenderer[MistralTokenizer]):
|
||||
messages,
|
||||
self.model_config,
|
||||
content_format="string",
|
||||
media_io_kwargs=params.media_io_kwargs,
|
||||
)
|
||||
|
||||
prompt_raw = safe_apply_chat_template(
|
||||
@@ -116,6 +117,7 @@ class MistralRenderer(BaseRenderer[MistralTokenizer]):
|
||||
messages,
|
||||
self.model_config,
|
||||
content_format="string",
|
||||
media_io_kwargs=params.media_io_kwargs,
|
||||
)
|
||||
|
||||
prompt_raw = await self._apply_chat_template_async(
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, TypeVar
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal.media.connector import merge_media_io_kwargs
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
|
||||
@@ -52,8 +53,15 @@ class ChatParams:
|
||||
chat_template_kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
"""The kwargs to pass to the chat template."""
|
||||
|
||||
def with_defaults(self, default_chat_template_kwargs: dict[str, Any] | None):
|
||||
if not default_chat_template_kwargs:
|
||||
media_io_kwargs: dict[str, dict[str, Any]] | None = None
|
||||
"""Per-modality kwargs for media I/O (loading/decoding images, videos, etc.)."""
|
||||
|
||||
def with_defaults(
|
||||
self,
|
||||
default_chat_template_kwargs: dict[str, Any] | None = None,
|
||||
default_media_io_kwargs: dict[str, dict[str, Any]] | None = None,
|
||||
):
|
||||
if not default_chat_template_kwargs and not default_media_io_kwargs:
|
||||
return self
|
||||
|
||||
return ChatParams(
|
||||
@@ -63,6 +71,10 @@ class ChatParams:
|
||||
default_chat_template_kwargs,
|
||||
self.chat_template_kwargs,
|
||||
),
|
||||
media_io_kwargs=merge_media_io_kwargs(
|
||||
default_media_io_kwargs,
|
||||
self.media_io_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
def get_apply_chat_template_kwargs(self) -> dict[str, Any]:
|
||||
|
||||
@@ -43,6 +43,7 @@ class TerratorchRenderer(BaseRenderer):
|
||||
messages,
|
||||
model_config,
|
||||
content_format="string",
|
||||
media_io_kwargs=params.media_io_kwargs,
|
||||
)
|
||||
|
||||
prompt = parse_dec_only_prompt([1]) # Dummy token IDs
|
||||
@@ -64,6 +65,7 @@ class TerratorchRenderer(BaseRenderer):
|
||||
messages,
|
||||
model_config,
|
||||
content_format="string",
|
||||
media_io_kwargs=params.media_io_kwargs,
|
||||
)
|
||||
|
||||
prompt = parse_dec_only_prompt([1]) # Dummy token IDs
|
||||
|
||||
Reference in New Issue
Block a user