@@ -13,6 +13,7 @@ from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 from vllm.entrypoints.openai.logits_processors import get_logits_processors
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import LogitsProcessor, SamplingParams
+from vllm.sequence import Logprob
 from vllm.utils import random_uuid

 # torch is mocked during docs generation,
@@ -152,6 +153,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     skip_special_tokens: bool = True
     spaces_between_special_tokens: bool = True
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    prompt_logprobs: Optional[int] = None
     # doc: end-chat-completion-sampling-params

     # doc: begin-chat-completion-extra-params
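With prompt_logprobs now a field on ChatCompletionRequest, a client can ask for prompt log probabilities on the chat endpoint. A minimal sketch with the OpenAI Python client follows; the server URL, API key, and model name are placeholders, and extra_body is used because the field is a vLLM extension rather than part of the upstream OpenAI schema.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # placeholder server

resp = client.chat.completions.create(
    model="my-model",  # placeholder model name
    messages=[{"role": "user", "content": "Hello!"}],
    # prompt_logprobs is a vLLM extension, so it rides in extra_body
    extra_body={"prompt_logprobs": 1},
)

# ChatCompletionResponse returns prompt_logprobs at the top level of the JSON;
# how the client library exposes unknown fields may vary by version.
print(resp.model_dump().get("prompt_logprobs"))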
@@ -263,7 +265,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
             logprobs=self.top_logprobs if self.logprobs else None,
-            prompt_logprobs=self.top_logprobs if self.echo else None,
+            prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else
+            (self.top_logprobs if self.echo else None),
             ignore_eos=self.ignore_eos,
             max_tokens=max_tokens,
             min_tokens=self.min_tokens,
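The replacement expression gives an explicit prompt_logprobs value precedence over the older echo + top_logprobs fallback. Restated as a standalone function (a sketch for illustration, not part of the patch), the behaviour looks like this, including the edge case that a value of 0 is falsy and falls through to the echo path:

def resolve_prompt_logprobs(prompt_logprobs, echo, top_logprobs):
    # explicit prompt_logprobs wins; otherwise fall back to the old
    # echo + top_logprobs behaviour
    return prompt_logprobs if prompt_logprobs else (
        top_logprobs if echo else None)

assert resolve_prompt_logprobs(5, echo=False, top_logprobs=2) == 5
assert resolve_prompt_logprobs(None, echo=True, top_logprobs=2) == 2
assert resolve_prompt_logprobs(None, echo=False, top_logprobs=2) is None
# a value of 0 is falsy, so it falls through to the echo behaviour
assert resolve_prompt_logprobs(0, echo=True, top_logprobs=2) == 2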
@@ -368,6 +371,7 @@ class CompletionRequest(OpenAIBaseModel):
     spaces_between_special_tokens: bool = True
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
     allowed_token_ids: Optional[List[int]] = None
+    prompt_logprobs: Optional[int] = None
     # doc: end-completion-sampling-params

     # doc: begin-completion-extra-params
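The same field lands on CompletionRequest, so the legacy /v1/completions endpoint accepts it as plain JSON. A hedged sketch using the requests library follows; the URL and model name are placeholders. Note that for completions the result comes back per choice (see the CompletionResponseChoice change below).

import requests

payload = {
    "model": "my-model",          # placeholder model name
    "prompt": "San Francisco is a",
    "max_tokens": 8,
    "prompt_logprobs": 1,         # field added by this change
}
r = requests.post("http://localhost:8000/v1/completions", json=payload)  # placeholder URL
# For completions, prompt_logprobs is attached to each choice.
print(r.json()["choices"][0].get("prompt_logprobs"))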
@@ -454,7 +458,8 @@ class CompletionRequest(OpenAIBaseModel):
             min_tokens=self.min_tokens,
             use_beam_search=self.use_beam_search,
             early_stopping=self.early_stopping,
-            prompt_logprobs=self.logprobs if self.echo else None,
+            prompt_logprobs=self.prompt_logprobs
+            if self.prompt_logprobs else self.logprobs if self.echo else None,
             skip_special_tokens=self.skip_special_tokens,
             spaces_between_special_tokens=self.spaces_between_special_tokens,
             include_stop_str_in_output=self.include_stop_str_in_output,
@@ -532,6 +537,7 @@ class CompletionResponseChoice(OpenAIBaseModel):
             "to stop, None if the completion finished for some other reason "
             "including encountering the EOS token"),
     )
+    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


 class CompletionResponse(OpenAIBaseModel):
@@ -627,6 +633,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
     model: str
     choices: List[ChatCompletionResponseChoice]
     usage: UsageInfo
+    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


 class DeltaMessage(OpenAIBaseModel):
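Both response models now carry prompt_logprobs as List[Optional[Dict[int, Logprob]]]: one dict of candidate token ids per prompt position, with the first position typically None because the first prompt token has no preceding context. ChatCompletionResponse exposes it at the top level, CompletionResponseChoice per choice. A hedged sketch of walking the serialized structure, assuming each Logprob entry serializes with logprob, rank, and decoded_token fields:

def summarize_prompt_logprobs(prompt_logprobs):
    # prompt_logprobs: one entry per prompt token position, or None overall
    for position, token_dict in enumerate(prompt_logprobs or []):
        if token_dict is None:
            print(f"position {position}: no logprobs (first prompt token)")
            continue
        for token_id, info in token_dict.items():
            # assumed serialized fields: logprob, rank, decoded_token
            print(f"position {position}: id={token_id} "
                  f"logprob={info['logprob']:.3f} "
                  f"token={info.get('decoded_token')!r}")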
@@ -83,6 +83,16 @@ class OpenAIServingChat(OpenAIServing):
         if error_check_ret is not None:
             return error_check_ret

+        if request.prompt_logprobs is not None:
+            if request.stream and request.prompt_logprobs > 0:
+                return self.create_error_response(
+                    "Prompt_logprobs are not available when stream is enabled")
+
+            if request.prompt_logprobs < 0:
+                return self.create_error_response(
+                    f"Prompt_logprobs set to invalid "
+                    f"negative value: {request.prompt_logprobs}")
+
         try:
             (
                 lora_request,
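The guard above rejects prompt_logprobs > 0 in combination with streaming and rejects negative values, while allowing a value of 0 alongside stream. A standalone restatement of those branches (a sketch for illustration, not code from the patch):

def validate_prompt_logprobs(stream: bool, prompt_logprobs):
    # Mirrors the branches above; returns an error message or None.
    if prompt_logprobs is None:
        return None
    if stream and prompt_logprobs > 0:
        return "Prompt_logprobs are not available when stream is enabled"
    if prompt_logprobs < 0:
        return f"Prompt_logprobs set to invalid negative value: {prompt_logprobs}"
    return None

assert validate_prompt_logprobs(stream=True, prompt_logprobs=1) is not None
assert validate_prompt_logprobs(stream=True, prompt_logprobs=0) is None
assert validate_prompt_logprobs(stream=False, prompt_logprobs=-1) is not None
assert validate_prompt_logprobs(stream=False, prompt_logprobs=5) is None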
@@ -506,6 +516,7 @@ class OpenAIServingChat(OpenAIServing):
             model=model_name,
             choices=choices,
             usage=usage,
+            prompt_logprobs=final_res.prompt_logprobs,
         )

         return response
@@ -84,6 +84,15 @@ class OpenAIServingCompletion(OpenAIServing):
         request_id = f"cmpl-{random_uuid()}"
         created_time = int(time.time())

+        if request.prompt_logprobs is not None:
+            if request.stream and request.prompt_logprobs > 0:
+                return self.create_error_response(
+                    "Prompt_logprobs are not available when stream is enabled")
+            elif request.prompt_logprobs < 0:
+                return self.create_error_response(
+                    f"Prompt_logprobs set to invalid negative "
+                    f"value: {request.prompt_logprobs}")
+
         # Schedule the request and get the result generator.
         generators: List[AsyncGenerator[RequestOutput, None]] = []
         try:
@@ -377,6 +386,7 @@ class OpenAIServingCompletion(OpenAIServing):
                 logprobs=logprobs,
                 finish_reason=output.finish_reason,
                 stop_reason=output.stop_reason,
+                prompt_logprobs=final_res.prompt_logprobs,
             )
             choices.append(choice_data)
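For context, the prompt_logprobs data that final_res.prompt_logprobs forwards is already produced by the engine whenever SamplingParams is built with prompt_logprobs set; this change only surfaces it through the OpenAI-compatible endpoints. A minimal offline sketch (model name is a placeholder):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model
params = SamplingParams(max_tokens=8, prompt_logprobs=1)
out = llm.generate(["San Francisco is a"], params)[0]
# List[Optional[Dict[int, Logprob]]]; first entry is None
print(out.prompt_logprobs)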