Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -5,10 +5,10 @@ import asyncio
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import Counter, defaultdict, deque
|
||||
from collections.abc import Awaitable, Iterable
|
||||
from collections.abc import Awaitable, Callable, Iterable
|
||||
from functools import cached_property, lru_cache, partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Generic, Literal, Optional, TypeVar, Union, cast
|
||||
from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast
|
||||
|
||||
import jinja2
|
||||
import jinja2.ext
|
||||
@@ -40,7 +40,7 @@ from pydantic import BaseModel, ConfigDict, TypeAdapter
|
||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin
|
||||
|
||||
# pydantic needs the TypedDict from typing_extensions
|
||||
from typing_extensions import Required, TypeAlias, TypedDict
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
@@ -76,7 +76,7 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False):
|
||||
|
||||
|
||||
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
|
||||
image_embeds: Optional[Union[str, dict[str, str]]]
|
||||
image_embeds: str | dict[str, str] | None
|
||||
"""
|
||||
The image embeddings. It can be either:
|
||||
- A single base64 string.
|
||||
@@ -84,7 +84,7 @@ class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
|
||||
"""
|
||||
type: Required[Literal["image_embeds"]]
|
||||
"""The type of the content part."""
|
||||
uuid: Optional[str]
|
||||
uuid: str | None
|
||||
"""
|
||||
User-provided UUID of a media. User must guarantee that it is properly
|
||||
generated and unique for different medias.
|
||||
@@ -123,8 +123,8 @@ class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
|
||||
}
|
||||
"""
|
||||
|
||||
image_pil: Optional[PILImage]
|
||||
uuid: Optional[str]
|
||||
image_pil: PILImage | None
|
||||
uuid: str | None
|
||||
"""
|
||||
User-provided UUID of a media. User must guarantee that it is properly
|
||||
generated and unique for different medias.
|
||||
@@ -141,8 +141,8 @@ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
|
||||
}
|
||||
"""
|
||||
|
||||
image_url: Optional[str]
|
||||
uuid: Optional[str]
|
||||
image_url: str | None
|
||||
uuid: str | None
|
||||
"""
|
||||
User-provided UUID of a media. User must guarantee that it is properly
|
||||
generated and unique for different medias.
|
||||
@@ -158,7 +158,7 @@ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
|
||||
}
|
||||
"""
|
||||
|
||||
audio_url: Optional[str]
|
||||
audio_url: str | None
|
||||
|
||||
|
||||
class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
|
||||
@@ -170,8 +170,8 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
|
||||
}
|
||||
"""
|
||||
|
||||
video_url: Optional[str]
|
||||
uuid: Optional[str]
|
||||
video_url: str | None
|
||||
uuid: str | None
|
||||
"""
|
||||
User-provided UUID of a media. User must guarantee that it is properly
|
||||
generated and unique for different medias.
|
||||
@@ -199,20 +199,20 @@ class CustomThinkCompletionContentParam(TypedDict, total=False):
|
||||
"""The thinking type."""
|
||||
|
||||
|
||||
ChatCompletionContentPartParam: TypeAlias = Union[
|
||||
OpenAIChatCompletionContentPartParam,
|
||||
ChatCompletionContentPartAudioParam,
|
||||
ChatCompletionContentPartInputAudioParam,
|
||||
ChatCompletionContentPartVideoParam,
|
||||
ChatCompletionContentPartRefusalParam,
|
||||
CustomChatCompletionContentPILImageParam,
|
||||
CustomChatCompletionContentSimpleImageParam,
|
||||
ChatCompletionContentPartImageEmbedsParam,
|
||||
CustomChatCompletionContentSimpleAudioParam,
|
||||
CustomChatCompletionContentSimpleVideoParam,
|
||||
str,
|
||||
CustomThinkCompletionContentParam,
|
||||
]
|
||||
ChatCompletionContentPartParam: TypeAlias = (
|
||||
OpenAIChatCompletionContentPartParam
|
||||
| ChatCompletionContentPartAudioParam
|
||||
| ChatCompletionContentPartInputAudioParam
|
||||
| ChatCompletionContentPartVideoParam
|
||||
| ChatCompletionContentPartRefusalParam
|
||||
| CustomChatCompletionContentPILImageParam
|
||||
| CustomChatCompletionContentSimpleImageParam
|
||||
| ChatCompletionContentPartImageEmbedsParam
|
||||
| CustomChatCompletionContentSimpleAudioParam
|
||||
| CustomChatCompletionContentSimpleVideoParam
|
||||
| str
|
||||
| CustomThinkCompletionContentParam
|
||||
)
|
||||
|
||||
|
||||
class CustomChatCompletionMessageParam(TypedDict, total=False):
|
||||
@@ -221,7 +221,7 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
|
||||
role: Required[str]
|
||||
"""The role of the message's author."""
|
||||
|
||||
content: Union[str, list[ChatCompletionContentPartParam]]
|
||||
content: str | list[ChatCompletionContentPartParam]
|
||||
"""The contents of the message."""
|
||||
|
||||
name: str
|
||||
@@ -231,18 +231,18 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
|
||||
same role.
|
||||
"""
|
||||
|
||||
tool_call_id: Optional[str]
|
||||
tool_call_id: str | None
|
||||
"""Tool call that this message is responding to."""
|
||||
|
||||
tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
|
||||
tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None
|
||||
"""The tool calls generated by the model, such as function calls."""
|
||||
|
||||
|
||||
ChatCompletionMessageParam = Union[
|
||||
OpenAIChatCompletionMessageParam,
|
||||
CustomChatCompletionMessageParam,
|
||||
OpenAIHarmonyMessage,
|
||||
]
|
||||
ChatCompletionMessageParam: TypeAlias = (
|
||||
OpenAIChatCompletionMessageParam
|
||||
| CustomChatCompletionMessageParam
|
||||
| OpenAIHarmonyMessage
|
||||
)
|
||||
|
||||
|
||||
# TODO: Make fields ReadOnly once mypy supports it
|
||||
@@ -250,16 +250,16 @@ class ConversationMessage(TypedDict, total=False):
|
||||
role: Required[str]
|
||||
"""The role of the message's author."""
|
||||
|
||||
content: Union[Optional[str], list[dict[str, str]]]
|
||||
content: str | None | list[dict[str, str]]
|
||||
"""The contents of the message"""
|
||||
|
||||
tool_call_id: Optional[str]
|
||||
tool_call_id: str | None
|
||||
"""Tool call that this message is responding to."""
|
||||
|
||||
name: Optional[str]
|
||||
name: str | None
|
||||
"""The name of the function to call"""
|
||||
|
||||
tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]]
|
||||
tool_calls: Iterable[ChatCompletionMessageToolCallParam] | None
|
||||
"""The tool calls generated by the model, such as function calls."""
|
||||
|
||||
|
||||
@@ -294,7 +294,7 @@ def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
|
||||
def _is_var_or_elems_access(
|
||||
node: jinja2.nodes.Node,
|
||||
varname: str,
|
||||
key: Optional[str] = None,
|
||||
key: str | None = None,
|
||||
) -> bool:
|
||||
if isinstance(node, jinja2.nodes.Filter):
|
||||
return node.node is not None and _is_var_or_elems_access(
|
||||
@@ -369,7 +369,7 @@ def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
|
||||
break
|
||||
|
||||
|
||||
def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]:
|
||||
def _try_extract_ast(chat_template: str) -> jinja2.nodes.Template | None:
|
||||
try:
|
||||
jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
|
||||
return jinja_compiled.environment.parse(chat_template)
|
||||
@@ -400,9 +400,9 @@ def _detect_content_format(
|
||||
|
||||
|
||||
def resolve_mistral_chat_template(
|
||||
chat_template: Optional[str],
|
||||
chat_template: str | None,
|
||||
**kwargs: Any,
|
||||
) -> Optional[str]:
|
||||
) -> str | None:
|
||||
if chat_template is not None or kwargs.get("chat_template_kwargs") is not None:
|
||||
raise ValueError(
|
||||
"'chat_template' or 'chat_template_kwargs' cannot be overridden "
|
||||
@@ -412,7 +412,7 @@ def resolve_mistral_chat_template(
|
||||
return None
|
||||
|
||||
|
||||
_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], Optional[str]]()
|
||||
_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], str | None]()
|
||||
"""
|
||||
Used in `_try_get_processor_chat_template` to avoid calling
|
||||
`cached_get_processor` again if the processor fails to be loaded.
|
||||
@@ -422,9 +422,9 @@ This is needed because `lru_cache` does not cache when an exception happens.
|
||||
|
||||
|
||||
def _try_get_processor_chat_template(
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
|
||||
model_config: ModelConfig,
|
||||
) -> Optional[str]:
|
||||
) -> str | None:
|
||||
cache_key = (tokenizer.name_or_path, model_config.trust_remote_code)
|
||||
if cache_key in _PROCESSOR_CHAT_TEMPLATES:
|
||||
return _PROCESSOR_CHAT_TEMPLATES[cache_key]
|
||||
@@ -458,12 +458,12 @@ def _try_get_processor_chat_template(
|
||||
|
||||
|
||||
def resolve_hf_chat_template(
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
chat_template: Optional[str],
|
||||
tools: Optional[list[dict[str, Any]]],
|
||||
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
|
||||
chat_template: str | None,
|
||||
tools: list[dict[str, Any]] | None,
|
||||
*,
|
||||
model_config: ModelConfig,
|
||||
) -> Optional[str]:
|
||||
) -> str | None:
|
||||
# 1st priority: The given chat template
|
||||
if chat_template is not None:
|
||||
return chat_template
|
||||
@@ -505,8 +505,8 @@ def resolve_hf_chat_template(
|
||||
|
||||
|
||||
def _resolve_chat_template_content_format(
|
||||
chat_template: Optional[str],
|
||||
tools: Optional[list[dict[str, Any]]],
|
||||
chat_template: str | None,
|
||||
tools: list[dict[str, Any]] | None,
|
||||
tokenizer: AnyTokenizer,
|
||||
*,
|
||||
model_config: ModelConfig,
|
||||
@@ -538,7 +538,7 @@ def _resolve_chat_template_content_format(
|
||||
|
||||
@lru_cache
|
||||
def _log_chat_template_content_format(
|
||||
chat_template: Optional[str],
|
||||
chat_template: str | None,
|
||||
given_format: ChatTemplateContentFormatOption,
|
||||
detected_format: ChatTemplateContentFormatOption,
|
||||
):
|
||||
@@ -561,8 +561,8 @@ def _log_chat_template_content_format(
|
||||
|
||||
|
||||
def resolve_chat_template_content_format(
|
||||
chat_template: Optional[str],
|
||||
tools: Optional[list[dict[str, Any]]],
|
||||
chat_template: str | None,
|
||||
tools: list[dict[str, Any]] | None,
|
||||
given_format: ChatTemplateContentFormatOption,
|
||||
tokenizer: AnyTokenizer,
|
||||
*,
|
||||
@@ -604,8 +604,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
self._model_config = model_config
|
||||
self._tokenizer = tokenizer
|
||||
|
||||
self._items_by_modality = defaultdict[str, list[Optional[_T]]](list)
|
||||
self._uuids_by_modality = defaultdict[str, list[Optional[str]]](list)
|
||||
self._items_by_modality = defaultdict[str, list[_T | None]](list)
|
||||
self._uuids_by_modality = defaultdict[str, list[str | None]](list)
|
||||
|
||||
@property
|
||||
def model_config(self) -> ModelConfig:
|
||||
@@ -637,9 +637,9 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
def add(
|
||||
self,
|
||||
modality: ModalityStr,
|
||||
item: Optional[_T],
|
||||
uuid: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
item: _T | None,
|
||||
uuid: str | None = None,
|
||||
) -> str | None:
|
||||
"""
|
||||
Add a multi-modal item to the current prompt and returns the
|
||||
placeholder string to use, if any.
|
||||
@@ -657,7 +657,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
|
||||
return self.model_cls.get_placeholder_str(modality, num_items)
|
||||
|
||||
def all_mm_uuids(self) -> Optional[MultiModalUUIDDict]:
|
||||
def all_mm_uuids(self) -> MultiModalUUIDDict | None:
|
||||
if not self._items_by_modality:
|
||||
return None
|
||||
mm_uuids = {}
|
||||
@@ -684,7 +684,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
|
||||
|
||||
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
|
||||
def all_mm_data(self) -> Optional[MultiModalDataDict]:
|
||||
def all_mm_data(self) -> MultiModalDataDict | None:
|
||||
if not self._items_by_modality:
|
||||
return None
|
||||
mm_inputs = {}
|
||||
@@ -710,7 +710,7 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
|
||||
|
||||
|
||||
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
|
||||
async def all_mm_data(self) -> Optional[MultiModalDataDict]:
|
||||
async def all_mm_data(self) -> MultiModalDataDict | None:
|
||||
if not self._items_by_modality:
|
||||
return None
|
||||
mm_inputs = {}
|
||||
@@ -756,7 +756,7 @@ class BaseMultiModalContentParser(ABC):
|
||||
# }
|
||||
self._placeholder_storage: dict[str, list] = defaultdict(list)
|
||||
|
||||
def _add_placeholder(self, modality: ModalityStr, placeholder: Optional[str]):
|
||||
def _add_placeholder(self, modality: ModalityStr, placeholder: str | None):
|
||||
mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality]
|
||||
if placeholder:
|
||||
self._placeholder_storage[mod_placeholder].append(placeholder)
|
||||
@@ -765,35 +765,35 @@ class BaseMultiModalContentParser(ABC):
|
||||
return dict(self._placeholder_storage)
|
||||
|
||||
@abstractmethod
|
||||
def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> None:
|
||||
def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def parse_image_embeds(
|
||||
self,
|
||||
image_embeds: Union[str, dict[str, str], None],
|
||||
uuid: Optional[str] = None,
|
||||
image_embeds: str | dict[str, str] | None,
|
||||
uuid: str | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def parse_image_pil(
|
||||
self, image_pil: Optional[Image.Image], uuid: Optional[str] = None
|
||||
self, image_pil: Image.Image | None, uuid: str | None = None
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def parse_audio(self, audio_url: Optional[str], uuid: Optional[str] = None) -> None:
|
||||
def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def parse_input_audio(
|
||||
self, input_audio: Optional[InputAudio], uuid: Optional[str] = None
|
||||
self, input_audio: InputAudio | None, uuid: str | None = None
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None:
|
||||
def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -810,7 +810,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
allowed_media_domains=tracker.allowed_media_domains,
|
||||
)
|
||||
|
||||
def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> None:
|
||||
def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
|
||||
image = self._connector.fetch_image(image_url) if image_url else None
|
||||
|
||||
placeholder = self._tracker.add("image", image, uuid)
|
||||
@@ -818,8 +818,8 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
|
||||
def parse_image_embeds(
|
||||
self,
|
||||
image_embeds: Union[str, dict[str, str], None],
|
||||
uuid: Optional[str] = None,
|
||||
image_embeds: str | dict[str, str] | None,
|
||||
uuid: str | None = None,
|
||||
) -> None:
|
||||
if isinstance(image_embeds, dict):
|
||||
embeds = {
|
||||
@@ -838,19 +838,19 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_image_pil(
|
||||
self, image_pil: Optional[Image.Image], uuid: Optional[str] = None
|
||||
self, image_pil: Image.Image | None, uuid: str | None = None
|
||||
) -> None:
|
||||
placeholder = self._tracker.add("image", image_pil, uuid)
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_audio(self, audio_url: Optional[str], uuid: Optional[str] = None) -> None:
|
||||
def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
|
||||
audio = self._connector.fetch_audio(audio_url) if audio_url else None
|
||||
|
||||
placeholder = self._tracker.add("audio", audio, uuid)
|
||||
self._add_placeholder("audio", placeholder)
|
||||
|
||||
def parse_input_audio(
|
||||
self, input_audio: Optional[InputAudio], uuid: Optional[str] = None
|
||||
self, input_audio: InputAudio | None, uuid: str | None = None
|
||||
) -> None:
|
||||
if input_audio:
|
||||
audio_data = input_audio.get("data", "")
|
||||
@@ -865,7 +865,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
|
||||
return self.parse_audio(audio_url, uuid)
|
||||
|
||||
def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None:
|
||||
def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
|
||||
video = self._connector.fetch_video(video_url=video_url) if video_url else None
|
||||
|
||||
placeholder = self._tracker.add("video", video, uuid)
|
||||
@@ -885,7 +885,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
allowed_media_domains=tracker.allowed_media_domains,
|
||||
)
|
||||
|
||||
def parse_image(self, image_url: Optional[str], uuid: Optional[str] = None) -> None:
|
||||
def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
|
||||
image_coro = self._connector.fetch_image_async(image_url) if image_url else None
|
||||
|
||||
placeholder = self._tracker.add("image", image_coro, uuid)
|
||||
@@ -893,10 +893,10 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
|
||||
def parse_image_embeds(
|
||||
self,
|
||||
image_embeds: Union[str, dict[str, str], None],
|
||||
uuid: Optional[str] = None,
|
||||
image_embeds: str | dict[str, str] | None,
|
||||
uuid: str | None = None,
|
||||
) -> None:
|
||||
future: asyncio.Future[Union[str, dict[str, str], None]] = asyncio.Future()
|
||||
future: asyncio.Future[str | dict[str, str] | None] = asyncio.Future()
|
||||
|
||||
if isinstance(image_embeds, dict):
|
||||
embeds = {
|
||||
@@ -916,9 +916,9 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_image_pil(
|
||||
self, image_pil: Optional[Image.Image], uuid: Optional[str] = None
|
||||
self, image_pil: Image.Image | None, uuid: str | None = None
|
||||
) -> None:
|
||||
future: asyncio.Future[Optional[Image.Image]] = asyncio.Future()
|
||||
future: asyncio.Future[Image.Image | None] = asyncio.Future()
|
||||
if image_pil:
|
||||
future.set_result(image_pil)
|
||||
else:
|
||||
@@ -927,14 +927,14 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
placeholder = self._tracker.add("image", future, uuid)
|
||||
self._add_placeholder("image", placeholder)
|
||||
|
||||
def parse_audio(self, audio_url: Optional[str], uuid: Optional[str] = None) -> None:
|
||||
def parse_audio(self, audio_url: str | None, uuid: str | None = None) -> None:
|
||||
audio_coro = self._connector.fetch_audio_async(audio_url) if audio_url else None
|
||||
|
||||
placeholder = self._tracker.add("audio", audio_coro, uuid)
|
||||
self._add_placeholder("audio", placeholder)
|
||||
|
||||
def parse_input_audio(
|
||||
self, input_audio: Optional[InputAudio], uuid: Optional[str] = None
|
||||
self, input_audio: InputAudio | None, uuid: str | None = None
|
||||
) -> None:
|
||||
if input_audio:
|
||||
audio_data = input_audio.get("data", "")
|
||||
@@ -949,7 +949,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
|
||||
return self.parse_audio(audio_url, uuid)
|
||||
|
||||
def parse_video(self, video_url: Optional[str], uuid: Optional[str] = None) -> None:
|
||||
def parse_video(self, video_url: str | None, uuid: str | None = None) -> None:
|
||||
video = (
|
||||
self._connector.fetch_video_async(video_url=video_url)
|
||||
if video_url
|
||||
@@ -960,7 +960,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
self._add_placeholder("video", placeholder)
|
||||
|
||||
|
||||
def validate_chat_template(chat_template: Optional[Union[Path, str]]):
|
||||
def validate_chat_template(chat_template: Path | str | None):
|
||||
"""Raises if the provided chat template appears invalid."""
|
||||
if chat_template is None:
|
||||
return
|
||||
@@ -984,10 +984,10 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]):
|
||||
|
||||
|
||||
def _load_chat_template(
|
||||
chat_template: Optional[Union[Path, str]],
|
||||
chat_template: Path | str | None,
|
||||
*,
|
||||
is_literal: bool = False,
|
||||
) -> Optional[str]:
|
||||
) -> str | None:
|
||||
if chat_template is None:
|
||||
return None
|
||||
|
||||
@@ -1024,10 +1024,10 @@ _cached_load_chat_template = lru_cache(_load_chat_template)
|
||||
|
||||
|
||||
def load_chat_template(
|
||||
chat_template: Optional[Union[Path, str]],
|
||||
chat_template: Path | str | None,
|
||||
*,
|
||||
is_literal: bool = False,
|
||||
) -> Optional[str]:
|
||||
) -> str | None:
|
||||
return _cached_load_chat_template(chat_template, is_literal=is_literal)
|
||||
|
||||
|
||||
@@ -1107,7 +1107,7 @@ _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
|
||||
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python
|
||||
|
||||
_ResponsesInputImageParser = TypeAdapter(ResponseInputImageParam).validate_python
|
||||
_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage]
|
||||
_ContentPart: TypeAlias = str | dict[str, str] | InputAudio | PILImage
|
||||
|
||||
# Define a mapping from part types to their corresponding parsing functions.
|
||||
MM_PARSER_MAP: dict[
|
||||
@@ -1264,7 +1264,7 @@ def _parse_chat_message_content_part(
|
||||
*,
|
||||
wrap_dicts: bool,
|
||||
interleave_strings: bool,
|
||||
) -> Optional[_ContentPart]:
|
||||
) -> _ContentPart | None:
|
||||
"""Parses a single part of a conversation. If wrap_dicts is True,
|
||||
structured dictionary pieces for texts and images will be
|
||||
wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
|
||||
@@ -1310,10 +1310,7 @@ def _parse_chat_message_content_part(
|
||||
mm_parser.parse_image(str_content, uuid)
|
||||
modality = "image"
|
||||
elif part_type == "image_embeds":
|
||||
if content is not None:
|
||||
content = cast(Union[str, dict[str, str]], content)
|
||||
else:
|
||||
content = None
|
||||
content = cast(str | dict[str, str], content) if content is not None else None
|
||||
mm_parser.parse_image_embeds(content, uuid)
|
||||
modality = "image"
|
||||
elif part_type == "audio_url":
|
||||
@@ -1411,8 +1408,8 @@ def parse_chat_messages(
|
||||
content_format: _ChatTemplateContentFormat,
|
||||
) -> tuple[
|
||||
list[ConversationMessage],
|
||||
Optional[MultiModalDataDict],
|
||||
Optional[MultiModalUUIDDict],
|
||||
MultiModalDataDict | None,
|
||||
MultiModalUUIDDict | None,
|
||||
]:
|
||||
conversation: list[ConversationMessage] = []
|
||||
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
|
||||
@@ -1443,8 +1440,8 @@ def parse_chat_messages_futures(
|
||||
content_format: _ChatTemplateContentFormat,
|
||||
) -> tuple[
|
||||
list[ConversationMessage],
|
||||
Awaitable[Optional[MultiModalDataDict]],
|
||||
Optional[MultiModalUUIDDict],
|
||||
Awaitable[MultiModalDataDict | None],
|
||||
MultiModalUUIDDict | None,
|
||||
]:
|
||||
conversation: list[ConversationMessage] = []
|
||||
mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
|
||||
@@ -1498,7 +1495,7 @@ _cached_resolve_chat_template_kwargs = lru_cache(_resolve_chat_template_kwargs)
|
||||
|
||||
|
||||
def resolve_chat_template_kwargs(
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
|
||||
chat_template: str,
|
||||
chat_template_kwargs: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
@@ -1518,10 +1515,10 @@ def resolve_chat_template_kwargs(
|
||||
|
||||
|
||||
def apply_hf_chat_template(
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
|
||||
conversation: list[ConversationMessage],
|
||||
chat_template: Optional[str],
|
||||
tools: Optional[list[dict[str, Any]]],
|
||||
chat_template: str | None,
|
||||
tools: list[dict[str, Any]] | None,
|
||||
*,
|
||||
model_config: ModelConfig,
|
||||
tokenize: bool = False, # Different from HF's default
|
||||
@@ -1569,8 +1566,8 @@ def apply_hf_chat_template(
|
||||
def apply_mistral_chat_template(
|
||||
tokenizer: MistralTokenizer,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
chat_template: Optional[str],
|
||||
tools: Optional[list[dict[str, Any]]],
|
||||
chat_template: str | None,
|
||||
tools: list[dict[str, Any]] | None,
|
||||
**kwargs: Any,
|
||||
) -> list[int]:
|
||||
from mistral_common.exceptions import MistralCommonException
|
||||
|
||||
Reference in New Issue
Block a user