Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -8,8 +8,7 @@ from collections import Counter, defaultdict, deque
from collections.abc import Awaitable, Iterable
from functools import cached_property, lru_cache, partial
from pathlib import Path
from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
cast)
from typing import Any, Callable, Generic, Literal, Optional, TypeVar, Union, cast
import jinja2
import jinja2.ext
@@ -18,40 +17,45 @@ import jinja2.nodes
import jinja2.parser
import jinja2.sandbox
import transformers.utils.chat_template_utils as hf_chat_utils
# yapf conflicts with isort for this block
# yapf: disable
from openai.types.chat import (ChatCompletionAssistantMessageParam,
ChatCompletionContentPartImageParam,
ChatCompletionContentPartInputAudioParam)
from openai.types.chat import (
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam)
from openai.types.chat import (ChatCompletionContentPartRefusalParam,
ChatCompletionContentPartTextParam)
ChatCompletionAssistantMessageParam,
ChatCompletionContentPartImageParam,
ChatCompletionContentPartInputAudioParam,
ChatCompletionContentPartRefusalParam,
ChatCompletionContentPartTextParam,
ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam,
)
from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
from openai.types.chat import (ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam)
from openai.types.chat.chat_completion_content_part_input_audio_param import (
InputAudio)
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam,
)
from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam,
)
from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
from openai.types.responses import ResponseInputImageParam
from openai_harmony import Message as OpenAIHarmonyMessage
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
# yapf: enable
from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast,
ProcessorMixin)
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin
# pydantic needs the TypedDict from typing_extensions
from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.models import SupportsMultiModal
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalUUIDDict)
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import MediaConnector
# yapf: disable
from vllm.transformers_utils.chat_templates import (
get_chat_template_fallback_path)
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
# yapf: enable
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
@@ -284,9 +288,11 @@ def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
if isinstance(node, jinja2.nodes.Getitem):
return (_is_var_access(node.node, varname)
and isinstance(node.arg, jinja2.nodes.Const)
and node.arg.value == key)
return (
_is_var_access(node.node, varname)
and isinstance(node.arg, jinja2.nodes.Const)
and node.arg.value == key
)
if isinstance(node, jinja2.nodes.Getattr):
return _is_var_access(node.node, varname) and node.attr == key
@@ -301,12 +307,14 @@ def _is_var_or_elems_access(
) -> bool:
if isinstance(node, jinja2.nodes.Filter):
return node.node is not None and _is_var_or_elems_access(
node.node, varname, key)
node.node, varname, key
)
if isinstance(node, jinja2.nodes.Test):
return _is_var_or_elems_access(node.node, varname, key)
if isinstance(node, jinja2.nodes.Getitem) and isinstance(
node.arg, jinja2.nodes.Slice):
node.arg, jinja2.nodes.Slice
):
return _is_var_or_elems_access(node.node, varname, key)
# yapf: disable
@@ -342,8 +350,7 @@ def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
# the scope in which each variable is defined, but that is too complicated
def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
messages_varnames = [
varname
for _, varname in _iter_nodes_assign_var_or_elems(root, "messages")
varname for _, varname in _iter_nodes_assign_var_or_elems(root, "messages")
]
# Search for {%- for message in messages -%} loops
@@ -484,8 +491,7 @@ def resolve_hf_chat_template(
# 2nd priority: AutoProcessor chat template, unless tool calling is enabled
if tools is None:
chat_template = _try_get_processor_chat_template(tokenizer,
model_config)
chat_template = _try_get_processor_chat_template(tokenizer, model_config)
if chat_template is not None:
return chat_template
@@ -678,16 +684,12 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
mm_uuids = {}
uuids_by_modality = dict(self._uuids_by_modality)
if "image" in uuids_by_modality and "image_embeds" in uuids_by_modality:
raise ValueError(
"Mixing raw image and embedding inputs is not allowed"
)
raise ValueError("Mixing raw image and embedding inputs is not allowed")
if "image_embeds" in uuids_by_modality:
image_embeds_uuids = uuids_by_modality["image_embeds"]
if len(image_embeds_uuids) > 1:
raise ValueError(
"Only one message can have {'type': 'image_embeds'}"
)
raise ValueError("Only one message can have {'type': 'image_embeds'}")
mm_uuids["image"] = uuids_by_modality["image_embeds"]
if "image" in uuids_by_modality:
mm_uuids["image"] = uuids_by_modality["image"] # UUIDs of images
@@ -709,16 +711,12 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
mm_inputs = {}
items_by_modality = dict(self._items_by_modality)
if "image" in items_by_modality and "image_embeds" in items_by_modality:
raise ValueError(
"Mixing raw image and embedding inputs is not allowed"
)
raise ValueError("Mixing raw image and embedding inputs is not allowed")
if "image_embeds" in items_by_modality:
image_embeds_lst = items_by_modality["image_embeds"]
if len(image_embeds_lst) > 1:
raise ValueError(
"Only one message can have {'type': 'image_embeds'}"
)
raise ValueError("Only one message can have {'type': 'image_embeds'}")
mm_inputs["image"] = image_embeds_lst[0]
if "image" in items_by_modality:
mm_inputs["image"] = items_by_modality["image"] # A list of images
@@ -748,16 +746,12 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
items_by_modality[modality] = await asyncio.gather(*coros)
if "image" in items_by_modality and "image_embeds" in items_by_modality:
raise ValueError(
"Mixing raw image and embedding inputs is not allowed"
)
raise ValueError("Mixing raw image and embedding inputs is not allowed")
if "image_embeds" in items_by_modality:
image_embeds_lst = items_by_modality["image_embeds"]
if len(image_embeds_lst) > 1:
raise ValueError(
"Only one message can have {'type': 'image_embeds'}"
)
raise ValueError("Only one message can have {'type': 'image_embeds'}")
mm_inputs["image"] = image_embeds_lst[0]
if "image" in items_by_modality:
mm_inputs["image"] = items_by_modality["image"] # A list of images
@@ -783,9 +777,7 @@ class BaseMultiModalContentParser(ABC):
# }
self._placeholder_storage: dict[str, list] = defaultdict(list)
def _add_placeholder(
self, modality: ModalityStr, placeholder: Optional[str]
):
def _add_placeholder(self, modality: ModalityStr, placeholder: Optional[str]):
mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality]
if placeholder:
self._placeholder_storage[mod_placeholder].append(placeholder)
@@ -794,8 +786,7 @@ class BaseMultiModalContentParser(ABC):
return dict(self._placeholder_storage)
@abstractmethod
def parse_image(
self, image_url: Optional[str], uuid: Optional[str] = None) -> None:
def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> None:
raise NotImplementedError
@abstractmethod
@@ -813,9 +804,7 @@ class BaseMultiModalContentParser(ABC):
raise NotImplementedError
@abstractmethod
def parse_audio(
self, audio_url: Optional[str], uuid: Optional[str] = None
) -> None:
def parse_audio(self, audio_url: Optional[str], uuid: Optional[str] = None) -> None:
raise NotImplementedError
@abstractmethod
@@ -825,9 +814,7 @@ class BaseMultiModalContentParser(ABC):
raise NotImplementedError
@abstractmethod
def parse_video(
self, video_url: Optional[str], uuid: Optional[str] = None
) -> None:
def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None:
raise NotImplementedError
@@ -844,9 +831,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
allowed_media_domains=tracker.allowed_media_domains,
)
def parse_image(
self, image_url: Optional[str], uuid: Optional[str] = None
) -> None:
def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> None:
image = self._connector.fetch_image(image_url) if image_url else None
placeholder = self._tracker.add("image", image, uuid)
@@ -879,9 +864,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
placeholder = self._tracker.add("image", image_pil, uuid)
self._add_placeholder("image", placeholder)
def parse_audio(
self, audio_url: Optional[str], uuid: Optional[str] = None
) -> None:
def parse_audio(self, audio_url: Optional[str], uuid: Optional[str] = None) -> None:
audio = self._connector.fetch_audio(audio_url) if audio_url else None
placeholder = self._tracker.add("audio", audio, uuid)
@@ -903,14 +886,8 @@ class MultiModalContentParser(BaseMultiModalContentParser):
return self.parse_audio(audio_url, uuid)
def parse_video(
self, video_url: Optional[str], uuid: Optional[str] = None
) -> None:
video = (
self._connector.fetch_video(video_url=video_url)
if video_url
else None
)
def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None:
video = self._connector.fetch_video(video_url=video_url) if video_url else None
placeholder = self._tracker.add("video", video, uuid)
self._add_placeholder("video", placeholder)
@@ -929,12 +906,8 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
allowed_media_domains=tracker.allowed_media_domains,
)
def parse_image(
self, image_url: Optional[str], uuid: Optional[str] = None
) -> None:
image_coro = (
self._connector.fetch_image_async(image_url) if image_url else None
)
def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> None:
image_coro = self._connector.fetch_image_async(image_url) if image_url else None
placeholder = self._tracker.add("image", image_coro, uuid)
self._add_placeholder("image", placeholder)
@@ -944,9 +917,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
image_embeds: Union[str, dict[str, str], None],
uuid: Optional[str] = None,
) -> None:
future: asyncio.Future[Union[str, dict[str, str], None]] = (
asyncio.Future()
)
future: asyncio.Future[Union[str, dict[str, str], None]] = asyncio.Future()
if isinstance(image_embeds, dict):
embeds = {
@@ -977,12 +948,8 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
placeholder = self._tracker.add("image", future, uuid)
self._add_placeholder("image", placeholder)
def parse_audio(
self, audio_url: Optional[str], uuid: Optional[str] = None
) -> None:
audio_coro = (
self._connector.fetch_audio_async(audio_url) if audio_url else None
)
def parse_audio(self, audio_url: Optional[str], uuid: Optional[str] = None) -> None:
audio_coro = self._connector.fetch_audio_async(audio_url) if audio_url else None
placeholder = self._tracker.add("audio", audio_coro, uuid)
self._add_placeholder("audio", placeholder)
@@ -1003,9 +970,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
return self.parse_audio(audio_url, uuid)
def parse_video(
self, video_url: Optional[str], uuid: Optional[str] = None
) -> None:
def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None:
video = (
self._connector.fetch_video_async(video_url=video_url)
if video_url
@@ -1036,9 +1001,7 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]):
)
else:
raise TypeError(
f"{type(chat_template)} is not a valid chat template type"
)
raise TypeError(f"{type(chat_template)} is not a valid chat template type")
def _load_chat_template(
@@ -1145,9 +1108,7 @@ def _get_full_multimodal_text_prompt(
"actual multimodal data items."
)
missing_placeholders.extend(
[placeholder] * placeholder_counts[placeholder]
)
missing_placeholders.extend([placeholder] * placeholder_counts[placeholder])
# NOTE: Default behaviour: we always add missing placeholders
# at the front of the prompt, if interleave_strings=False
@@ -1166,9 +1127,7 @@ _ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python
_ResponsesInputImageParser = TypeAdapter(
ResponseInputImageParam
).validate_python
_ResponsesInputImageParser = TypeAdapter(ResponseInputImageParam).validate_python
_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage]
# Define a mapping from part types to their corresponding parsing functions.
@@ -1179,26 +1138,14 @@ MM_PARSER_MAP: dict[
"text": lambda part: _TextParser(part).get("text", None),
"thinking": lambda part: _ThinkParser(part).get("thinking", None),
"input_text": lambda part: _TextParser(part).get("text", None),
"input_image": lambda part: _ResponsesInputImageParser(part).get(
"image_url", None
),
"image_url": lambda part: _ImageParser(part)
.get("image_url", {})
.get("url", None),
"image_embeds": lambda part: _ImageEmbedsParser(part).get(
"image_embeds", None
),
"input_image": lambda part: _ResponsesInputImageParser(part).get("image_url", None),
"image_url": lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
"image_embeds": lambda part: _ImageEmbedsParser(part).get("image_embeds", None),
"image_pil": lambda part: _PILImageParser(part).get("image_pil", None),
"audio_url": lambda part: _AudioParser(part)
.get("audio_url", {})
.get("url", None),
"input_audio": lambda part: _InputAudioParser(part).get(
"input_audio", None
),
"audio_url": lambda part: _AudioParser(part).get("audio_url", {}).get("url", None),
"input_audio": lambda part: _InputAudioParser(part).get("input_audio", None),
"refusal": lambda part: _RefusalParser(part).get("refusal", None),
"video_url": lambda part: _VideoParser(part)
.get("video_url", {})
.get("url", None),
"video_url": lambda part: _VideoParser(part).get("video_url", {}).get("url", None),
}
@@ -1225,15 +1172,14 @@ def _parse_chat_message_content_mm_part(
part_type = part.get("type", None)
uuid = part.get("uuid", None)
if isinstance(part_type, str) and part_type in MM_PARSER_MAP and uuid is None: # noqa: E501
if isinstance(part_type, str) and part_type in MM_PARSER_MAP and uuid is None: # noqa: E501
content = MM_PARSER_MAP[part_type](part)
# Special case for 'image_url.detail'
# We only support 'auto', which is the default
if part_type == "image_url" and part.get("detail", "auto") != "auto":
logger.warning(
"'image_url.detail' is currently not supported "
"and will be ignored."
"'image_url.detail' is currently not supported and will be ignored."
)
return part_type, content
@@ -1242,9 +1188,7 @@ def _parse_chat_message_content_mm_part(
# 'type' is required field by pydantic
if part_type is None or uuid is not None:
if "image_url" in part:
image_params = cast(
CustomChatCompletionContentSimpleImageParam, part
)
image_params = cast(CustomChatCompletionContentSimpleImageParam, part)
image_url = image_params.get("image_url", None)
if isinstance(image_url, dict):
# Can potentially happen if user provides a uuid
@@ -1253,22 +1197,20 @@ def _parse_chat_message_content_mm_part(
return "image_url", image_url
if "image_pil" in part:
# "image_pil" could be None if UUID is provided.
image_params = cast( # type: ignore
image_params = cast( # type: ignore
CustomChatCompletionContentPILImageParam, part
)
image_pil = image_params.get("image_pil", None)
return "image_pil", image_pil
if "image_embeds" in part:
# "image_embeds" could be None if UUID is provided.
image_params = cast( # type: ignore
image_params = cast( # type: ignore
ChatCompletionContentPartImageEmbedsParam, part
)
image_embeds = image_params.get("image_embeds", None)
return "image_embeds", image_embeds
if "audio_url" in part:
audio_params = cast(
CustomChatCompletionContentSimpleAudioParam, part
)
audio_params = cast(CustomChatCompletionContentSimpleAudioParam, part)
audio_url = audio_params.get("audio_url", None)
if isinstance(audio_url, dict):
# Can potentially happen if user provides a uuid
@@ -1279,9 +1221,7 @@ def _parse_chat_message_content_mm_part(
input_audio_params = cast(dict[str, str], part)
return "input_audio", input_audio_params
if "video_url" in part:
video_params = cast(
CustomChatCompletionContentSimpleVideoParam, part
)
video_params = cast(CustomChatCompletionContentSimpleVideoParam, part)
video_url = video_params.get("video_url", None)
if isinstance(video_url, dict):
# Can potentially happen if user provides a uuid
@@ -1418,9 +1358,7 @@ def _parse_chat_message_content_part(
return (
{"type": modality}
if wrap_dicts
else (
MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None
)
else (MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None)
)
@@ -1441,9 +1379,7 @@ def _parse_chat_message_content(
if content is None:
content = []
elif isinstance(content, str):
content = [
ChatCompletionContentPartTextParam(type="text", text=content)
]
content = [ChatCompletionContentPartTextParam(type="text", text=content)]
result = _parse_chat_message_content_parts(
role,
content, # type: ignore
@@ -1459,10 +1395,7 @@ def _parse_chat_message_content(
# The 'tool_calls' is not None check ensures compatibility.
# It's needed only if downstream code doesn't strictly
# follow the OpenAI spec.
if (
"tool_calls" in parsed_msg
and parsed_msg["tool_calls"] is not None
):
if "tool_calls" in parsed_msg and parsed_msg["tool_calls"] is not None:
result_msg["tool_calls"] = list(parsed_msg["tool_calls"])
elif role == "tool":
parsed_msg = _ToolParser(message)
@@ -1594,7 +1527,8 @@ def resolve_chat_template_kwargs(
chat_template_kwargs: dict[str, Any],
) -> dict[str, Any]:
fn_kw = {
k for k in chat_template_kwargs
k
for k in chat_template_kwargs
if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
}
@@ -1604,9 +1538,7 @@ def resolve_chat_template_kwargs(
# chat template has been already resolved at this stage
unexpected_vars = {"chat_template"}
accept_vars = (fn_kw | template_vars) - unexpected_vars
return {
k: v for k, v in chat_template_kwargs.items() if k in accept_vars
}
return {k: v for k, v in chat_template_kwargs.items() if k in accept_vars}
def apply_hf_chat_template(