[Frontend][2/n] Make pooling entrypoints request schema consensus | ChatRequest (#32574)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
@@ -2,10 +2,11 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Any
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
@@ -44,3 +45,66 @@ class CompletionRequestMixin(OpenAIBaseModel):
|
||||
"the prompt."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ChatRequestMixin(OpenAIBaseModel):
|
||||
messages: list[ChatCompletionMessageParam]
|
||||
|
||||
add_generation_prompt: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If true, the generation prompt will be added to the chat template. "
|
||||
"This is a parameter used by chat template in tokenizer config of the "
|
||||
"model."
|
||||
),
|
||||
)
|
||||
|
||||
continue_final_message: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If this is set, the chat will be formatted so that the final "
|
||||
"message in the chat is open-ended, without any EOS tokens. The "
|
||||
"model will continue this message rather than starting a new one. "
|
||||
'This allows you to "prefill" part of the model\'s response for it. '
|
||||
"Cannot be used at the same time as `add_generation_prompt`."
|
||||
),
|
||||
)
|
||||
|
||||
add_special_tokens: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If true, special tokens (e.g. BOS) will be added to the prompt "
|
||||
"on top of what is added by the chat template. "
|
||||
"For most models, the chat template takes care of adding the "
|
||||
"special tokens so this should be set to false (as is the "
|
||||
"default)."
|
||||
),
|
||||
)
|
||||
|
||||
chat_template: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"A Jinja template to use for this conversion. "
|
||||
"As of transformers v4.44, default chat template is no longer "
|
||||
"allowed, so you must provide a chat template if the tokenizer "
|
||||
"does not define one."
|
||||
),
|
||||
)
|
||||
|
||||
chat_template_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Additional keyword args to pass to the template renderer. "
|
||||
"Will be accessible by the chat template."
|
||||
),
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_generation_prompt(cls, data):
|
||||
if data.get("continue_final_message") and data.get("add_generation_prompt"):
|
||||
raise ValueError(
|
||||
"Cannot set both `continue_final_message` and "
|
||||
"`add_generation_prompt` to True."
|
||||
)
|
||||
return data
|
||||
|
||||
@@ -10,9 +10,9 @@ from pydantic import (
|
||||
|
||||
from vllm import PoolingParams
|
||||
from vllm.config.pooler import get_use_activation
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
|
||||
from vllm.entrypoints.pooling.base.protocol import (
|
||||
ChatRequestMixin,
|
||||
CompletionRequestMixin,
|
||||
PoolingBasicRequestMixin,
|
||||
)
|
||||
@@ -45,48 +45,8 @@ class ClassificationCompletionRequest(PoolingBasicRequestMixin, CompletionReques
|
||||
)
|
||||
|
||||
|
||||
class ClassificationChatRequest(PoolingBasicRequestMixin):
|
||||
messages: list[ChatCompletionMessageParam]
|
||||
|
||||
class ClassificationChatRequest(PoolingBasicRequestMixin, ChatRequestMixin):
|
||||
# --8<-- [start:chat-classification-extra-params]
|
||||
add_generation_prompt: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If true, the generation prompt will be added to the chat template. "
|
||||
"This is a parameter used by chat template in tokenizer config of the "
|
||||
"model."
|
||||
),
|
||||
)
|
||||
|
||||
add_special_tokens: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If true, special tokens (e.g. BOS) will be added to the prompt "
|
||||
"on top of what is added by the chat template. "
|
||||
"For most models, the chat template takes care of adding the "
|
||||
"special tokens so this should be set to false (as is the "
|
||||
"default)."
|
||||
),
|
||||
)
|
||||
|
||||
chat_template: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"A Jinja template to use for this conversion. "
|
||||
"As of transformers v4.44, default chat template is no longer "
|
||||
"allowed, so you must provide a chat template if the tokenizer "
|
||||
"does not define one."
|
||||
),
|
||||
)
|
||||
|
||||
chat_template_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Additional keyword args to pass to the template renderer. "
|
||||
"Will be accessible by the chat template."
|
||||
),
|
||||
)
|
||||
|
||||
mm_processor_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=("Additional kwargs to pass to the HF processor."),
|
||||
|
||||
@@ -86,8 +86,8 @@ class ClassificationMixin(OpenAIServing):
|
||||
ChatTemplateContentFormatOption,
|
||||
getattr(self, "chat_template_content_format", "auto"),
|
||||
),
|
||||
add_generation_prompt=False,
|
||||
continue_final_message=False,
|
||||
add_generation_prompt=chat_request.add_generation_prompt,
|
||||
continue_final_message=chat_request.continue_final_message,
|
||||
add_special_tokens=chat_request.add_special_tokens,
|
||||
)
|
||||
ctx.engine_prompts = engine_prompts
|
||||
|
||||
@@ -5,13 +5,12 @@ from typing import Any, TypeAlias
|
||||
|
||||
from pydantic import (
|
||||
Field,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from vllm import PoolingParams
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
|
||||
from vllm.entrypoints.pooling.base.protocol import (
|
||||
ChatRequestMixin,
|
||||
CompletionRequestMixin,
|
||||
PoolingBasicRequestMixin,
|
||||
)
|
||||
@@ -57,57 +56,11 @@ class EmbeddingCompletionRequest(PoolingBasicRequestMixin, CompletionRequestMixi
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingChatRequest(PoolingBasicRequestMixin):
|
||||
messages: list[ChatCompletionMessageParam]
|
||||
|
||||
class EmbeddingChatRequest(PoolingBasicRequestMixin, ChatRequestMixin):
|
||||
encoding_format: EncodingFormat = "float"
|
||||
dimensions: int | None = None
|
||||
|
||||
# --8<-- [start:chat-embedding-extra-params]
|
||||
add_generation_prompt: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If true, the generation prompt will be added to the chat template. "
|
||||
"This is a parameter used by chat template in tokenizer config of the "
|
||||
"model."
|
||||
),
|
||||
)
|
||||
continue_final_message: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If this is set, the chat will be formatted so that the final "
|
||||
"message in the chat is open-ended, without any EOS tokens. The "
|
||||
"model will continue this message rather than starting a new one. "
|
||||
'This allows you to "prefill" part of the model\'s response for it. '
|
||||
"Cannot be used at the same time as `add_generation_prompt`."
|
||||
),
|
||||
)
|
||||
add_special_tokens: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If true, special tokens (e.g. BOS) will be added to the prompt "
|
||||
"on top of what is added by the chat template. "
|
||||
"For most models, the chat template takes care of adding the "
|
||||
"special tokens so this should be set to false (as is the "
|
||||
"default)."
|
||||
),
|
||||
)
|
||||
chat_template: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"A Jinja template to use for this conversion. "
|
||||
"As of transformers v4.44, default chat template is no longer "
|
||||
"allowed, so you must provide a chat template if the tokenizer "
|
||||
"does not define one."
|
||||
),
|
||||
)
|
||||
chat_template_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Additional keyword args to pass to the template renderer. "
|
||||
"Will be accessible by the chat template."
|
||||
),
|
||||
)
|
||||
mm_processor_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=("Additional kwargs to pass to the HF processor."),
|
||||
@@ -134,16 +87,6 @@ class EmbeddingChatRequest(PoolingBasicRequestMixin):
|
||||
)
|
||||
# --8<-- [end:chat-embedding-extra-params]
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_generation_prompt(cls, data):
|
||||
if data.get("continue_final_message") and data.get("add_generation_prompt"):
|
||||
raise ValueError(
|
||||
"Cannot set both `continue_final_message` and "
|
||||
"`add_generation_prompt` to True."
|
||||
)
|
||||
return data
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
|
||||
@@ -144,10 +144,8 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
request.messages,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
chat_template_content_format=self.chat_template_content_format,
|
||||
# In pooling requests, we are not generating tokens,
|
||||
# so there is no need to append extra tokens to the input
|
||||
add_generation_prompt=False,
|
||||
continue_final_message=False,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
continue_final_message=request.continue_final_message,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
elif isinstance(request, PoolingCompletionRequest):
|
||||
|
||||
Reference in New Issue
Block a user