[BugFix][Frontend] Use LoRA tokenizer in OpenAI APIs (#6227)
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
 from typing import Tuple
 
 from fastapi import Request
+from transformers import PreTrainedTokenizer
 
 from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -100,20 +101,22 @@ class OpenAIServingCompletion(OpenAIServing):
         # Schedule the request and get the result generator.
         generators: List[AsyncIterator[RequestOutput]] = []
         try:
-            sampling_params = request.to_sampling_params()
             adapter_type, adapter_request = self._maybe_get_adapter(request)
             lora_request, prompt_adapter_request = None, None
             if adapter_type == 'LoRA':
                 lora_request, prompt_adapter_request = adapter_request, None
             elif adapter_type == 'PromptAdapter':
                 lora_request, prompt_adapter_request = None, adapter_request
+            tokenizer = await self.engine.get_tokenizer(lora_request)
+
+            sampling_params = request.to_sampling_params()
             decoding_config = await self.engine.get_decoding_config()
             guided_decoding_backend = request.guided_decoding_backend \
                 or decoding_config.guided_decoding_backend
             guided_decode_logit_processor = (
-                await get_guided_decoding_logits_processor(
-                    guided_decoding_backend, request, await
-                    self.engine.get_tokenizer()))
+                await
+                get_guided_decoding_logits_processor(guided_decoding_backend,
+                                                     request, tokenizer))
             if guided_decode_logit_processor is not None:
                 if sampling_params.logits_processors is None:
                     sampling_params.logits_processors = []
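
The hunk above resolves the request's adapter first and then fetches the tokenizer that belongs to the resolved LoRA adapter via self.engine.get_tokenizer(lora_request), instead of always decoding with the base model's tokenizer. A minimal sketch of the lookup pattern this relies on, with hypothetical class and path names rather than vLLM's actual implementation:

# Sketch: resolve a per-adapter tokenizer with a fallback to the base model.
# The cache class, adapter names, and paths below are hypothetical.
from typing import Dict, Optional

from transformers import AutoTokenizer, PreTrainedTokenizer


class TokenizerCache:

    def __init__(self, base_model_path: str):
        # Tokenizer of the base model, used when no adapter is requested.
        self._base = AutoTokenizer.from_pretrained(base_model_path)
        self._per_adapter: Dict[str, PreTrainedTokenizer] = {}

    def get(self, adapter_path: Optional[str]) -> PreTrainedTokenizer:
        # No LoRA request -> base tokenizer (mirrors get_tokenizer(None)).
        if adapter_path is None:
            return self._base
        # A LoRA adapter may ship its own tokenizer (e.g. with added
        # special tokens); load and cache it on first use.
        if adapter_path not in self._per_adapter:
            self._per_adapter[adapter_path] = AutoTokenizer.from_pretrained(
                adapter_path)
        return self._per_adapter[adapter_path]
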
@@ -122,18 +125,13 @@ class OpenAIServingCompletion(OpenAIServing):
             prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
 
             for i, prompt in enumerate(prompts):
-                if prompt_is_tokens:
-                    prompt_formats = self._validate_prompt_and_tokenize(
-                        request,
-                        prompt_ids=prompt,
-                        truncate_prompt_tokens=sampling_params.
-                        truncate_prompt_tokens)
-                else:
-                    prompt_formats = self._validate_prompt_and_tokenize(
-                        request,
-                        prompt=prompt,
-                        truncate_prompt_tokens=sampling_params.
-                        truncate_prompt_tokens)
+                prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt"
+                prompt_formats = await self._validate_prompt_and_tokenize(
+                    request,
+                    tokenizer,
+                    truncate_prompt_tokens=sampling_params.
+                    truncate_prompt_tokens,
+                    **{prompt_arg: prompt})
                 prompt_ids, prompt_text = prompt_formats
 
             is_tracing_enabled = await self.engine.is_tracing_enabled()
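
The two near-duplicate _validate_prompt_and_tokenize calls above are folded into one by choosing the keyword name at runtime and splatting it with **{prompt_arg: prompt}. A self-contained illustration of that dispatch pattern; tokenize_any is a made-up stand-in, not the vLLM helper:

# Sketch: pick a keyword argument name at runtime and pass it via **-unpacking.
from typing import List, Optional, Union


def tokenize_any(prompt: Optional[str] = None,
                 prompt_ids: Optional[List[int]] = None) -> List[int]:
    if prompt_ids is not None:
        return prompt_ids  # already tokenized
    assert prompt is not None
    return [ord(ch) for ch in prompt]  # toy "tokenization"


def run(prompt: Union[str, List[int]], prompt_is_tokens: bool) -> List[int]:
    # One call site instead of an if/else with two copies of the same call.
    prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt"
    return tokenize_any(**{prompt_arg: prompt})


print(run("hi", prompt_is_tokens=False))      # [104, 105]
print(run([1, 2, 3], prompt_is_tokens=True))  # [1, 2, 3]
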
@@ -179,7 +177,8 @@ class OpenAIServingCompletion(OpenAIServing):
                 request_id,
                 created_time,
                 model_name,
-                num_prompts=len(prompts))
+                num_prompts=len(prompts),
+                tokenizer=tokenizer)
 
         # Non-streaming response
         final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
@@ -191,7 +190,8 @@ class OpenAIServingCompletion(OpenAIServing):
                     return self.create_error_response("Client disconnected")
                 final_res_batch[i] = res
             response = self.request_output_to_completion_response(
-                final_res_batch, request, request_id, created_time, model_name)
+                final_res_batch, request, request_id, created_time, model_name,
+                tokenizer)
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
@@ -218,6 +218,7 @@ class OpenAIServingCompletion(OpenAIServing):
         created_time: int,
         model_name: str,
         num_prompts: int,
+        tokenizer: PreTrainedTokenizer,
     ) -> AsyncGenerator[str, None]:
         assert request.n is not None
         previous_texts = [""] * request.n * num_prompts
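
This hunk, like the signature changes below, passes the request-scoped tokenizer explicitly instead of reading a single server-wide self.tokenizer. A rough sketch of that shape, with hypothetical names, assuming the caller hands in whichever tokenizer the request's LoRA adapter resolved to:

# Sketch: a streaming helper that takes the request's tokenizer explicitly
# (hypothetical names; not vLLM's completion_stream_generator).
from typing import AsyncGenerator, List

from transformers import PreTrainedTokenizer


async def stream_text(token_id_chunks: List[List[int]],
                      tokenizer: PreTrainedTokenizer
                      ) -> AsyncGenerator[str, None]:
    # Each request decodes with its own tokenizer, so two concurrent
    # requests using different LoRA adapters do not interfere.
    for chunk in token_id_chunks:
        yield tokenizer.decode(chunk)
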
@@ -268,6 +269,7 @@ class OpenAIServingCompletion(OpenAIServing):
                         token_ids=delta_token_ids,
                         top_logprobs=out_logprobs,
                         num_output_top_logprobs=request.logprobs,
+                        tokenizer=tokenizer,
                         initial_text_offset=len(previous_texts[i]),
                     )
                 else:
@@ -336,6 +338,7 @@ class OpenAIServingCompletion(OpenAIServing):
         request_id: str,
         created_time: int,
         model_name: str,
+        tokenizer: PreTrainedTokenizer,
     ) -> CompletionResponse:
         choices: List[CompletionResponseChoice] = []
         num_prompt_tokens = 0
@@ -367,6 +370,7 @@ class OpenAIServingCompletion(OpenAIServing):
                 logprobs = self._create_completion_logprobs(
                     token_ids=token_ids,
                     top_logprobs=out_logprobs,
+                    tokenizer=tokenizer,
                     num_output_top_logprobs=request.logprobs,
                 )
             else:
@@ -404,6 +408,7 @@ class OpenAIServingCompletion(OpenAIServing):
         token_ids: GenericSequence[int],
         top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
         num_output_top_logprobs: int,
+        tokenizer: PreTrainedTokenizer,
         initial_text_offset: int = 0,
     ) -> CompletionLogProbs:
         """Create logprobs for OpenAI Completion API."""
@@ -417,13 +422,13 @@ class OpenAIServingCompletion(OpenAIServing):
         for i, token_id in enumerate(token_ids):
             step_top_logprobs = top_logprobs[i]
             if step_top_logprobs is None:
-                token = self.tokenizer.decode(token_id)
+                token = tokenizer.decode(token_id)
                 out_tokens.append(token)
                 out_token_logprobs.append(None)
                 out_top_logprobs.append(None)
             else:
                 token = self._get_decoded_token(step_top_logprobs[token_id],
-                                                token_id)
+                                                token_id, tokenizer)
                 token_logprob = max(step_top_logprobs[token_id].logprob,
                                     -9999.0)
                 out_tokens.append(token)
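
The fallback decode path now uses the per-request tokenizer rather than self.tokenizer, which matters when a LoRA adapter extends the vocabulary: an id added by the adapter only decodes correctly with the adapter's tokenizer. A short illustration with transformers; "gpt2" is only an example and assumes that tokenizer can be loaded:

# Sketch: a tokenizer extended by an adapter knows ids the base one does not.
# "gpt2" is only an example; any locally available tokenizer works the same way.
from transformers import AutoTokenizer

base = AutoTokenizer.from_pretrained("gpt2")
adapter = AutoTokenizer.from_pretrained("gpt2")
adapter.add_tokens(["<tool_call>"])  # pretend the LoRA adapter added this token

new_id = adapter.convert_tokens_to_ids("<tool_call>")
print(new_id >= base.vocab_size)  # True: the id lies outside the base vocab
print(adapter.decode([new_id]))   # "<tool_call>": decoding this id needs the adapter's tokenizer
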
@@ -436,7 +441,7 @@ class OpenAIServingCompletion(OpenAIServing):
                 out_top_logprobs.append({
                     # Convert float("-inf") to the
                     # JSON-serializable float that OpenAI uses
-                    self._get_decoded_token(top_lp[1], top_lp[0]):
+                    self._get_decoded_token(top_lp[1], top_lp[0], tokenizer):
                     max(top_lp[1].logprob, -9999.0)
                     for i, top_lp in enumerate(step_top_logprobs.items())
                     if num_output_top_logprobs >= i
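
Only the decode call changes in this hunk; the existing clamp of float("-inf") logprobs to -9999.0 stays, because -inf is not representable in strict JSON (the diff's own comment calls -9999.0 the JSON-serializable float that OpenAI uses). A quick standard-library check of that behavior:

# Sketch: float("-inf") cannot be emitted as strict JSON, so it is clamped.
import json

logprob = float("-inf")
print(json.dumps({"logprob": logprob}))  # '{"logprob": -Infinity}' - not valid strict JSON
try:
    json.dumps({"logprob": logprob}, allow_nan=False)  # strict mode refuses it
except ValueError as err:
    print(err)

print(json.dumps({"logprob": max(logprob, -9999.0)}))  # '{"logprob": -9999.0}'
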