[BugFix][Frontend] Use LoRA tokenizer in OpenAI APIs (#6227)

Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Author: Nick Hill
Date: 2024-07-18 00:13:30 -07:00
Committed by: GitHub
Parent: 8a74c68bd1
Commit: e2fbaee725
16 changed files with 267 additions and 186 deletions
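The gist of the fix: the OpenAI serving layer previously decoded tokens with the base-model tokenizer held on the serving object, but a request that targets a LoRA adapter may need the adapter's own (possibly extended) tokenizer. The hunks below, from the completion serving path (class OpenAIServingCompletion), therefore resolve the tokenizer per request via `engine.get_tokenizer(lora_request)` and thread it through every helper that detokenizes output. A minimal, dependency-free sketch of that idea (stand-in class and field names, not the actual vLLM types):

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FakeLoRARequest:
    """Stand-in for vllm.lora.request.LoRARequest (field names illustrative)."""
    lora_name: str
    tokenizer: Optional[Any] = None  # the adapter's own tokenizer, if it ships one


class FakeEngine:
    """Stand-in for AsyncLLMEngine: get_tokenizer() is LoRA-aware."""

    def __init__(self, base_tokenizer: Any):
        self._base_tokenizer = base_tokenizer

    async def get_tokenizer(self, lora_request: Optional[FakeLoRARequest] = None):
        # Prefer the adapter's tokenizer when the request targets a LoRA
        # adapter that carries one; otherwise fall back to the base model's.
        if lora_request is not None and lora_request.tokenizer is not None:
            return lora_request.tokenizer
        return self._base_tokenizer
```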


@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
from typing import Tuple
from fastapi import Request
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -100,20 +101,22 @@ class OpenAIServingCompletion(OpenAIServing):
# Schedule the request and get the result generator.
generators: List[AsyncIterator[RequestOutput]] = []
try:
sampling_params = request.to_sampling_params()
adapter_type, adapter_request = self._maybe_get_adapter(request)
lora_request, prompt_adapter_request = None, None
if adapter_type == 'LoRA':
lora_request, prompt_adapter_request = adapter_request, None
elif adapter_type == 'PromptAdapter':
lora_request, prompt_adapter_request = None, adapter_request
tokenizer = await self.engine.get_tokenizer(lora_request)
sampling_params = request.to_sampling_params()
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logit_processor = (
await get_guided_decoding_logits_processor(
guided_decoding_backend, request, await
self.engine.get_tokenizer()))
await
get_guided_decoding_logits_processor(guided_decoding_backend,
request, tokenizer))
if guided_decode_logit_processor is not None:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
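One consequence visible in this hunk: the guided-decoding logits processor is now built with the per-request tokenizer rather than a fresh `self.engine.get_tokenizer()` call, because the processor has to map the constrained output back to token ids in the vocabulary that will actually decode this request. A hedged sketch of that wiring (import path as used by this file at the time of the commit; the helper function itself is a hypothetical wrapper):

```python
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)


async def build_guided_processors(engine, request, lora_request,
                                  default_backend: str):
    # Resolve the tokenizer for this request's LoRA adapter (or the base model).
    tokenizer = await engine.get_tokenizer(lora_request)
    backend = request.guided_decoding_backend or default_backend
    processor = await get_guided_decoding_logits_processor(
        backend, request, tokenizer)
    # sampling_params.logits_processors expects a list.
    return [processor] if processor is not None else []
```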
@@ -122,18 +125,13 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
for i, prompt in enumerate(prompts):
if prompt_is_tokens:
prompt_formats = self._validate_prompt_and_tokenize(
request,
prompt_ids=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)
else:
prompt_formats = self._validate_prompt_and_tokenize(
request,
prompt=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)
prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt"
prompt_formats = await self._validate_prompt_and_tokenize(
request,
tokenizer,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens,
**{prompt_arg: prompt})
prompt_ids, prompt_text = prompt_formats
is_tracing_enabled = await self.engine.is_tracing_enabled()
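The tokenization refactor above replaces two near-identical branches with a single call by picking the keyword name at runtime and forwarding it with `**{prompt_arg: prompt}`. A self-contained illustration of that dispatch pattern (toy validate function, not the real `_validate_prompt_and_tokenize`):

```python
from typing import List, Optional, Tuple


def validate_and_tokenize(*,
                          prompt: Optional[str] = None,
                          prompt_ids: Optional[List[int]] = None
                          ) -> Tuple[List[int], str]:
    # Exactly one of the two keyword arguments is expected.
    if prompt_ids is not None:
        return prompt_ids, f"<{len(prompt_ids)} pre-tokenized ids>"
    return [ord(c) for c in prompt], prompt  # toy "tokenization"


prompt_is_tokens = False
prompt = "Hello"
prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt"
ids, text = validate_and_tokenize(**{prompt_arg: prompt})
print(ids, text)
```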
@@ -179,7 +177,8 @@ class OpenAIServingCompletion(OpenAIServing):
request_id,
created_time,
model_name,
num_prompts=len(prompts))
num_prompts=len(prompts),
tokenizer=tokenizer)
# Non-streaming response
final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
@@ -191,7 +190,8 @@ class OpenAIServingCompletion(OpenAIServing):
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
response = self.request_output_to_completion_response(
final_res_batch, request, request_id, created_time, model_name)
final_res_batch, request, request_id, created_time, model_name,
tokenizer)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@@ -218,6 +218,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time: int,
model_name: str,
num_prompts: int,
tokenizer: PreTrainedTokenizer,
) -> AsyncGenerator[str, None]:
assert request.n is not None
previous_texts = [""] * request.n * num_prompts
@@ -268,6 +269,7 @@ class OpenAIServingCompletion(OpenAIServing):
token_ids=delta_token_ids,
top_logprobs=out_logprobs,
num_output_top_logprobs=request.logprobs,
tokenizer=tokenizer,
initial_text_offset=len(previous_texts[i]),
)
else:
@@ -336,6 +338,7 @@ class OpenAIServingCompletion(OpenAIServing):
request_id: str,
created_time: int,
model_name: str,
tokenizer: PreTrainedTokenizer,
) -> CompletionResponse:
choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0
@@ -367,6 +370,7 @@ class OpenAIServingCompletion(OpenAIServing):
logprobs = self._create_completion_logprobs(
token_ids=token_ids,
top_logprobs=out_logprobs,
tokenizer=tokenizer,
num_output_top_logprobs=request.logprobs,
)
else:
@@ -404,6 +408,7 @@ class OpenAIServingCompletion(OpenAIServing):
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
num_output_top_logprobs: int,
tokenizer: PreTrainedTokenizer,
initial_text_offset: int = 0,
) -> CompletionLogProbs:
"""Create logprobs for OpenAI Completion API."""
@@ -417,13 +422,13 @@ class OpenAIServingCompletion(OpenAIServing):
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = self.tokenizer.decode(token_id)
token = tokenizer.decode(token_id)
out_tokens.append(token)
out_token_logprobs.append(None)
out_top_logprobs.append(None)
else:
token = self._get_decoded_token(step_top_logprobs[token_id],
token_id)
token_id, tokenizer)
token_logprob = max(step_top_logprobs[token_id].logprob,
-9999.0)
out_tokens.append(token)
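The decode path in this hunk is the heart of the bug fix: token ids are rendered with the tokenizer that was passed in (the LoRA adapter's, when applicable) instead of the serving object's base-model `self.tokenizer`. A simplified stand-in for the `_get_decoded_token` helper, assuming the `Logprob` object may already carry a cached `decoded_token`:

```python
def get_decoded_token(logprob, token_id, tokenizer) -> str:
    """Simplified stand-in: prefer the cached decoded text when present,
    otherwise decode with the per-request tokenizer threaded through above."""
    decoded = getattr(logprob, "decoded_token", None)
    if decoded is not None:
        return decoded
    return tokenizer.decode(token_id)
```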
@@ -436,7 +441,7 @@ class OpenAIServingCompletion(OpenAIServing):
out_top_logprobs.append({
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
self._get_decoded_token(top_lp[1], top_lp[0]):
self._get_decoded_token(top_lp[1], top_lp[0], tokenizer):
max(top_lp[1].logprob, -9999.0)
for i, top_lp in enumerate(step_top_logprobs.items())
if num_output_top_logprobs >= i
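Finally, on the `max(..., -9999.0)` clamping seen in both logprob branches: `float("-inf")` is a legal Python float but not valid JSON, so logprobs are clamped to the finite sentinel value the OpenAI API uses before serialization. A quick demonstration:

```python
import json

raw = float("-inf")                   # a fully filtered-out token's logprob
safe = max(raw, -9999.0)              # the sentinel used in the responses above
print(json.dumps({"logprob": safe}))  # {"logprob": -9999.0}

# Serializing the raw value would emit "-Infinity", which strict JSON parsers
# reject; with allow_nan=False, json.dumps raises ValueError instead.
try:
    json.dumps({"logprob": raw}, allow_nan=False)
except ValueError as exc:
    print("cannot serialize -inf:", exc)
```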