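Update deprecated Python 3.8 typing (#13971): replace `typing.List`/`Dict`/`Tuple` with the builtin `list`/`dict`/`tuple` generics and import `Awaitable`/`Iterable` from `collections.abc`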
This commit is contained in:
@@ -5,10 +5,11 @@ import codecs
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import Awaitable, Iterable
|
||||
from functools import cache, lru_cache, partial
|
||||
from pathlib import Path
|
||||
from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
|
||||
Literal, Optional, Tuple, TypeVar, Union, cast)
|
||||
from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
|
||||
cast)
|
||||
|
||||
import jinja2.nodes
|
||||
import transformers.utils.chat_template_utils as hf_chat_utils
|
||||
@@ -117,7 +118,7 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
|
||||
role: Required[str]
|
||||
"""The role of the message's author."""
|
||||
|
||||
content: Union[str, List[ChatCompletionContentPartParam]]
|
||||
content: Union[str, list[ChatCompletionContentPartParam]]
|
||||
"""The contents of the message."""
|
||||
|
||||
name: str
|
||||
@@ -143,7 +144,7 @@ class ConversationMessage(TypedDict, total=False):
|
||||
role: Required[str]
|
||||
"""The role of the message's author."""
|
||||
|
||||
content: Union[Optional[str], List[Dict[str, str]]]
|
||||
content: Union[Optional[str], list[dict[str, str]]]
|
||||
"""The contents of the message"""
|
||||
|
||||
tool_call_id: Optional[str]
|
||||
@@ -495,13 +496,13 @@ class BaseMultiModalContentParser(ABC):
|
||||
super().__init__()
|
||||
|
||||
# multimodal placeholder_string : count
|
||||
self._placeholder_counts: Dict[str, int] = defaultdict(lambda: 0)
|
||||
self._placeholder_counts: dict[str, int] = defaultdict(lambda: 0)
|
||||
|
||||
def _add_placeholder(self, placeholder: Optional[str]):
|
||||
if placeholder:
|
||||
self._placeholder_counts[placeholder] += 1
|
||||
|
||||
def mm_placeholder_counts(self) -> Dict[str, int]:
|
||||
def mm_placeholder_counts(self) -> dict[str, int]:
|
||||
return dict(self._placeholder_counts)
|
||||
|
||||
@abstractmethod
|
||||
@@ -652,12 +653,12 @@ def load_chat_template(
|
||||
|
||||
# TODO: Let user specify how to insert multimodal tokens into prompt
|
||||
# (similar to chat template)
|
||||
def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
|
||||
def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
|
||||
text_prompt: str) -> str:
|
||||
"""Combine multimodal prompts for a multimodal language model."""
|
||||
|
||||
# Look through the text prompt to check for missing placeholders
|
||||
missing_placeholders: List[str] = []
|
||||
missing_placeholders: list[str] = []
|
||||
for placeholder in placeholder_counts:
|
||||
|
||||
# For any existing placeholder in the text prompt, we leave it as is
|
||||
@@ -684,10 +685,10 @@ _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
|
||||
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
|
||||
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
|
||||
|
||||
_ContentPart: TypeAlias = Union[str, Dict[str, str], InputAudio]
|
||||
_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio]
|
||||
|
||||
# Define a mapping from part types to their corresponding parsing functions.
|
||||
MM_PARSER_MAP: Dict[
|
||||
MM_PARSER_MAP: dict[
|
||||
str,
|
||||
Callable[[ChatCompletionContentPartParam], _ContentPart],
|
||||
] = {
|
||||
@@ -749,7 +750,7 @@ def _parse_chat_message_content_mm_part(
|
||||
part)
|
||||
return "audio_url", audio_params.get("audio_url", "")
|
||||
if part.get("input_audio") is not None:
|
||||
input_audio_params = cast(Dict[str, str], part)
|
||||
input_audio_params = cast(dict[str, str], part)
|
||||
return "input_audio", input_audio_params
|
||||
if part.get("video_url") is not None:
|
||||
video_params = cast(CustomChatCompletionContentSimpleVideoParam,
|
||||
@@ -773,7 +774,7 @@ def _parse_chat_message_content_parts(
|
||||
mm_tracker: BaseMultiModalItemTracker,
|
||||
*,
|
||||
wrap_dicts: bool,
|
||||
) -> List[ConversationMessage]:
|
||||
) -> list[ConversationMessage]:
|
||||
content = list[_ContentPart]()
|
||||
|
||||
mm_parser = mm_tracker.create_parser()
|
||||
@@ -791,7 +792,7 @@ def _parse_chat_message_content_parts(
|
||||
# Parsing wraps images and texts as interleaved dictionaries
|
||||
return [ConversationMessage(role=role,
|
||||
content=content)] # type: ignore
|
||||
texts = cast(List[str], content)
|
||||
texts = cast(list[str], content)
|
||||
text_prompt = "\n".join(texts)
|
||||
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
|
||||
if mm_placeholder_counts:
|
||||
@@ -866,7 +867,7 @@ def _parse_chat_message_content(
|
||||
message: ChatCompletionMessageParam,
|
||||
mm_tracker: BaseMultiModalItemTracker,
|
||||
content_format: _ChatTemplateContentFormat,
|
||||
) -> List[ConversationMessage]:
|
||||
) -> list[ConversationMessage]:
|
||||
role = message["role"]
|
||||
content = message.get("content")
|
||||
|
||||
@@ -900,7 +901,7 @@ def _parse_chat_message_content(
|
||||
return result
|
||||
|
||||
|
||||
def _postprocess_messages(messages: List[ConversationMessage]) -> None:
|
||||
def _postprocess_messages(messages: list[ConversationMessage]) -> None:
|
||||
# per the Transformers docs & maintainers, tool call arguments in
|
||||
# assistant-role messages with tool_calls need to be dicts not JSON str -
|
||||
# this is how tool-use chat templates will expect them moving forwards
|
||||
@@ -916,12 +917,12 @@ def _postprocess_messages(messages: List[ConversationMessage]) -> None:
|
||||
|
||||
|
||||
def parse_chat_messages(
|
||||
messages: List[ChatCompletionMessageParam],
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
model_config: ModelConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
content_format: _ChatTemplateContentFormat,
|
||||
) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]:
|
||||
conversation: List[ConversationMessage] = []
|
||||
) -> tuple[list[ConversationMessage], Optional[MultiModalDataDict]]:
|
||||
conversation: list[ConversationMessage] = []
|
||||
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
|
||||
|
||||
for msg in messages:
|
||||
@@ -939,12 +940,12 @@ def parse_chat_messages(
|
||||
|
||||
|
||||
def parse_chat_messages_futures(
|
||||
messages: List[ChatCompletionMessageParam],
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
model_config: ModelConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
content_format: _ChatTemplateContentFormat,
|
||||
) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
|
||||
conversation: List[ConversationMessage] = []
|
||||
) -> tuple[list[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
|
||||
conversation: list[ConversationMessage] = []
|
||||
mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
|
||||
|
||||
for msg in messages:
|
||||
@@ -963,7 +964,7 @@ def parse_chat_messages_futures(
|
||||
|
||||
def apply_hf_chat_template(
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
conversation: List[ConversationMessage],
|
||||
conversation: list[ConversationMessage],
|
||||
chat_template: Optional[str],
|
||||
*,
|
||||
tokenize: bool = False, # Different from HF's default
|
||||
@@ -985,10 +986,10 @@ def apply_hf_chat_template(
|
||||
|
||||
def apply_mistral_chat_template(
|
||||
tokenizer: MistralTokenizer,
|
||||
messages: List[ChatCompletionMessageParam],
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
chat_template: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[int]:
|
||||
) -> list[int]:
|
||||
if chat_template is not None:
|
||||
logger.warning_once(
|
||||
"'chat_template' cannot be overridden for mistral tokenizer.")
|
||||
|
||||
Reference in New Issue
Block a user