[Frontend] Add logits_processors as an extra completion argument (#11150)
Signed-off-by: Brad Hilton <brad.hilton.nw@gmail.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# Adapted from
|
||||
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
|
||||
import re
|
||||
import time
|
||||
from argparse import Namespace
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
@@ -14,7 +15,7 @@ from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
|
||||
RequestOutputKind, SamplingParams)
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils import random_uuid, resolve_obj_by_qualname
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -148,6 +149,46 @@ class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
|
||||
type: Literal["function"] = "function"
|
||||
|
||||
|
||||
# Request-side description of how to build a single logits processor.
# `get_logits_processors` resolves `qualname` to an object and then calls
# it with `*args` / `**kwargs`; a missing (None) value is treated there as
# "no arguments of that kind".
# NOTE: documented with comments rather than a class docstring so the
# schema generated from this model is left untouched.
class LogitsProcessorConstructor(BaseModel):
    # Qualified name of the processor class/factory to resolve.
    qualname: str
    # Optional positional arguments forwarded to the resolved callable.
    args: Optional[List[Any]] = None
    # Optional keyword arguments forwarded to the resolved callable.
    kwargs: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
# Each entry is either a bare qualified name (the resolved object is used
# as-is) or a LogitsProcessorConstructor (the resolved object is called
# with the supplied args/kwargs to produce the processor).
LogitsProcessors = List[Union[str, LogitsProcessorConstructor]]
|
||||
|
||||
|
||||
def get_logits_processors(processors: Optional[LogitsProcessors],
                          pattern: Optional[str]) -> Optional[List[Any]]:
    """Resolve request-supplied logits processor specs into objects.

    Args:
        processors: Entries from the request's `logits_processors` field.
            Each entry is either a qualified-name string or a
            `LogitsProcessorConstructor`.
        pattern: Regex of allowed qualified names (see the
            --logits-processor-pattern engine argument referenced in the
            error messages). A falsy value means the feature is disabled.

    Returns:
        `None` when `processors` is empty or `None`; otherwise a list with
        one resolved (and, for constructor entries, instantiated) object
        per entry, in request order.

    Raises:
        ValueError: if processors were requested but no pattern is
            configured, if a qualified name does not match the pattern,
            or if a qualified name cannot be resolved.
    """
    if not processors:
        return None

    if not pattern:
        raise ValueError(
            "The `logits_processors` argument is not supported by this "
            "server. See --logits-processor-pattern engine argument "
            "for more information.")

    # Compile once instead of re-evaluating the pattern for every entry.
    # NOTE(review): `match` only anchors at the start of the string, so a
    # pattern without a trailing `$` allows any name that merely begins
    # with an allowed prefix. Kept as-is to preserve existing behaviour.
    allowed = re.compile(pattern)

    logits_processors = []
    for processor in processors:
        if isinstance(processor, str):
            qualname = processor
        else:
            qualname = processor.qualname

        # SECURITY: the name (and any args/kwargs below) come from the
        # request, so the allow-list check must run before anything is
        # resolved or called.
        if not allowed.match(qualname):
            raise ValueError(
                f"Logits processor '{qualname}' is not allowed by this "
                "server. See --logits-processor-pattern engine argument "
                "for more information.")

        try:
            logits_processor = resolve_obj_by_qualname(qualname)
        except Exception as e:
            # Surface any resolution failure as a ValueError.
            raise ValueError(
                f"Logits processor '{qualname}' could not be resolved: {e}"
            ) from e

        # Constructor entries name a class/factory: call it with the
        # supplied arguments. Bare-string entries are used as resolved.
        if isinstance(processor, LogitsProcessorConstructor):
            logits_processor = logits_processor(*processor.args or [],
                                                **processor.kwargs or {})
        logits_processors.append(logits_processor)

    return logits_processors
|
||||
|
||||
|
||||
class ChatCompletionRequest(OpenAIBaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/chat/create
|
||||
@@ -293,6 +334,17 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
"The request_id related to this request. If the caller does "
|
||||
"not set it, a random_uuid will be generated. This id is used "
|
||||
"through out the inference process and return in response."))
|
||||
logits_processors: Optional[LogitsProcessors] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"A list of either qualified names of logits processors, or "
|
||||
"constructor objects, to apply when sampling. A constructor is "
|
||||
"a JSON object with a required 'qualname' field specifying the "
|
||||
"qualified name of the processor class/factory, and optional "
|
||||
"'args' and 'kwargs' fields containing positional and keyword "
|
||||
"arguments. For example: {'qualname': "
|
||||
"'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
|
||||
"{'param': 'value'}}."))
|
||||
|
||||
# doc: end-chat-completion-extra-params
|
||||
|
||||
@@ -314,7 +366,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
length_penalty=self.length_penalty,
|
||||
include_stop_str_in_output=self.include_stop_str_in_output)
|
||||
|
||||
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
|
||||
def to_sampling_params(
|
||||
self, default_max_tokens: int,
|
||||
logits_processor_pattern: Optional[str]) -> SamplingParams:
|
||||
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
|
||||
max_tokens = self.max_completion_tokens or self.max_tokens
|
||||
if max_tokens is None:
|
||||
@@ -364,6 +418,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
min_tokens=self.min_tokens,
|
||||
skip_special_tokens=self.skip_special_tokens,
|
||||
spaces_between_special_tokens=self.spaces_between_special_tokens,
|
||||
logits_processors=get_logits_processors(self.logits_processors,
|
||||
logits_processor_pattern),
|
||||
include_stop_str_in_output=self.include_stop_str_in_output,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
output_kind=RequestOutputKind.DELTA if self.stream \
|
||||
@@ -599,6 +655,17 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."))
|
||||
logits_processors: Optional[LogitsProcessors] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"A list of either qualified names of logits processors, or "
|
||||
"constructor objects, to apply when sampling. A constructor is "
|
||||
"a JSON object with a required 'qualname' field specifying the "
|
||||
"qualified name of the processor class/factory, and optional "
|
||||
"'args' and 'kwargs' fields containing positional and keyword "
|
||||
"arguments. For example: {'qualname': "
|
||||
"'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
|
||||
"{'param': 'value'}}."))
|
||||
|
||||
# doc: end-completion-extra-params
|
||||
|
||||
@@ -619,7 +686,9 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
length_penalty=self.length_penalty,
|
||||
include_stop_str_in_output=self.include_stop_str_in_output)
|
||||
|
||||
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
|
||||
def to_sampling_params(
|
||||
self, default_max_tokens: int,
|
||||
logits_processor_pattern: Optional[str]) -> SamplingParams:
|
||||
max_tokens = self.max_tokens
|
||||
if max_tokens is None:
|
||||
max_tokens = default_max_tokens
|
||||
@@ -665,6 +734,8 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
skip_special_tokens=self.skip_special_tokens,
|
||||
spaces_between_special_tokens=self.spaces_between_special_tokens,
|
||||
include_stop_str_in_output=self.include_stop_str_in_output,
|
||||
logits_processors=get_logits_processors(self.logits_processors,
|
||||
logits_processor_pattern),
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
output_kind=RequestOutputKind.DELTA if self.stream \
|
||||
else RequestOutputKind.FINAL_ONLY,
|
||||
|
||||
Reference in New Issue
Block a user