[Frontend] Introduce Renderer for processing chat messages (using ModelConfig) (#30200)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -2,22 +2,15 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import Counter, defaultdict, deque
|
||||
from collections import Counter, defaultdict
|
||||
from collections.abc import Awaitable, Callable, Iterable
|
||||
from functools import cached_property, lru_cache, partial
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast
|
||||
|
||||
import jinja2
|
||||
import jinja2.ext
|
||||
import jinja2.meta
|
||||
import jinja2.nodes
|
||||
import jinja2.parser
|
||||
import jinja2.sandbox
|
||||
import transformers.utils.chat_template_utils as hf_chat_utils
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionContentPartImageParam,
|
||||
@@ -39,7 +32,6 @@ from openai.types.responses import ResponseInputImageParam
|
||||
from openai_harmony import Message as OpenAIHarmonyMessage
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, ConfigDict, TypeAdapter
|
||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin
|
||||
|
||||
# pydantic needs the TypedDict from typing_extensions
|
||||
from typing_extensions import Required, TypedDict
|
||||
@@ -50,24 +42,35 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.models import SupportsMultiModal
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
|
||||
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
|
||||
from vllm.transformers_utils.processor import cached_get_processor
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
from vllm.utils.func_utils import supports_kw
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
else:
|
||||
torch = LazyLoader("torch", globals(), "torch")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def __getattr__(name: str):
    """Module-level fallback lookup that redirects moved names.

    Emits a DeprecationWarning for ``resolve_hf_chat_template`` (relocated to
    ``vllm.renderers.hf``) and raises AttributeError for anything else.
    """
    if name != "resolve_hf_chat_template":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    from vllm.renderers.hf import resolve_chat_template

    warnings.warn(
        "`vllm.entrypoints.chat_utils.resolve_hf_chat_template` has been moved to "
        "`vllm.renderers.hf.resolve_chat_template`. "
        "The old name will be removed in v0.16.",
        DeprecationWarning,
        stacklevel=2,
    )
    return resolve_chat_template
|
||||
|
||||
|
||||
class ChatTemplateResolutionError(ValueError):
|
||||
"""Raised when chat template resolution fails.
|
||||
|
||||
@@ -320,325 +323,8 @@ class ConversationMessage(TypedDict, total=False):
|
||||
# Passed in by user; "auto" asks vLLM to detect the actual format from the
# chat template itself (see `resolve_chat_template_content_format`).
ChatTemplateContentFormatOption: TypeAlias = Literal["auto", "string", "openai"]

# Used internally, once "auto" has been resolved to a concrete format.
_ChatTemplateContentFormat: TypeAlias = Literal["string", "openai"]
|
||||
|
||||
|
||||
def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
    """Return True if *node* is a plain read ("load") of variable *varname*."""
    if not isinstance(node, jinja2.nodes.Name):
        return False

    return node.ctx == "load" and node.name == varname
|
||||
|
||||
|
||||
def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
    """Return True if *node* reads ``varname.key`` or ``varname["key"]``."""
    # Attribute form: varname.key
    if isinstance(node, jinja2.nodes.Getattr):
        return node.attr == key and _is_var_access(node.node, varname)

    # Subscript form with a constant key: varname["key"]
    if isinstance(node, jinja2.nodes.Getitem):
        subscript = node.arg
        return (
            isinstance(subscript, jinja2.nodes.Const)
            and subscript.value == key
            and _is_var_access(node.node, varname)
        )

    return False
|
||||
|
||||
|
||||
def _is_var_or_elems_access(
    node: jinja2.nodes.Node,
    varname: str,
    key: str | None = None,
) -> bool:
    """Return True if *node* accesses *varname* (or ``varname[key]``),
    looking through filters, tests and slice expressions."""
    # Unwrap `x | somefilter` and recurse on the filtered expression.
    if isinstance(node, jinja2.nodes.Filter):
        inner = node.node
        return inner is not None and _is_var_or_elems_access(inner, varname, key)

    # Unwrap `x is sometest`.
    if isinstance(node, jinja2.nodes.Test):
        return _is_var_or_elems_access(node.node, varname, key)

    # Unwrap slicing like `x[1:]` (but not constant subscripts).
    is_sliced = isinstance(node, jinja2.nodes.Getitem) and isinstance(
        node.arg, jinja2.nodes.Slice
    )
    if is_sliced:
        return _is_var_or_elems_access(node.node, varname, key)

    if key:
        return _is_attr_access(node, varname, key)
    return _is_var_access(node, varname)
|
||||
|
||||
|
||||
def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
    """Yield ``(node, name)`` pairs for *varname* and every template variable
    that aliases it, directly or transitively, via ``{% set %}`` assignments.
    """
    # Global variable that is implicitly defined at the root
    yield root, varname

    # Iterative BFS
    related_varnames = deque([varname])
    while related_varnames:
        related_varname = related_varnames.popleft()

        for assign_ast in root.find_all(jinja2.nodes.Assign):
            lhs = assign_ast.target
            rhs = assign_ast.node

            if _is_var_or_elems_access(rhs, related_varname):
                assert isinstance(lhs, jinja2.nodes.Name)
                yield assign_ast, lhs.name

                # Avoid infinite looping for self-assignment
                if lhs.name != related_varname:
                    related_varnames.append(lhs.name)
|
||||
|
||||
|
||||
# NOTE: The proper way to handle this is to build a CFG so that we can handle
# the scope in which each variable is defined, but that is too complicated
def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
    """Yield ``(loop_node, item_name)`` for each ``{% for %}`` loop that
    iterates over the ``messages`` variable (or any alias of it)."""
    messages_varnames = [
        varname for _, varname in _iter_nodes_assign_var_or_elems(root, "messages")
    ]

    # Search for {%- for message in messages -%} loops
    for loop_ast in root.find_all(jinja2.nodes.For):
        target = loop_ast.target
        iterated = loop_ast.iter

        if any(
            _is_var_or_elems_access(iterated, name) for name in messages_varnames
        ):
            assert isinstance(target, jinja2.nodes.Name)
            yield loop_ast, target.name
|
||||
|
||||
|
||||
def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
    """Yield ``(loop_node, item_name)`` for each ``{% for %}`` loop that
    iterates over ``message['content']`` of a messages-loop variable."""
    message_varnames = [
        varname for _, varname in _iter_nodes_assign_messages_item(root)
    ]

    # Search for {%- for content in message['content'] -%} loops
    for loop_ast in root.find_all(jinja2.nodes.For):
        target = loop_ast.target
        iterated = loop_ast.iter

        if any(
            _is_var_or_elems_access(iterated, name, "content")
            for name in message_varnames
        ):
            assert isinstance(target, jinja2.nodes.Name)
            yield loop_ast, target.name
|
||||
|
||||
|
||||
def _try_extract_ast(chat_template: str) -> jinja2.nodes.Template | None:
    """Best-effort parse of *chat_template* into a Jinja AST (None on failure)."""
    try:
        compiled = hf_chat_utils._compile_jinja_template(chat_template)
        return compiled.environment.parse(chat_template)
    except Exception:
        logger.exception("Error when compiling Jinja template")
        return None
|
||||
|
||||
|
||||
@lru_cache(maxsize=32)
def _detect_content_format(
    chat_template: str,
    *,
    default: _ChatTemplateContentFormat,
) -> _ChatTemplateContentFormat:
    """Inspect *chat_template* to decide whether it consumes message content
    as a plain string or as OpenAI-style content parts.

    Falls back to *default* when the template cannot be parsed or analyzed.
    """
    jinja_ast = _try_extract_ast(chat_template)
    if jinja_ast is None:
        return default

    try:
        # A loop over message["content"] implies OpenAI-style content parts.
        next(_iter_nodes_assign_content_item(jinja_ast))
    except StopIteration:
        return "string"
    except Exception:
        logger.exception("Error when parsing AST of Jinja template")
        return default
    else:
        return "openai"
|
||||
|
||||
|
||||
def resolve_mistral_chat_template(
|
||||
chat_template: str | None,
|
||||
**kwargs: Any,
|
||||
) -> str | None:
|
||||
if chat_template is not None or kwargs.get("chat_template_kwargs") is not None:
|
||||
raise ValueError(
|
||||
"'chat_template' or 'chat_template_kwargs' cannot be overridden "
|
||||
"for mistral tokenizer."
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], str | None]()
"""
Used in `_try_get_processor_chat_template` to avoid calling
`cached_get_processor` again if the processor fails to be loaded.

This is needed because `lru_cache` does not cache when an exception happens.
"""


def _try_get_processor_chat_template(
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    model_config: ModelConfig,
) -> str | None:
    """Return the chat template defined by the model's processor, if any.

    Results — including load failures, cached as None — are memoized in
    `_PROCESSOR_CHAT_TEMPLATES`, keyed by tokenizer path and trust setting.
    """
    cache_key = (tokenizer.name_or_path, model_config.trust_remote_code)
    if cache_key in _PROCESSOR_CHAT_TEMPLATES:
        return _PROCESSOR_CHAT_TEMPLATES[cache_key]

    try:
        processor = cached_get_processor(
            tokenizer.name_or_path,
            processor_cls=(
                PreTrainedTokenizer,
                PreTrainedTokenizerFast,
                ProcessorMixin,
            ),
            trust_remote_code=model_config.trust_remote_code,
        )
        # Only a ProcessorMixin with a non-None chat_template counts.
        if (
            isinstance(processor, ProcessorMixin)
            and hasattr(processor, "chat_template")
            and (chat_template := processor.chat_template) is not None
        ):
            _PROCESSOR_CHAT_TEMPLATES[cache_key] = chat_template
            return chat_template
    except Exception:
        logger.debug(
            "Failed to load AutoProcessor chat template for %s",
            tokenizer.name_or_path,
            exc_info=True,
        )

    # Remember the negative result so we never retry loading the processor.
    _PROCESSOR_CHAT_TEMPLATES[cache_key] = None
    return None
|
||||
|
||||
|
||||
def resolve_hf_chat_template(
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    chat_template: str | None,
    tools: list[dict[str, Any]] | None,
    *,
    model_config: ModelConfig,
) -> str | None:
    """Resolve which chat template to use for a HF tokenizer.

    Priority: the explicit *chat_template* argument, then the AutoProcessor
    template (skipped when *tools* are supplied), then the tokenizer's own
    template, then vLLM's predefined fallbacks. Returns None if nothing
    could be resolved.
    """
    # 1st priority: The given chat template
    if chat_template is not None:
        return chat_template

    # 2nd priority: AutoProcessor chat template, unless tool calling is enabled
    if tools is None:
        chat_template = _try_get_processor_chat_template(tokenizer, model_config)
        if chat_template is not None:
            return chat_template

    # 3rd priority: AutoTokenizer chat template
    try:
        return tokenizer.get_chat_template(chat_template, tools=tools)
    except Exception:
        logger.debug(
            "Failed to load AutoTokenizer chat template for %s",
            tokenizer.name_or_path,
            exc_info=True,
        )

    # 4th priority: Predefined fallbacks
    path = get_chat_template_fallback_path(
        model_type=model_config.hf_config.model_type,
        tokenizer_name_or_path=model_config.tokenizer,
    )
    if path is not None:
        logger.info_once(
            "Loading chat template fallback for %s as there isn't one "
            "defined on HF Hub.",
            tokenizer.name_or_path,
        )
        chat_template = load_chat_template(path)
    else:
        logger.debug_once(
            "There is no chat template fallback for %s", tokenizer.name_or_path
        )

    return chat_template
|
||||
|
||||
|
||||
def _resolve_chat_template_content_format(
    chat_template: str | None,
    tools: list[dict[str, Any]] | None,
    tokenizer: TokenizerLike | None,
    *,
    model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
    """Detect the content format of the chat template that would be applied."""
    hf_chat_template = None
    if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
        hf_chat_template = resolve_hf_chat_template(
            tokenizer,
            chat_template=chat_template,
            tools=tools,
            model_config=model_config,
        )

    # Prefer the resolved HF template; otherwise treat *chat_template* as a
    # literal template string.
    if isinstance(hf_chat_template, str):
        jinja_text = hf_chat_template
    else:
        jinja_text = load_chat_template(chat_template, is_literal=True)

    if jinja_text is None:
        return "string"

    return _detect_content_format(jinja_text, default="string")
|
||||
|
||||
|
||||
@lru_cache
def _log_chat_template_content_format(
    chat_template: str | None,
    given_format: ChatTemplateContentFormatOption,
    detected_format: ChatTemplateContentFormatOption,
):
    """Log the detected content format once per unique argument combination."""
    logger.info(
        "Detected the chat template content format to be '%s'. "
        "You can set `--chat-template-content-format` to override this.",
        detected_format,
    )

    # Warn when the user's explicit choice disagrees with detection.
    user_overrode = given_format != "auto"
    if user_overrode and given_format != detected_format:
        logger.warning(
            "You specified `--chat-template-content-format %s` "
            "which is different from the detected format '%s'. "
            "If our automatic detection is incorrect, please consider "
            "opening a GitHub issue so that we can improve it: "
            "https://github.com/vllm-project/vllm/issues/new/choose",
            given_format,
            detected_format,
        )
|
||||
|
||||
|
||||
def resolve_chat_template_content_format(
    chat_template: str | None,
    tools: list[dict[str, Any]] | None,
    given_format: ChatTemplateContentFormatOption,
    tokenizer: TokenizerLike | None,
    *,
    model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
    """Resolve "auto" into a concrete content format and log the outcome.

    A non-"auto" *given_format* is returned unchanged; otherwise the format
    is detected from the chat template that would actually be applied.
    """
    if given_format != "auto":
        return given_format

    detected_format = _resolve_chat_template_content_format(
        chat_template,
        tools,
        tokenizer,
        model_config=model_config,
    )

    # lru_cache-backed, so repeated identical requests don't spam the logs.
    _log_chat_template_content_format(
        chat_template,
        given_format=given_format,
        detected_format=detected_format,
    )

    return detected_format
|
||||
# After resolving "auto"
ChatTemplateContentFormat: TypeAlias = Literal["string", "openai"]


# Identifiers for the multimodal input modalities tracked by this module.
ModalityStr: TypeAlias = Literal[
    "image", "audio", "video", "image_embeds", "audio_embeds"
]
|
||||
@@ -1593,7 +1279,7 @@ _ToolParser = partial(cast, ChatCompletionToolMessageParam)
|
||||
def _parse_chat_message_content(
|
||||
message: ChatCompletionMessageParam,
|
||||
mm_tracker: BaseMultiModalItemTracker,
|
||||
content_format: _ChatTemplateContentFormat,
|
||||
content_format: ChatTemplateContentFormat,
|
||||
interleave_strings: bool,
|
||||
) -> list[ConversationMessage]:
|
||||
role = message["role"]
|
||||
@@ -1669,7 +1355,7 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
|
||||
def parse_chat_messages(
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
model_config: ModelConfig,
|
||||
content_format: _ChatTemplateContentFormat,
|
||||
content_format: ChatTemplateContentFormat,
|
||||
) -> tuple[
|
||||
list[ConversationMessage],
|
||||
MultiModalDataDict | None,
|
||||
@@ -1697,13 +1383,13 @@ def parse_chat_messages(
|
||||
return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
|
||||
|
||||
|
||||
def parse_chat_messages_futures(
|
||||
async def parse_chat_messages_async(
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
model_config: ModelConfig,
|
||||
content_format: _ChatTemplateContentFormat,
|
||||
content_format: ChatTemplateContentFormat,
|
||||
) -> tuple[
|
||||
list[ConversationMessage],
|
||||
Awaitable[MultiModalDataDict | None],
|
||||
MultiModalDataDict | None,
|
||||
MultiModalUUIDDict | None,
|
||||
]:
|
||||
conversation: list[ConversationMessage] = []
|
||||
@@ -1725,174 +1411,7 @@ def parse_chat_messages_futures(
|
||||
|
||||
_postprocess_messages(conversation)
|
||||
|
||||
return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
|
||||
|
||||
|
||||
# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
# only preserve the parse function used to resolve chat template kwargs
class AssistantTracker(jinja2.ext.Extension):
    """Jinja extension recognizing ``{% generation %}...{% endgeneration %}``
    blocks so templates that use them can still be parsed."""

    # Tag names this extension handles.
    tags = {"generation"}

    def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
        """Parse a ``{% generation %}`` block into a CallBlock node."""
        lineno = next(parser.stream).lineno
        # Consume everything up to (and including) {% endgeneration %}.
        body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
        call = self.call_method("_generation_support")
        call_block = jinja2.nodes.CallBlock(call, [], [], body)
        return call_block.set_lineno(lineno)
|
||||
|
||||
|
||||
def _resolve_chat_template_kwargs(
    chat_template: str,
):
    """Return the set of undeclared variable names used by *chat_template*."""
    sandbox_env = jinja2.sandbox.ImmutableSandboxedEnvironment(
        trim_blocks=True,
        lstrip_blocks=True,
        extensions=[AssistantTracker, jinja2.ext.loopcontrols],
    )
    template_ast = sandbox_env.parse(chat_template)
    return jinja2.meta.find_undeclared_variables(template_ast)


# Memoized variant; templates are immutable strings, so caching is safe.
_cached_resolve_chat_template_kwargs = lru_cache(_resolve_chat_template_kwargs)
|
||||
|
||||
|
||||
@lru_cache
def _get_hf_base_chat_template_params() -> frozenset[str]:
    """Standard ``apply_chat_template`` parameter names from HF's base class.

    Extracted dynamically from ``PreTrainedTokenizer.apply_chat_template`` so
    tokenizers that receive standard parameters via ``**kwargs`` still work.
    """
    # Read signature from HF's base class - the single source of truth
    base_sig = inspect.signature(PreTrainedTokenizer.apply_chat_template)

    # Exclude VAR_KEYWORD (**kwargs) and VAR_POSITIONAL (*args) placeholders
    variadic_kinds = (
        inspect.Parameter.VAR_KEYWORD,
        inspect.Parameter.VAR_POSITIONAL,
    )
    return frozenset(
        param.name
        for param in base_sig.parameters.values()
        if param.kind not in variadic_kinds
    )
|
||||
|
||||
|
||||
def resolve_chat_template_kwargs(
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    chat_template: str,
    chat_template_kwargs: dict[str, Any],
    raise_on_unexpected: bool = True,
) -> dict[str, Any]:
    """Filter *chat_template_kwargs* down to keys the template can accept.

    Accepted keys are: named parameters of the tokenizer's
    ``apply_chat_template``, undeclared variables of the template itself, and
    standard HF base parameters. ``chat_template`` and ``tokenize`` are always
    excluded (and rejected with a ValueError when *raise_on_unexpected*).
    """
    # We exclude chat_template from kwargs here, because
    # chat template has been already resolved at this stage
    unexpected_vars = {"chat_template", "tokenize"}
    if raise_on_unexpected and (
        unexpected_in_kwargs := unexpected_vars & chat_template_kwargs.keys()
    ):
        raise ValueError(
            "Found unexpected chat template kwargs from request: "
            f"{unexpected_in_kwargs}"
        )

    # Keys that apply_chat_template names explicitly (not via **kwargs).
    fn_kw = {
        k
        for k in chat_template_kwargs
        if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
    }
    # Undeclared variables referenced inside the template body.
    template_vars = _cached_resolve_chat_template_kwargs(chat_template)

    # Allow standard HF parameters even if tokenizer uses **kwargs to receive them
    hf_base_params = _get_hf_base_chat_template_params()

    accept_vars = (fn_kw | template_vars | hf_base_params) - unexpected_vars
    return {k: v for k, v in chat_template_kwargs.items() if k in accept_vars}
|
||||
|
||||
|
||||
def apply_hf_chat_template(
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    conversation: list[ConversationMessage],
    chat_template: str | None,
    tools: list[dict[str, Any]] | None,
    *,
    model_config: ModelConfig,
    **kwargs: Any,
) -> str:
    """Render *conversation* into a prompt string via the HF chat template.

    Raises:
        ChatTemplateResolutionError: if no chat template could be resolved.
        ValueError: if `transformers` fails while applying the template.
    """
    hf_chat_template = resolve_hf_chat_template(
        tokenizer,
        chat_template=chat_template,
        tools=tools,
        model_config=model_config,
    )

    if hf_chat_template is None:
        raise ChatTemplateResolutionError(
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        )

    # Drop any kwargs the resolved template/tokenizer cannot accept.
    resolved_kwargs = resolve_chat_template_kwargs(
        tokenizer=tokenizer,
        chat_template=hf_chat_template,
        chat_template_kwargs=kwargs,
    )

    try:
        return tokenizer.apply_chat_template(
            conversation=conversation,  # type: ignore[arg-type]
            tools=tools,  # type: ignore[arg-type]
            chat_template=hf_chat_template,
            tokenize=False,
            **resolved_kwargs,
        )

    # External library exceptions can sometimes occur despite the framework's
    # internal exception management capabilities.
    except Exception as e:
        # Log and report any library-related exceptions for further
        # investigation.
        logger.exception(
            "An error occurred in `transformers` while applying chat template"
        )
        raise ValueError(str(e)) from e
|
||||
|
||||
|
||||
def apply_mistral_chat_template(
    tokenizer: "MistralTokenizer",
    messages: list[ChatCompletionMessageParam],
    chat_template: str | None,
    tools: list[dict[str, Any]] | None,
    **kwargs: Any,
) -> list[int]:
    """Tokenize *messages* with a Mistral tokenizer's built-in chat template.

    Returns the token ids of the rendered prompt.

    Raises:
        ValueError: if a template override was supplied, or if
            `mistral_common` rejects or fails on the input.
    """
    from mistral_common.exceptions import MistralCommonException

    # The return value of resolve_mistral_chat_template is always None,
    # and we won't use it.
    resolve_mistral_chat_template(
        chat_template=chat_template,
        **kwargs,
    )

    try:
        return tokenizer.apply_chat_template(
            messages=messages,
            tools=tools,
            **kwargs,
        )
    # mistral-common uses assert statements to stop processing of input
    # if input does not comply with the expected format.
    # We convert those assertion errors to ValueErrors so they can be
    # properly caught in the preprocessing_input step
    except (AssertionError, MistralCommonException) as e:
        raise ValueError(str(e)) from e

    # External library exceptions can sometimes occur despite the framework's
    # internal exception management capabilities.
    except Exception as e:
        # Log and report any library-related exceptions for further
        # investigation.
        logger.exception(
            "An error occurred in `mistral_common` while applying chat template"
        )
        raise ValueError(str(e)) from e
|
||||
return conversation, await mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
|
||||
|
||||
|
||||
def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
|
||||
|
||||
Reference in New Issue
Block a user