[BugFix][Frontend] Use LoRA tokenizer in OpenAI APIs (#6227)
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
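
Previously, the OpenAI-compatible chat endpoint built prompts and decoded outputs with the base model's tokenizer (self.tokenizer), even when the request targeted a LoRA adapter that ships its own tokenizer or added tokens. This change resolves the adapter named by the request and threads that adapter's tokenizer through chat-template rendering, prompt tokenization, guided decoding, and logprob detokenization. A minimal sketch of the lookup pattern the diff introduces (the helper name _get_request_tokenizer is illustrative; _maybe_get_adapter and engine.get_tokenizer are the calls used in the diff below):

    async def _get_request_tokenizer(self, request):
        # Resolve the LoRA adapter referenced by request.model, if any.
        _, lora_request = self._maybe_get_adapter(request)
        # Ask the engine for that adapter's tokenizer; with no adapter this
        # returns the base model's tokenizer.
        return await self.engine.get_tokenizer(lora_request)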
@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
 from typing import Union
 
 from fastapi import Request
+from transformers import PreTrainedTokenizer
 
 from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -49,7 +50,9 @@ class OpenAIServingChat(OpenAIServing):
                          lora_modules=lora_modules)
 
         self.response_role = response_role
-        load_chat_template(self, chat_template)
+
+        # If this is None we use the tokenizer's default chat template
+        self.chat_template = load_chat_template(chat_template)
 
     async def create_chat_completion(
         self,
@@ -71,11 +74,15 @@ class OpenAIServingChat(OpenAIServing):
             return error_check_ret
 
         try:
+            _, lora_request = self._maybe_get_adapter(request)
+            tokenizer = await self.engine.get_tokenizer(lora_request)
+
             conversation: List[ConversationMessage] = []
             mm_futures: List[Awaitable[MultiModalDataDict]] = []
 
             for msg in request.messages:
-                chat_parsed_result = parse_chat_message_content(self, msg)
+                chat_parsed_result = parse_chat_message_content(
+                    msg, self.model_config, tokenizer)
 
                 conversation.extend(chat_parsed_result.messages)
                 mm_futures.extend(chat_parsed_result.mm_futures)
@@ -84,13 +91,13 @@ class OpenAIServingChat(OpenAIServing):
                 tool.model_dump() for tool in request.tools
             ]
 
-            prompt = self.tokenizer.apply_chat_template(
+            prompt = tokenizer.apply_chat_template(
                 conversation=conversation,
                 tokenize=False,
                 add_generation_prompt=request.add_generation_prompt,
                 tools=tool_dicts,
                 documents=request.documents,
-                chat_template=request.chat_template,
+                chat_template=request.chat_template or self.chat_template,
                 **(request.chat_template_kwargs or {}),
             )
         except Exception as e:
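
The hunk above also changes how the chat template is picked: a template supplied on the request takes precedence, then the template loaded at server startup, and if both are None the tokenizer's own default template applies. For reference, apply_chat_template is the standard Hugging Face tokenizer API; a standalone sketch of that fallback (model name and message are placeholders):

    from transformers import AutoTokenizer

    # Placeholder model; any chat model with a built-in chat template behaves the same.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

    # chat_template=None falls back to the tokenizer's built-in template,
    # mirroring request.chat_template or self.chat_template when both are None.
    prompt = tokenizer.apply_chat_template(
        conversation=[{"role": "user", "content": "Hello!"}],
        tokenize=False,
        add_generation_prompt=True,
        chat_template=None,
    )
    print(prompt)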
@@ -112,19 +119,19 @@ class OpenAIServingChat(OpenAIServing):
         request_id = f"cmpl-{random_uuid()}"
         try:
             # Tokenize/detokenize depending on prompt format (string/token list)
-            prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
+            prompt_ids, prompt_text = await self._validate_prompt_and_tokenize(
                 request,
+                tokenizer,
                 prompt=prompt,
                 add_special_tokens=request.add_special_tokens)
             sampling_params = request.to_sampling_params()
-            _, lora_request = self._maybe_get_adapter(request)
             decoding_config = await self.engine.get_decoding_config()
             guided_decoding_backend = request.guided_decoding_backend \
                 or decoding_config.guided_decoding_backend
             guided_decode_logits_processor = (
-                await get_guided_decoding_logits_processor(
-                    guided_decoding_backend, request, await
-                    self.engine.get_tokenizer()))
+                await
+                get_guided_decoding_logits_processor(guided_decoding_backend,
+                                                     request, tokenizer))
             if guided_decode_logits_processor:
                 if sampling_params.logits_processors is None:
                     sampling_params.logits_processors = []
@@ -158,12 +165,12 @@ class OpenAIServingChat(OpenAIServing):
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
-                request, result_generator, request_id, conversation)
+                request, result_generator, request_id, conversation, tokenizer)
         else:
             try:
                 return await self.chat_completion_full_generator(
                     request, raw_request, result_generator, request_id,
-                    conversation)
+                    conversation, tokenizer)
             except ValueError as e:
                 # TODO: Use a vllm-specific Validation Error
                 return self.create_error_response(str(e))
@@ -175,9 +182,12 @@ class OpenAIServingChat(OpenAIServing):
         return request.messages[-1]["role"]
 
     async def chat_completion_stream_generator(
-        self, request: ChatCompletionRequest,
-        result_generator: AsyncIterator[RequestOutput], request_id: str,
-        conversation: List[ConversationMessage]
+        self,
+        request: ChatCompletionRequest,
+        result_generator: AsyncIterator[RequestOutput],
+        request_id: str,
+        conversation: List[ConversationMessage],
+        tokenizer: PreTrainedTokenizer,
     ) -> AsyncGenerator[str, None]:
         model_name = self.served_model_names[0]
         created_time = int(time.time())
@@ -264,6 +274,7 @@ class OpenAIServingChat(OpenAIServing):
                     logprobs = self._create_chat_logprobs(
                         token_ids=delta_token_ids,
                         top_logprobs=out_logprobs,
+                        tokenizer=tokenizer,
                         num_output_top_logprobs=request.top_logprobs,
                     )
                 else:
@@ -352,9 +363,13 @@ class OpenAIServingChat(OpenAIServing):
         yield "data: [DONE]\n\n"
 
     async def chat_completion_full_generator(
-        self, request: ChatCompletionRequest, raw_request: Optional[Request],
-        result_generator: AsyncIterator[RequestOutput], request_id: str,
-        conversation: List[ConversationMessage]
+        self,
+        request: ChatCompletionRequest,
+        raw_request: Optional[Request],
+        result_generator: AsyncIterator[RequestOutput],
+        request_id: str,
+        conversation: List[ConversationMessage],
+        tokenizer: PreTrainedTokenizer,
     ) -> Union[ErrorResponse, ChatCompletionResponse]:
 
         model_name = self.served_model_names[0]
@@ -382,6 +397,7 @@ class OpenAIServingChat(OpenAIServing):
                     token_ids=token_ids,
                     top_logprobs=out_logprobs,
                     num_output_top_logprobs=request.top_logprobs,
+                    tokenizer=tokenizer,
                 )
             else:
                 logprobs = None
@@ -436,16 +452,14 @@ class OpenAIServingChat(OpenAIServing):
         return response
 
     def _get_top_logprobs(
-            self, logprobs: Dict[int, Logprob],
-            top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]:
+            self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
+            tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
         return [
             ChatCompletionLogProb(
-                token=self._get_decoded_token(p[1], p[0]),
+                token=(token := self._get_decoded_token(p[1], p[0],
+                                                        tokenizer)),
                 logprob=max(p[1].logprob, -9999.0),
-                bytes=list(
-                    self._get_decoded_token(p[1],
-                                            p[0]).encode("utf-8",
-                                                         errors="replace")))
+                bytes=list(token.encode("utf-8", errors="replace")))
             for i, p in enumerate(logprobs.items())
            if top_logprobs and i < top_logprobs
         ]
@@ -454,6 +468,7 @@ class OpenAIServingChat(OpenAIServing):
         self,
         token_ids: GenericSequence[int],
         top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
+        tokenizer: PreTrainedTokenizer,
         num_output_top_logprobs: Optional[int] = None,
     ) -> ChatCompletionLogProbs:
         """Create OpenAI-style logprobs."""
@@ -463,12 +478,11 @@ class OpenAIServingChat(OpenAIServing):
         for i, token_id in enumerate(token_ids):
             step_top_logprobs = top_logprobs[i]
             if step_top_logprobs is None:
+                token = tokenizer.decode(token_id)
                 logprobs_content.append(
                     ChatCompletionLogProbsContent(
-                        token=self.tokenizer.decode(token_id),
-                        bytes=list(
-                            self.tokenizer.decode(token_id).encode(
-                                "utf-8", errors="replace"))))
+                        token=token,
+                        bytes=list(token.encode("utf-8", errors="replace"))))
             else:
                 logprobs_content.append(
                     ChatCompletionLogProbsContent(
@@ -479,6 +493,7 @@ class OpenAIServingChat(OpenAIServing):
                            step_top_logprobs[token_id].decoded_token.encode(
                                "utf-8", errors="replace")),
                         top_logprobs=self._get_top_logprobs(
-                            step_top_logprobs, num_output_top_logprobs)))
+                            step_top_logprobs, num_output_top_logprobs,
+                            tokenizer)))
 
         return ChatCompletionLogProbs(content=logprobs_content)
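
End to end, a request addressed to a LoRA adapter now goes through that adapter's tokenizer rather than the base model's. A usage sketch against an OpenAI-compatible vLLM server (the adapter name, path, and URL are placeholders, not part of this commit):

    # Assumes a server started along the lines of:
    #   python -m vllm.entrypoints.openai.api_server \
    #       --model <base-model> --enable-lora \
    #       --lora-modules my-lora=/path/to/adapter
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    # Requesting model="my-lora" makes the server tokenize the chat with the
    # adapter's tokenizer (chat template, added tokens) instead of the base model's.
    completion = client.chat.completions.create(
        model="my-lora",
        messages=[{"role": "user", "content": "Hello!"}],
    )
    print(completion.choices[0].message.content)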