Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -8,8 +8,7 @@ from collections import Counter, defaultdict, deque
|
||||
from collections.abc import Awaitable, Iterable
|
||||
from functools import cached_property, lru_cache, partial
|
||||
from pathlib import Path
|
||||
from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
|
||||
cast)
|
||||
from typing import Any, Callable, Generic, Literal, Optional, TypeVar, Union, cast
|
||||
|
||||
import jinja2
|
||||
import jinja2.ext
|
||||
@@ -18,40 +17,45 @@ import jinja2.nodes
|
||||
import jinja2.parser
|
||||
import jinja2.sandbox
|
||||
import transformers.utils.chat_template_utils as hf_chat_utils
|
||||
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from openai.types.chat import (ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionContentPartImageParam,
|
||||
ChatCompletionContentPartInputAudioParam)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
|
||||
from openai.types.chat import (ChatCompletionContentPartRefusalParam,
|
||||
ChatCompletionContentPartTextParam)
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionContentPartImageParam,
|
||||
ChatCompletionContentPartInputAudioParam,
|
||||
ChatCompletionContentPartRefusalParam,
|
||||
ChatCompletionContentPartTextParam,
|
||||
ChatCompletionMessageToolCallParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
|
||||
from openai.types.chat import (ChatCompletionMessageToolCallParam,
|
||||
ChatCompletionToolMessageParam)
|
||||
from openai.types.chat.chat_completion_content_part_input_audio_param import (
|
||||
InputAudio)
|
||||
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
|
||||
)
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam,
|
||||
)
|
||||
from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
|
||||
from openai.types.responses import ResponseInputImageParam
|
||||
from openai_harmony import Message as OpenAIHarmonyMessage
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, ConfigDict, TypeAdapter
|
||||
|
||||
# yapf: enable
|
||||
from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast,
|
||||
ProcessorMixin)
|
||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin
|
||||
|
||||
# pydantic needs the TypedDict from typing_extensions
|
||||
from typing_extensions import Required, TypeAlias, TypedDict
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models import SupportsMultiModal
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
|
||||
MultiModalUUIDDict)
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
|
||||
from vllm.multimodal.utils import MediaConnector
|
||||
|
||||
# yapf: disable
|
||||
from vllm.transformers_utils.chat_templates import (
|
||||
get_chat_template_fallback_path)
|
||||
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
|
||||
|
||||
# yapf: enable
|
||||
from vllm.transformers_utils.processor import cached_get_processor
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
@@ -284,9 +288,11 @@ def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
|
||||
|
||||
def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
|
||||
if isinstance(node, jinja2.nodes.Getitem):
|
||||
return (_is_var_access(node.node, varname)
|
||||
and isinstance(node.arg, jinja2.nodes.Const)
|
||||
and node.arg.value == key)
|
||||
return (
|
||||
_is_var_access(node.node, varname)
|
||||
and isinstance(node.arg, jinja2.nodes.Const)
|
||||
and node.arg.value == key
|
||||
)
|
||||
|
||||
if isinstance(node, jinja2.nodes.Getattr):
|
||||
return _is_var_access(node.node, varname) and node.attr == key
|
||||
@@ -301,12 +307,14 @@ def _is_var_or_elems_access(
|
||||
) -> bool:
|
||||
if isinstance(node, jinja2.nodes.Filter):
|
||||
return node.node is not None and _is_var_or_elems_access(
|
||||
node.node, varname, key)
|
||||
node.node, varname, key
|
||||
)
|
||||
if isinstance(node, jinja2.nodes.Test):
|
||||
return _is_var_or_elems_access(node.node, varname, key)
|
||||
|
||||
if isinstance(node, jinja2.nodes.Getitem) and isinstance(
|
||||
node.arg, jinja2.nodes.Slice):
|
||||
node.arg, jinja2.nodes.Slice
|
||||
):
|
||||
return _is_var_or_elems_access(node.node, varname, key)
|
||||
|
||||
# yapf: disable
|
||||
@@ -342,8 +350,7 @@ def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
|
||||
# the scope in which each variable is defined, but that is too complicated
|
||||
def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
|
||||
messages_varnames = [
|
||||
varname
|
||||
for _, varname in _iter_nodes_assign_var_or_elems(root, "messages")
|
||||
varname for _, varname in _iter_nodes_assign_var_or_elems(root, "messages")
|
||||
]
|
||||
|
||||
# Search for {%- for message in messages -%} loops
|
||||
@@ -484,8 +491,7 @@ def resolve_hf_chat_template(
|
||||
|
||||
# 2nd priority: AutoProcessor chat template, unless tool calling is enabled
|
||||
if tools is None:
|
||||
chat_template = _try_get_processor_chat_template(tokenizer,
|
||||
model_config)
|
||||
chat_template = _try_get_processor_chat_template(tokenizer, model_config)
|
||||
if chat_template is not None:
|
||||
return chat_template
|
||||
|
||||
@@ -678,16 +684,12 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
mm_uuids = {}
|
||||
uuids_by_modality = dict(self._uuids_by_modality)
|
||||
if "image" in uuids_by_modality and "image_embeds" in uuids_by_modality:
|
||||
raise ValueError(
|
||||
"Mixing raw image and embedding inputs is not allowed"
|
||||
)
|
||||
raise ValueError("Mixing raw image and embedding inputs is not allowed")
|
||||
|
||||
if "image_embeds" in uuids_by_modality:
|
||||
image_embeds_uuids = uuids_by_modality["image_embeds"]
|
||||
if len(image_embeds_uuids) > 1:
|
||||
raise ValueError(
|
||||
"Only one message can have {'type': 'image_embeds'}"
|
||||
)
|
||||
raise ValueError("Only one message can have {'type': 'image_embeds'}")
|
||||
mm_uuids["image"] = uuids_by_modality["image_embeds"]
|
||||
if "image" in uuids_by_modality:
|
||||
mm_uuids["image"] = uuids_by_modality["image"] # UUIDs of images
|
||||
@@ -709,16 +711,12 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
|
||||
mm_inputs = {}
|
||||
items_by_modality = dict(self._items_by_modality)
|
||||
if "image" in items_by_modality and "image_embeds" in items_by_modality:
|
||||
raise ValueError(
|
||||
"Mixing raw image and embedding inputs is not allowed"
|
||||
)
|
||||
raise ValueError("Mixing raw image and embedding inputs is not allowed")
|
||||
|
||||
if "image_embeds" in items_by_modality:
|
||||
image_embeds_lst = items_by_modality["image_embeds"]
|
||||
if len(image_embeds_lst) > 1:
|
||||
raise ValueError(
|
||||
"Only one message can have {'type': 'image_embeds'}"
|
||||
)
|
||||
raise ValueError("Only one message can have {'type': 'image_embeds'}")
|
||||
mm_inputs["image"] = image_embeds_lst[0]
|
||||
if "image" in items_by_modality:
|
||||
mm_inputs["image"] = items_by_modality["image"] # A list of images
|
||||
@@ -748,16 +746,12 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
|
||||
items_by_modality[modality] = await asyncio.gather(*coros)
|
||||
|
||||
if "image" in items_by_modality and "image_embeds" in items_by_modality:
|
||||
raise ValueError(
|
||||
"Mixing raw image and embedding inputs is not allowed"
|
||||
)
|
||||
raise ValueError("Mixing raw image and embedding inputs is not allowed")
|
||||
|
||||
if "image_embeds" in items_by_modality:
|
||||
image_embeds_lst = items_by_modality["image_embeds"]
|
||||
if len(image_embeds_lst) > 1:
|
||||
raise ValueError(
|
||||
"Only one message can have {'type': 'image_embeds'}"
|
||||
)
|
||||
raise ValueError("Only one message can have {'type': 'image_embeds'}")
|
||||
mm_inputs["image"] = image_embeds_lst[0]
|
||||
if "image" in items_by_modality:
|
||||
mm_inputs["image"] = items_by_modality["image"] # A list of images
|
||||
@@ -783,9 +777,7 @@ class BaseMultiModalContentParser(ABC):
|
||||
# }
|
||||
self._placeholder_storage: dict[str, list] = defaultdict(list)
|
||||
|
||||
def _add_placeholder(
|
||||
self, modality: ModalityStr, placeholder: Optional[str]
|
||||
):
|
||||
def _add_placeholder(self, modality: ModalityStr, placeholder: Optional[str]):
|
||||
mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality]
|
||||
if placeholder:
|
||||
self._placeholder_storage[mod_placeholder].append(placeholder)
|
||||
@@ -794,8 +786,7 @@ class BaseMultiModalContentParser(ABC):
|
||||
return dict(self._placeholder_storage)
|
||||
|
||||
@abstractmethod
|
||||
def parse_image(
|
||||
self, image_url: Optional[str], uuid: Optional[str] = None) -> None:
|
||||
def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
@@ -813,9 +804,7 @@ class BaseMultiModalContentParser(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def parse_audio(
|
||||
self, audio_url: Optional[str], uuid: Optional[str] = None
|
||||
) -> None:
|
||||
def parse_audio(self, audio_url: Optional[str], uuid: Optional[str] = None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
@@ -825,9 +814,7 @@ class BaseMultiModalContentParser(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def parse_video(
|
||||
self, video_url: Optional[str], uuid: Optional[str] = None
|
||||
) -> None:
|
||||
def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -844,9 +831,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
allowed_media_domains=tracker.allowed_media_domains,
|
||||
)
|
||||
|
||||
def parse_image(
|
||||
self, image_url: Optional[str], uuid: Optional[str] = None
|
||||
) -> None:
|
||||
def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> None:
|
||||
image = self._connector.fetch_image(image_url) if image_url else None
|
||||
|
||||
placeholder = self._tracker.add("image", image, uuid)
|
||||
@@ -879,9 +864,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
placeholder = self._tracker.add("image", image_pil, uuid)
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_audio(
|
||||
self, audio_url: Optional[str], uuid: Optional[str] = None
|
||||
) -> None:
|
||||
def parse_audio(self, audio_url: Optional[str], uuid: Optional[str] = None) -> None:
|
||||
audio = self._connector.fetch_audio(audio_url) if audio_url else None
|
||||
|
||||
placeholder = self._tracker.add("audio", audio, uuid)
|
||||
@@ -903,14 +886,8 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
|
||||
return self.parse_audio(audio_url, uuid)
|
||||
|
||||
def parse_video(
|
||||
self, video_url: Optional[str], uuid: Optional[str] = None
|
||||
) -> None:
|
||||
video = (
|
||||
self._connector.fetch_video(video_url=video_url)
|
||||
if video_url
|
||||
else None
|
||||
)
|
||||
def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None:
|
||||
video = self._connector.fetch_video(video_url=video_url) if video_url else None
|
||||
|
||||
placeholder = self._tracker.add("video", video, uuid)
|
||||
self._add_placeholder("video", placeholder)
|
||||
@@ -929,12 +906,8 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
allowed_media_domains=tracker.allowed_media_domains,
|
||||
)
|
||||
|
||||
def parse_image(
|
||||
self, image_url: Optional[str], uuid: Optional[str] = None
|
||||
) -> None:
|
||||
image_coro = (
|
||||
self._connector.fetch_image_async(image_url) if image_url else None
|
||||
)
|
||||
def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> None:
|
||||
image_coro = self._connector.fetch_image_async(image_url) if image_url else None
|
||||
|
||||
placeholder = self._tracker.add("image", image_coro, uuid)
|
||||
self._add_placeholder("image", placeholder)
|
||||
@@ -944,9 +917,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
image_embeds: Union[str, dict[str, str], None],
|
||||
uuid: Optional[str] = None,
|
||||
) -> None:
|
||||
future: asyncio.Future[Union[str, dict[str, str], None]] = (
|
||||
asyncio.Future()
|
||||
)
|
||||
future: asyncio.Future[Union[str, dict[str, str], None]] = asyncio.Future()
|
||||
|
||||
if isinstance(image_embeds, dict):
|
||||
embeds = {
|
||||
@@ -977,12 +948,8 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
placeholder = self._tracker.add("image", future, uuid)
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_audio(
|
||||
self, audio_url: Optional[str], uuid: Optional[str] = None
|
||||
) -> None:
|
||||
audio_coro = (
|
||||
self._connector.fetch_audio_async(audio_url) if audio_url else None
|
||||
)
|
||||
def parse_audio(self, audio_url: Optional[str], uuid: Optional[str] = None) -> None:
|
||||
audio_coro = self._connector.fetch_audio_async(audio_url) if audio_url else None
|
||||
|
||||
placeholder = self._tracker.add("audio", audio_coro, uuid)
|
||||
self._add_placeholder("audio", placeholder)
|
||||
@@ -1003,9 +970,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
|
||||
return self.parse_audio(audio_url, uuid)
|
||||
|
||||
def parse_video(
|
||||
self, video_url: Optional[str], uuid: Optional[str] = None
|
||||
) -> None:
|
||||
def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None:
|
||||
video = (
|
||||
self._connector.fetch_video_async(video_url=video_url)
|
||||
if video_url
|
||||
@@ -1036,9 +1001,7 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]):
|
||||
)
|
||||
|
||||
else:
|
||||
raise TypeError(
|
||||
f"{type(chat_template)} is not a valid chat template type"
|
||||
)
|
||||
raise TypeError(f"{type(chat_template)} is not a valid chat template type")
|
||||
|
||||
|
||||
def _load_chat_template(
|
||||
@@ -1145,9 +1108,7 @@ def _get_full_multimodal_text_prompt(
|
||||
"actual multimodal data items."
|
||||
)
|
||||
|
||||
missing_placeholders.extend(
|
||||
[placeholder] * placeholder_counts[placeholder]
|
||||
)
|
||||
missing_placeholders.extend([placeholder] * placeholder_counts[placeholder])
|
||||
|
||||
# NOTE: Default behaviour: we always add missing placeholders
|
||||
# at the front of the prompt, if interleave_strings=False
|
||||
@@ -1166,9 +1127,7 @@ _ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
|
||||
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
|
||||
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python
|
||||
|
||||
_ResponsesInputImageParser = TypeAdapter(
|
||||
ResponseInputImageParam
|
||||
).validate_python
|
||||
_ResponsesInputImageParser = TypeAdapter(ResponseInputImageParam).validate_python
|
||||
_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage]
|
||||
|
||||
# Define a mapping from part types to their corresponding parsing functions.
|
||||
@@ -1179,26 +1138,14 @@ MM_PARSER_MAP: dict[
|
||||
"text": lambda part: _TextParser(part).get("text", None),
|
||||
"thinking": lambda part: _ThinkParser(part).get("thinking", None),
|
||||
"input_text": lambda part: _TextParser(part).get("text", None),
|
||||
"input_image": lambda part: _ResponsesInputImageParser(part).get(
|
||||
"image_url", None
|
||||
),
|
||||
"image_url": lambda part: _ImageParser(part)
|
||||
.get("image_url", {})
|
||||
.get("url", None),
|
||||
"image_embeds": lambda part: _ImageEmbedsParser(part).get(
|
||||
"image_embeds", None
|
||||
),
|
||||
"input_image": lambda part: _ResponsesInputImageParser(part).get("image_url", None),
|
||||
"image_url": lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
|
||||
"image_embeds": lambda part: _ImageEmbedsParser(part).get("image_embeds", None),
|
||||
"image_pil": lambda part: _PILImageParser(part).get("image_pil", None),
|
||||
"audio_url": lambda part: _AudioParser(part)
|
||||
.get("audio_url", {})
|
||||
.get("url", None),
|
||||
"input_audio": lambda part: _InputAudioParser(part).get(
|
||||
"input_audio", None
|
||||
),
|
||||
"audio_url": lambda part: _AudioParser(part).get("audio_url", {}).get("url", None),
|
||||
"input_audio": lambda part: _InputAudioParser(part).get("input_audio", None),
|
||||
"refusal": lambda part: _RefusalParser(part).get("refusal", None),
|
||||
"video_url": lambda part: _VideoParser(part)
|
||||
.get("video_url", {})
|
||||
.get("url", None),
|
||||
"video_url": lambda part: _VideoParser(part).get("video_url", {}).get("url", None),
|
||||
}
|
||||
|
||||
|
||||
@@ -1225,15 +1172,14 @@ def _parse_chat_message_content_mm_part(
|
||||
part_type = part.get("type", None)
|
||||
uuid = part.get("uuid", None)
|
||||
|
||||
if isinstance(part_type, str) and part_type in MM_PARSER_MAP and uuid is None: # noqa: E501
|
||||
if isinstance(part_type, str) and part_type in MM_PARSER_MAP and uuid is None: # noqa: E501
|
||||
content = MM_PARSER_MAP[part_type](part)
|
||||
|
||||
# Special case for 'image_url.detail'
|
||||
# We only support 'auto', which is the default
|
||||
if part_type == "image_url" and part.get("detail", "auto") != "auto":
|
||||
logger.warning(
|
||||
"'image_url.detail' is currently not supported "
|
||||
"and will be ignored."
|
||||
"'image_url.detail' is currently not supported and will be ignored."
|
||||
)
|
||||
|
||||
return part_type, content
|
||||
@@ -1242,9 +1188,7 @@ def _parse_chat_message_content_mm_part(
|
||||
# 'type' is required field by pydantic
|
||||
if part_type is None or uuid is not None:
|
||||
if "image_url" in part:
|
||||
image_params = cast(
|
||||
CustomChatCompletionContentSimpleImageParam, part
|
||||
)
|
||||
image_params = cast(CustomChatCompletionContentSimpleImageParam, part)
|
||||
image_url = image_params.get("image_url", None)
|
||||
if isinstance(image_url, dict):
|
||||
# Can potentially happen if user provides a uuid
|
||||
@@ -1253,22 +1197,20 @@ def _parse_chat_message_content_mm_part(
|
||||
return "image_url", image_url
|
||||
if "image_pil" in part:
|
||||
# "image_pil" could be None if UUID is provided.
|
||||
image_params = cast( # type: ignore
|
||||
image_params = cast( # type: ignore
|
||||
CustomChatCompletionContentPILImageParam, part
|
||||
)
|
||||
image_pil = image_params.get("image_pil", None)
|
||||
return "image_pil", image_pil
|
||||
if "image_embeds" in part:
|
||||
# "image_embeds" could be None if UUID is provided.
|
||||
image_params = cast( # type: ignore
|
||||
image_params = cast( # type: ignore
|
||||
ChatCompletionContentPartImageEmbedsParam, part
|
||||
)
|
||||
image_embeds = image_params.get("image_embeds", None)
|
||||
return "image_embeds", image_embeds
|
||||
if "audio_url" in part:
|
||||
audio_params = cast(
|
||||
CustomChatCompletionContentSimpleAudioParam, part
|
||||
)
|
||||
audio_params = cast(CustomChatCompletionContentSimpleAudioParam, part)
|
||||
audio_url = audio_params.get("audio_url", None)
|
||||
if isinstance(audio_url, dict):
|
||||
# Can potentially happen if user provides a uuid
|
||||
@@ -1279,9 +1221,7 @@ def _parse_chat_message_content_mm_part(
|
||||
input_audio_params = cast(dict[str, str], part)
|
||||
return "input_audio", input_audio_params
|
||||
if "video_url" in part:
|
||||
video_params = cast(
|
||||
CustomChatCompletionContentSimpleVideoParam, part
|
||||
)
|
||||
video_params = cast(CustomChatCompletionContentSimpleVideoParam, part)
|
||||
video_url = video_params.get("video_url", None)
|
||||
if isinstance(video_url, dict):
|
||||
# Can potentially happen if user provides a uuid
|
||||
@@ -1418,9 +1358,7 @@ def _parse_chat_message_content_part(
|
||||
return (
|
||||
{"type": modality}
|
||||
if wrap_dicts
|
||||
else (
|
||||
MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None
|
||||
)
|
||||
else (MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None)
|
||||
)
|
||||
|
||||
|
||||
@@ -1441,9 +1379,7 @@ def _parse_chat_message_content(
|
||||
if content is None:
|
||||
content = []
|
||||
elif isinstance(content, str):
|
||||
content = [
|
||||
ChatCompletionContentPartTextParam(type="text", text=content)
|
||||
]
|
||||
content = [ChatCompletionContentPartTextParam(type="text", text=content)]
|
||||
result = _parse_chat_message_content_parts(
|
||||
role,
|
||||
content, # type: ignore
|
||||
@@ -1459,10 +1395,7 @@ def _parse_chat_message_content(
|
||||
# The 'tool_calls' is not None check ensures compatibility.
|
||||
# It's needed only if downstream code doesn't strictly
|
||||
# follow the OpenAI spec.
|
||||
if (
|
||||
"tool_calls" in parsed_msg
|
||||
and parsed_msg["tool_calls"] is not None
|
||||
):
|
||||
if "tool_calls" in parsed_msg and parsed_msg["tool_calls"] is not None:
|
||||
result_msg["tool_calls"] = list(parsed_msg["tool_calls"])
|
||||
elif role == "tool":
|
||||
parsed_msg = _ToolParser(message)
|
||||
@@ -1594,7 +1527,8 @@ def resolve_chat_template_kwargs(
|
||||
chat_template_kwargs: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
fn_kw = {
|
||||
k for k in chat_template_kwargs
|
||||
k
|
||||
for k in chat_template_kwargs
|
||||
if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
|
||||
}
|
||||
|
||||
@@ -1604,9 +1538,7 @@ def resolve_chat_template_kwargs(
|
||||
# chat template has been already resolved at this stage
|
||||
unexpected_vars = {"chat_template"}
|
||||
accept_vars = (fn_kw | template_vars) - unexpected_vars
|
||||
return {
|
||||
k: v for k, v in chat_template_kwargs.items() if k in accept_vars
|
||||
}
|
||||
return {k: v for k, v in chat_template_kwargs.items() if k in accept_vars}
|
||||
|
||||
|
||||
def apply_hf_chat_template(
|
||||
|
||||
Reference in New Issue
Block a user