[BUGFIX] [FRONTEND] Correct chat logprobs (#5029)

Co-authored-by: Breno Faria <breno.faria@intrafind.com>
This commit is contained in:
Breno Faria
2024-05-30 11:52:14 +02:00
committed by GitHub
parent e07aff9e52
commit 87d41c849d
6 changed files with 362 additions and 99 deletions

View File

@@ -11,7 +11,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest,
EmbeddingRequest, ErrorResponse,
LogProbs, ModelCard, ModelList,
ModelCard, ModelList,
ModelPermission)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -75,51 +75,6 @@ class OpenAIServing:
model_cards.extend(lora_cards)
return ModelList(data=model_cards)
def _create_logprobs(
self,
token_ids: List[int],
top_logprobs: List[Optional[Dict[int, Logprob]]],
num_output_top_logprobs: Optional[int] = None,
initial_text_offset: int = 0,
) -> LogProbs:
"""Create OpenAI-style logprobs."""
logprobs = LogProbs()
last_token_len = 0
if num_output_top_logprobs:
logprobs.top_logprobs = []
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = self.tokenizer.decode(token_id)
logprobs.tokens.append(token)
logprobs.token_logprobs.append(None)
assert logprobs.top_logprobs is not None
logprobs.top_logprobs.append(None)
else:
token_logprob = step_top_logprobs[token_id].logprob
token = step_top_logprobs[token_id].decoded_token
logprobs.tokens.append(token)
token_logprob = max(token_logprob, -9999.0)
logprobs.token_logprobs.append(token_logprob)
if num_output_top_logprobs:
assert logprobs.top_logprobs is not None
logprobs.top_logprobs.append({
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
p.decoded_token: max(p.logprob, -9999.0)
for i, p in step_top_logprobs.items()
} if step_top_logprobs else None)
if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset)
else:
logprobs.text_offset.append(logprobs.text_offset[-1] +
last_token_len)
last_token_len = len(token)
return logprobs
def create_error_response(
self,
message: str,
@@ -235,3 +190,8 @@ class OpenAIServing:
f"Please reduce the length of the messages or completion.", )
else:
return input_ids, input_text
def _get_decoded_token(self, logprob: Logprob, token_id: int) -> str:
if logprob.decoded_token is not None:
return logprob.decoded_token
return self.tokenizer.decode(token_id)