[Frontend] Support image object in llm.chat (#19635)
Signed-off-by: sfeng33 <4florafeng@gmail.com> Signed-off-by: Flora Feng <4florafeng@gmail.com>
This commit is contained in:
@@ -28,7 +28,8 @@ from openai.types.chat import (ChatCompletionMessageToolCallParam,
|
||||
ChatCompletionToolMessageParam)
|
||||
from openai.types.chat.chat_completion_content_part_input_audio_param import (
|
||||
InputAudio)
|
||||
from pydantic import TypeAdapter
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, ConfigDict, TypeAdapter
|
||||
# yapf: enable
|
||||
from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast,
|
||||
ProcessorMixin)
|
||||
@@ -91,6 +92,25 @@ class ChatCompletionContentPartVideoParam(TypedDict, total=False):
|
||||
"""The type of the content part."""
|
||||
|
||||
|
||||
class PILImage(BaseModel):
|
||||
"""
|
||||
A PIL.Image.Image object.
|
||||
"""
|
||||
image_pil: Image.Image
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
|
||||
"""A simpler version of the param that only accepts a PIL image.
|
||||
|
||||
Example:
|
||||
{
|
||||
"image_pil": ImageAsset('cherry_blossom').pil_image
|
||||
}
|
||||
"""
|
||||
image_pil: Required[PILImage]
|
||||
|
||||
|
||||
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
|
||||
"""A simpler version of the param that only accepts a plain image_url.
|
||||
This is supported by OpenAI API, although it is not documented.
|
||||
@@ -129,6 +149,7 @@ ChatCompletionContentPartParam: TypeAlias = Union[
|
||||
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
|
||||
ChatCompletionContentPartInputAudioParam,
|
||||
ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam,
|
||||
CustomChatCompletionContentPILImageParam,
|
||||
CustomChatCompletionContentSimpleImageParam,
|
||||
ChatCompletionContentPartImageEmbedsParam,
|
||||
CustomChatCompletionContentSimpleAudioParam,
|
||||
@@ -631,6 +652,10 @@ class BaseMultiModalContentParser(ABC):
|
||||
image_embeds: Union[str, dict[str, str]]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def parse_image_pil(self, image_pil: Image.Image) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def parse_audio(self, audio_url: str) -> None:
|
||||
raise NotImplementedError
|
||||
@@ -677,6 +702,10 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
|
||||
self._add_placeholder(placeholder)
|
||||
|
||||
def parse_image_pil(self, image_pil: Image.Image) -> None:
|
||||
placeholder = self._tracker.add("image", image_pil)
|
||||
self._add_placeholder(placeholder)
|
||||
|
||||
def parse_audio(self, audio_url: str) -> None:
|
||||
audio = self._connector.fetch_audio(audio_url)
|
||||
|
||||
@@ -733,6 +762,13 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
placeholder = self._tracker.add("image_embeds", future)
|
||||
self._add_placeholder(placeholder)
|
||||
|
||||
def parse_image_pil(self, image_pil: Image.Image) -> None:
|
||||
future: asyncio.Future[Image.Image] = asyncio.Future()
|
||||
future.set_result(image_pil)
|
||||
|
||||
placeholder = self._tracker.add("image", future)
|
||||
self._add_placeholder(placeholder)
|
||||
|
||||
def parse_audio(self, audio_url: str) -> None:
|
||||
audio_coro = self._connector.fetch_audio_async(audio_url)
|
||||
|
||||
@@ -851,12 +887,13 @@ _TextParser = partial(cast, ChatCompletionContentPartTextParam)
|
||||
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
|
||||
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
|
||||
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
|
||||
_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
|
||||
# Need to validate url objects
|
||||
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
|
||||
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
|
||||
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python
|
||||
|
||||
_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio]
|
||||
_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage]
|
||||
|
||||
# Define a mapping from part types to their corresponding parsing functions.
|
||||
MM_PARSER_MAP: dict[
|
||||
@@ -869,6 +906,7 @@ MM_PARSER_MAP: dict[
|
||||
lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
|
||||
"image_embeds":
|
||||
lambda part: _ImageEmbedsParser(part).get("image_embeds", None),
|
||||
"image_pil": lambda part: _PILImageParser(part).get("image_pil", None),
|
||||
"audio_url":
|
||||
lambda part: _AudioParser(part).get("audio_url", {}).get("url", None),
|
||||
"input_audio":
|
||||
@@ -938,7 +976,7 @@ def _parse_chat_message_content_mm_part(
|
||||
|
||||
|
||||
VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
|
||||
"image_embeds",
|
||||
"image_embeds", "image_pil",
|
||||
"audio_url", "input_audio", "video_url")
|
||||
|
||||
|
||||
@@ -1009,6 +1047,10 @@ def _parse_chat_message_content_part(
|
||||
else:
|
||||
return str_content
|
||||
|
||||
if part_type == "image_pil":
|
||||
image_content = cast(Image.Image, content)
|
||||
mm_parser.parse_image_pil(image_content)
|
||||
return {'type': 'image'} if wrap_dicts else None
|
||||
if part_type == "image_url":
|
||||
str_content = cast(str, content)
|
||||
mm_parser.parse_image(str_content)
|
||||
|
||||
Reference in New Issue
Block a user