[Frontend] support image embeds (#13955)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
@@ -56,6 +56,17 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False):
|
||||
"""The type of the content part."""
|
||||
|
||||
|
||||
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
|
||||
image_embeds: Required[Union[str, dict[str, str]]]
|
||||
"""
|
||||
The image embeddings. It can be either:
|
||||
- A single base64 string.
|
||||
- A dictionary where each value is a base64 string.
|
||||
"""
|
||||
type: Required[Literal["image_embeds"]]
|
||||
"""The type of the content part."""
|
||||
|
||||
|
||||
class VideoURL(TypedDict, total=False):
|
||||
url: Required[str]
|
||||
"""
|
||||
@@ -109,6 +120,7 @@ ChatCompletionContentPartParam: TypeAlias = Union[
|
||||
ChatCompletionContentPartInputAudioParam,
|
||||
ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam,
|
||||
CustomChatCompletionContentSimpleImageParam,
|
||||
ChatCompletionContentPartImageEmbedsParam,
|
||||
CustomChatCompletionContentSimpleAudioParam,
|
||||
CustomChatCompletionContentSimpleVideoParam, str]
|
||||
|
||||
@@ -350,7 +362,7 @@ def resolve_chat_template_content_format(
|
||||
return detected_format
|
||||
|
||||
|
||||
ModalityStr = Literal["image", "audio", "video"]
|
||||
ModalityStr = Literal["image", "audio", "video", "image_embeds"]
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
@@ -391,7 +403,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
hf_config = self._model_config.hf_config
|
||||
model_type = hf_config.model_type
|
||||
|
||||
if modality == "image":
|
||||
if modality in ["image", "image_embeds"]:
|
||||
if model_type == "phi3_v":
|
||||
# Workaround since this token is not defined in the tokenizer
|
||||
return f"<|image_{current_count}|>"
|
||||
@@ -470,10 +482,27 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
    """Tracker whose items are plain objects and are collected synchronously."""

    def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """Assemble every tracked item into one multimodal input dict.

        Returns:
            ``None`` when nothing has been tracked. Otherwise a dict that may
            contain the keys ``"image"``, ``"audio"`` and ``"video"``. Image
            embeddings are returned under ``"image"`` as the single tracked
            embedding item (not wrapped in a list); the other entries are the
            lists of tracked items.

        Raises:
            ValueError: If raw images and image embeddings are both present,
                or if more than one ``image_embeds`` item was tracked.
        """
        if not self._items_by_modality:
            return None

        items_by_modality = dict(self._items_by_modality)
        has_images = "image" in items_by_modality
        has_image_embeds = "image_embeds" in items_by_modality

        if has_images and has_image_embeds:
            raise ValueError(
                "Mixing raw image and embedding inputs is not allowed")

        mm_inputs = {}

        # Each modality is handled independently. An if/elif chain here would
        # return only the first modality found and silently drop the others
        # (e.g. the audio of an image + audio request).
        if has_image_embeds:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError(
                    "Only one message can have {'type': 'image_embeds'}")
            mm_inputs["image"] = image_embeds_lst[0]
        if has_images:
            mm_inputs["image"] = items_by_modality["image"]  # A list of images
        if "audio" in items_by_modality:
            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
        if "video" in items_by_modality:
            mm_inputs["video"] = items_by_modality["video"]  # A list of videos
        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        """Return a content parser bound to this tracker."""
        return MultiModalContentParser(self)
|
||||
@@ -482,13 +511,31 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
|
||||
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
    """Tracker whose items are awaitables, resolved when the data is collected."""

    async def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """Await every tracked item and assemble one multimodal input dict.

        Returns:
            ``None`` when nothing has been tracked. Otherwise a dict that may
            contain the keys ``"image"``, ``"audio"`` and ``"video"``. Image
            embeddings are returned under ``"image"`` as the single resolved
            embedding item (not wrapped in a list); the other entries are the
            lists of resolved items.

        Raises:
            ValueError: If raw images and image embeddings are both present,
                or if more than one ``image_embeds`` item was tracked.
        """
        if not self._items_by_modality:
            return None

        # Resolve all awaitables first, one gather per modality.
        items_by_modality = {
            modality: await asyncio.gather(*items)
            for modality, items in self._items_by_modality.items()
        }
        has_images = "image" in items_by_modality
        has_image_embeds = "image_embeds" in items_by_modality

        if has_images and has_image_embeds:
            raise ValueError(
                "Mixing raw image and embedding inputs is not allowed")

        mm_inputs = {}

        # Each modality is handled independently. An if/elif chain here would
        # return only the first modality found and silently drop the others
        # (e.g. the audio of an image + audio request).
        if has_image_embeds:
            image_embeds_lst = items_by_modality["image_embeds"]
            if len(image_embeds_lst) > 1:
                raise ValueError(
                    "Only one message can have {'type': 'image_embeds'}")
            mm_inputs["image"] = image_embeds_lst[0]
        if has_images:
            mm_inputs["image"] = items_by_modality["image"]  # A list of images
        if "audio" in items_by_modality:
            mm_inputs["audio"] = items_by_modality["audio"]  # A list of audios
        if "video" in items_by_modality:
            mm_inputs["video"] = items_by_modality["video"]  # A list of videos
        return mm_inputs

    def create_parser(self) -> "BaseMultiModalContentParser":
        """Return an async content parser bound to this tracker."""
        return AsyncMultiModalContentParser(self)
|
||||
@@ -513,6 +560,11 @@ class BaseMultiModalContentParser(ABC):
|
||||
def parse_image(self, image_url: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
def parse_image_embeds(
        self, image_embeds: Union[str, dict[str, str]]) -> None:
    """Handle the payload of an ``image_embeds`` content part.

    Abstract hook: this base implementation always raises.

    Args:
        image_embeds: Either a single string or a dict of strings.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
|
||||
|
||||
@abstractmethod
def parse_audio(self, audio_url: str) -> None:
    """Handle an audio content part given as ``audio_url``.

    Abstract hook: this base implementation always raises.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
|
||||
@@ -543,6 +595,21 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
placeholder = self._tracker.add("image", image)
|
||||
self._add_placeholder(placeholder)
|
||||
|
||||
def parse_image_embeds(self,
                       image_embeds: Union[str, dict[str, str]]) -> None:
    """Fetch image embeddings and register them with the tracker.

    Args:
        image_embeds: Either a single string, or a dict whose values are
            strings. Every string is passed to the connector's
            ``fetch_image_embedding``; a dict keeps its keys.

    Raises:
        TypeError: If ``image_embeds`` is neither a ``str`` nor a ``dict``.
            Nothing is tracked in that case.
    """
    fetch = self._connector.fetch_image_embedding

    # Mutually exclusive branches with an explicit failure: two independent
    # isinstance checks would leave the placeholder unbound for any other
    # input type and surface as an UnboundLocalError.
    if isinstance(image_embeds, dict):
        item: object = {
            key: fetch(value)
            for key, value in image_embeds.items()
        }
    elif isinstance(image_embeds, str):
        item = fetch(image_embeds)
    else:
        raise TypeError(
            "image_embeds must be a str or a dict of str, got "
            f"{type(image_embeds).__name__}")

    placeholder = self._tracker.add("image_embeds", item)
    self._add_placeholder(placeholder)
|
||||
|
||||
def parse_audio(self, audio_url: str) -> None:
|
||||
audio = self._connector.fetch_audio(audio_url)
|
||||
|
||||
@@ -579,6 +646,25 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
placeholder = self._tracker.add("image", image_coro)
|
||||
self._add_placeholder(placeholder)
|
||||
|
||||
def parse_image_embeds(self,
                       image_embeds: Union[str, dict[str, str]]) -> None:
    """Fetch image embeddings and register them as an awaitable item.

    Args:
        image_embeds: Either a single string, or a dict whose values are
            strings. Every string is passed to the connector's
            ``fetch_image_embedding``; a dict keeps its keys.

    Raises:
        TypeError: If ``image_embeds`` is neither a ``str`` nor a ``dict``.
            Nothing is tracked in that case.
    """
    fetch = self._connector.fetch_image_embedding

    # Validate and fetch before creating the future. Registering a future
    # that never receives a result would make the later gather over the
    # tracked items wait forever.
    # NOTE(review): the fetch runs synchronously here, unlike the *_async
    # connector calls used by the sibling parsers - confirm that is intended.
    if isinstance(image_embeds, dict):
        resolved: object = {
            key: fetch(value)
            for key, value in image_embeds.items()
        }
    elif isinstance(image_embeds, str):
        resolved = fetch(image_embeds)
    else:
        raise TypeError(
            "image_embeds must be a str or a dict of str, got "
            f"{type(image_embeds).__name__}")

    # The tracker is given an awaitable, so wrap the already-computed value
    # in a completed future.
    future: asyncio.Future[object] = asyncio.Future()
    future.set_result(resolved)

    placeholder = self._tracker.add("image_embeds", future)
    self._add_placeholder(placeholder)
|
||||
|
||||
def parse_audio(self, audio_url: str) -> None:
|
||||
audio_coro = self._connector.fetch_audio_async(audio_url)
|
||||
|
||||
@@ -684,6 +770,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
|
||||
# No need to validate using Pydantic again
|
||||
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
|
||||
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
|
||||
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
|
||||
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
|
||||
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
|
||||
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
|
||||
@@ -700,6 +787,8 @@ MM_PARSER_MAP: dict[
|
||||
lambda part: _TextParser(part).get("text", ""),
|
||||
"image_url":
|
||||
lambda part: _ImageParser(part).get("image_url", {}).get("url", ""),
|
||||
"image_embeds":
|
||||
lambda part: _ImageEmbedsParser(part).get("image_embeds", {}),
|
||||
"audio_url":
|
||||
lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""),
|
||||
"input_audio":
|
||||
@@ -769,6 +858,7 @@ def _parse_chat_message_content_mm_part(
|
||||
|
||||
|
||||
# Recognized ``type`` values for multimodal message content parts.
VALID_MESSAGE_CONTENT_MM_PART_TYPES = (
    "text",
    "refusal",
    "image_url",
    "image_embeds",
    "audio_url",
    "input_audio",
    "video_url",
)
|
||||
|
||||
|
||||
@@ -843,7 +933,10 @@ def _parse_chat_message_content_part(
|
||||
str_content = cast(str, content)
|
||||
mm_parser.parse_image(str_content)
|
||||
return {'type': 'image'} if wrap_dicts else None
|
||||
|
||||
if part_type == "image_embeds":
|
||||
content = cast(Union[str, dict[str, str]], content)
|
||||
mm_parser.parse_image_embeds(content)
|
||||
return {'type': 'image'} if wrap_dicts else None
|
||||
if part_type == "audio_url":
|
||||
str_content = cast(str, content)
|
||||
mm_parser.parse_audio(str_content)
|
||||
|
||||
Reference in New Issue
Block a user