[Frontend] Multimodal support in offline chat (#8098)

This commit is contained in:
Cyrus Leung
2024-09-04 13:22:17 +08:00
committed by GitHub
parent 2be8ec6e71
commit 855c262a6b
8 changed files with 356 additions and 112 deletions

View File

@@ -1,10 +1,11 @@
import asyncio
import codecs
from abc import ABC, abstractmethod
from collections import defaultdict
from functools import lru_cache
from pathlib import Path
from typing import (Any, Awaitable, Dict, Iterable, List, Literal, Mapping,
Optional, Tuple, Union)
from typing import (Any, Awaitable, Dict, Generic, Iterable, List, Literal,
Mapping, Optional, Tuple, TypeVar, Union)
# yapf conflicts with isort for this block
# yapf: disable
@@ -23,7 +24,8 @@ from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import (async_get_and_parse_audio,
async_get_and_parse_image)
async_get_and_parse_image,
get_and_parse_audio, get_and_parse_image)
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
@@ -81,7 +83,11 @@ class ConversationMessage(TypedDict):
content: str
class MultiModalItemTracker:
ModalityStr = Literal["image", "audio"]
_T = TypeVar("_T")
class BaseMultiModalItemTracker(ABC, Generic[_T]):
"""
Tracks multi-modal items in a given request and ensures that the number
of multi-modal items in a given request does not exceed the configured
@@ -89,37 +95,28 @@ class MultiModalItemTracker:
"""
def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
super().__init__()
self._model_config = model_config
self._tokenizer = tokenizer
self._allowed_items = (model_config.multimodal_config.limit_per_prompt
if model_config.multimodal_config else {})
self._consumed_items = {k: 0 for k in self._allowed_items}
self._futures: List[Awaitable[MultiModalDataDict]] = []
self._items: List[_T] = []
@staticmethod
@lru_cache(maxsize=None)
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int):
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
return tokenizer.decode(token_index)
def add(self, modality: Literal["image", "audio"],
mm_future: Awaitable[MultiModalDataDict]) -> Optional[str]:
"""
Adds the multi-modal item to the current prompt and returns the
placeholder string to use, if any.
"""
allowed_count = self._allowed_items.get(modality, 1)
current_count = self._consumed_items.get(modality, 0) + 1
if current_count > allowed_count:
raise ValueError(
f"At most {allowed_count} {modality}(s) may be provided in "
"one request.")
self._consumed_items[modality] = current_count
self._futures.append(mm_future)
def _placeholder_str(self, modality: ModalityStr,
current_count: int) -> Optional[str]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
model_type = self._model_config.hf_config.model_type
hf_config = self._model_config.hf_config
model_type = hf_config.model_type
if modality == "image":
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
@@ -130,9 +127,8 @@ class MultiModalItemTracker:
# These models do not use image tokens in the prompt
return None
if model_type.startswith("llava"):
return MultiModalItemTracker._cached_token_str(
self._tokenizer,
self._model_config.hf_config.image_token_index)
return self._cached_token_str(self._tokenizer,
hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat"):
return "<image>"
@@ -145,11 +141,11 @@ class MultiModalItemTracker:
raise TypeError(f"Unknown modality: {modality}")
@staticmethod
async def _combine(futures: List[Awaitable[MultiModalDataDict]]):
def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict:
mm_lists: Mapping[str, List[object]] = defaultdict(list)
# Merge all the multi-modal items
for single_mm_data in (await asyncio.gather(*futures)):
for single_mm_data in items:
for mm_key, mm_item in single_mm_data.items():
if isinstance(mm_item, list):
mm_lists[mm_key].extend(mm_item)
@@ -162,9 +158,113 @@ class MultiModalItemTracker:
for mm_key, mm_list in mm_lists.items()
}
def all_mm_data(self) -> Optional[Awaitable[MultiModalDataDict]]:
    """
    Return the combined multi-modal data for the request, or ``None`` when
    no multi-modal item has been registered.
    """
    # NOTE(review): this is the pre-refactor form that works on
    # ``self._futures``; confirm it is still wanted next to the
    # item-based trackers.
    if not self._futures:
        return None

    return MultiModalItemTracker._combine(self._futures)
def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
    """
    Register one multi-modal item for the current prompt.

    The per-modality counter is incremented and checked against the
    configured limit (a modality without a configured limit is allowed
    one item). On success the item is stored and the placeholder string
    for it, if any, is returned.

    Raises:
        ValueError: If adding this item would exceed the allowed number
            of items for ``modality``.
    """
    limit = self._allowed_items.get(modality, 1)
    updated_count = 1 + self._consumed_items.get(modality, 0)

    # Reject before touching any state so a failed call leaves the
    # tracker unchanged.
    if limit < updated_count:
        raise ValueError(
            f"At most {limit} {modality}(s) may be provided in "
            "one request.")

    self._consumed_items[modality] = updated_count
    self._items.append(item)

    return self._placeholder_str(modality, updated_count)
@abstractmethod
def create_parser(self) -> "BaseMultiModalContentParser":
    """
    Build the content parser that feeds items into this tracker.

    Concrete trackers must override this method.
    """
    raise NotImplementedError
class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]):
    """Tracker whose items are already-materialized multi-modal data."""

    def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """
        Merge every tracked item into one mapping, or return ``None`` when
        nothing has been added.
        """
        if not self._items:
            return None

        return self._combine(self._items)

    def create_parser(self) -> "BaseMultiModalContentParser":
        """Return a synchronous content parser that adds to this tracker."""
        return MultiModalContentParser(self)
class AsyncMultiModalItemTracker(
        BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]):
    """Tracker whose items are awaitables resolving to multi-modal data."""

    async def all_mm_data(self) -> Optional[MultiModalDataDict]:
        """
        Await every tracked item and merge the results into one mapping.

        Returns ``None`` when nothing has been added.
        """
        if not self._items:
            return None

        # gather() yields the results in the order the items were added.
        resolved = await asyncio.gather(*self._items)
        return self._combine(resolved)

    def create_parser(self) -> "BaseMultiModalContentParser":
        """Return an async-aware content parser that adds to this tracker."""
        return AsyncMultiModalContentParser(self)
class BaseMultiModalContentParser(ABC):
    """
    Base class for parsers that turn multi-modal content parts into tracked
    items while counting the placeholder strings they produce.

    Subclasses implement :meth:`parse_image` and :meth:`parse_audio`, and
    report each resulting placeholder through :meth:`_add_placeholder`.
    """

    def __init__(self) -> None:
        super().__init__()

        # multimodal placeholder_string : count
        self._placeholder_counts: Dict[str, int] = defaultdict(int)

    def _add_placeholder(self, placeholder: Optional[str]) -> None:
        """
        Count one occurrence of ``placeholder``.

        ``None`` and the empty string are ignored, so callers can pass the
        optional placeholder they were given without checking it first.
        """
        if placeholder:
            self._placeholder_counts[placeholder] += 1

    def mm_placeholder_counts(self) -> Dict[str, int]:
        """
        Return a plain ``dict`` copy of the placeholder counts.

        A copy is returned so that callers cannot mutate the parser's
        internal state (or trigger ``defaultdict`` auto-insertion).
        """
        return dict(self._placeholder_counts)

    @abstractmethod
    def parse_image(self, image_url: str) -> None:
        """Handle the image referenced by ``image_url``."""
        raise NotImplementedError

    @abstractmethod
    def parse_audio(self, audio_url: str) -> None:
        """Handle the audio referenced by ``audio_url``."""
        raise NotImplementedError
class MultiModalContentParser(BaseMultiModalContentParser):
    """Content parser that obtains each item before handing it to the tracker."""

    def __init__(self, tracker: MultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker

    def parse_image(self, image_url: str) -> None:
        """Fetch and parse the image, track it, and count its placeholder."""
        self._add_placeholder(
            self._tracker.add("image", get_and_parse_image(image_url)))

    def parse_audio(self, audio_url: str) -> None:
        """Fetch and parse the audio, track it, and count its placeholder."""
        self._add_placeholder(
            self._tracker.add("audio", get_and_parse_audio(audio_url)))
class AsyncMultiModalContentParser(BaseMultiModalContentParser):
    """
    Content parser that hands the tracker the not-yet-awaited result of the
    async loaders instead of the loaded data.
    """

    def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
        super().__init__()

        self._tracker = tracker

    def parse_image(self, image_url: str) -> None:
        """Track the pending image load and count its placeholder."""
        pending_image = async_get_and_parse_image(image_url)
        self._add_placeholder(self._tracker.add("image", pending_image))

    def parse_audio(self, audio_url: str) -> None:
        """Track the pending audio load and count its placeholder."""
        pending_audio = async_get_and_parse_audio(audio_url)
        self._add_placeholder(self._tracker.add("audio", pending_audio))
def load_chat_template(
@@ -197,10 +297,10 @@ def load_chat_template(
# (similar to chat template)
def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
text_prompt: str) -> str:
"""Combine multimodal prompts for a multimodal language model"""
"""Combine multimodal prompts for a multimodal language model."""
# Look through the text prompt to check for missing placeholders
missing_placeholders = []
missing_placeholders: List[str] = []
for placeholder in placeholder_counts:
# For any existing placeholder in the text prompt, we leave it as is
@@ -227,12 +327,11 @@ _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam)
def _parse_chat_message_content_parts(
role: str,
parts: Iterable[ChatCompletionContentPartParam],
mm_tracker: MultiModalItemTracker,
mm_tracker: BaseMultiModalItemTracker,
) -> List[ConversationMessage]:
texts: List[str] = []
# multimodal placeholder_string : count
mm_placeholder_counts: Dict[str, int] = {}
mm_parser = mm_tracker.create_parser()
for part in parts:
part_type = part["type"]
@@ -247,22 +346,16 @@ def _parse_chat_message_content_parts(
"'image_url.detail' is currently not supported and "
"will be ignored.")
image_coro = async_get_and_parse_image(image_url["url"])
placeholder = mm_tracker.add("image", image_coro)
if placeholder:
mm_placeholder_counts[placeholder] = mm_placeholder_counts.get(
placeholder, 0) + 1
mm_parser.parse_image(image_url["url"])
elif part_type == "audio_url":
audio_url = _AudioParser.validate_python(part)["audio_url"]
audio_coro = async_get_and_parse_audio(audio_url["url"])
placeholder = mm_tracker.add("audio", audio_coro)
if placeholder:
mm_placeholder_counts[placeholder] = mm_placeholder_counts.get(
placeholder, 0) + 1
mm_parser.parse_audio(audio_url["url"])
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
text_prompt = "\n".join(texts)
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
if mm_placeholder_counts:
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
text_prompt)
@@ -271,8 +364,9 @@ def _parse_chat_message_content_parts(
def _parse_chat_message_content(
message: ChatCompletionMessageParam,
mm_tracker: MultiModalItemTracker) -> List[ConversationMessage]:
message: ChatCompletionMessageParam,
mm_tracker: BaseMultiModalItemTracker,
) -> List[ConversationMessage]:
role = message["role"]
content = message.get("content")
@@ -292,7 +386,7 @@ def parse_chat_messages(
messages: List[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: AnyTokenizer,
) -> Tuple[List[ConversationMessage], Optional[Awaitable[MultiModalDataDict]]]:
) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]:
conversation: List[ConversationMessage] = []
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
@@ -304,6 +398,22 @@ def parse_chat_messages(
return conversation, mm_tracker.all_mm_data()
def parse_chat_messages_futures(
    messages: List[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: AnyTokenizer,
) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
    """
    Parse chat messages, deferring the multi-modal data to an awaitable.

    Every message is parsed with a shared ``AsyncMultiModalItemTracker``.
    The return value is the flattened conversation together with the
    awaitable produced by the tracker's ``all_mm_data()``.
    """
    tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
    parsed: List[ConversationMessage] = []

    for message in messages:
        # One input message may expand to several conversation entries.
        parsed += _parse_chat_message_content(message, tracker)

    return parsed, tracker.all_mm_data()
def apply_chat_template(
tokenizer: AnyTokenizer,
conversation: List[ConversationMessage],