Signed-off-by: Roger Wang <hey@rogerw.io> Signed-off-by: Chenheli Hua <huachenheli@outlook.com> Signed-off-by: Roger Wang <hey@rogerw.me> Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: Roger Wang <hey@rogerw.me> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
@@ -41,7 +41,8 @@ from typing_extensions import Required, TypeAlias, TypedDict
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models import SupportsMultiModal
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
|
||||
MultiModalUUIDDict)
|
||||
from vllm.multimodal.utils import MediaConnector
|
||||
# yapf: disable
|
||||
from vllm.transformers_utils.chat_templates import (
|
||||
@@ -72,6 +73,11 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False):
|
||||
|
||||
type: Required[Literal["audio_url"]]
|
||||
"""The type of the content part."""
|
||||
uuid: Optional[str]
|
||||
"""
|
||||
User-provided UUID of a media. User must guarantee that it is properly
|
||||
generated and unique for different medias.
|
||||
"""
|
||||
|
||||
|
||||
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
|
||||
@@ -83,6 +89,11 @@ class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
|
||||
"""
|
||||
type: Required[Literal["image_embeds"]]
|
||||
"""The type of the content part."""
|
||||
uuid: Optional[str]
|
||||
"""
|
||||
User-provided UUID of a media. User must guarantee that it is properly
|
||||
generated and unique for different medias.
|
||||
"""
|
||||
|
||||
|
||||
class VideoURL(TypedDict, total=False):
|
||||
@@ -97,6 +108,11 @@ class ChatCompletionContentPartVideoParam(TypedDict, total=False):
|
||||
|
||||
type: Required[Literal["video_url"]]
|
||||
"""The type of the content part."""
|
||||
uuid: Optional[str]
|
||||
"""
|
||||
User-provided UUID of a media. User must guarantee that it is properly
|
||||
generated and unique for different medias.
|
||||
"""
|
||||
|
||||
|
||||
class PILImage(BaseModel):
|
||||
@@ -118,6 +134,11 @@ class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
|
||||
"""
|
||||
|
||||
image_pil: Required[PILImage]
|
||||
uuid: Optional[str]
|
||||
"""
|
||||
User-provided UUID of a media. User must guarantee that it is properly
|
||||
generated and unique for different medias.
|
||||
"""
|
||||
|
||||
|
||||
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
|
||||
@@ -131,6 +152,11 @@ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
|
||||
"""
|
||||
|
||||
image_url: Required[str]
|
||||
uuid: Optional[str]
|
||||
"""
|
||||
User-provided UUID of a media. User must guarantee that it is properly
|
||||
generated and unique for different medias.
|
||||
"""
|
||||
|
||||
|
||||
class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
|
||||
@@ -155,6 +181,11 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
|
||||
"""
|
||||
|
||||
video_url: Required[str]
|
||||
uuid: Optional[str]
|
||||
"""
|
||||
User-provided UUID of a media. User must guarantee that it is properly
|
||||
generated and unique for different medias.
|
||||
"""
|
||||
|
||||
|
||||
class CustomThinkCompletionContentParam(TypedDict, total=False):
|
||||
@@ -567,6 +598,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
self._tokenizer = tokenizer
|
||||
|
||||
self._items_by_modality = defaultdict[str, list[_T]](list)
|
||||
self._uuids_by_modality = defaultdict[str, list[Optional[str]]](list)
|
||||
|
||||
@property
|
||||
def model_config(self) -> ModelConfig:
|
||||
@@ -591,10 +623,15 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
def mm_processor(self):
|
||||
return self.mm_registry.create_processor(self.model_config)
|
||||
|
||||
def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
|
||||
def add(
|
||||
self, modality: ModalityStr, item: _T, uuid: Optional[str] = None
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Add a multi-modal item to the current prompt and returns the
|
||||
placeholder string to use, if any.
|
||||
|
||||
An optional uuid can be added which serves as a unique identifier of the
|
||||
media.
|
||||
"""
|
||||
input_modality = modality.replace("_embeds", "")
|
||||
num_items = len(self._items_by_modality[modality]) + 1
|
||||
@@ -602,9 +639,35 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
self.mm_processor.validate_num_items(input_modality, num_items)
|
||||
|
||||
self._items_by_modality[modality].append(item)
|
||||
self._uuids_by_modality[modality].append(uuid)
|
||||
|
||||
return self.model_cls.get_placeholder_str(modality, num_items)
|
||||
|
||||
def all_mm_uuids(self) -> Optional[MultiModalUUIDDict]:
    """
    Collect the UUIDs recorded by `add`, grouped by output modality.

    Each returned list is index-aligned with the items added for that
    modality; an entry is `None` when no UUID was supplied for the item.
    The lists are the tracker's own list objects, not copies.

    Returns:
        `None` if no multi-modal item has been added. Otherwise a dict that
        may contain the keys `"image"`, `"audio"` and `"video"`. UUIDs
        registered under `"image_embeds"` are reported under `"image"`.

    Raises:
        ValueError: If both `"image"` and `"image_embeds"` items were
            added, or if more than one `"image_embeds"` item was added.
    """
    if not self._items_by_modality:
        return None

    # Snapshot into a plain dict so the lookups below work on a fixed
    # mapping rather than on the tracker's defaultdict.
    uuids_by_modality = dict(self._uuids_by_modality)
    has_image_embeds = "image_embeds" in uuids_by_modality

    if "image" in uuids_by_modality and has_image_embeds:
        raise ValueError(
            "Mixing raw image and embedding inputs is not allowed"
        )

    mm_uuids = {}

    if has_image_embeds:
        image_embeds_uuids = uuids_by_modality["image_embeds"]
        if len(image_embeds_uuids) > 1:
            raise ValueError(
                "Only one message can have {'type': 'image_embeds'}"
            )
        # Embedding inputs are exposed under the "image" key.
        mm_uuids["image"] = image_embeds_uuids
    if "image" in uuids_by_modality:
        mm_uuids["image"] = uuids_by_modality["image"]  # UUIDs of images
    if "audio" in uuids_by_modality:
        mm_uuids["audio"] = uuids_by_modality["audio"]  # UUIDs of audios
    if "video" in uuids_by_modality:
        mm_uuids["video"] = uuids_by_modality["video"]  # UUIDs of videos
    return mm_uuids
|
||||
|
||||
@abstractmethod
|
||||
def create_parser(self) -> "BaseMultiModalContentParser":
|
||||
raise NotImplementedError
|
||||
@@ -697,29 +760,35 @@ class BaseMultiModalContentParser(ABC):
|
||||
return dict(self._placeholder_storage)
|
||||
|
||||
@abstractmethod
|
||||
def parse_image(self, image_url: str) -> None:
|
||||
def parse_image(self, image_url: str, uuid: Optional[str] = None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def parse_image_embeds(
|
||||
self, image_embeds: Union[str, dict[str, str]]
|
||||
self,
|
||||
image_embeds: Union[str, dict[str, str]],
|
||||
uuid: Optional[str] = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def parse_image_pil(self, image_pil: Image.Image) -> None:
|
||||
def parse_image_pil(
|
||||
self, image_pil: Image.Image, uuid: Optional[str] = None
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def parse_audio(self, audio_url: str) -> None:
|
||||
def parse_audio(self, audio_url: str, uuid: Optional[str] = None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def parse_input_audio(self, input_audio: InputAudio) -> None:
|
||||
def parse_input_audio(
|
||||
self, input_audio: InputAudio, uuid: Optional[str] = None
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def parse_video(self, video_url: str) -> None:
|
||||
def parse_video(self, video_url: str, uuid: Optional[str] = None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -734,49 +803,55 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||
)
|
||||
|
||||
def parse_image(self, image_url: str) -> None:
|
||||
def parse_image(self, image_url: str, uuid: Optional[str] = None) -> None:
|
||||
image = self._connector.fetch_image(image_url)
|
||||
|
||||
placeholder = self._tracker.add("image", image)
|
||||
placeholder = self._tracker.add("image", image, uuid)
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_image_embeds(
|
||||
self, image_embeds: Union[str, dict[str, str]]
|
||||
self,
|
||||
image_embeds: Union[str, dict[str, str]],
|
||||
uuid: Optional[str] = None,
|
||||
) -> None:
|
||||
if isinstance(image_embeds, dict):
|
||||
embeds = {
|
||||
k: self._connector.fetch_image_embedding(v)
|
||||
for k, v in image_embeds.items()
|
||||
}
|
||||
placeholder = self._tracker.add("image_embeds", embeds)
|
||||
placeholder = self._tracker.add("image_embeds", embeds, uuid)
|
||||
|
||||
if isinstance(image_embeds, str):
|
||||
embedding = self._connector.fetch_image_embedding(image_embeds)
|
||||
placeholder = self._tracker.add("image_embeds", embedding)
|
||||
placeholder = self._tracker.add("image_embeds", embedding, uuid)
|
||||
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_image_pil(self, image_pil: Image.Image) -> None:
|
||||
placeholder = self._tracker.add("image", image_pil)
|
||||
def parse_image_pil(
|
||||
self, image_pil: Image.Image, uuid: Optional[str] = None
|
||||
) -> None:
|
||||
placeholder = self._tracker.add("image", image_pil, uuid)
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_audio(self, audio_url: str) -> None:
|
||||
def parse_audio(self, audio_url: str, uuid: Optional[str] = None) -> None:
|
||||
audio = self._connector.fetch_audio(audio_url)
|
||||
|
||||
placeholder = self._tracker.add("audio", audio)
|
||||
placeholder = self._tracker.add("audio", audio, uuid)
|
||||
self._add_placeholder("audio", placeholder)
|
||||
|
||||
def parse_input_audio(self, input_audio: InputAudio) -> None:
|
||||
def parse_input_audio(
|
||||
self, input_audio: InputAudio, uuid: Optional[str] = None
|
||||
) -> None:
|
||||
audio_data = input_audio.get("data", "")
|
||||
audio_format = input_audio.get("format", "")
|
||||
audio_url = f"data:audio/{audio_format};base64,{audio_data}"
|
||||
|
||||
return self.parse_audio(audio_url)
|
||||
return self.parse_audio(audio_url, uuid)
|
||||
|
||||
def parse_video(self, video_url: str) -> None:
|
||||
def parse_video(self, video_url: str, uuid: Optional[str] = None) -> None:
|
||||
video = self._connector.fetch_video(video_url=video_url)
|
||||
|
||||
placeholder = self._tracker.add("video", video)
|
||||
placeholder = self._tracker.add("video", video, uuid)
|
||||
self._add_placeholder("video", placeholder)
|
||||
|
||||
|
||||
@@ -790,14 +865,16 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||
)
|
||||
|
||||
def parse_image(self, image_url: str) -> None:
|
||||
def parse_image(self, image_url: str, uuid: Optional[str] = None) -> None:
|
||||
image_coro = self._connector.fetch_image_async(image_url)
|
||||
|
||||
placeholder = self._tracker.add("image", image_coro)
|
||||
placeholder = self._tracker.add("image", image_coro, uuid)
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_image_embeds(
|
||||
self, image_embeds: Union[str, dict[str, str]]
|
||||
self,
|
||||
image_embeds: Union[str, dict[str, str]],
|
||||
uuid: Optional[str] = None,
|
||||
) -> None:
|
||||
future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future()
|
||||
|
||||
@@ -812,33 +889,37 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
embedding = self._connector.fetch_image_embedding(image_embeds)
|
||||
future.set_result(embedding)
|
||||
|
||||
placeholder = self._tracker.add("image_embeds", future)
|
||||
placeholder = self._tracker.add("image_embeds", future, uuid)
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_image_pil(self, image_pil: Image.Image) -> None:
|
||||
def parse_image_pil(
|
||||
self, image_pil: Image.Image, uuid: Optional[str] = None
|
||||
) -> None:
|
||||
future: asyncio.Future[Image.Image] = asyncio.Future()
|
||||
future.set_result(image_pil)
|
||||
|
||||
placeholder = self._tracker.add("image", future)
|
||||
placeholder = self._tracker.add("image", future, uuid)
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_audio(self, audio_url: str) -> None:
|
||||
def parse_audio(self, audio_url: str, uuid: Optional[str] = None) -> None:
|
||||
audio_coro = self._connector.fetch_audio_async(audio_url)
|
||||
|
||||
placeholder = self._tracker.add("audio", audio_coro)
|
||||
placeholder = self._tracker.add("audio", audio_coro, uuid)
|
||||
self._add_placeholder("audio", placeholder)
|
||||
|
||||
def parse_input_audio(self, input_audio: InputAudio) -> None:
|
||||
def parse_input_audio(
|
||||
self, input_audio: InputAudio, uuid: Optional[str] = None
|
||||
) -> None:
|
||||
audio_data = input_audio.get("data", "")
|
||||
audio_format = input_audio.get("format", "")
|
||||
audio_url = f"data:audio/{audio_format};base64,{audio_data}"
|
||||
|
||||
return self.parse_audio(audio_url)
|
||||
return self.parse_audio(audio_url, uuid)
|
||||
|
||||
def parse_video(self, video_url: str) -> None:
|
||||
def parse_video(self, video_url: str, uuid: Optional[str] = None) -> None:
|
||||
video = self._connector.fetch_video_async(video_url=video_url)
|
||||
|
||||
placeholder = self._tracker.add("video", video)
|
||||
placeholder = self._tracker.add("video", video, uuid)
|
||||
self._add_placeholder("video", placeholder)
|
||||
|
||||
|
||||
@@ -1177,30 +1258,36 @@ def _parse_chat_message_content_part(
|
||||
else:
|
||||
return str_content
|
||||
|
||||
# For media items, if a user has provided one, use it. Otherwise, insert
|
||||
# a placeholder empty uuid.
|
||||
uuid = part.get("uuid", None)
|
||||
if uuid is not None:
|
||||
uuid = str(uuid)
|
||||
|
||||
modality = None
|
||||
if part_type == "image_pil":
|
||||
image_content = cast(Image.Image, content)
|
||||
mm_parser.parse_image_pil(image_content)
|
||||
mm_parser.parse_image_pil(image_content, uuid)
|
||||
modality = "image"
|
||||
elif part_type in ("image_url", "input_image"):
|
||||
str_content = cast(str, content)
|
||||
mm_parser.parse_image(str_content)
|
||||
mm_parser.parse_image(str_content, uuid)
|
||||
modality = "image"
|
||||
elif part_type == "image_embeds":
|
||||
content = cast(Union[str, dict[str, str]], content)
|
||||
mm_parser.parse_image_embeds(content)
|
||||
mm_parser.parse_image_embeds(content, uuid)
|
||||
modality = "image"
|
||||
elif part_type == "audio_url":
|
||||
str_content = cast(str, content)
|
||||
mm_parser.parse_audio(str_content)
|
||||
mm_parser.parse_audio(str_content, uuid)
|
||||
modality = "audio"
|
||||
elif part_type == "input_audio":
|
||||
dict_content = cast(InputAudio, content)
|
||||
mm_parser.parse_input_audio(dict_content)
|
||||
mm_parser.parse_input_audio(dict_content, uuid)
|
||||
modality = "audio"
|
||||
elif part_type == "video_url":
|
||||
str_content = cast(str, content)
|
||||
mm_parser.parse_video(str_content)
|
||||
mm_parser.parse_video(str_content, uuid)
|
||||
modality = "video"
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown part type: {part_type}")
|
||||
@@ -1288,7 +1375,11 @@ def parse_chat_messages(
|
||||
model_config: ModelConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
content_format: _ChatTemplateContentFormat,
|
||||
) -> tuple[list[ConversationMessage], Optional[MultiModalDataDict]]:
|
||||
) -> tuple[
|
||||
list[ConversationMessage],
|
||||
Optional[MultiModalDataDict],
|
||||
Optional[MultiModalUUIDDict],
|
||||
]:
|
||||
conversation: list[ConversationMessage] = []
|
||||
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
|
||||
|
||||
@@ -1308,7 +1399,7 @@ def parse_chat_messages(
|
||||
|
||||
_postprocess_messages(conversation)
|
||||
|
||||
return conversation, mm_tracker.all_mm_data()
|
||||
return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
|
||||
|
||||
|
||||
def parse_chat_messages_futures(
|
||||
@@ -1316,7 +1407,11 @@ def parse_chat_messages_futures(
|
||||
model_config: ModelConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
content_format: _ChatTemplateContentFormat,
|
||||
) -> tuple[list[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
|
||||
) -> tuple[
|
||||
list[ConversationMessage],
|
||||
Awaitable[Optional[MultiModalDataDict]],
|
||||
Optional[MultiModalUUIDDict],
|
||||
]:
|
||||
conversation: list[ConversationMessage] = []
|
||||
mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
|
||||
|
||||
@@ -1336,7 +1431,7 @@ def parse_chat_messages_futures(
|
||||
|
||||
_postprocess_messages(conversation)
|
||||
|
||||
return conversation, mm_tracker.all_mm_data()
|
||||
return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
|
||||
|
||||
|
||||
def apply_hf_chat_template(
|
||||
|
||||
Reference in New Issue
Block a user