[Frontend] Clean up type annotations for mistral tokenizer (#8314)
@@ -23,6 +23,7 @@ from openai.types.chat import (ChatCompletionMessageToolCallParam,
 # yapf: enable
 # pydantic needs the TypedDict from typing_extensions
 from pydantic import ConfigDict
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 from typing_extensions import Required, TypeAlias, TypedDict
 
 from vllm.config import ModelConfig
@@ -31,7 +32,7 @@ from vllm.multimodal import MultiModalDataDict
 from vllm.multimodal.utils import (async_get_and_parse_audio,
                                    async_get_and_parse_image,
                                    get_and_parse_audio, get_and_parse_image)
-from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
 
 logger = init_logger(__name__)
 
@@ -379,6 +380,9 @@ def _parse_chat_message_content_parts(
             audio_url = _AudioParser(part)["audio_url"]
 
             mm_parser.parse_audio(audio_url["url"])
+        elif part_type == "refusal":
+            text = _RefusalParser(part)["refusal"]
+            texts.append(text)
         else:
             raise NotImplementedError(f"Unknown part type: {part_type}")
 
@@ -433,6 +437,21 @@ def _parse_chat_message_content(
     return result
 
 
+def _postprocess_messages(messages: List[ConversationMessage]) -> None:
+    # per the Transformers docs & maintainers, tool call arguments in
+    # assistant-role messages with tool_calls need to be dicts not JSON str -
+    # this is how tool-use chat templates will expect them moving forwards
+    # so, for messages that have tool_calls, parse the string (which we get
+    # from openAI format) to dict
+    for message in messages:
+        if (message["role"] == "assistant" and "tool_calls" in message
+                and isinstance(message["tool_calls"], list)):
+
+            for item in message["tool_calls"]:
+                item["function"]["arguments"] = json.loads(
+                    item["function"]["arguments"])
+
+
 def parse_chat_messages(
     messages: List[ChatCompletionMessageParam],
     model_config: ModelConfig,
@@ -446,6 +465,8 @@ def parse_chat_messages(
 
         conversation.extend(sub_messages)
 
+    _postprocess_messages(conversation)
+
     return conversation, mm_tracker.all_mm_data()
 
 
@@ -462,41 +483,41 @@ def parse_chat_messages_futures(
 
         conversation.extend(sub_messages)
 
+    _postprocess_messages(conversation)
+
     return conversation, mm_tracker.all_mm_data()
 
 
-def apply_chat_template(
-    tokenizer: AnyTokenizer,
+def apply_hf_chat_template(
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     conversation: List[ConversationMessage],
     chat_template: Optional[str],
     *,
     tokenize: bool = False,  # Different from HF's default
     **kwargs: Any,
-) -> Union[str, List[int]]:
+) -> str:
     if chat_template is None and tokenizer.chat_template is None:
         raise ValueError(
             "As of transformers v4.44, default chat template is no longer "
             "allowed, so you must provide a chat template if the tokenizer "
             "does not define one.")
 
-    # per the Transformers docs & maintainers, tool call arguments in
-    # assistant-role messages with tool_calls need to be dicts not JSON str -
-    # this is how tool-use chat templates will expect them moving forwards
-    # so, for messages that have tool_calls, parse the string (which we get
-    # from openAI format) to dict
-    for message in conversation:
-        if (message["role"] == "assistant" and "tool_calls" in message
-                and isinstance(message["tool_calls"], list)):
-
-            for i in range(len(message["tool_calls"])):
-                args: str = message["tool_calls"][i]["function"]["arguments"]
-                parsed_args: Dict = json.loads(args)
-                message["tool_calls"][i]["function"]["arguments"] = parsed_args
-
-    prompt = tokenizer.apply_chat_template(
-        conversation=conversation,
+    return tokenizer.apply_chat_template(
+        conversation=conversation,  # type: ignore[arg-type]
         chat_template=chat_template,
         tokenize=tokenize,
         **kwargs,
     )
-    return prompt
+
+
+def apply_mistral_chat_template(
+    tokenizer: MistralTokenizer,
+    messages: List[ChatCompletionMessageParam],
+    chat_template: Optional[str],
+    **kwargs: Any,
+) -> List[int]:
+    return tokenizer.apply_chat_template(
+        messages=messages,
+        chat_template=chat_template,
+        **kwargs,
+    )
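For reference, the _postprocess_messages helper factored out above converts tool-call arguments from the JSON strings of the OpenAI wire format into dicts, which is the form tool-use chat templates expect. A minimal standalone sketch of that transformation; the message contents here are made up for illustration:

import json

# Hypothetical assistant message in OpenAI wire format: "arguments"
# arrives as a JSON string, not a dict.
message = {
    "role": "assistant",
    "tool_calls": [{
        "id": "call_0",
        "type": "function",
        "function": {
            "name": "get_weather",
            "arguments": '{"city": "Paris", "unit": "celsius"}',
        },
    }],
}

# The same in-place transformation _postprocess_messages performs:
if message["role"] == "assistant" and isinstance(
        message.get("tool_calls"), list):
    for item in message["tool_calls"]:
        item["function"]["arguments"] = json.loads(
            item["function"]["arguments"])

assert message["tool_calls"][0]["function"]["arguments"]["unit"] == "celsius"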
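The split of apply_chat_template into apply_hf_chat_template and apply_mistral_chat_template moves tokenizer-type dispatch to the call site: the Mistral path takes the raw OpenAI-format messages and returns token ids, while the HF path takes the parsed conversation and returns a prompt string. A sketch of how a caller might branch, assuming these names are importable from vllm.entrypoints.chat_utils; the wiring below is illustrative, not part of this diff:

from typing import List, Optional, Union

from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                         ConversationMessage,
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer


def render_prompt(
    tokenizer: AnyTokenizer,
    messages: List[ChatCompletionMessageParam],
    conversation: List[ConversationMessage],
    chat_template: Optional[str],
) -> Union[str, List[int]]:
    if isinstance(tokenizer, MistralTokenizer):
        # Mistral path: consumes the raw OpenAI-format messages and
        # returns token ids directly (List[int]).
        return apply_mistral_chat_template(
            tokenizer, messages=messages, chat_template=chat_template)
    # HF path: consumes the parsed conversation and returns the
    # rendered prompt string.
    return apply_hf_chat_template(
        tokenizer, conversation=conversation, chat_template=chat_template)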