[Frontend] Add logits_processors as an extra completion argument (#11150)

Signed-off-by: Brad Hilton <brad.hilton.nw@gmail.com>
This commit is contained in:
Brad Hilton
2024-12-14 09:46:42 -07:00
committed by GitHub
parent 3cb5769883
commit 9c3dadd1c9
6 changed files with 127 additions and 39 deletions

View File

@@ -1,5 +1,6 @@
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import re
import time
from argparse import Namespace
from typing import Any, Dict, List, Literal, Optional, Union
@@ -14,7 +15,7 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
RequestOutputKind, SamplingParams)
from vllm.sequence import Logprob
from vllm.utils import random_uuid
from vllm.utils import random_uuid, resolve_obj_by_qualname
logger = init_logger(__name__)
@@ -148,6 +149,46 @@ class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
type: Literal["function"] = "function"
class LogitsProcessorConstructor(BaseModel):
    """Specification for building a logits processor from a request.

    Describes a processor as a qualified name plus optional constructor
    arguments; it is resolved and instantiated by
    ``get_logits_processors`` when sampling params are built.
    """

    # Fully qualified name of the processor class/factory,
    # e.g. "my_module.MyLogitsProcessor"; resolved at request time.
    qualname: str
    # Positional arguments forwarded to the constructor (None -> no args).
    args: Optional[List[Any]] = None
    # Keyword arguments forwarded to the constructor (None -> no kwargs).
    kwargs: Optional[Dict[str, Any]] = None


# A request-level processor list entry is either a bare qualified name
# (resolved directly) or a constructor spec (resolved, then called).
LogitsProcessors = List[Union[str, LogitsProcessorConstructor]]
def get_logits_processors(
        processors: Optional["LogitsProcessors"],
        pattern: Optional[str]) -> Optional[List[Any]]:
    """Resolve user-supplied logits processor specs into callables.

    Args:
        processors: List of either fully qualified names (``str``) or
            ``LogitsProcessorConstructor`` objects. ``None``/empty means
            no processors were requested.
        pattern: Regex (matched with ``re.match``) that a processor's
            qualified name must satisfy to be allowed. ``None`` means the
            server does not permit request-supplied processors at all.

    Returns:
        The resolved processors (constructor specs are instantiated with
        their args/kwargs), or ``None`` when none were requested.

    Raises:
        ValueError: If processors were supplied but the server has no
            pattern configured, if a qualified name is rejected by the
            pattern, or if a name cannot be resolved.
    """
    if not processors:
        return None
    if not pattern:
        # Processors were requested but the server never opted in via
        # --logits-processor-pattern: reject the request outright.
        raise ValueError(
            "The `logits_processors` argument is not supported by this "
            "server. See --logits-processor-pattern engine argument "
            "for more information.")
    logits_processors = []
    for processor in processors:
        qualname = processor if isinstance(processor,
                                           str) else processor.qualname
        # Allow-list check happens before any import/resolution so a
        # disallowed name can never trigger arbitrary module loading.
        if not re.match(pattern, qualname):
            raise ValueError(
                f"Logits processor '{qualname}' is not allowed by this "
                "server. See --logits-processor-pattern engine argument "
                "for more information.")
        try:
            logits_processor = resolve_obj_by_qualname(qualname)
        except Exception as e:
            raise ValueError(
                f"Logits processor '{qualname}' could not be resolved: {e}"
            ) from e
        if isinstance(processor, LogitsProcessorConstructor):
            # A constructor spec names a factory: call it with the
            # user-provided positional/keyword arguments.
            logits_processor = logits_processor(*processor.args or [],
                                                **processor.kwargs or {})
        logits_processors.append(logits_processor)
    return logits_processors
class ChatCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
@@ -293,6 +334,17 @@ class ChatCompletionRequest(OpenAIBaseModel):
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."))
logits_processors: Optional[LogitsProcessors] = Field(
default=None,
description=(
"A list of either qualified names of logits processors, or "
"constructor objects, to apply when sampling. A constructor is "
"a JSON object with a required 'qualname' field specifying the "
"qualified name of the processor class/factory, and optional "
"'args' and 'kwargs' fields containing positional and keyword "
"arguments. For example: {'qualname': "
"'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
"{'param': 'value'}}."))
# doc: end-chat-completion-extra-params
@@ -314,7 +366,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
length_penalty=self.length_penalty,
include_stop_str_in_output=self.include_stop_str_in_output)
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
def to_sampling_params(
self, default_max_tokens: int,
logits_processor_pattern: Optional[str]) -> SamplingParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens
if max_tokens is None:
@@ -364,6 +418,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
min_tokens=self.min_tokens,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
logits_processors=get_logits_processors(self.logits_processors,
logits_processor_pattern),
include_stop_str_in_output=self.include_stop_str_in_output,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
@@ -599,6 +655,17 @@ class CompletionRequest(OpenAIBaseModel):
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
logits_processors: Optional[LogitsProcessors] = Field(
default=None,
description=(
"A list of either qualified names of logits processors, or "
"constructor objects, to apply when sampling. A constructor is "
"a JSON object with a required 'qualname' field specifying the "
"qualified name of the processor class/factory, and optional "
"'args' and 'kwargs' fields containing positional and keyword "
"arguments. For example: {'qualname': "
"'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
"{'param': 'value'}}."))
# doc: end-completion-extra-params
@@ -619,7 +686,9 @@ class CompletionRequest(OpenAIBaseModel):
length_penalty=self.length_penalty,
include_stop_str_in_output=self.include_stop_str_in_output)
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
def to_sampling_params(
self, default_max_tokens: int,
logits_processor_pattern: Optional[str]) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
@@ -665,6 +734,8 @@ class CompletionRequest(OpenAIBaseModel):
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
logits_processors=get_logits_processors(self.logits_processors,
logits_processor_pattern),
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,