[BUGFIX] [FRONTEND] Correct chat logprobs (#5029)

Co-authored-by: Breno Faria <breno.faria@intrafind.com>
Author: Breno Faria
Date: 2024-05-30 11:52:14 +02:00 (committed by GitHub)
Parent: e07aff9e52
Commit: 87d41c849d
6 changed files with 362 additions and 99 deletions
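
For context: the OpenAI Completions API and the Chat Completions API report logprobs in different shapes, and before this fix vLLM's chat endpoint reused the completion-style object. Below is a rough sketch of the two layouts as plain Python dicts; field names follow the OpenAI API reference, the values are illustrative only, and the corrected chat-side model lives in protocol.py hunks not shown in this excerpt.

# Completions-style logprobs: four parallel per-token arrays. This is the
# shape the renamed CompletionLogProbs model (see the imports hunk below)
# keeps serving from the completions endpoint.
completion_style = {
    "tokens": ["Hello", ","],
    "token_logprobs": [-0.12, -1.5],
    "text_offset": [0, 5],
    "top_logprobs": [{"Hello": -0.12, "Hi": -2.3}, {",": -1.5}],
}

# Chat-style logprobs: a list of per-token objects under "content", as in
# the OpenAI Chat Completions reference. Illustrative only.
chat_style = {
    "content": [{
        "token": "Hello",
        "logprob": -0.12,
        "top_logprobs": [{"token": "Hello", "logprob": -0.12},
                         {"token": "Hi", "logprob": -2.3}],
    }],
}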

vllm/entrypoints/openai/serving_completion.py

@@ -1,23 +1,29 @@
 import time
 from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
-                    Optional, Tuple)
+                    Optional)
+from typing import Sequence as GenericSequence
+from typing import Tuple
 
 from fastapi import Request
 
 from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.entrypoints.openai.protocol import (CompletionRequest,
+# yapf: disable
+from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
+                                              CompletionRequest,
                                               CompletionResponse,
                                               CompletionResponseChoice,
                                               CompletionResponseStreamChoice,
                                               CompletionStreamResponse,
-                                              LogProbs, UsageInfo)
+                                              UsageInfo)
+# yapf: enable
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing)
 from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
 from vllm.outputs import RequestOutput
+from vllm.sequence import Logprob
 from vllm.utils import merge_async_iterators, random_uuid
 
 logger = init_logger(__name__)
@@ -25,7 +31,7 @@ logger = init_logger(__name__)
 TypeTokenIDs = List[int]
 TypeTopLogProbs = List[Optional[Dict[int, float]]]
 TypeCreateLogProbsFn = Callable[
-    [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs]
+    [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
 
 
 def parse_prompt_format(prompt) -> Tuple[bool, list]:
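
The alias above only changes its return annotation from LogProbs to CompletionLogProbs. For reference, a callable conforming to it would look like the following stub; reading the last two parameters as (num_output_top_logprobs, initial_text_offset) is an assumption based on the helper added at the end of this diff, not something the alias itself states.

# Illustrative stub only, not part of the commit.
def make_completion_logprobs(
        token_ids: TypeTokenIDs,
        top_logprobs: TypeTopLogProbs,
        num_output_top_logprobs: Optional[int],  # assumed meaning
        initial_text_offset: int,                # assumed meaning
) -> CompletionLogProbs:
    ...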
@@ -235,7 +241,7 @@ class OpenAIServingCompletion(OpenAIServing):
                             i]:] if output.logprobs else None
 
                     if request.logprobs is not None:
-                        logprobs = self._create_logprobs(
+                        logprobs = self._create_completion_logprobs(
                             token_ids=delta_token_ids,
                             top_logprobs=top_logprobs,
                             num_output_top_logprobs=request.logprobs,
@@ -317,7 +323,7 @@ class OpenAIServingCompletion(OpenAIServing):
                 assert top_logprobs is not None, (
                     "top_logprobs must be provided when logprobs "
                     "is requested")
-                logprobs = self._create_logprobs(
+                logprobs = self._create_completion_logprobs(
                     token_ids=token_ids,
                     top_logprobs=top_logprobs,
                     num_output_top_logprobs=request.logprobs,
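
The two hunks above (streaming and non-streaming paths) change only the helper's name. A quick client-side way to exercise the completions logprobs path, assuming a vLLM OpenAI-compatible server on localhost:8000 and a placeholder model name:

# Hypothetical smoke test; requires the openai>=1.0 client package.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.completions.create(
    model="my-model",   # placeholder
    prompt="Hello",
    max_tokens=3,
    logprobs=2,         # becomes num_output_top_logprobs server-side
)
lp = resp.choices[0].logprobs
print(lp.tokens, lp.token_logprobs, lp.text_offset, lp.top_logprobs)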
@@ -351,3 +357,59 @@ class OpenAIServingCompletion(OpenAIServing):
             choices=choices,
             usage=usage,
         )
+
+    def _create_completion_logprobs(
+        self,
+        token_ids: GenericSequence[int],
+        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
+        num_output_top_logprobs: int,
+        initial_text_offset: int = 0,
+    ) -> CompletionLogProbs:
+        """Create logprobs for OpenAI Completion API."""
+        out_text_offset: List[int] = []
+        out_token_logprobs: List[Optional[float]] = []
+        out_tokens: List[str] = []
+        out_top_logprobs: List[Optional[Dict[str, float]]] = []
+
+        last_token_len = 0
+
+        for i, token_id in enumerate(token_ids):
+            step_top_logprobs = top_logprobs[i]
+            if step_top_logprobs is None:
+                token = self.tokenizer.decode(token_id)
+                out_tokens.append(token)
+                out_token_logprobs.append(None)
+                out_top_logprobs.append(None)
+            else:
+                token = self._get_decoded_token(step_top_logprobs[token_id],
+                                                token_id)
+                token_logprob = max(step_top_logprobs[token_id].logprob,
+                                    -9999.0)
+                out_tokens.append(token)
+                out_token_logprobs.append(token_logprob)
+
+                # makes sure to add the top num_output_top_logprobs + 1
+                # logprobs, as defined in the openai API
+                # (cf. https://github.com/openai/openai-openapi/blob/
+                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
+                out_top_logprobs.append({
+                    # Convert float("-inf") to the
+                    # JSON-serializable float that OpenAI uses
+                    self._get_decoded_token(top_lp[1], top_lp[0]):
+                    max(top_lp[1].logprob, -9999.0)
+                    for i, top_lp in enumerate(step_top_logprobs.items())
+                    if num_output_top_logprobs >= i
+                })
+
+            if len(out_text_offset) == 0:
+                out_text_offset.append(initial_text_offset)
+            else:
+                out_text_offset.append(out_text_offset[-1] + last_token_len)
+            last_token_len = len(token)
+
+        return CompletionLogProbs(
+            text_offset=out_text_offset,
+            token_logprobs=out_token_logprobs,
+            tokens=out_tokens,
+            top_logprobs=out_top_logprobs,
+        )
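
The offset and clamping bookkeeping in _create_completion_logprobs can be checked in isolation. Here is a minimal standalone sketch of the same logic (a stub function, not vLLM code): text_offset[i] is where token i starts in the generated text, and float("-inf") is clamped to -9999.0 so the value stays JSON-serializable.

from typing import List

def offsets_and_clamped(tokens: List[str],
                        logprobs: List[float],
                        initial_text_offset: int = 0):
    # Mirrors the loop above: the first offset is initial_text_offset,
    # and each later offset adds the previous token's length.
    offsets: List[int] = []
    clamped: List[float] = []
    last_len = 0
    for tok, lp in zip(tokens, logprobs):
        offsets.append(initial_text_offset if not offsets
                       else offsets[-1] + last_len)
        clamped.append(max(lp, -9999.0))
        last_len = len(tok)
    return offsets, clamped

print(offsets_and_clamped(["Hello", ",", " world"],
                          [-0.1, float("-inf"), -0.5]))
# ([0, 5, 6], [-0.1, -9999.0, -0.5])

Note also the comprehension's guard (num_output_top_logprobs >= i): since enumerate starts at 0, it keeps num_output_top_logprobs + 1 entries per step, the chosen token plus the requested alternatives, matching the OpenAI spec cited in the code comment.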