Fix echo/logprob OpenAI completion bug (#3441)

Co-authored-by: Dylan Hawk <dylanwawk@gmail.com>
This commit is contained in:
Dylan Hawk
2024-04-11 15:15:50 -07:00
committed by GitHub
parent 559eb852f8
commit 95e7d4a97c
4 changed files with 73 additions and 29 deletions

View File

@@ -136,23 +136,24 @@ class OpenAIServingCompletion(OpenAIServing):
for i, prompt in enumerate(prompts):
if prompt_is_tokens:
- input_ids = self._validate_prompt_and_tokenize(
+ prompt_formats = self._validate_prompt_and_tokenize(
request,
prompt_ids=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)
else:
- input_ids = self._validate_prompt_and_tokenize(
+ prompt_formats = self._validate_prompt_and_tokenize(
request,
prompt=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)
+ prompt_ids, prompt_text = prompt_formats
generators.append(
- self.engine.generate(prompt,
+ self.engine.generate(prompt_text,
sampling_params,
f"{request_id}-{i}",
- prompt_token_ids=input_ids,
+ prompt_token_ids=prompt_ids,
lora_request=lora_request))
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
@@ -326,7 +327,8 @@ class OpenAIServingCompletion(OpenAIServing):
output_text = prompt_text
elif request.echo and request.max_tokens > 0:
token_ids = prompt_token_ids + output.token_ids
- top_logprobs = prompt_logprobs + output.logprobs
+ top_logprobs = (prompt_logprobs + output.logprobs
+ if request.logprobs else None)
output_text = prompt_text + output.text
else:
token_ids = output.token_ids
@@ -334,6 +336,9 @@ class OpenAIServingCompletion(OpenAIServing):
output_text = output.text
if request.logprobs is not None:
+ assert top_logprobs is not None, (
+ "top_logprobs must be provided when logprobs "
+ "is requested")
logprobs = self._create_logprobs(
token_ids=token_ids,
top_logprobs=top_logprobs,