[Frontend][2/n] Make pooling entrypoints' request schemas consistent | ChatRequest (#32574)

Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
wang.yuqi
2026-01-22 18:32:44 +08:00
committed by GitHub
parent 64e3d67ac0
commit 328cbb2773
24 changed files with 456 additions and 205 deletions

View File

@@ -2,10 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Annotated
from typing import Annotated, Any
from pydantic import Field
from pydantic import Field, model_validator
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
from vllm.utils import random_uuid
@@ -44,3 +45,66 @@ class CompletionRequestMixin(OpenAIBaseModel):
"the prompt."
),
)
class ChatRequestMixin(OpenAIBaseModel):
    """Shared chat-template fields for pooling chat-style requests.

    Mixed into request schemas (e.g. classification / embedding chat
    requests) so all pooling entrypoints expose the same chat-template
    controls. Carries the conversation ``messages`` plus the knobs that
    control how the chat template renders them into a prompt.
    """

    # The conversation to render through the model's chat template.
    messages: list[ChatCompletionMessageParam]
    add_generation_prompt: bool = Field(
        default=False,
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
    )
    continue_final_message: bool = Field(
        default=False,
        description=(
            "If this is set, the chat will be formatted so that the final "
            "message in the chat is open-ended, without any EOS tokens. The "
            "model will continue this message rather than starting a new one. "
            'This allows you to "prefill" part of the model\'s response for it. '
            "Cannot be used at the same time as `add_generation_prompt`."
        ),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."
        ),
    )
    chat_template: str | None = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        ),
    )
    chat_template_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=(
            "Additional keyword args to pass to the template renderer. "
            "Will be accessible by the chat template."
        ),
    )

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        """Reject requests that set both mutually-exclusive prompt flags.

        `continue_final_message` keeps the final message open-ended while
        `add_generation_prompt` starts a new assistant turn — the chat
        template cannot do both at once.
        """
        # A mode="before" validator may receive non-dict input (e.g. an
        # already-constructed model instance passed to model_validate);
        # only mapping payloads carry these raw keys, so guard the lookup.
        if (
            isinstance(data, dict)
            and data.get("continue_final_message")
            and data.get("add_generation_prompt")
        ):
            raise ValueError(
                "Cannot set both `continue_final_message` and "
                "`add_generation_prompt` to True."
            )
        return data

View File

@@ -10,9 +10,9 @@ from pydantic import (
from vllm import PoolingParams
from vllm.config.pooler import get_use_activation
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
ChatRequestMixin,
CompletionRequestMixin,
PoolingBasicRequestMixin,
)
@@ -45,48 +45,8 @@ class ClassificationCompletionRequest(PoolingBasicRequestMixin, CompletionReques
)
class ClassificationChatRequest(PoolingBasicRequestMixin):
messages: list[ChatCompletionMessageParam]
class ClassificationChatRequest(PoolingBasicRequestMixin, ChatRequestMixin):
# --8<-- [start:chat-classification-extra-params]
add_generation_prompt: bool = Field(
default=False,
description=(
"If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."
),
)
add_special_tokens: bool = Field(
default=False,
description=(
"If true, special tokens (e.g. BOS) will be added to the prompt "
"on top of what is added by the chat template. "
"For most models, the chat template takes care of adding the "
"special tokens so this should be set to false (as is the "
"default)."
),
)
chat_template: str | None = Field(
default=None,
description=(
"A Jinja template to use for this conversion. "
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."
),
)
chat_template_kwargs: dict[str, Any] | None = Field(
default=None,
description=(
"Additional keyword args to pass to the template renderer. "
"Will be accessible by the chat template."
),
)
mm_processor_kwargs: dict[str, Any] | None = Field(
default=None,
description=("Additional kwargs to pass to the HF processor."),

View File

@@ -86,8 +86,8 @@ class ClassificationMixin(OpenAIServing):
ChatTemplateContentFormatOption,
getattr(self, "chat_template_content_format", "auto"),
),
add_generation_prompt=False,
continue_final_message=False,
add_generation_prompt=chat_request.add_generation_prompt,
continue_final_message=chat_request.continue_final_message,
add_special_tokens=chat_request.add_special_tokens,
)
ctx.engine_prompts = engine_prompts

View File

@@ -5,13 +5,12 @@ from typing import Any, TypeAlias
from pydantic import (
Field,
model_validator,
)
from vllm import PoolingParams
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
ChatRequestMixin,
CompletionRequestMixin,
PoolingBasicRequestMixin,
)
@@ -57,57 +56,11 @@ class EmbeddingCompletionRequest(PoolingBasicRequestMixin, CompletionRequestMixi
)
class EmbeddingChatRequest(PoolingBasicRequestMixin):
messages: list[ChatCompletionMessageParam]
class EmbeddingChatRequest(PoolingBasicRequestMixin, ChatRequestMixin):
encoding_format: EncodingFormat = "float"
dimensions: int | None = None
# --8<-- [start:chat-embedding-extra-params]
add_generation_prompt: bool = Field(
default=False,
description=(
"If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."
),
)
continue_final_message: bool = Field(
default=False,
description=(
"If this is set, the chat will be formatted so that the final "
"message in the chat is open-ended, without any EOS tokens. The "
"model will continue this message rather than starting a new one. "
'This allows you to "prefill" part of the model\'s response for it. '
"Cannot be used at the same time as `add_generation_prompt`."
),
)
add_special_tokens: bool = Field(
default=False,
description=(
"If true, special tokens (e.g. BOS) will be added to the prompt "
"on top of what is added by the chat template. "
"For most models, the chat template takes care of adding the "
"special tokens so this should be set to false (as is the "
"default)."
),
)
chat_template: str | None = Field(
default=None,
description=(
"A Jinja template to use for this conversion. "
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."
),
)
chat_template_kwargs: dict[str, Any] | None = Field(
default=None,
description=(
"Additional keyword args to pass to the template renderer. "
"Will be accessible by the chat template."
),
)
mm_processor_kwargs: dict[str, Any] | None = Field(
default=None,
description=("Additional kwargs to pass to the HF processor."),
@@ -134,16 +87,6 @@ class EmbeddingChatRequest(PoolingBasicRequestMixin):
)
# --8<-- [end:chat-embedding-extra-params]
@model_validator(mode="before")
@classmethod
def check_generation_prompt(cls, data):
if data.get("continue_final_message") and data.get("add_generation_prompt"):
raise ValueError(
"Cannot set both `continue_final_message` and "
"`add_generation_prompt` to True."
)
return data
def to_pooling_params(self):
return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens,

View File

@@ -144,10 +144,8 @@ class OpenAIServingPooling(OpenAIServing):
request.messages,
chat_template=request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
# In pooling requests, we are not generating tokens,
# so there is no need to append extra tokens to the input
add_generation_prompt=False,
continue_final_message=False,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
add_special_tokens=request.add_special_tokens,
)
elif isinstance(request, PoolingCompletionRequest):