[Feature]: Add OpenAI server prompt_logprobs support #6508 (#7453)

Author: Grant Pinkert
Date: 2024-08-16 12:38:08 +10:00
Committer: GitHub
Parent: b67ae00cdb
Commit: f878c8feb0
4 changed files with 154 additions and 3 deletions
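
This change exposes vLLM's per-prompt-token log-probabilities through both OpenAI-compatible endpoints: requests gain an optional prompt_logprobs field, and non-streaming responses carry the computed values back. As a rough sketch of client-side usage (the server URL and model name below are placeholders, not part of this commit), the parameter can ride along via the OpenAI client's extra_body passthrough:

# Sketch of client-side usage, assuming a vLLM OpenAI-compatible server
# at localhost:8000 serving a model named "my-model" (both placeholders).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="my-model",
    prompt="The capital of France is",
    max_tokens=5,
    # prompt_logprobs is not a standard OpenAI field, so it is passed
    # through extra_body; the server reads it off the request model.
    extra_body={"prompt_logprobs": 1},
)

# Per this commit, the non-streaming response choice carries the
# per-prompt-token logprobs (attribute access on extra fields may
# vary by openai client version).
print(completion.choices[0].prompt_logprobs)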

vllm/entrypoints/openai/protocol.py

@@ -13,6 +13,7 @@ from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 from vllm.entrypoints.openai.logits_processors import get_logits_processors
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import LogitsProcessor, SamplingParams
+from vllm.sequence import Logprob
 from vllm.utils import random_uuid

 # torch is mocked during docs generation,
@@ -152,6 +153,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     skip_special_tokens: bool = True
     spaces_between_special_tokens: bool = True
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    prompt_logprobs: Optional[int] = None
     # doc: end-chat-completion-sampling-params

     # doc: begin-chat-completion-extra-params
@@ -263,7 +265,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
             logprobs=self.top_logprobs if self.logprobs else None,
-            prompt_logprobs=self.top_logprobs if self.echo else None,
+            prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else
+            (self.top_logprobs if self.echo else None),
             ignore_eos=self.ignore_eos,
             max_tokens=max_tokens,
             min_tokens=self.min_tokens,
@@ -368,6 +371,7 @@ class CompletionRequest(OpenAIBaseModel):
     spaces_between_special_tokens: bool = True
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
     allowed_token_ids: Optional[List[int]] = None
+    prompt_logprobs: Optional[int] = None
     # doc: end-completion-sampling-params

     # doc: begin-completion-extra-params
@@ -454,7 +458,8 @@ class CompletionRequest(OpenAIBaseModel):
             min_tokens=self.min_tokens,
             use_beam_search=self.use_beam_search,
             early_stopping=self.early_stopping,
-            prompt_logprobs=self.logprobs if self.echo else None,
+            prompt_logprobs=self.prompt_logprobs
+            if self.prompt_logprobs else self.logprobs if self.echo else None,
             skip_special_tokens=self.skip_special_tokens,
             spaces_between_special_tokens=self.spaces_between_special_tokens,
             include_stop_str_in_output=self.include_stop_str_in_output,
@@ -532,6 +537,7 @@ class CompletionResponseChoice(OpenAIBaseModel):
             "to stop, None if the completion finished for some other reason "
             "including encountering the EOS token"),
     )
+    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


 class CompletionResponse(OpenAIBaseModel):
@@ -627,6 +633,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
     model: str
     choices: List[ChatCompletionResponseChoice]
     usage: UsageInfo
+    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


 class DeltaMessage(OpenAIBaseModel):
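
In both request models the new field takes precedence over the pre-existing echo-based behaviour: an explicit prompt_logprobs wins, otherwise top_logprobs (chat) or logprobs (completions) is reused when echo is set. Note that the diff checks truthiness rather than "is not None", so an explicit prompt_logprobs=0 still falls through to the echo fallback. The precedence, restated as a standalone hypothetical helper (not part of the commit) for readability:

from typing import Optional

def resolve_prompt_logprobs(prompt_logprobs: Optional[int],
                            fallback_logprobs: Optional[int],
                            echo: bool) -> Optional[int]:
    # Hypothetical helper mirroring the inline expressions above: a
    # truthy prompt_logprobs wins; otherwise reuse the logprobs value
    # only when echo is requested.
    if prompt_logprobs:
        return prompt_logprobs
    return fallback_logprobs if echo else None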

vllm/entrypoints/openai/serving_chat.py

@@ -83,6 +83,16 @@ class OpenAIServingChat(OpenAIServing):
         if error_check_ret is not None:
             return error_check_ret

+        if request.prompt_logprobs is not None:
+            if request.stream and request.prompt_logprobs > 0:
+                return self.create_error_response(
+                    "Prompt_logprobs are not available when stream is enabled")
+
+            if request.prompt_logprobs < 0:
+                return self.create_error_response(
+                    f"Prompt_logprobs set to invalid "
+                    f"negative value: {request.prompt_logprobs}")
+
         try:
             (
                 lora_request,
@@ -506,6 +516,7 @@ class OpenAIServingChat(OpenAIServing):
             model=model_name,
             choices=choices,
             usage=usage,
+            prompt_logprobs=final_res.prompt_logprobs,
         )

         return response
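
The guard added above rejects combinations the server cannot honour: prompt logprobs are only attached to the final, fully assembled response, so they are unavailable when streaming, and negative counts are invalid. A sketch of a request that would trip the stream check (URL and model name are placeholders):

# Sketch of a rejected request: streaming plus prompt_logprobs > 0.
# URL and model name are placeholders, not part of the commit.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "my-model",
        "messages": [{"role": "user", "content": "Hi"}],
        "stream": True,        # streaming enabled...
        "prompt_logprobs": 1,  # ...so the server returns an error
    },
)
print(resp.json())  # expect an error body mentioning Prompt_logprobs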

vllm/entrypoints/openai/serving_completion.py

@@ -84,6 +84,15 @@ class OpenAIServingCompletion(OpenAIServing):
         request_id = f"cmpl-{random_uuid()}"
         created_time = int(time.time())

+        if request.prompt_logprobs is not None:
+            if request.stream and request.prompt_logprobs > 0:
+                return self.create_error_response(
+                    "Prompt_logprobs are not available when stream is enabled")
+            elif request.prompt_logprobs < 0:
+                return self.create_error_response(
+                    f"Prompt_logprobs set to invalid negative "
+                    f"value: {request.prompt_logprobs}")
+
         # Schedule the request and get the result generator.
         generators: List[AsyncGenerator[RequestOutput, None]] = []
         try:
@@ -377,6 +386,7 @@ class OpenAIServingCompletion(OpenAIServing):
                 logprobs=logprobs,
                 finish_reason=output.finish_reason,
                 stop_reason=output.stop_reason,
+                prompt_logprobs=final_res.prompt_logprobs,
             )
             choices.append(choice_data)
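
The wiring above copies final_res.prompt_logprobs (vLLM's internal RequestOutput.prompt_logprobs) straight into the response, typed as Optional[List[Optional[Dict[int, Logprob]]]]: one list entry per prompt token, with the first entry None because the first token has no context to be scored against, and each remaining entry mapping candidate token ids to Logprob records. A sketch of walking the serialized structure, assuming the JSON field names follow vLLM's Logprob dataclass (logprob, rank, decoded_token):

from typing import Any, Dict, List, Optional

def summarize_prompt_logprobs(
        prompt_logprobs: List[Optional[Dict[str, Any]]]) -> None:
    # Walk the per-prompt-token logprobs from a parsed JSON response.
    # JSON object keys arrive as strings, hence Dict[str, Any] here.
    for position, entry in enumerate(prompt_logprobs):
        if entry is None:
            # The first prompt token has nothing to be conditioned on.
            print(f"token {position}: no logprobs")
            continue
        for token_id, info in entry.items():
            print(f"token {position}: id={token_id} "
                  f"logprob={info['logprob']:.3f} "
                  f"decoded={info.get('decoded_token')!r}")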