[Frontend] Add OpenAI API support for input_audio (#11027)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -13,7 +13,8 @@ import transformers.utils.chat_template_utils as hf_chat_utils
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from openai.types.chat import (ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionContentPartImageParam)
|
||||
ChatCompletionContentPartImageParam,
|
||||
ChatCompletionContentPartInputAudioParam)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
|
||||
from openai.types.chat import (ChatCompletionContentPartRefusalParam,
|
||||
@@ -105,6 +106,7 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
|
||||
|
||||
ChatCompletionContentPartParam: TypeAlias = Union[
|
||||
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
|
||||
ChatCompletionContentPartInputAudioParam,
|
||||
ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam,
|
||||
CustomChatCompletionContentSimpleImageParam,
|
||||
CustomChatCompletionContentSimpleAudioParam,
|
||||
@@ -519,6 +521,10 @@ class BaseMultiModalContentParser(ABC):
|
||||
def parse_audio(self, audio_url: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def parse_video(self, video_url: str) -> None:
|
||||
raise NotImplementedError
|
||||
@@ -545,6 +551,15 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
placeholder = self._tracker.add("audio", audio)
|
||||
self._add_placeholder(placeholder)
|
||||
|
||||
def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
|
||||
input_audio_data = input_audio.get("data","")
|
||||
input_audio_format = input_audio.get("format","")
|
||||
audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}"
|
||||
audio = get_and_parse_audio(audio_url)
|
||||
|
||||
placeholder = self._tracker.add("audio", audio)
|
||||
self._add_placeholder(placeholder)
|
||||
|
||||
def parse_video(self, video_url: str) -> None:
|
||||
video = get_and_parse_video(video_url)
|
||||
|
||||
@@ -574,6 +589,15 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
placeholder = self._tracker.add("audio", audio_coro)
|
||||
self._add_placeholder(placeholder)
|
||||
|
||||
def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
|
||||
input_audio_data = input_audio.get("data","")
|
||||
input_audio_format = input_audio.get("format","")
|
||||
audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}"
|
||||
audio_coro = async_get_and_parse_audio(audio_url)
|
||||
|
||||
placeholder = self._tracker.add("audio", audio_coro)
|
||||
self._add_placeholder(placeholder)
|
||||
|
||||
def parse_video(self, video_url: str) -> None:
|
||||
video = async_get_and_parse_video(video_url)
|
||||
|
||||
@@ -667,17 +691,22 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
|
||||
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
|
||||
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
|
||||
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
|
||||
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
|
||||
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
|
||||
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
|
||||
|
||||
# Define a mapping from part types to their corresponding parsing functions.
|
||||
MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = {
|
||||
MM_PARSER_MAP: Dict[str,
|
||||
Callable[[ChatCompletionContentPartParam],
|
||||
Union[str, Dict[str,str]]]] = {
|
||||
"text":
|
||||
lambda part: _TextParser(part).get("text", ""),
|
||||
"image_url":
|
||||
lambda part: _ImageParser(part).get("image_url", {}).get("url", ""),
|
||||
"audio_url":
|
||||
lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""),
|
||||
"input_audio":
|
||||
lambda part: _InputAudioParser(part).get("input_audio", {}),
|
||||
"refusal":
|
||||
lambda part: _RefusalParser(part).get("refusal", ""),
|
||||
"video_url":
|
||||
@@ -686,7 +715,8 @@ MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = {
|
||||
|
||||
|
||||
def _parse_chat_message_content_mm_part(
|
||||
part: ChatCompletionContentPartParam) -> Tuple[str, str]:
|
||||
part: ChatCompletionContentPartParam) -> Tuple[str,
|
||||
Union[str, Dict[str, str]]]:
|
||||
"""
|
||||
Parses a given multi-modal content part based on its type.
|
||||
|
||||
@@ -717,6 +747,7 @@ def _parse_chat_message_content_mm_part(
|
||||
return part_type, content
|
||||
|
||||
# Handle missing 'type' but provided direct URL fields.
|
||||
# 'type' is required field by pydantic
|
||||
if part_type is None:
|
||||
if part.get("image_url") is not None:
|
||||
image_params = cast(CustomChatCompletionContentSimpleImageParam,
|
||||
@@ -726,6 +757,9 @@ def _parse_chat_message_content_mm_part(
|
||||
audio_params = cast(CustomChatCompletionContentSimpleAudioParam,
|
||||
part)
|
||||
return "audio_url", audio_params.get("audio_url", "")
|
||||
if part.get("input_audio") is not None:
|
||||
input_audio_params = cast(Dict[str, str], part)
|
||||
return "input_audio", input_audio_params
|
||||
if part.get("video_url") is not None:
|
||||
video_params = cast(CustomChatCompletionContentSimpleVideoParam,
|
||||
part)
|
||||
@@ -739,7 +773,7 @@ def _parse_chat_message_content_mm_part(
|
||||
|
||||
|
||||
VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
|
||||
"audio_url", "video_url")
|
||||
"audio_url", "input_audio", "video_url")
|
||||
|
||||
|
||||
def _parse_chat_message_content_parts(
|
||||
@@ -795,7 +829,7 @@ def _parse_chat_message_content_part(
|
||||
# Handle structured dictionary parts
|
||||
part_type, content = _parse_chat_message_content_mm_part(part)
|
||||
|
||||
# if part_type is text/refusal/image_url/audio_url/video_url but
|
||||
# if part_type is text/refusal/image_url/audio_url/video_url/input_audio but
|
||||
# content is empty, log a warning and skip
|
||||
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content:
|
||||
logger.warning(
|
||||
@@ -804,18 +838,30 @@ def _parse_chat_message_content_part(
|
||||
return None
|
||||
|
||||
if part_type in ("text", "refusal"):
|
||||
return {'type': 'text', 'text': content} if wrap_dicts else content
|
||||
str_content = cast(str, content)
|
||||
if wrap_dicts:
|
||||
return {'type': 'text', 'text': str_content}
|
||||
else:
|
||||
return str_content
|
||||
|
||||
if part_type == "image_url":
|
||||
mm_parser.parse_image(content)
|
||||
str_content = cast(str, content)
|
||||
mm_parser.parse_image(str_content)
|
||||
return {'type': 'image'} if wrap_dicts else None
|
||||
|
||||
if part_type == "audio_url":
|
||||
mm_parser.parse_audio(content)
|
||||
str_content = cast(str, content)
|
||||
mm_parser.parse_audio(str_content)
|
||||
return {'type': 'audio'} if wrap_dicts else None
|
||||
|
||||
if part_type == "input_audio":
|
||||
dict_content = cast(Dict[str, str], content)
|
||||
mm_parser.parse_input_audio(dict_content)
|
||||
return {'type': 'audio'} if wrap_dicts else None
|
||||
|
||||
if part_type == "video_url":
|
||||
mm_parser.parse_video(content)
|
||||
str_content = cast(str, content)
|
||||
mm_parser.parse_video(str_content)
|
||||
return {'type': 'video'} if wrap_dicts else None
|
||||
|
||||
raise NotImplementedError(f"Unknown part type: {part_type}")
|
||||
@@ -840,7 +886,6 @@ def _parse_chat_message_content(
|
||||
content = [
|
||||
ChatCompletionContentPartTextParam(type="text", text=content)
|
||||
]
|
||||
|
||||
result = _parse_chat_message_content_parts(
|
||||
role,
|
||||
content, # type: ignore
|
||||
|
||||
Reference in New Issue
Block a user