[Frontend][VLM] Add support for multiple multi-modal items (#8049)

This commit is contained in:
Roger Wang
2024-08-31 16:35:53 -07:00
committed by GitHub
parent 8423aef4c8
commit 5231f0898e
8 changed files with 524 additions and 136 deletions

View File

@@ -1,9 +1,10 @@
import asyncio
import codecs
from dataclasses import dataclass
from collections import defaultdict
from functools import lru_cache
from pathlib import Path
from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple,
Union)
from typing import (Any, Awaitable, Dict, Iterable, List, Literal, Mapping,
Optional, Tuple, Union)
# yapf conflicts with isort for this block
# yapf: disable
@@ -80,10 +81,90 @@ class ConversationMessage(TypedDict):
content: str
@dataclass(frozen=True)
class ChatMessageParseResult:
messages: List[ConversationMessage]
mm_futures: List[Awaitable[MultiModalDataDict]]
class MultiModalItemTracker:
"""
Tracks multi-modal items in a given request and ensures that the number
of multi-modal items in a given request does not exceed the configured
maximum per prompt.
"""
def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
self._model_config = model_config
self._tokenizer = tokenizer
self._allowed_items = (model_config.multimodal_config.limit_per_prompt
if model_config.multimodal_config else {})
self._consumed_items = {k: 0 for k in self._allowed_items}
self._futures: List[Awaitable[MultiModalDataDict]] = []
@staticmethod
@lru_cache(maxsize=None)
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int):
return tokenizer.decode(token_index)
def add(self, modality: Literal["image", "audio"],
mm_future: Awaitable[MultiModalDataDict]) -> Optional[str]:
"""
Adds the multi-modal item to the current prompt and returns the
placeholder string to use, if any.
"""
allowed_count = self._allowed_items.get(modality, 1)
current_count = self._consumed_items.get(modality, 0) + 1
if current_count > allowed_count:
raise ValueError(
f"At most {allowed_count} {modality}(s) may be provided in "
"one request.")
self._consumed_items[modality] = current_count
self._futures.append(mm_future)
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
model_type = self._model_config.hf_config.model_type
if modality == "image":
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
return f"<|image_{current_count}|>"
if model_type == "minicpmv":
return "(<image>./</image>)"
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
# These models do not use image tokens in the prompt
return None
if model_type.startswith("llava"):
return MultiModalItemTracker._cached_token_str(
self._tokenizer,
self._model_config.hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat"):
return "<image>"
raise TypeError(f"Unknown model type: {model_type}")
elif modality == "audio":
if model_type == "ultravox":
return "<|reserved_special_token_0|>"
raise TypeError(f"Unknown model type: {model_type}")
else:
raise TypeError(f"Unknown modality: {modality}")
@staticmethod
async def _combine(futures: List[Awaitable[MultiModalDataDict]]):
mm_lists: Mapping[str, List[object]] = defaultdict(list)
# Merge all the multi-modal items
for single_mm_data in (await asyncio.gather(*futures)):
for mm_key, mm_item in single_mm_data.items():
if isinstance(mm_item, list):
mm_lists[mm_key].extend(mm_item)
else:
mm_lists[mm_key].append(mm_item)
# Unpack any single item lists for models that don't expect multiple.
return {
mm_key: mm_list[0] if len(mm_list) == 1 else mm_list
for mm_key, mm_list in mm_lists.items()
}
def all_mm_data(self) -> Optional[Awaitable[MultiModalDataDict]]:
return MultiModalItemTracker._combine(
self._futures) if self._futures else None
def load_chat_template(
@@ -112,44 +193,30 @@ def load_chat_template(
return resolved_chat_template
@lru_cache(maxsize=None)
def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer,
modality: Literal["image", "audio"]) -> Optional[str]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
model_type = model_config.hf_config.model_type
if modality == "image":
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
return "<|image_1|>"
if model_type == "minicpmv":
return "(<image>./</image>)"
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
# These models do not use image tokens in the prompt
return None
if model_type.startswith("llava"):
return tokenizer.decode(model_config.hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat"):
return "<image>"
raise TypeError(f"Unknown model type: {model_type}")
elif modality == "audio":
if model_type == "ultravox":
return "<|reserved_special_token_0|>"
raise TypeError(f"Unknown model type: {model_type}")
else:
raise TypeError(f"Unknown modality: {modality}")
# TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template)
def _get_full_multimodal_text_prompt(placeholder_token_str: str,
def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
text_prompt: str) -> str:
"""Combine multimodal prompts for a multimodal language model"""
# NOTE: For now we assume all model architectures use the same
# placeholder + text prompt format. This may change in the future.
return f"{placeholder_token_str}\n{text_prompt}"
# Look through the text prompt to check for missing placeholders
missing_placeholders = []
for placeholder in placeholder_counts:
# For any existing placeholder in the text prompt, we leave it as is
placeholder_counts[placeholder] -= text_prompt.count(placeholder)
if placeholder_counts[placeholder] < 0:
raise ValueError(
f"Found more '{placeholder}' placeholders in input prompt than "
"actual multimodal data items.")
missing_placeholders.extend([placeholder] *
placeholder_counts[placeholder])
# NOTE: For now we always add missing placeholders at the front of
# the prompt. This may change to be customizable in the future.
return "\n".join(missing_placeholders + [text_prompt])
_TextParser = TypeAdapter(ChatCompletionContentPartTextParam)
@@ -160,12 +227,12 @@ _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam)
def _parse_chat_message_content_parts(
role: str,
parts: Iterable[ChatCompletionContentPartParam],
model_config: ModelConfig,
tokenizer: AnyTokenizer,
) -> ChatMessageParseResult:
mm_tracker: MultiModalItemTracker,
) -> List[ConversationMessage]:
texts: List[str] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
modality: Literal["image", "audio"] = "image"
# multimodal placeholder_string : count
mm_placeholder_counts: Dict[str, int] = {}
for part in parts:
part_type = part["type"]
@@ -173,11 +240,6 @@ def _parse_chat_message_content_parts(
text = _TextParser.validate_python(part)["text"]
texts.append(text)
elif part_type == "image_url":
modality = "image"
if len(mm_futures) > 0:
raise NotImplementedError(
"Multiple multimodal inputs is currently not supported.")
image_url = _ImageParser.validate_python(part)["image_url"]
if image_url.get("detail", "auto") != "auto":
@@ -185,60 +247,44 @@ def _parse_chat_message_content_parts(
"'image_url.detail' is currently not supported and "
"will be ignored.")
image_future = async_get_and_parse_image(image_url["url"])
mm_futures.append(image_future)
image_coro = async_get_and_parse_image(image_url["url"])
placeholder = mm_tracker.add("image", image_coro)
if placeholder:
mm_placeholder_counts[placeholder] = mm_placeholder_counts.get(
placeholder, 0) + 1
elif part_type == "audio_url":
modality = "audio"
if len(mm_futures) > 0:
raise NotImplementedError(
"Multiple multimodal inputs is currently not supported.")
audio_url = _AudioParser.validate_python(part)["audio_url"]
audio_future = async_get_and_parse_audio(audio_url["url"])
mm_futures.append(audio_future)
audio_coro = async_get_and_parse_audio(audio_url["url"])
placeholder = mm_tracker.add("audio", audio_coro)
if placeholder:
mm_placeholder_counts[placeholder] = mm_placeholder_counts.get(
placeholder, 0) + 1
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
text_prompt = "\n".join(texts)
if mm_placeholder_counts:
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
text_prompt)
if mm_futures:
placeholder_token_str = _mm_token_str(model_config, tokenizer,
modality)
if placeholder_token_str is not None:
if placeholder_token_str in text_prompt:
logger.warning(
"Detected multi-modal token string in the text prompt. "
"Skipping prompt formatting.")
else:
text_prompt = _get_full_multimodal_text_prompt(
placeholder_token_str=placeholder_token_str,
text_prompt=text_prompt,
)
messages = [ConversationMessage(role=role, content=text_prompt)]
return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
return [ConversationMessage(role=role, content=text_prompt)]
def _parse_chat_message_content(
message: ChatCompletionMessageParam,
model_config: ModelConfig,
tokenizer: AnyTokenizer,
) -> ChatMessageParseResult:
message: ChatCompletionMessageParam,
mm_tracker: MultiModalItemTracker) -> List[ConversationMessage]:
role = message["role"]
content = message.get("content")
if content is None:
return ChatMessageParseResult(messages=[], mm_futures=[])
return []
if isinstance(content, str):
messages = [ConversationMessage(role=role, content=content)]
return ChatMessageParseResult(messages=messages, mm_futures=[])
return [ConversationMessage(role=role, content=content)]
return _parse_chat_message_content_parts(
role,
content, # type: ignore
model_config,
tokenizer,
mm_tracker,
)
@@ -246,18 +292,16 @@ def parse_chat_messages(
messages: List[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: AnyTokenizer,
) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
) -> Tuple[List[ConversationMessage], Optional[Awaitable[MultiModalDataDict]]]:
conversation: List[ConversationMessage] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
for msg in messages:
parse_result = _parse_chat_message_content(msg, model_config,
tokenizer)
sub_messages = _parse_chat_message_content(msg, mm_tracker)
conversation.extend(parse_result.messages)
mm_futures.extend(parse_result.mm_futures)
conversation.extend(sub_messages)
return conversation, mm_futures
return conversation, mm_tracker.all_mm_data()
def apply_chat_template(