[Model] Add Mistral Tokenization to improve robustness and chat encoding (#7739)

This commit is contained in:
Patrick von Platen
2024-08-27 14:40:02 +02:00
committed by GitHub
parent 9606c7197d
commit 6fc4e6e07a
12 changed files with 275 additions and 60 deletions

@@ -22,7 +22,8 @@ from vllm.entrypoints.openai.protocol import (
                                               FunctionCall, ToolCall, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing,
-                                                    PromptAdapterPath)
+                                                    PromptAdapterPath,
+                                                    TextTokensPrompt)
 from vllm.inputs import TokensPrompt
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalDataDict
@@ -130,13 +131,22 @@ class OpenAIServingChat(OpenAIServing):
             guided_decode_logits_processor = (
                 await self._guided_decode_logits_processor(request, tokenizer))
 
-            prompt_inputs = self._tokenize_prompt_input(
-                request,
-                tokenizer,
-                prompt,
-                truncate_prompt_tokens=request.truncate_prompt_tokens,
-                add_special_tokens=request.add_special_tokens,
-            )
+            if isinstance(prompt, str):
+                prompt_inputs = self._tokenize_prompt_input(
+                    request,
+                    tokenizer,
+                    prompt,
+                    truncate_prompt_tokens=request.truncate_prompt_tokens,
+                    add_special_tokens=request.add_special_tokens,
+                )
+            else:
+                assert isinstance(prompt, list) and isinstance(
+                    prompt[0], int
+                ), "Prompt has to be either a string or a list of token ids"
+                prompt_inputs = TextTokensPrompt(
+                    prompt=tokenizer.decode(prompt), prompt_token_ids=prompt)
+
+            assert prompt_inputs is not None
 
             sampling_params = request.to_sampling_params(
                 tokenizer,
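
The second hunk is the substance of the change: the chat endpoint now accepts either a rendered prompt string, which is tokenized exactly as before, or a list of token ids (as produced when the Mistral chat encoding builds the prompt directly in token space), which is wrapped into a TextTokensPrompt. Below is a minimal standalone sketch of that branching logic, assuming a hypothetical build_prompt_inputs helper and a generic tokenizer object with encode/decode methods in place of vLLM's serving internals; the real code path goes through self._tokenize_prompt_input and the TextTokensPrompt type shown in the diff.

# Sketch only: a hypothetical helper mirroring the branching added above.
from typing import List, Union


def build_prompt_inputs(prompt: Union[str, List[int]], tokenizer) -> dict:
    """Accept either a rendered chat string or pre-encoded token ids."""
    if isinstance(prompt, str):
        # Plain string prompt: tokenize it, as the previous code always did
        # (stand-in for the real _tokenize_prompt_input call).
        token_ids = tokenizer.encode(prompt)
        return {"prompt": prompt, "prompt_token_ids": token_ids}

    # Token-id prompt (e.g. produced directly by the Mistral chat encoder):
    # keep the ids as-is and decode them only to fill the text field.
    assert isinstance(prompt, list) and isinstance(prompt[0], int), \
        "Prompt has to be either a string or a list of token ids"
    return {"prompt": tokenizer.decode(prompt), "prompt_token_ids": prompt}

Note that in the token-id branch the decode is only there to populate the human-readable prompt field; the ids themselves are passed through untouched, so nothing is lost to a lossy text round trip.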