Update deprecated Python 3.8 typing (#13971)

Harry Mellor authored on 2025-03-03 01:34:51 +00:00, committed by GitHub
parent bf33700ecd
commit cf069aa8aa
300 changed files with 2294 additions and 2347 deletions
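
The change swaps the deprecated Python 3.8-era aliases from typing (List, Dict, Tuple, plus the Awaitable and Iterable re-exports) for the PEP 585 builtin generics and their collections.abc counterparts, which annotations can use directly on Python 3.9 and later. A minimal before/after sketch of the pattern being applied repo-wide (the function below is illustrative only, not code from this diff):

# Before: Python 3.8-style annotations via typing aliases.
from typing import Dict, Iterable, List, Optional, Tuple

def tally_roles(messages: Iterable[Dict[str, str]]) -> Tuple[List[str], Optional[int]]:
    roles: List[str] = [m["role"] for m in messages]
    return roles, len(roles) or None

# After: builtin generics (PEP 585) and collections.abc ABCs; only names
# that still live in typing (e.g. Optional, Union) are imported from it.
from collections.abc import Iterable
from typing import Optional

def tally_roles(messages: Iterable[dict[str, str]]) -> tuple[list[str], Optional[int]]:
    roles: list[str] = [m["role"] for m in messages]
    return roles, len(roles) or None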

@@ -5,10 +5,11 @@ import codecs
import json
from abc import ABC, abstractmethod
from collections import defaultdict, deque
+from collections.abc import Awaitable, Iterable
from functools import cache, lru_cache, partial
from pathlib import Path
-from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
-                    Literal, Optional, Tuple, TypeVar, Union, cast)
+from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
+                    cast)
import jinja2.nodes
import transformers.utils.chat_template_utils as hf_chat_utils
@@ -117,7 +118,7 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
role: Required[str]
"""The role of the message's author."""
-content: Union[str, List[ChatCompletionContentPartParam]]
+content: Union[str, list[ChatCompletionContentPartParam]]
"""The contents of the message."""
name: str
@@ -143,7 +144,7 @@ class ConversationMessage(TypedDict, total=False):
role: Required[str]
"""The role of the message's author."""
-content: Union[Optional[str], List[Dict[str, str]]]
+content: Union[Optional[str], list[dict[str, str]]]
"""The contents of the message"""
tool_call_id: Optional[str]
@@ -495,13 +496,13 @@ class BaseMultiModalContentParser(ABC):
super().__init__()
# multimodal placeholder_string : count
-self._placeholder_counts: Dict[str, int] = defaultdict(lambda: 0)
+self._placeholder_counts: dict[str, int] = defaultdict(lambda: 0)
def _add_placeholder(self, placeholder: Optional[str]):
if placeholder:
self._placeholder_counts[placeholder] += 1
-def mm_placeholder_counts(self) -> Dict[str, int]:
+def mm_placeholder_counts(self) -> dict[str, int]:
return dict(self._placeholder_counts)
@abstractmethod
@@ -652,12 +653,12 @@ def load_chat_template(
# TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template)
-def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
+def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
text_prompt: str) -> str:
"""Combine multimodal prompts for a multimodal language model."""
# Look through the text prompt to check for missing placeholders
-missing_placeholders: List[str] = []
+missing_placeholders: list[str] = []
for placeholder in placeholder_counts:
# For any existing placeholder in the text prompt, we leave it as is
@@ -684,10 +685,10 @@ _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
-_ContentPart: TypeAlias = Union[str, Dict[str, str], InputAudio]
+_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio]
# Define a mapping from part types to their corresponding parsing functions.
-MM_PARSER_MAP: Dict[
+MM_PARSER_MAP: dict[
str,
Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
@@ -749,7 +750,7 @@ def _parse_chat_message_content_mm_part(
part)
return "audio_url", audio_params.get("audio_url", "")
if part.get("input_audio") is not None:
-input_audio_params = cast(Dict[str, str], part)
+input_audio_params = cast(dict[str, str], part)
return "input_audio", input_audio_params
if part.get("video_url") is not None:
video_params = cast(CustomChatCompletionContentSimpleVideoParam,
@@ -773,7 +774,7 @@ def _parse_chat_message_content_parts(
mm_tracker: BaseMultiModalItemTracker,
*,
wrap_dicts: bool,
-) -> List[ConversationMessage]:
+) -> list[ConversationMessage]:
content = list[_ContentPart]()
mm_parser = mm_tracker.create_parser()
@@ -791,7 +792,7 @@ def _parse_chat_message_content_parts(
# Parsing wraps images and texts as interleaved dictionaries
return [ConversationMessage(role=role,
content=content)] # type: ignore
-texts = cast(List[str], content)
+texts = cast(list[str], content)
text_prompt = "\n".join(texts)
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
if mm_placeholder_counts:
@@ -866,7 +867,7 @@ def _parse_chat_message_content(
message: ChatCompletionMessageParam,
mm_tracker: BaseMultiModalItemTracker,
content_format: _ChatTemplateContentFormat,
-) -> List[ConversationMessage]:
+) -> list[ConversationMessage]:
role = message["role"]
content = message.get("content")
@@ -900,7 +901,7 @@ def _parse_chat_message_content(
return result
-def _postprocess_messages(messages: List[ConversationMessage]) -> None:
+def _postprocess_messages(messages: list[ConversationMessage]) -> None:
# per the Transformers docs & maintainers, tool call arguments in
# assistant-role messages with tool_calls need to be dicts not JSON str -
# this is how tool-use chat templates will expect them moving forwards
@@ -916,12 +917,12 @@ def _postprocess_messages(messages: List[ConversationMessage]) -> None:
def parse_chat_messages(
-messages: List[ChatCompletionMessageParam],
+messages: list[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: AnyTokenizer,
content_format: _ChatTemplateContentFormat,
-) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]:
-conversation: List[ConversationMessage] = []
+) -> tuple[list[ConversationMessage], Optional[MultiModalDataDict]]:
+conversation: list[ConversationMessage] = []
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
for msg in messages:
@@ -939,12 +940,12 @@ def parse_chat_messages(
def parse_chat_messages_futures(
-messages: List[ChatCompletionMessageParam],
+messages: list[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: AnyTokenizer,
content_format: _ChatTemplateContentFormat,
-) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
-conversation: List[ConversationMessage] = []
+) -> tuple[list[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
+conversation: list[ConversationMessage] = []
mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
for msg in messages:
@@ -963,7 +964,7 @@ def parse_chat_messages_futures(
def apply_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-conversation: List[ConversationMessage],
+conversation: list[ConversationMessage],
chat_template: Optional[str],
*,
tokenize: bool = False, # Different from HF's default
@@ -985,10 +986,10 @@ def apply_hf_chat_template(
def apply_mistral_chat_template(
tokenizer: MistralTokenizer,
-messages: List[ChatCompletionMessageParam],
+messages: list[ChatCompletionMessageParam],
chat_template: Optional[str] = None,
**kwargs: Any,
-) -> List[int]:
+) -> list[int]:
if chat_template is not None:
logger.warning_once(
"'chat_template' cannot be overridden for mistral tokenizer.")