Support online use_audio_in_video (#36319)

Signed-off-by: Tianyu Guo <guoty9@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Tianyu Guo
2026-03-09 22:16:44 +08:00
committed by GitHub
parent 3ec2115015
commit 5578f2a4d3
10 changed files with 152 additions and 10 deletions

View File

@@ -564,7 +564,9 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return self.model_cls.get_placeholder_str(modality, num_items) return self.model_cls.get_placeholder_str(modality, num_items)
@abstractmethod @abstractmethod
def create_parser(self) -> "BaseMultiModalContentParser": def create_parser(
self, mm_processor_kwargs: dict[str, Any] | None = None
) -> "BaseMultiModalContentParser":
raise NotImplementedError raise NotImplementedError
@@ -690,8 +692,10 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[tuple[object, str | None]]
dict(self._items_by_modality), self.mm_processor, self._modality_order dict(self._items_by_modality), self.mm_processor, self._modality_order
) )
def create_parser(self) -> "BaseMultiModalContentParser": def create_parser(
return MultiModalContentParser(self) self, mm_processor_kwargs: dict[str, Any] | None = None
) -> "BaseMultiModalContentParser":
return MultiModalContentParser(self, mm_processor_kwargs=mm_processor_kwargs)
class AsyncMultiModalItemTracker( class AsyncMultiModalItemTracker(
@@ -712,8 +716,12 @@ class AsyncMultiModalItemTracker(
resolved_items_by_modality, self.mm_processor, self._modality_order resolved_items_by_modality, self.mm_processor, self._modality_order
) )
def create_parser(self) -> "BaseMultiModalContentParser": def create_parser(
return AsyncMultiModalContentParser(self) self, mm_processor_kwargs: dict[str, Any] | None = None
) -> "BaseMultiModalContentParser":
return AsyncMultiModalContentParser(
self, mm_processor_kwargs=mm_processor_kwargs
)
class BaseMultiModalContentParser(ABC): class BaseMultiModalContentParser(ABC):
@@ -778,7 +786,11 @@ class BaseMultiModalContentParser(ABC):
class MultiModalContentParser(BaseMultiModalContentParser): class MultiModalContentParser(BaseMultiModalContentParser):
def __init__(self, tracker: MultiModalItemTracker) -> None: def __init__(
self,
tracker: MultiModalItemTracker,
mm_processor_kwargs: dict[str, Any] | None = None,
) -> None:
super().__init__() super().__init__()
self._tracker = tracker self._tracker = tracker
@@ -790,6 +802,8 @@ class MultiModalContentParser(BaseMultiModalContentParser):
allowed_media_domains=tracker.allowed_media_domains, allowed_media_domains=tracker.allowed_media_domains,
) )
self._mm_processor_kwargs = mm_processor_kwargs
@property @property
def model_config(self) -> ModelConfig: def model_config(self) -> ModelConfig:
return self._tracker.model_config return self._tracker.model_config
@@ -886,9 +900,23 @@ class MultiModalContentParser(BaseMultiModalContentParser):
placeholder = self._tracker.add("video", (video, uuid)) placeholder = self._tracker.add("video", (video, uuid))
self._add_placeholder("video", placeholder) self._add_placeholder("video", placeholder)
# Extract audio from video if use_audio_in_video is True
if (
video_url
and self._mm_processor_kwargs
and self._mm_processor_kwargs.get("use_audio_in_video", False)
):
audio = self._connector.fetch_audio(video_url) if video_url else None
audio_placeholder = self._tracker.add("audio", (audio, uuid))
self._add_placeholder("audio", audio_placeholder)
class AsyncMultiModalContentParser(BaseMultiModalContentParser): class AsyncMultiModalContentParser(BaseMultiModalContentParser):
def __init__(self, tracker: AsyncMultiModalItemTracker) -> None: def __init__(
self,
tracker: AsyncMultiModalItemTracker,
mm_processor_kwargs: dict[str, Any] | None = None,
) -> None:
super().__init__() super().__init__()
self._tracker = tracker self._tracker = tracker
@@ -898,6 +926,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
allowed_local_media_path=tracker.allowed_local_media_path, allowed_local_media_path=tracker.allowed_local_media_path,
allowed_media_domains=tracker.allowed_media_domains, allowed_media_domains=tracker.allowed_media_domains,
) )
self._mm_processor_kwargs: dict[str, Any] | None = mm_processor_kwargs
@property @property
def model_config(self) -> ModelConfig: def model_config(self) -> ModelConfig:
@@ -1033,6 +1062,16 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
placeholder = self._tracker.add("video", coro) placeholder = self._tracker.add("video", coro)
self._add_placeholder("video", placeholder) self._add_placeholder("video", placeholder)
# Extract audio from video if use_audio_in_video is True
if (
video_url
and self._mm_processor_kwargs
and self._mm_processor_kwargs.get("use_audio_in_video", False)
):
audio_coro = self._audio_with_uuid_async(video_url, uuid)
audio_placeholder = self._tracker.add("audio", audio_coro)
self._add_placeholder("audio", audio_placeholder)
@dataclass @dataclass
class ChatTemplateConfig: class ChatTemplateConfig:
@@ -1343,10 +1382,11 @@ def _parse_chat_message_content_parts(
*, *,
wrap_dicts: bool, wrap_dicts: bool,
interleave_strings: bool, interleave_strings: bool,
mm_processor_kwargs: dict[str, Any] | None = None,
) -> list[ConversationMessage]: ) -> list[ConversationMessage]:
content = list[_ContentPart]() content = list[_ContentPart]()
mm_parser = mm_tracker.create_parser() mm_parser = mm_tracker.create_parser(mm_processor_kwargs=mm_processor_kwargs)
for part in parts: for part in parts:
parse_res = _parse_chat_message_content_part( parse_res = _parse_chat_message_content_part(
@@ -1464,6 +1504,7 @@ def _parse_chat_message_content(
mm_tracker: BaseMultiModalItemTracker, mm_tracker: BaseMultiModalItemTracker,
content_format: ChatTemplateContentFormat, content_format: ChatTemplateContentFormat,
interleave_strings: bool, interleave_strings: bool,
mm_processor_kwargs: dict[str, Any] | None = None,
) -> list[ConversationMessage]: ) -> list[ConversationMessage]:
role = message["role"] role = message["role"]
content = message.get("content") content = message.get("content")
@@ -1479,6 +1520,7 @@ def _parse_chat_message_content(
mm_tracker, mm_tracker,
wrap_dicts=(content_format == "openai"), wrap_dicts=(content_format == "openai"),
interleave_strings=interleave_strings, interleave_strings=interleave_strings,
mm_processor_kwargs=mm_processor_kwargs,
) )
for result_msg in result: for result_msg in result:
@@ -1540,6 +1582,7 @@ def parse_chat_messages(
model_config: ModelConfig, model_config: ModelConfig,
content_format: ChatTemplateContentFormat, content_format: ChatTemplateContentFormat,
media_io_kwargs: dict[str, dict[str, Any]] | None = None, media_io_kwargs: dict[str, dict[str, Any]] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None,
) -> tuple[ ) -> tuple[
list[ConversationMessage], list[ConversationMessage],
MultiModalDataDict | None, MultiModalDataDict | None,
@@ -1558,6 +1601,7 @@ def parse_chat_messages(
and model_config.multimodal_config is not None and model_config.multimodal_config is not None
and model_config.multimodal_config.interleave_mm_strings and model_config.multimodal_config.interleave_mm_strings
), ),
mm_processor_kwargs=mm_processor_kwargs,
) )
conversation.extend(sub_messages) conversation.extend(sub_messages)
@@ -1574,6 +1618,7 @@ async def parse_chat_messages_async(
model_config: ModelConfig, model_config: ModelConfig,
content_format: ChatTemplateContentFormat, content_format: ChatTemplateContentFormat,
media_io_kwargs: dict[str, dict[str, Any]] | None = None, media_io_kwargs: dict[str, dict[str, Any]] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None,
) -> tuple[ ) -> tuple[
list[ConversationMessage], list[ConversationMessage],
MultiModalDataDict | None, MultiModalDataDict | None,
@@ -1594,6 +1639,7 @@ async def parse_chat_messages_async(
and model_config.multimodal_config is not None and model_config.multimodal_config is not None
and model_config.multimodal_config.interleave_mm_strings and model_config.multimodal_config.interleave_mm_strings
), ),
mm_processor_kwargs=mm_processor_kwargs,
) )
conversation.extend(sub_messages) conversation.extend(sub_messages)

View File

@@ -892,6 +892,7 @@ class OpenAIServing:
).with_defaults( ).with_defaults(
default_template_kwargs, default_template_kwargs,
default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None), default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None),
default_mm_processor_kwargs=getattr(request, "mm_processor_kwargs", None),
) )
(conversation,), (engine_prompt,) = await renderer.render_chat_async( (conversation,), (engine_prompt,) = await renderer.render_chat_async(

View File

@@ -78,7 +78,11 @@ from vllm.multimodal.parse import (
ModalityDataItems, ModalityDataItems,
MultiModalDataItems, MultiModalDataItems,
) )
from vllm.multimodal.processing import BaseDummyInputsBuilder from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
ProcessorInputs,
TimingContext,
)
from vllm.multimodal.processing.processor import ( from vllm.multimodal.processing.processor import (
BaseMultiModalProcessor, BaseMultiModalProcessor,
MultiModalPromptUpdates, MultiModalPromptUpdates,
@@ -811,6 +815,16 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
), ),
] ]
def _cached_apply_hf_processor(
self,
inputs: ProcessorInputs,
timing_ctx: TimingContext,
):
mm_processor_kwargs = inputs.hf_processor_mm_kwargs
if mm_processor_kwargs.get("use_audio_in_video", False):
return self._apply_hf_processor(inputs, timing_ctx)
return super()._cached_apply_hf_processor(inputs, timing_ctx)
def _apply_hf_processor_main( def _apply_hf_processor_main(
self, self,
prompt: str | list[int], prompt: str | list[int],

View File

@@ -82,6 +82,35 @@ def extract_audio_from_video_bytes(
return audio, float(native_sr) return audio, float(native_sr)
def is_video(data: bytes) -> bool:
"""Check if the fetched bytes are video"""
if len(data) < 12:
return False
box_type = data[4:8]
major_brand = data[8:12]
MP4_BRANDS = {
b"mp41",
b"mp42", # MP4
b"isom", # ISO Base Media
b"iso2",
b"iso4",
b"iso5",
b"iso6",
b"M4V ",
b"M4A ", # Apple
b"avc1", # H.264
b"dash", # DASH
b"mmp4",
b"MSNV",
}
is_avi = data[:4] == b"RIFF" and major_brand == b"AVI "
is_mp4 = box_type == b"ftyp" and major_brand in MP4_BRANDS
return is_mp4 or is_avi
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]): class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
"""Configuration values can be user-provided either by --media-io-kwargs or """Configuration values can be user-provided either by --media-io-kwargs or
by the runtime API field "media_io_kwargs". Ensure proper validation and by the runtime API field "media_io_kwargs". Ensure proper validation and
@@ -100,6 +129,8 @@ class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
self.kwargs = kwargs self.kwargs = kwargs
def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]: def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
if is_video(data):
return extract_audio_from_video_bytes(data)
return librosa.load(BytesIO(data), sr=None) return librosa.load(BytesIO(data), sr=None)
def load_base64( def load_base64(

View File

@@ -50,6 +50,7 @@ class DeepseekV32Renderer(BaseRenderer[DeepseekV32Tokenizer]):
self.model_config, self.model_config,
content_format="string", content_format="string",
media_io_kwargs=params.media_io_kwargs, media_io_kwargs=params.media_io_kwargs,
mm_processor_kwargs=params.mm_processor_kwargs,
) )
prompt_raw = tokenizer.apply_chat_template( prompt_raw = tokenizer.apply_chat_template(
@@ -77,6 +78,7 @@ class DeepseekV32Renderer(BaseRenderer[DeepseekV32Tokenizer]):
self.model_config, self.model_config,
content_format="string", content_format="string",
media_io_kwargs=params.media_io_kwargs, media_io_kwargs=params.media_io_kwargs,
mm_processor_kwargs=params.mm_processor_kwargs,
) )
prompt_raw = tokenizer.apply_chat_template( prompt_raw = tokenizer.apply_chat_template(

View File

@@ -50,6 +50,7 @@ class Grok2Renderer(BaseRenderer[Grok2Tokenizer]):
self.model_config, self.model_config,
content_format="string", content_format="string",
media_io_kwargs=params.media_io_kwargs, media_io_kwargs=params.media_io_kwargs,
mm_processor_kwargs=params.mm_processor_kwargs,
) )
prompt_raw = tokenizer.apply_chat_template( prompt_raw = tokenizer.apply_chat_template(
@@ -77,6 +78,7 @@ class Grok2Renderer(BaseRenderer[Grok2Tokenizer]):
self.model_config, self.model_config,
content_format="string", content_format="string",
media_io_kwargs=params.media_io_kwargs, media_io_kwargs=params.media_io_kwargs,
mm_processor_kwargs=params.mm_processor_kwargs,
) )
prompt_raw = tokenizer.apply_chat_template( prompt_raw = tokenizer.apply_chat_template(

View File

@@ -636,6 +636,7 @@ class HfRenderer(BaseRenderer[HfTokenizer]):
model_config=model_config, model_config=model_config,
), ),
media_io_kwargs=params.media_io_kwargs, media_io_kwargs=params.media_io_kwargs,
mm_processor_kwargs=params.mm_processor_kwargs,
) )
prompt_raw = safe_apply_chat_template( prompt_raw = safe_apply_chat_template(
@@ -691,6 +692,7 @@ class HfRenderer(BaseRenderer[HfTokenizer]):
model_config=model_config, model_config=model_config,
), ),
media_io_kwargs=params.media_io_kwargs, media_io_kwargs=params.media_io_kwargs,
mm_processor_kwargs=params.mm_processor_kwargs,
) )
prompt_raw = safe_apply_chat_template( prompt_raw = safe_apply_chat_template(

View File

@@ -91,6 +91,7 @@ class MistralRenderer(BaseRenderer[MistralTokenizer]):
self.model_config, self.model_config,
content_format="string", content_format="string",
media_io_kwargs=params.media_io_kwargs, media_io_kwargs=params.media_io_kwargs,
mm_processor_kwargs=params.mm_processor_kwargs,
) )
prompt_raw = safe_apply_chat_template( prompt_raw = safe_apply_chat_template(
@@ -118,6 +119,7 @@ class MistralRenderer(BaseRenderer[MistralTokenizer]):
self.model_config, self.model_config,
content_format="string", content_format="string",
media_io_kwargs=params.media_io_kwargs, media_io_kwargs=params.media_io_kwargs,
mm_processor_kwargs=params.mm_processor_kwargs,
) )
prompt_raw = await self._apply_chat_template_async( prompt_raw = await self._apply_chat_template_async(

View File

@@ -40,6 +40,34 @@ def merge_kwargs(
return defaults | {k: v for k, v in overrides.items() if v not in unset_values} return defaults | {k: v for k, v in overrides.items() if v not in unset_values}
def recursively_merge_kwargs(
defaults: dict[str, Any] | None,
overrides: dict[str, Any] | None,
/,
*,
unset_values: tuple[object, ...] = (None, "auto"),
) -> dict[str, Any]:
if defaults is None:
defaults = {}
if overrides is None:
overrides = {}
merged = dict(defaults)
for k, v in overrides.items():
if v in unset_values:
continue
if k in merged and isinstance(merged[k], dict) and isinstance(v, dict):
merged[k] = recursively_merge_kwargs(
merged[k], v, unset_values=unset_values
)
else:
merged[k] = v
return merged
@dataclass(frozen=True) @dataclass(frozen=True)
class ChatParams: class ChatParams:
"""Configuration to control how to parse chat messages.""" """Configuration to control how to parse chat messages."""
@@ -56,12 +84,20 @@ class ChatParams:
media_io_kwargs: dict[str, dict[str, Any]] | None = None media_io_kwargs: dict[str, dict[str, Any]] | None = None
"""Per-modality kwargs for media I/O (loading/decoding images, videos, etc.).""" """Per-modality kwargs for media I/O (loading/decoding images, videos, etc.)."""
mm_processor_kwargs: dict[str, Any] | None = None
"""The kwargs to pass to the multi-modal processor."""
def with_defaults( def with_defaults(
self, self,
default_chat_template_kwargs: dict[str, Any] | None = None, default_chat_template_kwargs: dict[str, Any] | None = None,
default_media_io_kwargs: dict[str, dict[str, Any]] | None = None, default_media_io_kwargs: dict[str, dict[str, Any]] | None = None,
default_mm_processor_kwargs: dict[str, Any] | None = None,
):
if (
not default_chat_template_kwargs
and not default_media_io_kwargs
and not default_mm_processor_kwargs
): ):
if not default_chat_template_kwargs and not default_media_io_kwargs:
return self return self
return ChatParams( return ChatParams(
@@ -75,6 +111,10 @@ class ChatParams:
default_media_io_kwargs, default_media_io_kwargs,
self.media_io_kwargs, self.media_io_kwargs,
), ),
mm_processor_kwargs=recursively_merge_kwargs(
default_mm_processor_kwargs,
self.mm_processor_kwargs,
),
) )
def get_apply_chat_template_kwargs(self) -> dict[str, Any]: def get_apply_chat_template_kwargs(self) -> dict[str, Any]:

View File

@@ -44,6 +44,7 @@ class TerratorchRenderer(BaseRenderer):
model_config, model_config,
content_format="string", content_format="string",
media_io_kwargs=params.media_io_kwargs, media_io_kwargs=params.media_io_kwargs,
mm_processor_kwargs=params.mm_processor_kwargs,
) )
prompt = parse_dec_only_prompt([1]) # Dummy token IDs prompt = parse_dec_only_prompt([1]) # Dummy token IDs
@@ -66,6 +67,7 @@ class TerratorchRenderer(BaseRenderer):
model_config, model_config,
content_format="string", content_format="string",
media_io_kwargs=params.media_io_kwargs, media_io_kwargs=params.media_io_kwargs,
mm_processor_kwargs=params.mm_processor_kwargs,
) )
prompt = parse_dec_only_prompt([1]) # Dummy token IDs prompt = parse_dec_only_prompt([1]) # Dummy token IDs