2024-07-16 12:18:09 +00:00
|
|
|
import codecs
|
|
|
|
|
from dataclasses import dataclass, field
|
|
|
|
|
from functools import lru_cache
|
2024-07-21 08:38:17 +08:00
|
|
|
from typing import Awaitable, Iterable, List, Optional, Union, cast, final
|
2024-07-16 12:18:09 +00:00
|
|
|
|
2024-07-21 08:38:17 +08:00
|
|
|
# yapf conflicts with isort for this block
|
|
|
|
|
# yapf: disable
|
|
|
|
|
from openai.types.chat import ChatCompletionContentPartImageParam
|
|
|
|
|
from openai.types.chat import (
|
|
|
|
|
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
|
|
|
|
|
from openai.types.chat import ChatCompletionContentPartTextParam
|
|
|
|
|
from openai.types.chat import (
|
|
|
|
|
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
|
|
|
|
|
# yapf: enable
|
|
|
|
|
# pydantic needs the TypedDict from typing_extensions
|
|
|
|
|
from pydantic import ConfigDict
|
2024-07-18 00:13:30 -07:00
|
|
|
from transformers import PreTrainedTokenizer
|
2024-07-21 08:38:17 +08:00
|
|
|
from typing_extensions import Required, TypedDict
|
2024-07-16 12:18:09 +00:00
|
|
|
|
2024-07-18 00:13:30 -07:00
|
|
|
from vllm.config import ModelConfig
|
2024-07-16 12:18:09 +00:00
|
|
|
from vllm.logger import init_logger
|
|
|
|
|
from vllm.multimodal import MultiModalDataDict
|
|
|
|
|
from vllm.multimodal.utils import async_get_and_parse_image
|
|
|
|
|
|
|
|
|
|
# Module-level logger used by the chat-template and message-parsing helpers
# below.
logger = init_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
2024-07-21 08:38:17 +08:00
|
|
|
class CustomChatCompletionContentPartParam(TypedDict, total=False):
    """A chat content part with an arbitrary, non-OpenAI ``type`` string.

    Only ``type`` is required; any additional keys are allowed through by
    the pydantic config attribute below.
    """

    # Tell pydantic to accept (rather than reject) keys not declared here.
    __pydantic_config__ = ConfigDict(extra="allow")  # type: ignore

    type: Required[str]
    """The type of the content part."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# A single content part of a chat message: either one of the part types
# defined by the OpenAI client library, or a custom part with any ``type``.
ChatCompletionContentPartParam = Union[OpenAIChatCompletionContentPartParam,
                                       CustomChatCompletionContentPartParam]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API."""

    # Required even though the TypedDict is declared total=False.
    role: Required[str]
    """The role of the message's author."""

    # Either plain text or a list of typed content parts.
    content: Union[str, List[ChatCompletionContentPartParam]]
    """The contents of the message."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the
    same role.
    """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# A chat message as accepted by this module: either an OpenAI-defined message
# or a custom message whose ``role`` may be any string.
ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
                                   CustomChatCompletionMessageParam]
|
|
|
|
|
|
|
|
|
|
|
2024-07-16 12:18:09 +00:00
|
|
|
@final  # Marked final so that it can be treated as compatible with Dict[str, str]
class ConversationMessage(TypedDict):
    """A normalized chat turn: the author's role and plain-text content."""

    role: str
    content: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class ChatMessageParseResult:
    """Immutable result of parsing one incoming chat message."""

    # Normalized conversation turns produced from the message.
    messages: List[ConversationMessage]
    # Pending multimodal-data awaitables; defaults to a fresh empty list
    # (default_factory avoids sharing one list between instances).
    mm_futures: List[Awaitable[MultiModalDataDict]] = field(
        default_factory=list)
|
|
|
|
|
|
|
|
|
|
|
2024-07-18 00:13:30 -07:00
|
|
|
def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
    """Resolve a chat template supplied either as a file path or inline text.

    Args:
        chat_template: Path to a template file, or the template text itself.
            ``None`` means no template was supplied.

    Returns:
        The template text, or ``None`` when ``chat_template`` is ``None``.

    Raises:
        ValueError: If ``chat_template`` cannot be opened as a file and
            contains none of ``{``, ``}`` or a newline, i.e. it looks like a
            bad file path rather than an inline template.
    """
    if chat_template is None:
        return None
    try:
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's locale encoding.
        with open(chat_template, "r", encoding="utf-8") as f:
            resolved_chat_template = f.read()
    except OSError as e:
        JINJA_CHARS = "{}\n"
        if not any(c in chat_template for c in JINJA_CHARS):
            msg = (f"The supplied chat template ({chat_template}) "
                   f"looks like a file path, but it failed to be "
                   f"opened. Reason: {e}")
            raise ValueError(msg) from e

        # If opening a file fails, treat the argument as the template itself
        # and interpret its backslash escapes (e.g. a literal backslash-n).
        # Passing the str straight to the "unicode_escape" codec would first
        # convert it to UTF-8 bytes and then read them back as Latin-1,
        # mangling every non-ASCII character. Encoding via latin-1 with
        # backslashreplace maps characters above U+00FF to \uXXXX escapes,
        # which the codec then decodes back to the original characters.
        resolved_chat_template = codecs.decode(
            chat_template.encode("latin-1", "backslashreplace"),
            "unicode_escape")

    logger.info("Using supplied chat template:\n%s", resolved_chat_template)
    return resolved_chat_template
|
2024-07-16 12:18:09 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
@lru_cache(maxsize=None)
|
2024-07-18 00:13:30 -07:00
|
|
|
def _image_token_str(model_config: ModelConfig,
|
|
|
|
|
tokenizer: PreTrainedTokenizer) -> Optional[str]:
|
2024-07-16 12:18:09 +00:00
|
|
|
# TODO: Let user specify how to insert image tokens into prompt
|
|
|
|
|
# (similar to chat template)
|
2024-07-18 00:13:30 -07:00
|
|
|
model_type = model_config.hf_config.model_type
|
2024-07-16 12:18:09 +00:00
|
|
|
if model_type == "phi3_v":
|
|
|
|
|
# Workaround since this token is not defined in the tokenizer
|
|
|
|
|
return "<|image_1|>"
|
|
|
|
|
if model_type in ("blip-2", "chatglm", "fuyu", "minicpmv", "paligemma"):
|
|
|
|
|
# These models do not use image tokens in the prompt
|
|
|
|
|
return None
|
|
|
|
|
if model_type.startswith("llava"):
|
2024-07-18 00:13:30 -07:00
|
|
|
return tokenizer.decode(model_config.hf_config.image_token_index)
|
2024-07-22 23:50:48 -07:00
|
|
|
if model_type == "chameleon":
|
|
|
|
|
return "<image>"
|
2024-07-18 00:13:30 -07:00
|
|
|
raise TypeError("Unknown model type: {model_type}")
|
2024-07-16 12:18:09 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# TODO: Let user specify how to insert image tokens into prompt
|
|
|
|
|
# (similar to chat template)
|
2024-07-18 00:13:30 -07:00
|
|
|
def _get_full_image_text_prompt(image_token_str: str, text_prompt: str) -> str:
|
2024-07-16 12:18:09 +00:00
|
|
|
"""Combine image and text prompts for vision language model"""
|
|
|
|
|
|
|
|
|
|
# NOTE: For now we assume all model architectures use the same
|
|
|
|
|
# image + text prompt format. This may change in the future.
|
|
|
|
|
return f"{image_token_str}\n{text_prompt}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _parse_chat_message_content_parts(
    role: str,
    parts: Iterable[ChatCompletionContentPartParam],
    model_config: ModelConfig,
    tokenizer: PreTrainedTokenizer,
) -> ChatMessageParseResult:
    """Collapse a sequence of content parts into a single conversation turn.

    ``text`` parts are joined with newlines. An ``image_url`` part is handed
    to ``async_get_and_parse_image`` and the returned awaitable is collected
    (not awaited here). When an image is present and the model defines an
    image placeholder string, that placeholder is prepended to the text,
    unless the text already contains it.

    Raises:
        NotImplementedError: On a second ``image_url`` part, or on a part
            whose ``type`` is neither ``"text"`` nor ``"image_url"``.
    """
    text_chunks: List[str] = []
    pending_images: List[Awaitable[MultiModalDataDict]] = []

    for content_part in parts:
        kind = content_part["type"]

        if kind == "text":
            text_part = cast(ChatCompletionContentPartTextParam, content_part)
            text_chunks.append(text_part["text"])
            continue

        if kind != "image_url":
            raise NotImplementedError(f"Unknown part type: {kind}")

        # Only one image per message is handled at the moment.
        if pending_images:
            raise NotImplementedError(
                "Multiple 'image_url' input is currently not supported.")

        image_spec = cast(ChatCompletionContentPartImageParam,
                          content_part)["image_url"]
        if image_spec.get("detail", "auto") != "auto":
            logger.warning("'image_url.detail' is currently not supported and "
                           "will be ignored.")
        pending_images.append(async_get_and_parse_image(image_spec["url"]))

    prompt = "\n".join(text_chunks)

    # The placeholder lookup only matters (and is only performed) when an
    # image was actually attached.
    placeholder = (_image_token_str(model_config, tokenizer)
                   if pending_images else None)
    if placeholder is not None:
        if placeholder in prompt:
            # The caller already positioned the image token; leave it alone.
            logger.warning("Detected image token string in the text prompt. "
                           "Skipping prompt formatting.")
        else:
            prompt = _get_full_image_text_prompt(
                image_token_str=placeholder,
                text_prompt=prompt,
            )

    return ChatMessageParseResult(
        messages=[ConversationMessage(role=role, content=prompt)],
        mm_futures=pending_images,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_chat_message_content(
    message: ChatCompletionMessageParam,
    model_config: ModelConfig,
    tokenizer: PreTrainedTokenizer,
) -> ChatMessageParseResult:
    """Normalize one incoming chat message into conversation turns.

    - No ``content`` (missing or ``None``): no turns and no pending futures.
    - String ``content``: a single turn carrying that string verbatim.
    - Otherwise ``content`` is treated as a list of content parts and is
      delegated to ``_parse_chat_message_content_parts``.
    """
    role = message["role"]
    content = message.get("content")

    is_part_list = content is not None and not isinstance(content, str)
    if is_part_list:
        return _parse_chat_message_content_parts(role, content, model_config,
                                                 tokenizer)

    turns = ([] if content is None else
             [ConversationMessage(role=role, content=content)])
    return ChatMessageParseResult(messages=turns, mm_futures=[])
|