[BugFix][Frontend] Use LoRA tokenizer in OpenAI APIs (#6227)

Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Author: Nick Hill
Date: 2024-07-18 00:13:30 -07:00
Committed by: GitHub
Parent: 8a74c68bd1
Commit: e2fbaee725
16 changed files with 267 additions and 186 deletions

vllm/entrypoints/openai/serving_chat.py

@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
from typing import Union
from fastapi import Request
+from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -49,7 +50,9 @@ class OpenAIServingChat(OpenAIServing):
lora_modules=lora_modules)
self.response_role = response_role
-load_chat_template(self, chat_template)
+# If this is None we use the tokenizer's default chat template
+self.chat_template = load_chat_template(chat_template)
async def create_chat_completion(
self,
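
Note on the constructor change above: the server now stores the resolved chat template string on self.chat_template instead of mutating shared state, and a value of None later falls through to the tokenizer's own default template. A minimal sketch of what a loader like load_chat_template might do (hypothetical helper, not the vLLM implementation), accepting either a path to a Jinja template file or a literal template string:

    from pathlib import Path
    from typing import Optional

    def load_chat_template_sketch(chat_template: Optional[str]) -> Optional[str]:
        # Nothing supplied: return None so the tokenizer's built-in template is used.
        if chat_template is None:
            return None
        path = Path(chat_template)
        # A path to a Jinja template file: return its contents.
        if path.is_file():
            return path.read_text()
        # Otherwise treat the argument as a literal template string.
        return chat_template
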
@@ -71,11 +74,15 @@ class OpenAIServingChat(OpenAIServing):
return error_check_ret
try:
+_, lora_request = self._maybe_get_adapter(request)
+tokenizer = await self.engine.get_tokenizer(lora_request)
conversation: List[ConversationMessage] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
for msg in request.messages:
-chat_parsed_result = parse_chat_message_content(self, msg)
+chat_parsed_result = parse_chat_message_content(
+msg, self.model_config, tokenizer)
conversation.extend(chat_parsed_result.messages)
mm_futures.extend(chat_parsed_result.mm_futures)
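
The heart of the change is in the hunk above: rather than always using the server-wide self.tokenizer, the handler looks up the request's LoRA adapter and asks the engine for the matching tokenizer. A minimal sketch of that lookup (get_request_tokenizer is a hypothetical wrapper; only the engine.get_tokenizer(lora_request) call and vLLM's LoRARequest type are taken from the diff):

    from typing import Optional

    from vllm.lora.request import LoRARequest

    async def get_request_tokenizer(engine, lora_request: Optional[LoRARequest]):
        # With no adapter the engine returns the base model's tokenizer; with an
        # adapter it returns the tokenizer shipped with that LoRA, so chat
        # templates and added special tokens match the model actually being served.
        return await engine.get_tokenizer(lora_request)
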
@@ -84,13 +91,13 @@ class OpenAIServingChat(OpenAIServing):
tool.model_dump() for tool in request.tools
]
-prompt = self.tokenizer.apply_chat_template(
+prompt = tokenizer.apply_chat_template(
conversation=conversation,
tokenize=False,
add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts,
documents=request.documents,
-chat_template=request.chat_template,
+chat_template=request.chat_template or self.chat_template,
**(request.chat_template_kwargs or {}),
)
except Exception as e:
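
With the fallback above (request.chat_template or self.chat_template), a template supplied on the request wins, then the one configured at server start-up, and finally, when both are None, Hugging Face's apply_chat_template uses the template stored on the tokenizer itself. A rough standalone illustration; the model name is only an example:

    from transformers import AutoTokenizer

    # Illustrative model; any chat model that ships a chat template works.
    tok = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
    prompt = tok.apply_chat_template(
        conversation=[{"role": "user", "content": "Hello!"}],
        tokenize=False,
        add_generation_prompt=True,
        chat_template=None,  # i.e. request.chat_template or self.chat_template in the server
    )
    print(prompt)
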
@@ -112,19 +119,19 @@ class OpenAIServingChat(OpenAIServing):
request_id = f"cmpl-{random_uuid()}"
try:
# Tokenize/detokenize depending on prompt format (string/token list)
-prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
+prompt_ids, prompt_text = await self._validate_prompt_and_tokenize(
request,
+tokenizer,
prompt=prompt,
add_special_tokens=request.add_special_tokens)
sampling_params = request.to_sampling_params()
-_, lora_request = self._maybe_get_adapter(request)
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logits_processor = (
-await get_guided_decoding_logits_processor(
-guided_decoding_backend, request, await
-self.engine.get_tokenizer()))
+await
+get_guided_decoding_logits_processor(guided_decoding_backend,
+request, tokenizer))
if guided_decode_logits_processor:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
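
For reference, the guided-decoding object appended after this hunk is an ordinary vLLM logits processor: a callable that receives the token ids generated so far plus the logits tensor and returns (possibly modified) logits. A tiny sketch of that shape, using a no-op processor of my own naming rather than the real guided-decoding one:

    from typing import List

    import torch
    from vllm import SamplingParams

    def passthrough_processor(token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
        # A real guided-decoding processor would mask logits of disallowed tokens here.
        return logits

    params = SamplingParams(logits_processors=[passthrough_processor])
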
@@ -158,12 +165,12 @@ class OpenAIServingChat(OpenAIServing):
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
-request, result_generator, request_id, conversation)
+request, result_generator, request_id, conversation, tokenizer)
else:
try:
return await self.chat_completion_full_generator(
request, raw_request, result_generator, request_id,
-conversation)
+conversation, tokenizer)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@@ -175,9 +182,12 @@ class OpenAIServingChat(OpenAIServing):
return request.messages[-1]["role"]
async def chat_completion_stream_generator(
-self, request: ChatCompletionRequest,
-result_generator: AsyncIterator[RequestOutput], request_id: str,
-conversation: List[ConversationMessage]
+self,
+request: ChatCompletionRequest,
+result_generator: AsyncIterator[RequestOutput],
+request_id: str,
+conversation: List[ConversationMessage],
+tokenizer: PreTrainedTokenizer,
) -> AsyncGenerator[str, None]:
model_name = self.served_model_names[0]
created_time = int(time.time())
@@ -264,6 +274,7 @@ class OpenAIServingChat(OpenAIServing):
logprobs = self._create_chat_logprobs(
token_ids=delta_token_ids,
top_logprobs=out_logprobs,
+tokenizer=tokenizer,
num_output_top_logprobs=request.top_logprobs,
)
else:
@@ -352,9 +363,13 @@ class OpenAIServingChat(OpenAIServing):
yield "data: [DONE]\n\n"
async def chat_completion_full_generator(
-self, request: ChatCompletionRequest, raw_request: Optional[Request],
-result_generator: AsyncIterator[RequestOutput], request_id: str,
-conversation: List[ConversationMessage]
+self,
+request: ChatCompletionRequest,
+raw_request: Optional[Request],
+result_generator: AsyncIterator[RequestOutput],
+request_id: str,
+conversation: List[ConversationMessage],
+tokenizer: PreTrainedTokenizer,
) -> Union[ErrorResponse, ChatCompletionResponse]:
model_name = self.served_model_names[0]
@@ -382,6 +397,7 @@ class OpenAIServingChat(OpenAIServing):
token_ids=token_ids,
top_logprobs=out_logprobs,
num_output_top_logprobs=request.top_logprobs,
+tokenizer=tokenizer,
)
else:
logprobs = None
@@ -436,16 +452,14 @@ class OpenAIServingChat(OpenAIServing):
return response
def _get_top_logprobs(
-self, logprobs: Dict[int, Logprob],
-top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]:
+self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
+tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
return [
ChatCompletionLogProb(
-token=self._get_decoded_token(p[1], p[0]),
+token=(token := self._get_decoded_token(p[1], p[0],
+tokenizer)),
logprob=max(p[1].logprob, -9999.0),
-bytes=list(
-self._get_decoded_token(p[1],
-p[0]).encode("utf-8",
-errors="replace")))
+bytes=list(token.encode("utf-8", errors="replace")))
for i, p in enumerate(logprobs.items())
if top_logprobs and i < top_logprobs
]
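
The rewrite of _get_top_logprobs above uses an assignment expression so each candidate token is decoded once and reused for both the token and bytes fields, instead of calling _get_decoded_token twice. The pattern in isolation, with a plain dict lookup standing in for tokenizer decoding:

    decoded = {101: "Hello", 102: " world"}  # stand-in for tokenizer decoding

    entries = [
        {
            "token": (tok := decoded[token_id]),
            # Reuse the walrus-bound value instead of decoding a second time.
            "bytes": list(tok.encode("utf-8", errors="replace")),
        }
        for token_id in [101, 102]
    ]
    print(entries)
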
@@ -454,6 +468,7 @@ class OpenAIServingChat(OpenAIServing):
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
+tokenizer: PreTrainedTokenizer,
num_output_top_logprobs: Optional[int] = None,
) -> ChatCompletionLogProbs:
"""Create OpenAI-style logprobs."""
@@ -463,12 +478,11 @@ class OpenAIServingChat(OpenAIServing):
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
+token = tokenizer.decode(token_id)
logprobs_content.append(
ChatCompletionLogProbsContent(
-token=self.tokenizer.decode(token_id),
-bytes=list(
-self.tokenizer.decode(token_id).encode(
-"utf-8", errors="replace"))))
+token=token,
+bytes=list(token.encode("utf-8", errors="replace"))))
else:
logprobs_content.append(
ChatCompletionLogProbsContent(
@@ -479,6 +493,7 @@ class OpenAIServingChat(OpenAIServing):
step_top_logprobs[token_id].decoded_token.encode(
"utf-8", errors="replace")),
top_logprobs=self._get_top_logprobs(
-step_top_logprobs, num_output_top_logprobs)))
+step_top_logprobs, num_output_top_logprobs,
+tokenizer)))
return ChatCompletionLogProbs(content=logprobs_content)
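
Finally, the fallback branch above (taken when a sampling step has no logprobs entry for the chosen token) now decodes with the per-request tokenizer instead of self.tokenizer. Reproduced standalone with an illustrative tokenizer:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")  # illustrative tokenizer
    token_id = 15496
    token = tok.decode(token_id)  # decode a single token id to text
    content = {"token": token, "bytes": list(token.encode("utf-8", errors="replace"))}
    print(content)
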