[Frontend] Clean up type annotations for mistral tokenizer (#8314)
This commit is contained in:
@@ -6,7 +6,8 @@ from tqdm import tqdm

 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.llm_engine import LLMEngine
 from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
-                                         apply_chat_template,
+                                         apply_hf_chat_template,
+                                         apply_mistral_chat_template,
                                          parse_chat_messages)
 from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
 from vllm.inputs.parse import parse_and_batch_prompt
@@ -19,7 +20,7 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
-from vllm.transformers_utils.tokenizer import (AnyTokenizer,
+from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                                get_cached_tokenizer)
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.usage.usage_lib import UsageContext
@@ -393,12 +394,21 @@ class LLM:
         conversation, mm_data = parse_chat_messages(messages, model_config,
                                                     tokenizer)

-        prompt = apply_chat_template(
-            tokenizer,
-            conversation,
-            chat_template=chat_template,
-            add_generation_prompt=add_generation_prompt,
-        )
+        prompt: Union[str, List[int]]
+        if isinstance(tokenizer, MistralTokenizer):
+            prompt = apply_mistral_chat_template(
+                tokenizer,
+                messages=messages,
+                chat_template=chat_template,
+                add_generation_prompt=add_generation_prompt,
+            )
+        else:
+            prompt = apply_hf_chat_template(
+                tokenizer,
+                conversation=conversation,
+                chat_template=chat_template,
+                add_generation_prompt=add_generation_prompt,
+            )

         inputs: PromptInputs
         if is_list_of(prompt, int):
Reference in New Issue
Block a user