[Frontend][2/n] Make pooling entrypoints request schema consensus | ChatRequest (#32574)

Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
wang.yuqi
2026-01-22 18:32:44 +08:00
committed by GitHub
parent 64e3d67ac0
commit 328cbb2773
24 changed files with 456 additions and 205 deletions

View File

@@ -79,7 +79,7 @@ The `post_process*` methods take `PoolingRequestOutput` objects as input and gen
The `validate_or_generate_params` method is used for validating with the plugin any `SamplingParameters`/`PoolingParameters` received with the user request, or to generate new ones if none are specified. The function always returns the validated/generated parameters.
The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available here [vllm/entrypoints/pooling/pooling/serving.py](../../vllm/entrypoints/pooling/pooling/serving.py).
An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please, also refer to our online ([examples/pooling/plugin/prithvi_geospatial_mae_client.py](../../examples/pooling/plugin/prithvi_geospatial_mae_client.py)) and offline ([examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py](../../examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py)) inference examples.
An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please, also refer to our online ([examples/pooling/plugin/prithvi_geospatial_mae_online.py](../../examples/pooling/plugin/prithvi_geospatial_mae_online.py)) and offline ([examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py](../../examples/pooling/plugin/prithvi_geospatial_mae_io_processor.py)) inference examples.
## Using an IO Processor plugin

View File

@@ -273,7 +273,7 @@ outputs = llm.embed(
print(outputs[0].outputs)
```
A code example can be found here: [examples/pooling/embed/embed_matryoshka_fy.py](../../examples/pooling/embed/embed_matryoshka_fy.py)
A code example can be found here: [examples/pooling/embed/embed_matryoshka_fy_offline.py](../../examples/pooling/embed/embed_matryoshka_fy_offline.py)
### Online Inference
@@ -303,7 +303,7 @@ Expected output:
{"id":"embd-5c21fc9a5c9d4384a1b021daccaf9f64","object":"list","created":1745476417,"model":"jinaai/jina-embeddings-v3","data":[{"index":0,"object":"embedding","embedding":[-0.3828125,-0.1357421875,0.03759765625,0.125,0.21875,0.09521484375,-0.003662109375,0.1591796875,-0.130859375,-0.0869140625,-0.1982421875,0.1689453125,-0.220703125,0.1728515625,-0.2275390625,-0.0712890625,-0.162109375,-0.283203125,-0.055419921875,-0.0693359375,0.031982421875,-0.04052734375,-0.2734375,0.1826171875,-0.091796875,0.220703125,0.37890625,-0.0888671875,-0.12890625,-0.021484375,-0.0091552734375,0.23046875]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0,"prompt_tokens_details":null}}
```
An OpenAI client example can be found here: [examples/pooling/embed/openai_embedding_matryoshka_fy.py](../../examples/pooling/embed/openai_embedding_matryoshka_fy.py)
An OpenAI client example can be found here: [examples/pooling/embed/openai_embedding_matryoshka_fy_client.py](../../examples/pooling/embed/openai_embedding_matryoshka_fy_client.py)
## Deprecated Features

View File

@@ -619,7 +619,7 @@ These models primarily support the [`LLM.encode`](./pooling_models.md#llmencode)
| `ModernBertForTokenClassification` | ModernBERT-based | `disham993/electrical-ner-ModernBERT-base` | | |
!!! note
Named Entity Recognition (NER) usage, please refer to [examples/pooling/token_classify/ner.py](../../examples/pooling/token_classify/ner.py), [examples/pooling/token_classify/ner_client.py](../../examples/pooling/token_classify/ner_client.py).
Named Entity Recognition (NER) usage, please refer to [examples/pooling/token_classify/ner_offline.py](../../examples/pooling/token_classify/ner_offline.py), [examples/pooling/token_classify/ner_online.py](../../examples/pooling/token_classify/ner_online.py).
## List of Multimodal Language Models

View File

@@ -551,7 +551,7 @@ Our Pooling API encodes input prompts using a [pooling model](../models/pooling_
The input format is the same as [Embeddings API](#embeddings-api), but the output data can contain an arbitrary nested list, not just a 1-D list of floats.
Code example: [examples/pooling/pooling/openai_pooling_client.py](../../examples/pooling/pooling/openai_pooling_client.py)
Code example: [examples/pooling/pooling/pooling_online.py](../../examples/pooling/pooling/pooling_online.py)
### Classification API

View File

@@ -26,36 +26,42 @@ def post_http_request(prompt: dict, api_url: str) -> requests.Response:
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--model", type=str, default="intfloat/e5-small")
return parser.parse_args()
parse = argparse.ArgumentParser()
parse.add_argument("--host", type=str, default="localhost")
parse.add_argument("--port", type=int, default=8000)
return parse.parse_args()
def main(args):
api_url = f"http://{args.host}:{args.port}/v1/embeddings"
model_name = args.model
base_url = f"http://{args.host}:{args.port}"
models_url = base_url + "/v1/models"
embeddings_url = base_url + "/v1/embeddings"
response = requests.get(models_url)
model = response.json()["data"][0]["id"]
input_texts = [
"The best thing about vLLM is that it supports many different models",
] * 2
# The OpenAI client does not support the embed_dtype and endianness parameters.
for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE:
for endianness in ENDIANNESS:
prompt = {
"model": model_name,
"input": "vLLM is great!",
"model": model,
"input": input_texts,
"encoding_format": "base64",
"embed_dtype": embed_dtype,
"endianness": endianness,
}
response = post_http_request(prompt=prompt, api_url=api_url)
response = post_http_request(prompt=prompt, api_url=embeddings_url)
embedding = []
for data in response.json()["data"]:
binary = base64.b64decode(data["embedding"])
tensor = binary2tensor(binary, (-1,), embed_dtype, endianness)
embedding.append(tensor.to(torch.float32))
embedding = torch.cat(embedding)
embedding = torch.stack(embedding)
print(embed_dtype, endianness, embedding.shape)

View File

@@ -31,14 +31,18 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--model", type=str, default="intfloat/e5-small")
return parser.parse_args()
def main(args):
api_url = f"http://{args.host}:{args.port}/v1/embeddings"
model_name = args.model
base_url = f"http://{args.host}:{args.port}"
models_url = base_url + "/v1/models"
embeddings_url = base_url + "/v1/embeddings"
response = requests.get(models_url)
model = response.json()["data"][0]["id"]
embedding_size = 0
input_texts = [
@@ -50,13 +54,13 @@ def main(args):
for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE:
for endianness in ENDIANNESS:
prompt = {
"model": model_name,
"model": model,
"input": input_texts,
"encoding_format": "bytes",
"embed_dtype": embed_dtype,
"endianness": endianness,
}
response = post_http_request(prompt=prompt, api_url=api_url)
response = post_http_request(prompt=prompt, api_url=embeddings_url)
metadata = json.loads(response.headers["metadata"])
body = response.content
items = [MetadataItem(**x) for x in metadata["data"]]
@@ -73,13 +77,13 @@ def main(args):
for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE:
for endianness in ENDIANNESS:
prompt = {
"model": model_name,
"model": model,
"input": input_texts,
"encoding_format": "bytes_only",
"embed_dtype": embed_dtype,
"endianness": endianness,
}
response = post_http_request(prompt=prompt, api_url=api_url)
response = post_http_request(prompt=prompt, api_url=embeddings_url)
body = response.content
items = build_metadata_items(

View File

@@ -25,18 +25,21 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--model", type=str, default="internlm/internlm2-1_8b-reward")
return parser.parse_args()
def main(args):
api_url = f"http://{args.host}:{args.port}/pooling"
model_name = args.model
base_url = f"http://{args.host}:{args.port}"
models_url = base_url + "/v1/models"
pooling_url = base_url + "/pooling"
response = requests.get(models_url)
model = response.json()["data"][0]["id"]
# Input like Completions API
prompt = {"model": model_name, "input": "vLLM is great!"}
pooling_response = post_http_request(prompt=prompt, api_url=api_url)
prompt = {"model": model, "input": "vLLM is great!"}
pooling_response = post_http_request(prompt=prompt, api_url=pooling_url)
print("-" * 50)
print("Pooling Response:")
pprint.pprint(pooling_response.json())
@@ -44,7 +47,7 @@ def main(args):
# Input like Chat API
prompt = {
"model": model_name,
"model": model,
"messages": [
{
"role": "user",
@@ -52,7 +55,7 @@ def main(args):
}
],
}
pooling_response = post_http_request(prompt=prompt, api_url=api_url)
pooling_response = post_http_request(prompt=prompt, api_url=pooling_url)
print("Pooling Response:")
pprint.pprint(pooling_response.json())
print("-" * 50)

View File

@@ -167,7 +167,8 @@ def test_truncate_prompt_tokens(server: RemoteOpenAIServer, model_name: str):
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_add_special_tokens(server: RemoteOpenAIServer, model_name: str):
# FIXME: The add_special_tokens parameter doesn't seem to be working.
# The add_special_tokens parameter doesn't seem to be working with this model.
# It does work with papluca/xlm-roberta-base-language-detection.
response = requests.post(
server.url_for("classify"),
json={"model": model_name, "input": input_text, "add_special_tokens": False},
@@ -184,7 +185,110 @@ def test_add_special_tokens(server: RemoteOpenAIServer, model_name: str):
@pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer):
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_chat_request(server: RemoteOpenAIServer, model_name: str):
messages = [
{
"role": "user",
"content": "The cat sat on the mat.",
},
{
"role": "assistant",
"content": "A feline was resting on a rug.",
},
{
"role": "user",
"content": "Stars twinkle brightly in the night sky.",
},
]
# test chat request basic usage
response = requests.post(
server.url_for("classify"),
json={"model": model_name, "messages": messages},
)
response.raise_for_status()
output = ClassificationResponse.model_validate(response.json())
assert output.object == "list"
assert output.model == MODEL_NAME
assert len(output.data) == 1
assert hasattr(output.data[0], "label")
assert hasattr(output.data[0], "probs")
assert output.usage.prompt_tokens == 51
# test add_generation_prompt
response = requests.post(
server.url_for("classify"),
json={"model": model_name, "messages": messages, "add_generation_prompt": True},
)
response.raise_for_status()
output = ClassificationResponse.model_validate(response.json())
assert output.object == "list"
assert output.model == MODEL_NAME
assert len(output.data) == 1
assert hasattr(output.data[0], "label")
assert hasattr(output.data[0], "probs")
assert output.usage.prompt_tokens == 54
# test continue_final_message
response = requests.post(
server.url_for("classify"),
json={
"model": model_name,
"messages": messages,
"continue_final_message": True,
},
)
response.raise_for_status()
output = ClassificationResponse.model_validate(response.json())
assert output.object == "list"
assert output.model == MODEL_NAME
assert len(output.data) == 1
assert hasattr(output.data[0], "label")
assert hasattr(output.data[0], "probs")
assert output.usage.prompt_tokens == 49
# test add_special_tokens
# The add_special_tokens parameter doesn't seem to be working with this model.
response = requests.post(
server.url_for("classify"),
json={"model": model_name, "messages": messages, "add_special_tokens": True},
)
response.raise_for_status()
output = ClassificationResponse.model_validate(response.json())
assert output.object == "list"
assert output.model == MODEL_NAME
assert len(output.data) == 1
assert hasattr(output.data[0], "label")
assert hasattr(output.data[0], "probs")
assert output.usage.prompt_tokens == 51
# test continue_final_message with add_generation_prompt
response = requests.post(
server.url_for("classify"),
json={
"model": model_name,
"messages": messages,
"continue_final_message": True,
"add_generation_prompt": True,
},
)
assert (
"Cannot set both `continue_final_message` and `add_generation_prompt` to True."
in response.json()["error"]["message"]
)
@pytest.mark.asyncio
async def test_invocations_completion_request(server: RemoteOpenAIServer):
request_args = {
"model": MODEL_NAME,
"input": input_text,
@@ -213,6 +317,48 @@ async def test_invocations(server: RemoteOpenAIServer):
)
@pytest.mark.asyncio
async def test_invocations_chat_request(server: RemoteOpenAIServer):
messages = [
{
"role": "user",
"content": "The cat sat on the mat.",
},
{
"role": "assistant",
"content": "A feline was resting on a rug.",
},
{
"role": "user",
"content": "Stars twinkle brightly in the night sky.",
},
]
request_args = {"model": MODEL_NAME, "messages": messages}
classification_response = requests.post(
server.url_for("classify"), json=request_args
)
classification_response.raise_for_status()
invocation_response = requests.post(
server.url_for("invocations"), json=request_args
)
invocation_response.raise_for_status()
classification_output = classification_response.json()
invocation_output = invocation_response.json()
assert classification_output.keys() == invocation_output.keys()
for classification_data, invocation_data in zip(
classification_output["data"], invocation_output["data"]
):
assert classification_data.keys() == invocation_data.keys()
assert classification_data["probs"] == pytest.approx(
invocation_data["probs"], rel=0.01
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_use_activation(server: RemoteOpenAIServer, model_name: str):

View File

@@ -214,64 +214,6 @@ async def test_completion_request_batched(
run_embedding_correctness_test(hf_model, input_texts, vllm_outputs)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_conversation_embedding(
server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
):
messages = [
{
"role": "user",
"content": "The cat sat on the mat.",
},
{
"role": "assistant",
"content": "A feline was resting on a rug.",
},
{
"role": "user",
"content": "Stars twinkle brightly in the night sky.",
},
]
chat_response = requests.post(
server.url_for("v1/embeddings"),
json={
"model": model_name,
"messages": messages,
"encoding_format": "float",
},
)
chat_response.raise_for_status()
chat_embeddings = EmbeddingResponse.model_validate(chat_response.json())
tokenizer = get_tokenizer(tokenizer_name=model_name)
prompt = tokenizer.apply_chat_template(
messages,
chat_template=DUMMY_CHAT_TEMPLATE,
add_generation_prompt=True,
continue_final_message=False,
tokenize=False,
)
completion_response = await client.embeddings.create(
model=model_name,
input=prompt,
encoding_format="float",
# To be consistent with chat
extra_body={"add_special_tokens": False},
)
completion_embeddings = EmbeddingResponse.model_validate(
completion_response.model_dump(mode="json")
)
assert chat_embeddings.id is not None
assert completion_embeddings.id is not None
assert chat_embeddings.created <= completion_embeddings.created
assert chat_embeddings.model_dump(exclude={"id", "created"}) == (
completion_embeddings.model_dump(exclude={"id", "created"})
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_truncate_prompt_tokens(client: openai.AsyncOpenAI, model_name: str):
@@ -350,7 +292,129 @@ async def test_truncate_prompt_tokens(client: openai.AsyncOpenAI, model_name: st
@pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_chat_request(
server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
):
messages = [
{
"role": "user",
"content": "The cat sat on the mat.",
},
{
"role": "assistant",
"content": "A feline was resting on a rug.",
},
{
"role": "user",
"content": "Stars twinkle brightly in the night sky.",
},
]
# test chat request basic usage
chat_response = requests.post(
server.url_for("v1/embeddings"),
json={
"model": model_name,
"messages": messages,
"encoding_format": "float",
},
)
chat_response.raise_for_status()
chat_embeddings = EmbeddingResponse.model_validate(chat_response.json())
tokenizer = get_tokenizer(tokenizer_name=model_name)
prompt = tokenizer.apply_chat_template(
messages,
chat_template=DUMMY_CHAT_TEMPLATE,
add_generation_prompt=True,
continue_final_message=False,
tokenize=False,
)
completion_response = await client.embeddings.create(
model=model_name,
input=prompt,
encoding_format="float",
# To be consistent with chat
extra_body={"add_special_tokens": False},
)
completion_embeddings = EmbeddingResponse.model_validate(
completion_response.model_dump(mode="json")
)
assert chat_embeddings.id is not None
assert completion_embeddings.id is not None
assert chat_embeddings.created <= completion_embeddings.created
assert chat_embeddings.model_dump(exclude={"id", "created"}) == (
completion_embeddings.model_dump(exclude={"id", "created"})
)
# test add_generation_prompt
response = requests.post(
server.url_for("v1/embeddings"),
json={"model": model_name, "messages": messages, "add_generation_prompt": True},
)
response.raise_for_status()
output = EmbeddingResponse.model_validate(response.json())
assert output.object == "list"
assert len(output.data) == 1
assert output.model == MODEL_NAME
assert output.usage.prompt_tokens == 34
# test continue_final_message
response = requests.post(
server.url_for("v1/embeddings"),
json={
"model": model_name,
"messages": messages,
"continue_final_message": True,
},
)
response.raise_for_status()
output = EmbeddingResponse.model_validate(response.json())
assert output.object == "list"
assert len(output.data) == 1
assert output.model == MODEL_NAME
assert output.usage.prompt_tokens == 33
# test add_special_tokens
response = requests.post(
server.url_for("v1/embeddings"),
json={"model": model_name, "messages": messages, "add_special_tokens": True},
)
response.raise_for_status()
output = EmbeddingResponse.model_validate(response.json())
assert output.object == "list"
assert len(output.data) == 1
assert output.model == MODEL_NAME
assert output.usage.prompt_tokens == 36
# test continue_final_message with add_generation_prompt
response = requests.post(
server.url_for("v1/embeddings"),
json={
"model": model_name,
"messages": messages,
"continue_final_message": True,
"add_generation_prompt": True,
},
)
assert (
"Cannot set both `continue_final_message` and `add_generation_prompt` to True."
in response.json()["error"]["message"]
)
@pytest.mark.asyncio
async def test_invocations_completion_request(
server: RemoteOpenAIServer, client: openai.AsyncOpenAI
):
request_args = {
"model": MODEL_NAME,
"input": input_text,
@@ -381,7 +445,7 @@ async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenA
@pytest.mark.asyncio
async def test_invocations_conversation(server: RemoteOpenAIServer):
async def test_invocations_chat_request(server: RemoteOpenAIServer):
messages = [
{
"role": "user",

View File

@@ -138,7 +138,7 @@ def test_completion_request_batched(server: RemoteOpenAIServer, model_name: str)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_conversation_pooling(server: RemoteOpenAIServer, model_name: str):
async def test_chat_request(server: RemoteOpenAIServer, model_name: str):
messages = [
{
"role": "user",
@@ -154,6 +154,7 @@ async def test_conversation_pooling(server: RemoteOpenAIServer, model_name: str)
},
]
# test chat request basic usage
chat_response = requests.post(
server.url_for("pooling"),
json={
@@ -193,6 +194,68 @@ async def test_conversation_pooling(server: RemoteOpenAIServer, model_name: str)
completion_poolings.model_dump(exclude={"id", "created"})
)
# test add_generation_prompt
response = requests.post(
server.url_for("pooling"),
json={"model": model_name, "messages": messages, "add_generation_prompt": True},
)
response.raise_for_status()
output = PoolingResponse.model_validate(response.json())
assert output.object == "list"
assert len(output.data) == 1
assert output.model == MODEL_NAME
assert output.usage.prompt_tokens == 33
# test continue_final_message
# The continue_final_message parameter doesn't seem to be working with this model.
response = requests.post(
server.url_for("pooling"),
json={
"model": model_name,
"messages": messages,
"continue_final_message": True,
},
)
response.raise_for_status()
output = PoolingResponse.model_validate(response.json())
assert output.object == "list"
assert len(output.data) == 1
assert output.model == MODEL_NAME
assert output.usage.prompt_tokens == 33
# test add_special_tokens
response = requests.post(
server.url_for("pooling"),
json={"model": model_name, "messages": messages, "add_special_tokens": True},
)
response.raise_for_status()
output = PoolingResponse.model_validate(response.json())
assert output.object == "list"
assert len(output.data) == 1
assert output.model == MODEL_NAME
assert output.usage.prompt_tokens == 34
# test continue_final_message with add_generation_prompt
response = requests.post(
server.url_for("pooling"),
json={
"model": model_name,
"messages": messages,
"continue_final_message": True,
"add_generation_prompt": True,
},
)
assert (
"Cannot set both `continue_final_message` and `add_generation_prompt` to True."
in response.json()["error"]["message"]
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@@ -430,7 +493,7 @@ async def test_params_not_supported(
@pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer):
async def test_invocations_chat_request(server: RemoteOpenAIServer):
request_args = {
"model": MODEL_NAME,
"input": input_text,
@@ -462,7 +525,7 @@ async def test_invocations(server: RemoteOpenAIServer):
@pytest.mark.asyncio
async def test_invocations_conversation(server: RemoteOpenAIServer):
async def test_invocations_conversation_chat_request(server: RemoteOpenAIServer):
messages = [
{
"role": "user",

View File

@@ -2,10 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Annotated
from typing import Annotated, Any
from pydantic import Field
from pydantic import Field, model_validator
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
from vllm.utils import random_uuid
@@ -44,3 +45,66 @@ class CompletionRequestMixin(OpenAIBaseModel):
"the prompt."
),
)
class ChatRequestMixin(OpenAIBaseModel):
    """Mixin with the chat-style request fields shared by pooling endpoints.

    Provides the conversation ``messages`` plus the chat-template rendering
    options (generation prompt, final-message continuation, special tokens,
    template override and extra template kwargs), so that the
    classification/embedding/pooling chat request schemas stay consistent.
    """

    # The conversation to be rendered through the model's chat template.
    messages: list[ChatCompletionMessageParam]

    add_generation_prompt: bool = Field(
        default=False,
        description=(
            "If true, the generation prompt will be added to the chat template. "
            "This is a parameter used by chat template in tokenizer config of the "
            "model."
        ),
    )
    continue_final_message: bool = Field(
        default=False,
        description=(
            "If this is set, the chat will be formatted so that the final "
            "message in the chat is open-ended, without any EOS tokens. The "
            "model will continue this message rather than starting a new one. "
            'This allows you to "prefill" part of the model\'s response for it. '
            "Cannot be used at the same time as `add_generation_prompt`."
        ),
    )
    add_special_tokens: bool = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to false (as is the "
            "default)."
        ),
    )
    chat_template: str | None = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        ),
    )
    chat_template_kwargs: dict[str, Any] | None = Field(
        default=None,
        description=(
            "Additional keyword args to pass to the template renderer. "
            "Will be accessible by the chat template."
        ),
    )

    @model_validator(mode="before")
    @classmethod
    def check_generation_prompt(cls, data):
        # `continue_final_message` keeps the final message open-ended for the
        # model to continue, while `add_generation_prompt` starts a fresh
        # assistant turn — the two options are mutually exclusive.
        # NOTE(review): mode="before" means `data` is the raw input; this
        # assumes dict-like input (`.get`) — presumably guaranteed upstream.
        if data.get("continue_final_message") and data.get("add_generation_prompt"):
            raise ValueError(
                "Cannot set both `continue_final_message` and "
                "`add_generation_prompt` to True."
            )
        return data

View File

@@ -10,9 +10,9 @@ from pydantic import (
from vllm import PoolingParams
from vllm.config.pooler import get_use_activation
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
ChatRequestMixin,
CompletionRequestMixin,
PoolingBasicRequestMixin,
)
@@ -45,48 +45,8 @@ class ClassificationCompletionRequest(PoolingBasicRequestMixin, CompletionReques
)
class ClassificationChatRequest(PoolingBasicRequestMixin):
messages: list[ChatCompletionMessageParam]
class ClassificationChatRequest(PoolingBasicRequestMixin, ChatRequestMixin):
# --8<-- [start:chat-classification-extra-params]
add_generation_prompt: bool = Field(
default=False,
description=(
"If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."
),
)
add_special_tokens: bool = Field(
default=False,
description=(
"If true, special tokens (e.g. BOS) will be added to the prompt "
"on top of what is added by the chat template. "
"For most models, the chat template takes care of adding the "
"special tokens so this should be set to false (as is the "
"default)."
),
)
chat_template: str | None = Field(
default=None,
description=(
"A Jinja template to use for this conversion. "
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."
),
)
chat_template_kwargs: dict[str, Any] | None = Field(
default=None,
description=(
"Additional keyword args to pass to the template renderer. "
"Will be accessible by the chat template."
),
)
mm_processor_kwargs: dict[str, Any] | None = Field(
default=None,
description=("Additional kwargs to pass to the HF processor."),

View File

@@ -86,8 +86,8 @@ class ClassificationMixin(OpenAIServing):
ChatTemplateContentFormatOption,
getattr(self, "chat_template_content_format", "auto"),
),
add_generation_prompt=False,
continue_final_message=False,
add_generation_prompt=chat_request.add_generation_prompt,
continue_final_message=chat_request.continue_final_message,
add_special_tokens=chat_request.add_special_tokens,
)
ctx.engine_prompts = engine_prompts

View File

@@ -5,13 +5,12 @@ from typing import Any, TypeAlias
from pydantic import (
Field,
model_validator,
)
from vllm import PoolingParams
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
from vllm.entrypoints.pooling.base.protocol import (
ChatRequestMixin,
CompletionRequestMixin,
PoolingBasicRequestMixin,
)
@@ -57,57 +56,11 @@ class EmbeddingCompletionRequest(PoolingBasicRequestMixin, CompletionRequestMixi
)
class EmbeddingChatRequest(PoolingBasicRequestMixin):
messages: list[ChatCompletionMessageParam]
class EmbeddingChatRequest(PoolingBasicRequestMixin, ChatRequestMixin):
encoding_format: EncodingFormat = "float"
dimensions: int | None = None
# --8<-- [start:chat-embedding-extra-params]
add_generation_prompt: bool = Field(
default=False,
description=(
"If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."
),
)
continue_final_message: bool = Field(
default=False,
description=(
"If this is set, the chat will be formatted so that the final "
"message in the chat is open-ended, without any EOS tokens. The "
"model will continue this message rather than starting a new one. "
'This allows you to "prefill" part of the model\'s response for it. '
"Cannot be used at the same time as `add_generation_prompt`."
),
)
add_special_tokens: bool = Field(
default=False,
description=(
"If true, special tokens (e.g. BOS) will be added to the prompt "
"on top of what is added by the chat template. "
"For most models, the chat template takes care of adding the "
"special tokens so this should be set to false (as is the "
"default)."
),
)
chat_template: str | None = Field(
default=None,
description=(
"A Jinja template to use for this conversion. "
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."
),
)
chat_template_kwargs: dict[str, Any] | None = Field(
default=None,
description=(
"Additional keyword args to pass to the template renderer. "
"Will be accessible by the chat template."
),
)
mm_processor_kwargs: dict[str, Any] | None = Field(
default=None,
description=("Additional kwargs to pass to the HF processor."),
@@ -134,16 +87,6 @@ class EmbeddingChatRequest(PoolingBasicRequestMixin):
)
# --8<-- [end:chat-embedding-extra-params]
@model_validator(mode="before")
@classmethod
def check_generation_prompt(cls, data):
if data.get("continue_final_message") and data.get("add_generation_prompt"):
raise ValueError(
"Cannot set both `continue_final_message` and "
"`add_generation_prompt` to True."
)
return data
def to_pooling_params(self):
return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens,

View File

@@ -144,10 +144,8 @@ class OpenAIServingPooling(OpenAIServing):
request.messages,
chat_template=request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
# In pooling requests, we are not generating tokens,
# so there is no need to append extra tokens to the input
add_generation_prompt=False,
continue_final_message=False,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
add_special_tokens=request.add_special_tokens,
)
elif isinstance(request, PoolingCompletionRequest):