[Frontend] Automatic detection of chat content format from AST (#9919)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -2,12 +2,14 @@ import asyncio
|
||||
import codecs
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from collections import defaultdict, deque
|
||||
from functools import lru_cache, partial
|
||||
from pathlib import Path
|
||||
from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
|
||||
Literal, Mapping, Optional, Tuple, TypeVar, Union, cast)
|
||||
|
||||
import jinja2.nodes
|
||||
import transformers.utils.chat_template_utils as hf_chat_utils
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from openai.types.chat import (ChatCompletionAssistantMessageParam,
|
||||
@@ -153,6 +155,199 @@ class ConversationMessage(TypedDict, total=False):
|
||||
"""The tool calls generated by the model, such as function calls."""
|
||||
|
||||
|
||||
# Content-format option as passed in by the user via
# `--chat-template-content-format` ("auto" = detect from the template's AST).
ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]

# Concrete format used internally once "auto" has been resolved.
_ChatTemplateContentFormat = Literal["string", "openai"]
|
||||
|
||||
|
||||
def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
    """Return whether ``node`` is a plain read of the variable ``varname``."""
    if not isinstance(node, jinja2.nodes.Name):
        return False

    # "load" context means the name is being read, not assigned.
    return node.ctx == "load" and node.name == varname
|
||||
|
||||
|
||||
def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
    """Return whether ``node`` reads ``varname["key"]`` or ``varname.key``."""
    # Subscript form: varname["key"]
    if isinstance(node, jinja2.nodes.Getitem):
        if not _is_var_access(node.node, varname):
            return False
        arg = node.arg
        return isinstance(arg, jinja2.nodes.Const) and arg.value == key

    # Attribute form: varname.key
    if isinstance(node, jinja2.nodes.Getattr):
        return node.attr == key and _is_var_access(node.node, varname)

    return False
|
||||
|
||||
|
||||
def _is_var_or_elems_access(
    node: jinja2.nodes.Node,
    varname: str,
    key: Optional[str] = None,
) -> bool:
    """Return whether ``node`` accesses ``varname`` (or ``varname[key]`` when
    ``key`` is given), looking through filters, tests and slice subscripts.
    """
    # {{ x | some_filter }} -- the filter is transparent for our purposes.
    if isinstance(node, jinja2.nodes.Filter):
        if node.node is None:
            return False
        return _is_var_or_elems_access(node.node, varname, key)

    # {{ x is some_test }} -- likewise transparent.
    if isinstance(node, jinja2.nodes.Test):
        return _is_var_or_elems_access(node.node, varname, key)

    # x[a:b] -- a slice still refers to elements of the same sequence.
    if (isinstance(node, jinja2.nodes.Getitem)
            and isinstance(node.arg, jinja2.nodes.Slice)):
        return _is_var_or_elems_access(node.node, varname, key)

    if key:
        return _is_attr_access(node, varname, key)

    return _is_var_access(node, varname)
|
||||
|
||||
|
||||
def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
    """Yield ``(node, name)`` pairs for every name that aliases ``varname``
    (or elements of it) via ``{% set ... %}``, starting with ``varname``
    itself (implicitly defined at the template root).
    """
    yield root, varname

    # Iterative BFS over chains of assignments such as
    # {%- set msgs = messages -%} followed by {%- set m = msgs -%}.
    pending = deque([varname])
    while pending:
        current = pending.popleft()

        for assign_ast in root.find_all(jinja2.nodes.Assign):
            lhs = assign_ast.target
            rhs = assign_ast.node

            if not _is_var_or_elems_access(rhs, current):
                continue

            assert isinstance(lhs, jinja2.nodes.Name)
            yield assign_ast, lhs.name

            # Avoid infinite looping for self-assignment
            if lhs.name != current:
                pending.append(lhs.name)
|
||||
|
||||
|
||||
# NOTE: The proper way to handle this is to build a CFG so that we can handle
# the scope in which each variable is defined, but that is too complicated
def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
    """Yield ``(for_node, name)`` for each loop that iterates over
    ``messages`` (or any alias of it); ``name`` is the loop variable."""
    messages_varnames = [
        name for _, name in _iter_nodes_assign_var_or_elems(root, "messages")
    ]

    # Search for {%- for message in messages -%} loops
    for loop_ast in root.find_all(jinja2.nodes.For):
        target = loop_ast.target

        for name in messages_varnames:
            if not _is_var_or_elems_access(loop_ast.iter, name):
                continue

            assert isinstance(target, jinja2.nodes.Name)
            yield loop_ast, target.name
            break
|
||||
|
||||
|
||||
def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
    """Yield ``(for_node, name)`` for each loop over a message's "content"
    field; ``name`` is the loop variable bound to each content part."""
    message_varnames = [
        name for _, name in _iter_nodes_assign_messages_item(root)
    ]

    # Search for {%- for content in message['content'] -%} loops
    for loop_ast in root.find_all(jinja2.nodes.For):
        target = loop_ast.target

        for name in message_varnames:
            if not _is_var_or_elems_access(loop_ast.iter, name, "content"):
                continue

            assert isinstance(target, jinja2.nodes.Name)
            yield loop_ast, target.name
            break
|
||||
|
||||
|
||||
def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]:
    """Compile ``chat_template`` and return its Jinja AST, or ``None`` if
    compilation/parsing fails (the error is logged, not raised)."""
    try:
        # Compile via HF so the environment matches how the template will
        # actually be rendered, then parse with that same environment.
        compiled = hf_chat_utils._compile_jinja_template(chat_template)
        return compiled.environment.parse(chat_template)
    except Exception:
        logger.exception("Error when compiling Jinja template")
        return None
|
||||
|
||||
|
||||
def _detect_content_format(
    chat_template: str,
    *,
    default: _ChatTemplateContentFormat,
) -> _ChatTemplateContentFormat:
    """Infer whether ``chat_template`` iterates over message content parts
    ("openai") or treats content as a plain string ("string").

    Falls back to ``default`` when the template cannot be compiled or its
    AST cannot be analyzed.
    """
    jinja_ast = _try_extract_ast(chat_template)
    if jinja_ast is None:
        return default

    try:
        # Any loop over message["content"] implies OpenAI-style parts.
        content_loop = next(_iter_nodes_assign_content_item(jinja_ast), None)
    except Exception:
        logger.exception("Error when parsing AST of Jinja template")
        return default

    return "string" if content_loop is None else "openai"
|
||||
|
||||
|
||||
def _resolve_chat_template_content_format(
    chat_template: Optional[str],
    given_format: ChatTemplateContentFormatOption,
    tokenizer: AnyTokenizer,
) -> _ChatTemplateContentFormat:
    """Resolve the effective content format, auto-detecting it from the chat
    template's AST when ``given_format`` is "auto"; otherwise return
    ``given_format`` unchanged."""
    if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
        tokenizer_chat_template = tokenizer.chat_template
    else:
        tokenizer_chat_template = None

    # Pick the template text: explicit argument wins; a dict-valued tokenizer
    # template is keyed by name; otherwise treat the argument as a literal.
    jinja_text: Optional[str]
    if isinstance(tokenizer_chat_template, str) and chat_template is None:
        jinja_text = tokenizer_chat_template
    elif (isinstance(tokenizer_chat_template, dict)
          and chat_template in tokenizer_chat_template):
        jinja_text = tokenizer_chat_template[chat_template]
    else:
        jinja_text = load_chat_template(chat_template, is_literal=True)

    if jinja_text is None:
        detected_format: _ChatTemplateContentFormat = "string"
    else:
        detected_format = _detect_content_format(jinja_text, default="string")

    if given_format == "auto":
        return detected_format
    return given_format
|
||||
|
||||
|
||||
@lru_cache
def resolve_chat_template_content_format(
    chat_template: Optional[str],
    given_format: ChatTemplateContentFormatOption,
    tokenizer: AnyTokenizer,
) -> _ChatTemplateContentFormat:
    """Cached public entry point for content-format resolution.

    Logs the detected format, warns when the user's explicit
    ``--chat-template-content-format`` choice disagrees with detection, and
    returns the detected format (or the user's choice when not "auto").
    """
    # BUG FIX: always detect with "auto" here. Previously `given_format` was
    # passed through, and `_resolve_chat_template_content_format` returns
    # `given_format` itself whenever it is not "auto" -- so
    # `detected_format == given_format` held in exactly the cases the
    # mismatch warning below targets, making the warning unreachable.
    detected_format = _resolve_chat_template_content_format(
        chat_template,
        "auto",
        tokenizer,
    )

    logger.info(
        "Detected the chat template content format to be '%s'. "
        "You can set `--chat-template-content-format` to override this.",
        detected_format,
    )

    if given_format != "auto" and given_format != detected_format:
        logger.warning(
            "You specified `--chat-template-content-format %s` "
            "which is different from the detected format '%s'. "
            "If our automatic detection is incorrect, please consider "
            "opening a GitHub issue so that we can improve it: "
            "https://github.com/vllm-project/vllm/issues/new/choose",
            given_format,
            detected_format,
        )

    # Same return value as before: detection only applies under "auto".
    return detected_format if given_format == "auto" else given_format
|
||||
|
||||
|
||||
# Modality identifiers used by the multi-modal content parsers below.
ModalityStr = Literal["image", "audio", "video"]
# Generic placeholder type variable.
_T = TypeVar("_T")
|
||||
|
||||
@@ -407,12 +602,23 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]):
|
||||
|
||||
|
||||
def load_chat_template(
|
||||
chat_template: Optional[Union[Path, str]]) -> Optional[str]:
|
||||
chat_template: Optional[Union[Path, str]],
|
||||
*,
|
||||
is_literal: bool = False,
|
||||
) -> Optional[str]:
|
||||
if chat_template is None:
|
||||
return None
|
||||
|
||||
if is_literal:
|
||||
if isinstance(chat_template, Path):
|
||||
raise TypeError("chat_template is expected to be read directly "
|
||||
"from its value")
|
||||
|
||||
return codecs.decode(chat_template, "unicode_escape")
|
||||
|
||||
try:
|
||||
with open(chat_template) as f:
|
||||
resolved_chat_template = f.read()
|
||||
return f.read()
|
||||
except OSError as e:
|
||||
if isinstance(chat_template, Path):
|
||||
raise
|
||||
@@ -426,10 +632,7 @@ def load_chat_template(
|
||||
|
||||
# If opening a file fails, set chat template to be args to
|
||||
# ensure we decode so our escape are interpreted correctly
|
||||
resolved_chat_template = codecs.decode(chat_template, "unicode_escape")
|
||||
|
||||
logger.info("Using supplied chat template:\n%s", resolved_chat_template)
|
||||
return resolved_chat_template
|
||||
return load_chat_template(chat_template, is_literal=True)
|
||||
|
||||
|
||||
# TODO: Let user specify how to insert multimodal tokens into prompt
|
||||
@@ -464,7 +667,6 @@ _ImageParser = partial(cast, ChatCompletionContentPartImageParam)
|
||||
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
|
||||
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
|
||||
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
|
||||
MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'}
|
||||
|
||||
# Define a mapping from part types to their corresponding parsing functions.
|
||||
MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = {
|
||||
@@ -542,18 +744,12 @@ def _parse_chat_message_content_parts(
|
||||
role: str,
|
||||
parts: Iterable[ChatCompletionContentPartParam],
|
||||
mm_tracker: BaseMultiModalItemTracker,
|
||||
chat_template_text_format: str,
|
||||
*,
|
||||
wrap_dicts: bool,
|
||||
) -> List[ConversationMessage]:
|
||||
content: List[Union[str, Dict[str, str]]] = []
|
||||
|
||||
mm_parser = mm_tracker.create_parser()
|
||||
model_config = mm_tracker.model_config
|
||||
|
||||
wrap_dicts = (chat_template_text_format == "openai"
|
||||
or (model_config.task == "embedding"
|
||||
and model_config.is_multimodal_model)
|
||||
or (model_config.hf_config.model_type
|
||||
in MODEL_KEEP_MULTI_MODAL_CONTENT))
|
||||
|
||||
for part in parts:
|
||||
parse_res = _parse_chat_message_content_part(
|
||||
@@ -578,9 +774,11 @@ def _parse_chat_message_content_parts(
|
||||
|
||||
|
||||
def _parse_chat_message_content_part(
|
||||
part: ChatCompletionContentPartParam,
|
||||
mm_parser: BaseMultiModalContentParser,
|
||||
wrap_dicts: bool) -> Optional[Union[str, Dict[str, str]]]:
|
||||
part: ChatCompletionContentPartParam,
|
||||
mm_parser: BaseMultiModalContentParser,
|
||||
*,
|
||||
wrap_dicts: bool,
|
||||
) -> Optional[Union[str, Dict[str, str]]]:
|
||||
"""Parses a single part of a conversation. If wrap_dicts is True,
|
||||
structured dictionary pieces for texts and images will be
|
||||
wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
|
||||
@@ -629,7 +827,7 @@ _ToolParser = partial(cast, ChatCompletionToolMessageParam)
|
||||
def _parse_chat_message_content(
|
||||
message: ChatCompletionMessageParam,
|
||||
mm_tracker: BaseMultiModalItemTracker,
|
||||
chat_template_text_format: str,
|
||||
content_format: _ChatTemplateContentFormat,
|
||||
) -> List[ConversationMessage]:
|
||||
role = message["role"]
|
||||
content = message.get("content")
|
||||
@@ -645,7 +843,7 @@ def _parse_chat_message_content(
|
||||
role,
|
||||
content, # type: ignore
|
||||
mm_tracker,
|
||||
chat_template_text_format,
|
||||
wrap_dicts=(content_format == "openai"),
|
||||
)
|
||||
|
||||
for result_msg in result:
|
||||
@@ -684,6 +882,7 @@ def parse_chat_messages(
|
||||
messages: List[ChatCompletionMessageParam],
|
||||
model_config: ModelConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
content_format: _ChatTemplateContentFormat,
|
||||
) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]:
|
||||
conversation: List[ConversationMessage] = []
|
||||
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
|
||||
@@ -692,7 +891,7 @@ def parse_chat_messages(
|
||||
sub_messages = _parse_chat_message_content(
|
||||
msg,
|
||||
mm_tracker,
|
||||
model_config.chat_template_text_format,
|
||||
content_format,
|
||||
)
|
||||
|
||||
conversation.extend(sub_messages)
|
||||
@@ -706,6 +905,7 @@ def parse_chat_messages_futures(
|
||||
messages: List[ChatCompletionMessageParam],
|
||||
model_config: ModelConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
content_format: _ChatTemplateContentFormat,
|
||||
) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
|
||||
conversation: List[ConversationMessage] = []
|
||||
mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
|
||||
@@ -714,7 +914,7 @@ def parse_chat_messages_futures(
|
||||
sub_messages = _parse_chat_message_content(
|
||||
msg,
|
||||
mm_tracker,
|
||||
model_config.chat_template_text_format,
|
||||
content_format,
|
||||
)
|
||||
|
||||
conversation.extend(sub_messages)
|
||||
|
||||
Reference in New Issue
Block a user