[Core] Add audio_embeds support to chat completions (#29059)

Signed-off-by: Jeremy Teboul <jeremyteboul@fb.com>
Co-authored-by: Jeremy Teboul <jeremyteboul@fb.com>
This commit is contained in:
jeremyteboul
2025-11-20 19:39:47 -08:00
committed by GitHub
parent a982f5b5ea
commit 0730414999
5 changed files with 360 additions and 3 deletions

View File

@@ -94,6 +94,22 @@ class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
"""
class ChatCompletionContentPartAudioEmbedsParam(TypedDict, total=False):
audio_embeds: str | dict[str, str] | None
"""
The audio embeddings. It can be either:
- A single base64 string representing a serialized torch tensor.
- A dictionary where each value is a base64 string.
"""
type: Required[Literal["audio_embeds"]]
"""The type of the content part."""
uuid: str | None
"""
User-provided UUID of a media. User must guarantee that it is properly
generated and unique for different medias.
"""
class VideoURL(TypedDict, total=False):
url: Required[str]
"""
@@ -211,6 +227,7 @@ ChatCompletionContentPartParam: TypeAlias = (
| CustomChatCompletionContentPILImageParam
| CustomChatCompletionContentSimpleImageParam
| ChatCompletionContentPartImageEmbedsParam
| ChatCompletionContentPartAudioEmbedsParam
| CustomChatCompletionContentSimpleAudioParam
| CustomChatCompletionContentSimpleVideoParam
| str
@@ -599,7 +616,7 @@ def resolve_chat_template_content_format(
return detected_format
ModalityStr = Literal["image", "audio", "video", "image_embeds"]
ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"]
_T = TypeVar("_T")
@@ -684,6 +701,11 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
mm_uuids["image"] = uuids_by_modality["image_embeds"]
if "image" in uuids_by_modality:
mm_uuids["image"] = uuids_by_modality["image"] # UUIDs of images
if "audio_embeds" in uuids_by_modality:
audio_embeds_uuids = uuids_by_modality["audio_embeds"]
if len(audio_embeds_uuids) > 1:
raise ValueError("Only one message can have {'type': 'audio_embeds'}")
mm_uuids["audio"] = uuids_by_modality["audio_embeds"]
if "audio" in uuids_by_modality:
mm_uuids["audio"] = uuids_by_modality["audio"] # UUIDs of audios
if "video" in uuids_by_modality:
@@ -703,6 +725,8 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
items_by_modality = dict(self._items_by_modality)
if "image" in items_by_modality and "image_embeds" in items_by_modality:
raise ValueError("Mixing raw image and embedding inputs is not allowed")
if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
raise ValueError("Mixing raw audio and embedding inputs is not allowed")
if "image_embeds" in items_by_modality:
image_embeds_lst = items_by_modality["image_embeds"]
@@ -711,6 +735,11 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
mm_inputs["image"] = image_embeds_lst[0]
if "image" in items_by_modality:
mm_inputs["image"] = items_by_modality["image"] # A list of images
if "audio_embeds" in items_by_modality:
audio_embeds_lst = items_by_modality["audio_embeds"]
if len(audio_embeds_lst) > 1:
raise ValueError("Only one message can have {'type': 'audio_embeds'}")
mm_inputs["audio"] = audio_embeds_lst[0]
if "audio" in items_by_modality:
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
if "video" in items_by_modality:
@@ -738,6 +767,8 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
if "image" in items_by_modality and "image_embeds" in items_by_modality:
raise ValueError("Mixing raw image and embedding inputs is not allowed")
if "audio" in items_by_modality and "audio_embeds" in items_by_modality:
raise ValueError("Mixing raw audio and embedding inputs is not allowed")
if "image_embeds" in items_by_modality:
image_embeds_lst = items_by_modality["image_embeds"]
@@ -746,6 +777,11 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
mm_inputs["image"] = image_embeds_lst[0]
if "image" in items_by_modality:
mm_inputs["image"] = items_by_modality["image"] # A list of images
if "audio_embeds" in items_by_modality:
audio_embeds_lst = items_by_modality["audio_embeds"]
if len(audio_embeds_lst) > 1:
raise ValueError("Only one message can have {'type': 'audio_embeds'}")
mm_inputs["audio"] = audio_embeds_lst[0]
if "audio" in items_by_modality:
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
if "video" in items_by_modality:
@@ -804,6 +840,14 @@ class BaseMultiModalContentParser(ABC):
) -> None:
raise NotImplementedError
@abstractmethod
def parse_audio_embeds(
self,
audio_embeds: str | dict[str, str] | None,
uuid: str | None = None,
) -> None:
raise NotImplementedError
@abstractmethod
def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
raise NotImplementedError
@@ -861,6 +905,31 @@ class MultiModalContentParser(BaseMultiModalContentParser):
self._add_placeholder("image", placeholder)
def parse_audio_embeds(
self,
audio_embeds: str | dict[str, str] | None,
uuid: str | None = None,
) -> None:
mm_config = self.model_config.get_multimodal_config()
if not mm_config.enable_mm_embeds:
raise ValueError(
"You must set `--enable-mm-embeds` to input `audio_embeds`"
)
if isinstance(audio_embeds, dict):
embeds = {
k: self._connector.fetch_audio_embedding(v)
for k, v in audio_embeds.items()
}
placeholder = self._tracker.add("audio_embeds", embeds, uuid)
elif isinstance(audio_embeds, str):
embedding = self._connector.fetch_audio_embedding(audio_embeds)
placeholder = self._tracker.add("audio_embeds", embedding, uuid)
else:
placeholder = self._tracker.add("audio_embeds", None, uuid)
self._add_placeholder("audio", placeholder)
def parse_image_pil(
self, image_pil: Image.Image | None, uuid: str | None = None
) -> None:
@@ -950,6 +1019,67 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
placeholder = self._tracker.add("image_embeds", future, uuid)
self._add_placeholder("image", placeholder)
def parse_audio_embeds(
self,
audio_embeds: str | dict[str, str] | None,
uuid: str | None = None,
) -> None:
mm_config = self.model_config.get_multimodal_config()
if not mm_config.enable_mm_embeds:
raise ValueError(
"You must set `--enable-mm-embeds` to input `audio_embeds`"
)
logger.info(
"🎵 Parsing audio_embeds: type=%s, uuid=%s, is_dict=%s, "
"is_str=%s, is_none=%s",
type(audio_embeds).__name__,
uuid,
isinstance(audio_embeds, dict),
isinstance(audio_embeds, str),
audio_embeds is None,
)
future: asyncio.Future[str | dict[str, str] | None] = asyncio.Future()
if isinstance(audio_embeds, dict):
logger.info(
"🎵 Processing dict audio_embeds with %d entries",
len(audio_embeds),
)
embeds = {
k: self._connector.fetch_audio_embedding(v)
for k, v in audio_embeds.items()
}
future.set_result(embeds)
logger.info(
"🎵 Successfully loaded %d audio embeddings from dict",
len(embeds),
)
if isinstance(audio_embeds, str):
base64_size = len(audio_embeds)
logger.info(
"🎵 Processing base64 audio_embeds: %d chars (%.2f KB)",
base64_size,
base64_size / 1024,
)
embedding = self._connector.fetch_audio_embedding(audio_embeds)
future.set_result(embedding)
logger.info(
"🎵 Successfully loaded audio embedding tensor: shape=%s, dtype=%s",
embedding.shape,
embedding.dtype,
)
if audio_embeds is None:
logger.info("🎵 Audio embeds is None (UUID-only reference)")
future.set_result(None)
placeholder = self._tracker.add("audio_embeds", future, uuid)
self._add_placeholder("audio", placeholder)
logger.info("🎵 Added audio_embeds placeholder with uuid=%s", uuid)
def parse_image_pil(
self, image_pil: Image.Image | None, uuid: str | None = None
) -> None:
@@ -1132,6 +1262,7 @@ def _get_full_multimodal_text_prompt(
# No need to validate using Pydantic again
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_AudioEmbedsParser = partial(cast, ChatCompletionContentPartAudioEmbedsParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
@@ -1155,6 +1286,7 @@ MM_PARSER_MAP: dict[
"input_image": lambda part: _ResponsesInputImageParser(part).get("image_url", None),
"image_url": lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
"image_embeds": lambda part: _ImageEmbedsParser(part).get("image_embeds", None),
"audio_embeds": lambda part: _AudioEmbedsParser(part).get("audio_embeds", None),
"image_pil": lambda part: _PILImageParser(part).get("image_pil", None),
"audio_url": lambda part: _AudioParser(part).get("audio_url", {}).get("url", None),
"input_audio": lambda part: _InputAudioParser(part).get("input_audio", None),
@@ -1223,8 +1355,17 @@ def _parse_chat_message_content_mm_part(
)
image_embeds = image_params.get("image_embeds", None)
return "image_embeds", image_embeds
if "audio_embeds" in part:
# "audio_embeds" could be None if UUID is provided.
audio_params = cast( # type: ignore[assignment]
ChatCompletionContentPartAudioEmbedsParam, part
)
audio_embeds = audio_params.get("audio_embeds", None)
return "audio_embeds", audio_embeds
if "audio_url" in part:
audio_params = cast(CustomChatCompletionContentSimpleAudioParam, part)
audio_params = cast( # type: ignore[assignment]
CustomChatCompletionContentSimpleAudioParam, part
)
audio_url = audio_params.get("audio_url", None)
if isinstance(audio_url, dict):
# Can potentially happen if user provides a uuid
@@ -1348,6 +1489,10 @@ def _parse_chat_message_content_part(
content = cast(str | dict[str, str], content) if content is not None else None
mm_parser.parse_image_embeds(content, uuid)
modality = "image"
elif part_type == "audio_embeds":
content = cast(str | dict[str, str], content) if content is not None else None
mm_parser.parse_audio_embeds(content, uuid)
modality = "audio"
elif part_type == "audio_url":
str_content = cast(str, content)
mm_parser.parse_audio(str_content, uuid)

View File

@@ -7,6 +7,8 @@ from typing import Literal
import numpy as np
import numpy.typing as npt
import pybase64
import torch
from vllm.utils.import_utils import PlaceholderModule
@@ -116,3 +118,25 @@ class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
data = buffer.getvalue()
return base64.b64encode(data).decode("utf-8")
class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]):
def __init__(self) -> None:
super().__init__()
def load_bytes(self, data: bytes) -> torch.Tensor:
buffer = BytesIO(data)
return torch.load(buffer, weights_only=True)
def load_base64(self, media_type: str, data: str) -> torch.Tensor:
return self.load_bytes(pybase64.b64decode(data, validate=True))
def load_file(self, filepath: Path) -> torch.Tensor:
return torch.load(filepath, weights_only=True)
def encode_base64(self, media: torch.Tensor) -> str:
buffer = BytesIO()
torch.save(media, buffer)
buffer.seek(0)
binary_data = buffer.read()
return pybase64.b64encode(binary_data).decode("utf-8")

View File

@@ -22,7 +22,7 @@ from vllm.logger import init_logger
from vllm.utils.jsontree import json_map_leaves
from vllm.utils.registry import ExtensionManager
from .audio import AudioMediaIO
from .audio import AudioEmbeddingMediaIO, AudioMediaIO
from .base import MediaIO
from .image import ImageEmbeddingMediaIO, ImageMediaIO
from .video import VideoMediaIO
@@ -342,6 +342,17 @@ class MediaConnector:
return image_embedding_io.load_base64("", data)
def fetch_audio_embedding(
self,
data: str,
) -> torch.Tensor:
"""
Load audio embedding from a URL.
"""
audio_embedding_io = AudioEmbeddingMediaIO()
return audio_embedding_io.load_base64("", data)
def encode_audio_base64(
audio: np.ndarray,