[Frontend] Multimodal support in offline chat (#8098)
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
import asyncio
|
||||
import codecs
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import (Any, Awaitable, Dict, Iterable, List, Literal, Mapping,
|
||||
Optional, Tuple, Union)
|
||||
from typing import (Any, Awaitable, Dict, Generic, Iterable, List, Literal,
|
||||
Mapping, Optional, Tuple, TypeVar, Union)
|
||||
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
@@ -23,7 +24,8 @@ from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.multimodal.utils import (async_get_and_parse_audio,
|
||||
async_get_and_parse_image)
|
||||
async_get_and_parse_image,
|
||||
get_and_parse_audio, get_and_parse_image)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -81,7 +83,11 @@ class ConversationMessage(TypedDict):
|
||||
content: str
|
||||
|
||||
|
||||
class MultiModalItemTracker:
|
||||
ModalityStr = Literal["image", "audio"]
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
"""
|
||||
Tracks multi-modal items in a given request and ensures that the number
|
||||
of multi-modal items in a given request does not exceed the configured
|
||||
@@ -89,37 +95,28 @@ class MultiModalItemTracker:
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
|
||||
super().__init__()
|
||||
|
||||
self._model_config = model_config
|
||||
self._tokenizer = tokenizer
|
||||
self._allowed_items = (model_config.multimodal_config.limit_per_prompt
|
||||
if model_config.multimodal_config else {})
|
||||
self._consumed_items = {k: 0 for k in self._allowed_items}
|
||||
self._futures: List[Awaitable[MultiModalDataDict]] = []
|
||||
|
||||
self._items: List[_T] = []
|
||||
|
||||
@staticmethod
|
||||
@lru_cache(maxsize=None)
|
||||
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int):
|
||||
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
|
||||
return tokenizer.decode(token_index)
|
||||
|
||||
def add(self, modality: Literal["image", "audio"],
|
||||
mm_future: Awaitable[MultiModalDataDict]) -> Optional[str]:
|
||||
"""
|
||||
Adds the multi-modal item to the current prompt and returns the
|
||||
placeholder string to use, if any.
|
||||
"""
|
||||
allowed_count = self._allowed_items.get(modality, 1)
|
||||
current_count = self._consumed_items.get(modality, 0) + 1
|
||||
if current_count > allowed_count:
|
||||
raise ValueError(
|
||||
f"At most {allowed_count} {modality}(s) may be provided in "
|
||||
"one request.")
|
||||
|
||||
self._consumed_items[modality] = current_count
|
||||
self._futures.append(mm_future)
|
||||
|
||||
def _placeholder_str(self, modality: ModalityStr,
|
||||
current_count: int) -> Optional[str]:
|
||||
# TODO: Let user specify how to insert image tokens into prompt
|
||||
# (similar to chat template)
|
||||
model_type = self._model_config.hf_config.model_type
|
||||
hf_config = self._model_config.hf_config
|
||||
model_type = hf_config.model_type
|
||||
|
||||
if modality == "image":
|
||||
if model_type == "phi3_v":
|
||||
# Workaround since this token is not defined in the tokenizer
|
||||
@@ -130,9 +127,8 @@ class MultiModalItemTracker:
|
||||
# These models do not use image tokens in the prompt
|
||||
return None
|
||||
if model_type.startswith("llava"):
|
||||
return MultiModalItemTracker._cached_token_str(
|
||||
self._tokenizer,
|
||||
self._model_config.hf_config.image_token_index)
|
||||
return self._cached_token_str(self._tokenizer,
|
||||
hf_config.image_token_index)
|
||||
if model_type in ("chameleon", "internvl_chat"):
|
||||
return "<image>"
|
||||
|
||||
@@ -145,11 +141,11 @@ class MultiModalItemTracker:
|
||||
raise TypeError(f"Unknown modality: {modality}")
|
||||
|
||||
@staticmethod
|
||||
async def _combine(futures: List[Awaitable[MultiModalDataDict]]):
|
||||
def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict:
|
||||
mm_lists: Mapping[str, List[object]] = defaultdict(list)
|
||||
|
||||
# Merge all the multi-modal items
|
||||
for single_mm_data in (await asyncio.gather(*futures)):
|
||||
for single_mm_data in items:
|
||||
for mm_key, mm_item in single_mm_data.items():
|
||||
if isinstance(mm_item, list):
|
||||
mm_lists[mm_key].extend(mm_item)
|
||||
@@ -162,9 +158,113 @@ class MultiModalItemTracker:
|
||||
for mm_key, mm_list in mm_lists.items()
|
||||
}
|
||||
|
||||
def all_mm_data(self) -> Optional[Awaitable[MultiModalDataDict]]:
|
||||
return MultiModalItemTracker._combine(
|
||||
self._futures) if self._futures else None
|
||||
def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
|
||||
"""
|
||||
Add a multi-modal item to the current prompt and returns the
|
||||
placeholder string to use, if any.
|
||||
"""
|
||||
allowed_count = self._allowed_items.get(modality, 1)
|
||||
current_count = self._consumed_items.get(modality, 0) + 1
|
||||
if current_count > allowed_count:
|
||||
raise ValueError(
|
||||
f"At most {allowed_count} {modality}(s) may be provided in "
|
||||
"one request.")
|
||||
|
||||
self._consumed_items[modality] = current_count
|
||||
self._items.append(item)
|
||||
|
||||
return self._placeholder_str(modality, current_count)
|
||||
|
||||
@abstractmethod
|
||||
def create_parser(self) -> "BaseMultiModalContentParser":
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]):
|
||||
|
||||
def all_mm_data(self) -> Optional[MultiModalDataDict]:
|
||||
return self._combine(self._items) if self._items else None
|
||||
|
||||
def create_parser(self) -> "BaseMultiModalContentParser":
|
||||
return MultiModalContentParser(self)
|
||||
|
||||
|
||||
class AsyncMultiModalItemTracker(
|
||||
BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]):
|
||||
|
||||
async def all_mm_data(self) -> Optional[MultiModalDataDict]:
|
||||
if self._items:
|
||||
items = await asyncio.gather(*self._items)
|
||||
return self._combine(items)
|
||||
|
||||
return None
|
||||
|
||||
def create_parser(self) -> "BaseMultiModalContentParser":
|
||||
return AsyncMultiModalContentParser(self)
|
||||
|
||||
|
||||
class BaseMultiModalContentParser(ABC):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
# multimodal placeholder_string : count
|
||||
self._placeholder_counts: Dict[str, int] = defaultdict(lambda: 0)
|
||||
|
||||
def _add_placeholder(self, placeholder: Optional[str]):
|
||||
if placeholder:
|
||||
self._placeholder_counts[placeholder] += 1
|
||||
|
||||
def mm_placeholder_counts(self) -> Dict[str, int]:
|
||||
return dict(self._placeholder_counts)
|
||||
|
||||
@abstractmethod
|
||||
def parse_image(self, image_url: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def parse_audio(self, audio_url: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
|
||||
def __init__(self, tracker: MultiModalItemTracker) -> None:
|
||||
super().__init__()
|
||||
|
||||
self._tracker = tracker
|
||||
|
||||
def parse_image(self, image_url: str) -> None:
|
||||
image = get_and_parse_image(image_url)
|
||||
|
||||
placeholder = self._tracker.add("image", image)
|
||||
self._add_placeholder(placeholder)
|
||||
|
||||
def parse_audio(self, audio_url: str) -> None:
|
||||
audio = get_and_parse_audio(audio_url)
|
||||
|
||||
placeholder = self._tracker.add("audio", audio)
|
||||
self._add_placeholder(placeholder)
|
||||
|
||||
|
||||
class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
|
||||
def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
|
||||
super().__init__()
|
||||
|
||||
self._tracker = tracker
|
||||
|
||||
def parse_image(self, image_url: str) -> None:
|
||||
image_coro = async_get_and_parse_image(image_url)
|
||||
|
||||
placeholder = self._tracker.add("image", image_coro)
|
||||
self._add_placeholder(placeholder)
|
||||
|
||||
def parse_audio(self, audio_url: str) -> None:
|
||||
audio_coro = async_get_and_parse_audio(audio_url)
|
||||
|
||||
placeholder = self._tracker.add("audio", audio_coro)
|
||||
self._add_placeholder(placeholder)
|
||||
|
||||
|
||||
def load_chat_template(
|
||||
@@ -197,10 +297,10 @@ def load_chat_template(
|
||||
# (similar to chat template)
|
||||
def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
|
||||
text_prompt: str) -> str:
|
||||
"""Combine multimodal prompts for a multimodal language model"""
|
||||
"""Combine multimodal prompts for a multimodal language model."""
|
||||
|
||||
# Look through the text prompt to check for missing placeholders
|
||||
missing_placeholders = []
|
||||
missing_placeholders: List[str] = []
|
||||
for placeholder in placeholder_counts:
|
||||
|
||||
# For any existing placeholder in the text prompt, we leave it as is
|
||||
@@ -227,12 +327,11 @@ _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam)
|
||||
def _parse_chat_message_content_parts(
|
||||
role: str,
|
||||
parts: Iterable[ChatCompletionContentPartParam],
|
||||
mm_tracker: MultiModalItemTracker,
|
||||
mm_tracker: BaseMultiModalItemTracker,
|
||||
) -> List[ConversationMessage]:
|
||||
texts: List[str] = []
|
||||
|
||||
# multimodal placeholder_string : count
|
||||
mm_placeholder_counts: Dict[str, int] = {}
|
||||
mm_parser = mm_tracker.create_parser()
|
||||
|
||||
for part in parts:
|
||||
part_type = part["type"]
|
||||
@@ -247,22 +346,16 @@ def _parse_chat_message_content_parts(
|
||||
"'image_url.detail' is currently not supported and "
|
||||
"will be ignored.")
|
||||
|
||||
image_coro = async_get_and_parse_image(image_url["url"])
|
||||
placeholder = mm_tracker.add("image", image_coro)
|
||||
if placeholder:
|
||||
mm_placeholder_counts[placeholder] = mm_placeholder_counts.get(
|
||||
placeholder, 0) + 1
|
||||
mm_parser.parse_image(image_url["url"])
|
||||
elif part_type == "audio_url":
|
||||
audio_url = _AudioParser.validate_python(part)["audio_url"]
|
||||
audio_coro = async_get_and_parse_audio(audio_url["url"])
|
||||
placeholder = mm_tracker.add("audio", audio_coro)
|
||||
if placeholder:
|
||||
mm_placeholder_counts[placeholder] = mm_placeholder_counts.get(
|
||||
placeholder, 0) + 1
|
||||
|
||||
mm_parser.parse_audio(audio_url["url"])
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown part type: {part_type}")
|
||||
|
||||
text_prompt = "\n".join(texts)
|
||||
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
|
||||
if mm_placeholder_counts:
|
||||
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
|
||||
text_prompt)
|
||||
@@ -271,8 +364,9 @@ def _parse_chat_message_content_parts(
|
||||
|
||||
|
||||
def _parse_chat_message_content(
|
||||
message: ChatCompletionMessageParam,
|
||||
mm_tracker: MultiModalItemTracker) -> List[ConversationMessage]:
|
||||
message: ChatCompletionMessageParam,
|
||||
mm_tracker: BaseMultiModalItemTracker,
|
||||
) -> List[ConversationMessage]:
|
||||
role = message["role"]
|
||||
content = message.get("content")
|
||||
|
||||
@@ -292,7 +386,7 @@ def parse_chat_messages(
|
||||
messages: List[ChatCompletionMessageParam],
|
||||
model_config: ModelConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
) -> Tuple[List[ConversationMessage], Optional[Awaitable[MultiModalDataDict]]]:
|
||||
) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]:
|
||||
conversation: List[ConversationMessage] = []
|
||||
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
|
||||
|
||||
@@ -304,6 +398,22 @@ def parse_chat_messages(
|
||||
return conversation, mm_tracker.all_mm_data()
|
||||
|
||||
|
||||
def parse_chat_messages_futures(
|
||||
messages: List[ChatCompletionMessageParam],
|
||||
model_config: ModelConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
|
||||
conversation: List[ConversationMessage] = []
|
||||
mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
|
||||
|
||||
for msg in messages:
|
||||
sub_messages = _parse_chat_message_content(msg, mm_tracker)
|
||||
|
||||
conversation.extend(sub_messages)
|
||||
|
||||
return conversation, mm_tracker.all_mm_data()
|
||||
|
||||
|
||||
def apply_chat_template(
|
||||
tokenizer: AnyTokenizer,
|
||||
conversation: List[ConversationMessage],
|
||||
|
||||
Reference in New Issue
Block a user