Fix echo/logprob OpenAI completion bug (#3441)
Co-authored-by: Dylan Hawk <dylanwawk@gmail.com>
This commit is contained in:
@@ -2,7 +2,7 @@ import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import conint
|
||||
|
||||
@@ -99,27 +99,32 @@ class OpenAIServing:
|
||||
last_token_len = 0
|
||||
if num_output_top_logprobs:
|
||||
logprobs.top_logprobs = []
|
||||
|
||||
for i, token_id in enumerate(token_ids):
|
||||
step_top_logprobs = top_logprobs[i]
|
||||
if step_top_logprobs is not None:
|
||||
token_logprob = step_top_logprobs[token_id].logprob
|
||||
if step_top_logprobs is None:
|
||||
token = self.tokenizer.decode(token_id)
|
||||
logprobs.tokens.append(token)
|
||||
logprobs.token_logprobs.append(None)
|
||||
logprobs.top_logprobs.append(None)
|
||||
else:
|
||||
token_logprob = None
|
||||
token = step_top_logprobs[token_id].decoded_token
|
||||
logprobs.tokens.append(token)
|
||||
logprobs.token_logprobs.append(token_logprob)
|
||||
token_logprob = step_top_logprobs[token_id].logprob
|
||||
token = step_top_logprobs[token_id].decoded_token
|
||||
logprobs.tokens.append(token)
|
||||
logprobs.token_logprobs.append(token_logprob)
|
||||
|
||||
if num_output_top_logprobs:
|
||||
logprobs.top_logprobs.append({
|
||||
p.decoded_token: p.logprob
|
||||
for i, p in step_top_logprobs.items()
|
||||
} if step_top_logprobs else None)
|
||||
|
||||
if len(logprobs.text_offset) == 0:
|
||||
logprobs.text_offset.append(initial_text_offset)
|
||||
else:
|
||||
logprobs.text_offset.append(logprobs.text_offset[-1] +
|
||||
last_token_len)
|
||||
last_token_len = len(token)
|
||||
|
||||
if num_output_top_logprobs:
|
||||
logprobs.top_logprobs.append({
|
||||
p.decoded_token: p.logprob
|
||||
for i, p in step_top_logprobs.items()
|
||||
} if step_top_logprobs else None)
|
||||
return logprobs
|
||||
|
||||
def create_error_response(
|
||||
@@ -164,12 +169,12 @@ class OpenAIServing:
|
||||
raise ValueError("The model `{request.model}` does not exist.")
|
||||
|
||||
def _validate_prompt_and_tokenize(
|
||||
self,
|
||||
request: Union[ChatCompletionRequest, CompletionRequest],
|
||||
prompt: Optional[str] = None,
|
||||
prompt_ids: Optional[List[int]] = None,
|
||||
truncate_prompt_tokens: Optional[conint(ge=1)] = None
|
||||
) -> List[int]:
|
||||
self,
|
||||
request: Union[ChatCompletionRequest, CompletionRequest],
|
||||
prompt: Optional[str] = None,
|
||||
prompt_ids: Optional[List[int]] = None,
|
||||
truncate_prompt_tokens: Optional[conint(ge=1)] = None
|
||||
) -> Tuple[List[int], str]:
|
||||
if not (prompt or prompt_ids):
|
||||
raise ValueError("Either prompt or prompt_ids should be provided.")
|
||||
if (prompt and prompt_ids):
|
||||
@@ -187,6 +192,8 @@ class OpenAIServing:
|
||||
else:
|
||||
input_ids = prompt_ids
|
||||
|
||||
input_text = prompt if prompt is not None else self.tokenizer.decode(
|
||||
prompt_ids)
|
||||
token_num = len(input_ids)
|
||||
|
||||
if request.max_tokens is None:
|
||||
@@ -201,4 +208,4 @@ class OpenAIServing:
|
||||
f"{request.max_tokens} in the completion). "
|
||||
f"Please reduce the length of the messages or completion.", )
|
||||
else:
|
||||
return input_ids
|
||||
return input_ids, input_text
|
||||
|
||||
Reference in New Issue
Block a user