Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-12 17:51:31 +01:00
committed by GitHub
parent 9bb38130cb
commit 8fcaaf6a16
944 changed files with 9490 additions and 10121 deletions

View File

@@ -5,10 +5,10 @@ import asyncio
import json
from abc import ABC, abstractmethod
from collections import Counter, defaultdict, deque
from collections.abc import Awaitable, Iterable
from collections.abc import Awaitable, Callable, Iterable
from functools import cached_property, lru_cache, partial
from pathlib import Path
from typing import Any, Callable, Generic, Literal, Optional, TypeVar, Union, cast
from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast
import jinja2
import jinja2.ext
@@ -40,7 +40,7 @@ from pydantic import BaseModel, ConfigDict, TypeAdapter
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin
# pydantic needs the TypedDict from typing_extensions
from typing_extensions import Required, TypeAlias, TypedDict
from typing_extensions import Required, TypedDict
from vllm.config import ModelConfig
from vllm.logger import init_logger
@@ -76,7 +76,7 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False):
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
image_embeds: Optional[Union[str, dict[str, str]]]
image_embeds: str | dict[str, str] | None
"""
The image embeddings. It can be either:
- A single base64 string.
@@ -84,7 +84,7 @@ class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
"""
type: Required[Literal["image_embeds"]]
"""The type of the content part."""
uuid: Optional[str]
uuid: str | None
"""
User-provided UUID of a media item. The user must guarantee that it is
properly generated and unique across different media items.
@@ -123,8 +123,8 @@ class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
}
"""
image_pil: Optional[PILImage]
uuid: Optional[str]
image_pil: PILImage | None
uuid: str | None
"""
User-provided UUID of a media item. The user must guarantee that it is
properly generated and unique across different media items.
@@ -141,8 +141,8 @@ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
}
"""
image_url: Optional[str]
uuid: Optional[str]
image_url: str | None
uuid: str | None
"""
User-provided UUID of a media item. The user must guarantee that it is
properly generated and unique across different media items.
@@ -158,7 +158,7 @@ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
}
"""
audio_url: Optional[str]
audio_url: str | None
class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
@@ -170,8 +170,8 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
}
"""
video_url: Optional[str]
uuid: Optional[str]
video_url: str | None
uuid: str | None
"""
User-provided UUID of a media item. The user must guarantee that it is
properly generated and unique across different media items.
@@ -199,20 +199,20 @@ class CustomThinkCompletionContentParam(TypedDict, total=False):
"""The thinking type."""
ChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam,
ChatCompletionContentPartAudioParam,
ChatCompletionContentPartInputAudioParam,
ChatCompletionContentPartVideoParam,
ChatCompletionContentPartRefusalParam,
CustomChatCompletionContentPILImageParam,
CustomChatCompletionContentSimpleImageParam,
ChatCompletionContentPartImageEmbedsParam,
CustomChatCompletionContentSimpleAudioParam,
CustomChatCompletionContentSimpleVideoParam,
str,
CustomThinkCompletionContentParam,
]
ChatCompletionContentPartParam: TypeAlias = (
OpenAIChatCompletionContentPartParam
| ChatCompletionContentPartAudioParam
| ChatCompletionContentPartInputAudioParam
| ChatCompletionContentPartVideoParam
| ChatCompletionContentPartRefusalParam
| CustomChatCompletionContentPILImageParam
| CustomChatCompletionContentSimpleImageParam
| ChatCompletionContentPartImageEmbedsParam
| CustomChatCompletionContentSimpleAudioParam
| CustomChatCompletionContentSimpleVideoParam
| str
| CustomThinkCompletionContentParam
)
class CustomChatCompletionMessageParam(TypedDict, total=False):
@@ -221,7 +221,7 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
role: Required[str]
"""The role of the message's author."""
content: Union[str, list[ChatCompletionContentPartParam]]
content: str | list[ChatCompletionContentPartParam]
"""The contents of the message."""
name: str
@@ -231,18 +231,18 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
same role.
"""
tool_call_id: Optional[str]
tool_call_id: str | None
"""Tool call that this message is responding to."""
tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None
"""The tool calls generated by the model, such as function calls."""
ChatCompletionMessageParam = Union[
OpenAIChatCompletionMessageParam,
CustomChatCompletionMessageParam,
OpenAIHarmonyMessage,
]
ChatCompletionMessageParam: TypeAlias = (
OpenAIChatCompletionMessageParam
| CustomChatCompletionMessageParam
| OpenAIHarmonyMessage
)
# TODO: Make fields ReadOnly once mypy supports it
@@ -250,16 +250,16 @@ class ConversationMessage(TypedDict, total=False):
role: Required[str]
"""The role of the message's author."""
content: Union[Optional[str], list[dict[str, str]]]
content: str | None | list[dict[str, str]]
"""The contents of the message"""
tool_call_id: Optional[str]
tool_call_id: str | None
"""Tool call that this message is responding to."""
name: Optional[str]
name: str | None
"""The name of the function to call"""
tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None
"""The tool calls generated by the model, such as function calls."""
@@ -294,7 +294,7 @@ def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
def _is_var_or_elems_access(
node: jinja2.nodes.Node,
varname: str,
key: Optional[str] = None,
key: str | None = None,
) -> bool:
if isinstance(node, jinja2.nodes.Filter):
return node.node is not None and _is_var_or_elems_access(
@@ -369,7 +369,7 @@ def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
break
def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]:
def _try_extract_ast(chat_template: str) -> jinja2.nodes.Template | None:
try:
jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
return jinja_compiled.environment.parse(chat_template)
@@ -400,9 +400,9 @@ def _detect_content_format(
def resolve_mistral_chat_template(
chat_template: Optional[str],
chat_template: str | None,
**kwargs: Any,
) -> Optional[str]:
) -> str | None:
if chat_template is not None or kwargs.get("chat_template_kwargs") is not None:
raise ValueError(
"'chat_template' or 'chat_template_kwargs' cannot be overridden "
@@ -412,7 +412,7 @@ def resolve_mistral_chat_template(
return None
_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], Optional[str]]()
_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], str | None]()
"""
Used in `_try_get_processor_chat_template` to avoid calling
`cached_get_processor` again if the processor fails to be loaded.
@@ -422,9 +422,9 @@ This is needed because `lru_cache` does not cache when an exception happens.
def _try_get_processor_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
model_config: ModelConfig,
) -> Optional[str]:
) -> str | None:
cache_key = (tokenizer.name_or_path, model_config.trust_remote_code)
if cache_key in _PROCESSOR_CHAT_TEMPLATES:
return _PROCESSOR_CHAT_TEMPLATES[cache_key]
@@ -458,12 +458,12 @@ def _try_get_processor_chat_template(
def resolve_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
chat_template: str | None,
tools: list[dict[str, Any]] | None,
*,
model_config: ModelConfig,
) -> Optional[str]:
) -> str | None:
# 1st priority: The given chat template
if chat_template is not None:
return chat_template
@@ -505,8 +505,8 @@ def resolve_hf_chat_template(
def _resolve_chat_template_content_format(
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
chat_template: str | None,
tools: list[dict[str, Any]] | None,
tokenizer: AnyTokenizer,
*,
model_config: ModelConfig,
@@ -538,7 +538,7 @@ def _resolve_chat_template_content_format(
@lru_cache
def _log_chat_template_content_format(
chat_template: Optional[str],
chat_template: str | None,
given_format: ChatTemplateContentFormatOption,
detected_format: ChatTemplateContentFormatOption,
):
@@ -561,8 +561,8 @@ def _log_chat_template_content_format(
def resolve_chat_template_content_format(
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
chat_template: str | None,
tools: list[dict[str, Any]] | None,
given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer,
*,
@@ -604,8 +604,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self._model_config = model_config
self._tokenizer = tokenizer
self._items_by_modality = defaultdict[str, list[Optional[_T]]](list)
self._uuids_by_modality = defaultdict[str, list[Optional[str]]](list)
self._items_by_modality = defaultdict[str, list[_T | None]](list)
self._uuids_by_modality = defaultdict[str, list[str | None]](list)
@property
def model_config(self) -> ModelConfig:
@@ -637,9 +637,9 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
def add(
self,
modality: ModalityStr,
item: Optional[_T],
uuid: Optional[str] = None,
) -> Optional[str]:
item: _T | None,
uuid: str | None = None,
) -> str | None:
"""
Add a multi-modal item to the current prompt and return the
placeholder string to use, if any.
@@ -657,7 +657,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return self.model_cls.get_placeholder_str(modality, num_items)
def all_mm_uuids(self) -> Optional[MultiModalUUIDDict]:
def all_mm_uuids(self) -> MultiModalUUIDDict | None:
if not self._items_by_modality:
return None
mm_uuids = {}
@@ -684,7 +684,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
def all_mm_data(self) -> Optional[MultiModalDataDict]:
def all_mm_data(self) -> MultiModalDataDict | None:
if not self._items_by_modality:
return None
mm_inputs = {}
@@ -710,7 +710,7 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
async def all_mm_data(self) -> Optional[MultiModalDataDict]:
async def all_mm_data(self) -> MultiModalDataDict | None:
if not self._items_by_modality:
return None
mm_inputs = {}
@@ -756,7 +756,7 @@ class BaseMultiModalContentParser(ABC):
# }
self._placeholder_storage: dict[str, list] = defaultdict(list)
def _add_placeholder(self, modality: ModalityStr, placeholder: Optional[str]):
def _add_placeholder(self, modality: ModalityStr, placeholder: str | None):
mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality]
if placeholder:
self._placeholder_storage[mod_placeholder].append(placeholder)
@@ -765,35 +765,35 @@ class BaseMultiModalContentParser(ABC):
return dict(self._placeholder_storage)
@abstractmethod
def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> None:
def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
raise NotImplementedError
@abstractmethod
def parse_image_embeds(
self,
image_embeds: Union[str, dict[str, str], None],
uuid: Optional[str] = None,
image_embeds: str | dict[str, str] | None,
uuid: str | None = None,
) -> None:
raise NotImplementedError
@abstractmethod
def parse_image_pil(
self, image_pil: Optional[Image.Image], uuid: Optional[str] = None
self, image_pil: Image.Image | None, uuid: str | None = None
) -> None:
raise NotImplementedError
@abstractmethod
def parse_audio(self, audio_url: Optional[str], uuid: Optional[str] = None) -> None:
def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
raise NotImplementedError
@abstractmethod
def parse_input_audio(
self, input_audio: Optional[InputAudio], uuid: Optional[str] = None
self, input_audio: InputAudio | None, uuid: str | None = None
) -> None:
raise NotImplementedError
@abstractmethod
def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None:
def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
raise NotImplementedError
@@ -810,7 +810,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
allowed_media_domains=tracker.allowed_media_domains,
)
def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> None:
def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
image = self._connector.fetch_image(image_url) if image_url else None
placeholder = self._tracker.add("image", image, uuid)
@@ -818,8 +818,8 @@ class MultiModalContentParser(BaseMultiModalContentParser):
def parse_image_embeds(
self,
image_embeds: Union[str, dict[str, str], None],
uuid: Optional[str] = None,
image_embeds: str | dict[str, str] | None,
uuid: str | None = None,
) -> None:
if isinstance(image_embeds, dict):
embeds = {
@@ -838,19 +838,19 @@ class MultiModalContentParser(BaseMultiModalContentParser):
self._add_placeholder("image", placeholder)
def parse_image_pil(
self, image_pil: Optional[Image.Image], uuid: Optional[str] = None
self, image_pil: Image.Image | None, uuid: str | None = None
) -> None:
placeholder = self._tracker.add("image", image_pil, uuid)
self._add_placeholder("image", placeholder)
def parse_audio(self, audio_url: Optional[str], uuid: Optional[str] = None) -> None:
def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
audio = self._connector.fetch_audio(audio_url) if audio_url else None
placeholder = self._tracker.add("audio", audio, uuid)
self._add_placeholder("audio", placeholder)
def parse_input_audio(
self, input_audio: Optional[InputAudio], uuid: Optional[str] = None
self, input_audio: InputAudio | None, uuid: str | None = None
) -> None:
if input_audio:
audio_data = input_audio.get("data", "")
@@ -865,7 +865,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
return self.parse_audio(audio_url, uuid)
def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None:
def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
video = self._connector.fetch_video(video_url=video_url) if video_url else None
placeholder = self._tracker.add("video", video, uuid)
@@ -885,7 +885,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
allowed_media_domains=tracker.allowed_media_domains,
)
def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> None:
def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
image_coro = self._connector.fetch_image_async(image_url) if image_url else None
placeholder = self._tracker.add("image", image_coro, uuid)
@@ -893,10 +893,10 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
def parse_image_embeds(
self,
image_embeds: Union[str, dict[str, str], None],
uuid: Optional[str] = None,
image_embeds: str | dict[str, str] | None,
uuid: str | None = None,
) -> None:
future: asyncio.Future[Union[str, dict[str, str], None]] = asyncio.Future()
future: asyncio.Future[str | dict[str, str] | None] = asyncio.Future()
if isinstance(image_embeds, dict):
embeds = {
@@ -916,9 +916,9 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
self._add_placeholder("image", placeholder)
def parse_image_pil(
self, image_pil: Optional[Image.Image], uuid: Optional[str] = None
self, image_pil: Image.Image | None, uuid: str | None = None
) -> None:
future: asyncio.Future[Optional[Image.Image]] = asyncio.Future()
future: asyncio.Future[Image.Image | None] = asyncio.Future()
if image_pil:
future.set_result(image_pil)
else:
@@ -927,14 +927,14 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
placeholder = self._tracker.add("image", future, uuid)
self._add_placeholder("image", placeholder)
def parse_audio(self, audio_url: Optional[str], uuid: Optional[str] = None) -> None:
def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
audio_coro = self._connector.fetch_audio_async(audio_url) if audio_url else None
placeholder = self._tracker.add("audio", audio_coro, uuid)
self._add_placeholder("audio", placeholder)
def parse_input_audio(
self, input_audio: Optional[InputAudio], uuid: Optional[str] = None
self, input_audio: InputAudio | None, uuid: str | None = None
) -> None:
if input_audio:
audio_data = input_audio.get("data", "")
@@ -949,7 +949,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
return self.parse_audio(audio_url, uuid)
def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None:
def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
video = (
self._connector.fetch_video_async(video_url=video_url)
if video_url
@@ -960,7 +960,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
self._add_placeholder("video", placeholder)
def validate_chat_template(chat_template: Optional[Union[Path, str]]):
def validate_chat_template(chat_template: Path | str | None):
"""Raises if the provided chat template appears invalid."""
if chat_template is None:
return
@@ -984,10 +984,10 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]):
def _load_chat_template(
chat_template: Optional[Union[Path, str]],
chat_template: Path | str | None,
*,
is_literal: bool = False,
) -> Optional[str]:
) -> str | None:
if chat_template is None:
return None
@@ -1024,10 +1024,10 @@ _cached_load_chat_template = lru_cache(_load_chat_template)
def load_chat_template(
chat_template: Optional[Union[Path, str]],
chat_template: Path | str | None,
*,
is_literal: bool = False,
) -> Optional[str]:
) -> str | None:
return _cached_load_chat_template(chat_template, is_literal=is_literal)
@@ -1107,7 +1107,7 @@ _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python
_ResponsesInputImageParser = TypeAdapter(ResponseInputImageParam).validate_python
_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage]
_ContentPart: TypeAlias = str | dict[str, str] | InputAudio | PILImage
# Define a mapping from part types to their corresponding parsing functions.
MM_PARSER_MAP: dict[
@@ -1264,7 +1264,7 @@ def _parse_chat_message_content_part(
*,
wrap_dicts: bool,
interleave_strings: bool,
) -> Optional[_ContentPart]:
) -> _ContentPart | None:
"""Parses a single part of a conversation. If wrap_dicts is True,
structured dictionary pieces for texts and images will be
wrapped in dictionaries, i.e., {"type": "text", "text": ...} and
@@ -1310,10 +1310,7 @@ def _parse_chat_message_content_part(
mm_parser.parse_image(str_content, uuid)
modality = "image"
elif part_type == "image_embeds":
if content is not None:
content = cast(Union[str, dict[str, str]], content)
else:
content = None
content = cast(str | dict[str, str], content) if content is not None else None
mm_parser.parse_image_embeds(content, uuid)
modality = "image"
elif part_type == "audio_url":
@@ -1411,8 +1408,8 @@ def parse_chat_messages(
content_format: _ChatTemplateContentFormat,
) -> tuple[
list[ConversationMessage],
Optional[MultiModalDataDict],
Optional[MultiModalUUIDDict],
MultiModalDataDict | None,
MultiModalUUIDDict | None,
]:
conversation: list[ConversationMessage] = []
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
@@ -1443,8 +1440,8 @@ def parse_chat_messages_futures(
content_format: _ChatTemplateContentFormat,
) -> tuple[
list[ConversationMessage],
Awaitable[Optional[MultiModalDataDict]],
Optional[MultiModalUUIDDict],
Awaitable[MultiModalDataDict | None],
MultiModalUUIDDict | None,
]:
conversation: list[ConversationMessage] = []
mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
@@ -1498,7 +1495,7 @@ _cached_resolve_chat_template_kwargs = lru_cache(_resolve_chat_template_kwargs)
def resolve_chat_template_kwargs(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
chat_template: str,
chat_template_kwargs: dict[str, Any],
) -> dict[str, Any]:
@@ -1518,10 +1515,10 @@ def resolve_chat_template_kwargs(
def apply_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
conversation: list[ConversationMessage],
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
chat_template: str | None,
tools: list[dict[str, Any]] | None,
*,
model_config: ModelConfig,
tokenize: bool = False, # Different from HF's default
@@ -1569,8 +1566,8 @@ def apply_hf_chat_template(
def apply_mistral_chat_template(
tokenizer: MistralTokenizer,
messages: list[ChatCompletionMessageParam],
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
chat_template: str | None,
tools: list[dict[str, Any]] | None,
**kwargs: Any,
) -> list[int]:
from mistral_common.exceptions import MistralCommonException