[Bugfix][Frontend] Cleanup "fix chat logprobs" (#5026)

Cyrus Leung
2024-06-11 13:36:46 +08:00
committed by GitHub
parent 351d5e7b82
commit 640052b069
6 changed files with 122 additions and 123 deletions

vllm/entrypoints/openai/serving_completion.py

@@ -8,6 +8,7 @@ from fastapi import Request
 from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 # yapf conflicts with isort for this block
 # yapf: disable
-from vllm.entrypoints.openai.protocol import (CompletionRequest,
+from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
+                                              CompletionRequest,
@@ -16,7 +17,6 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                               CompletionResponseStreamChoice,
                                               CompletionStreamResponse,
                                               UsageInfo)
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing)
 from vllm.logger import init_logger
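
Note: CompletionLogProbs is the OpenAI-compatible logprobs payload for the Completions API. As a rough sketch of its shape (field names follow the OpenAI Completions API; this is an illustration, not vLLM's exact class definition):

from typing import Dict, List, Optional

from pydantic import BaseModel, Field


class CompletionLogProbsSketch(BaseModel):
    # Character offset of each token within the returned text.
    text_offset: List[int] = Field(default_factory=list)
    # Logprob of each returned token; None for the first prompt token.
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    # The returned tokens as strings.
    tokens: List[str] = Field(default_factory=list)
    # Per-token mapping of the top-k candidate tokens to their logprobs.
    top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None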
@@ -221,7 +221,7 @@ class OpenAIServingCompletion(OpenAIServing):
                         # only return the prompt
                         delta_text = res.prompt
                         delta_token_ids = res.prompt_token_ids
-                        top_logprobs = res.prompt_logprobs
+                        out_logprobs = res.prompt_logprobs
                         has_echoed[i] = True
                     elif (request.echo and request.max_tokens > 0
                           and not has_echoed[i]):
@@ -229,7 +229,7 @@ class OpenAIServingCompletion(OpenAIServing):
                         delta_text = res.prompt + output.text
                         delta_token_ids = (res.prompt_token_ids +
                                            output.token_ids)
-                        top_logprobs = res.prompt_logprobs + (output.logprobs
+                        out_logprobs = res.prompt_logprobs + (output.logprobs
                                                               or [])
                         has_echoed[i] = True
                     else:
@@ -237,13 +237,15 @@ class OpenAIServingCompletion(OpenAIServing):
                         delta_text = output.text[len(previous_texts[i]):]
                         delta_token_ids = output.token_ids[
                             previous_num_tokens[i]:]
-                        top_logprobs = output.logprobs[previous_num_tokens[
+                        out_logprobs = output.logprobs[previous_num_tokens[
                             i]:] if output.logprobs else None

                     if request.logprobs is not None:
+                        assert out_logprobs is not None, (
+                            "Did not output logprobs")
                         logprobs = self._create_completion_logprobs(
                             token_ids=delta_token_ids,
-                            top_logprobs=top_logprobs,
+                            top_logprobs=out_logprobs,
                             num_output_top_logprobs=request.logprobs,
                             initial_text_offset=len(previous_texts[i]),
                         )
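
Renaming the local from top_logprobs to out_logprobs keeps the accumulated logprobs distinct from the top_logprobs keyword argument of _create_completion_logprobs, and the new assertion fails fast when the engine did not return logprobs. The three echo branches reduce to a simple selection rule; a minimal standalone sketch of that rule (hypothetical helper, not vLLM code, and it omits the streaming path's slicing by previous_num_tokens):

def select_logprobs(echo: bool, max_tokens: int, prompt_logprobs: list,
                    output_logprobs: list) -> list:
    # echo with max_tokens == 0: the response is the prompt alone,
    # so only prompt logprobs are returned.
    if echo and max_tokens == 0:
        return prompt_logprobs
    # echo with generation: prompt logprobs first, then the logprobs
    # of the generated tokens.
    if echo and max_tokens > 0:
        return prompt_logprobs + output_logprobs
    # no echo: only the generated tokens' logprobs.
    return output_logprobs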
@@ -325,25 +327,23 @@ class OpenAIServingCompletion(OpenAIServing):
             assert request.max_tokens is not None
             if request.echo and request.max_tokens == 0:
                 token_ids = prompt_token_ids
-                top_logprobs = prompt_logprobs
+                out_logprobs = prompt_logprobs
                 output_text = prompt_text
             elif request.echo and request.max_tokens > 0:
                 token_ids = prompt_token_ids + output.token_ids
-                top_logprobs = (prompt_logprobs + output.logprobs
+                out_logprobs = (prompt_logprobs + output.logprobs
                                 if request.logprobs is not None else None)
                 output_text = prompt_text + output.text
             else:
                 token_ids = output.token_ids
-                top_logprobs = output.logprobs
+                out_logprobs = output.logprobs
                 output_text = output.text

             if request.logprobs is not None:
-                assert top_logprobs is not None, (
-                    "top_logprobs must be provided when logprobs "
-                    "is requested")
+                assert out_logprobs is not None, "Did not output logprobs"
                 logprobs = self._create_completion_logprobs(
                     token_ids=token_ids,
                     top_logprobs=out_logprobs,
                     num_output_top_logprobs=request.logprobs,
                 )
             else:
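
For reference, a request that exercises the echo + logprobs path against a running vLLM OpenAI-compatible server might look like this (a hedged usage sketch; host, port, and model name are placeholders):

import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "my-model",      # placeholder model name
        "prompt": "Hello, world",
        "max_tokens": 5,
        "echo": True,             # exercises the prompt + output branch
        "logprobs": 2,            # top-2 candidates per token
    },
)
choice = resp.json()["choices"][0]
lp = choice["logprobs"]
print(lp["tokens"])          # prompt tokens followed by generated tokens
print(lp["token_logprobs"])  # the first prompt token has no logprob (None)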