[Frontend][1/n] Make pooling entrypoints request schema consensus | CompletionRequest (#32395)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
@@ -559,7 +559,7 @@ Our Classification API directly supports Hugging Face sequence-classification mo
|
||||
|
||||
We automatically wrap any other transformer via `as_seq_cls_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities.
|
||||
|
||||
Code example: [examples/pooling/classify/openai_classification_client.py](../../examples/pooling/classify/openai_classification_client.py)
|
||||
Code example: [examples/pooling/classify/classification_online.py](../../examples/pooling/classify/classification_online.py)
|
||||
|
||||
#### Example Requests
|
||||
|
||||
|
||||
67
examples/pooling/classify/classification_online.py
Normal file
67
examples/pooling/classify/classification_online.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Example Python client for classification API using vLLM API server
|
||||
NOTE:
|
||||
start a supported classification model server with `vllm serve`, e.g.
|
||||
vllm serve jason9693/Qwen2.5-1.5B-apeach
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import pprint
|
||||
|
||||
import requests
|
||||
|
||||
headers = {"accept": "application/json", "Content-Type": "application/json"}
|
||||
|
||||
|
||||
def parse_args():
    """Parse command-line options (server host and port) for this client."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    return parser.parse_args()
|
||||
|
||||
|
||||
def main(args):
    """Exercise the vLLM /classify endpoint with text and token-id inputs.

    Discovers the served model via /v1/models, classifies a batch of plain
    string prompts, then tokenizes the same prompts via /tokenize and
    classifies the resulting token-id lists, pretty-printing each response.
    """
    base_url = f"http://{args.host}:{args.port}"
    models_url = base_url + "/v1/models"
    classify_url = base_url + "/classify"
    tokenize_url = base_url + "/tokenize"

    # Ask the server which model it is serving.
    models_response = requests.get(models_url, headers=headers)
    model = models_response.json()["data"][0]["id"]

    # /classify can accept str as input
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    text_payload = {
        "model": model,
        "input": prompts,
    }
    text_response = requests.post(classify_url, headers=headers, json=text_payload)
    pprint.pprint(text_response.json())

    # /classify can accept token ids as input
    token_ids = []
    for prompt in prompts:
        tokenize_response = requests.post(
            tokenize_url,
            json={"model": model, "prompt": prompt},
        )
        token_ids.append(tokenize_response.json()["tokens"])

    token_payload = {
        "model": model,
        "input": token_ids,
    }
    token_response = requests.post(classify_url, headers=headers, json=token_payload)
    pprint.pprint(token_response.json())
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: parse CLI options and run the client.
    main(parse_args())
|
||||
@@ -1,53 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Example Python client for classification API using vLLM API server
|
||||
NOTE:
|
||||
start a supported classification model server with `vllm serve`, e.g.
|
||||
vllm serve jason9693/Qwen2.5-1.5B-apeach
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import pprint
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def post_http_request(payload: dict, api_url: str) -> requests.Response:
    """POST *payload* as JSON to *api_url* and return the raw response."""
    return requests.post(
        api_url,
        headers={"User-Agent": "Test Client"},
        json=payload,
    )
|
||||
|
||||
|
||||
def parse_args():
    """Parse command-line options (host, port, model name) for this client."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model", type=str, default="jason9693/Qwen2.5-1.5B-apeach")
    return parser.parse_args()
|
||||
|
||||
|
||||
def main(args):
    """Send a batch of example prompts to the server's /classify endpoint."""
    api_url = f"http://{args.host}:{args.port}/classify"

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    payload = {
        "model": args.model,
        "input": prompts,
    }

    # POST the batch and pretty-print the classification results.
    classify_response = post_http_request(payload=payload, api_url=api_url)
    pprint.pprint(classify_response.json())
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: parse CLI options and run the client.
    main(parse_args())
|
||||
@@ -16,7 +16,7 @@ from typing import Literal, NamedTuple, TypeAlias, TypedDict, get_args
|
||||
from PIL.Image import Image
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.entrypoints.score_utils import ScoreMultiModalParam
|
||||
from vllm.entrypoints.pooling.score.utils import ScoreMultiModalParam
|
||||
from vllm.multimodal.utils import fetch_image
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ from pathlib import Path
|
||||
from typing import NamedTuple
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.entrypoints.score_utils import ScoreMultiModalParam
|
||||
from vllm.entrypoints.pooling.score.utils import ScoreMultiModalParam
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
TEMPLATE_HOME = Path(__file__).parent / "template"
|
||||
|
||||
@@ -12,6 +12,8 @@ from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
|
||||
|
||||
MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
|
||||
DTYPE = "float32" # Use float32 to avoid NaN issue
|
||||
input_text = "This product was excellent and exceeded my expectations"
|
||||
input_tokens = [1986, 1985, 572, 9073, 323, 33808, 847, 16665]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -29,9 +31,23 @@ def server():
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_single_input_classification(server: RemoteOpenAIServer, model_name: str):
|
||||
input_text = "This product was excellent and exceeded my expectations"
|
||||
def test_basic(server: RemoteOpenAIServer, model_name: str):
|
||||
# test /v1/models
|
||||
response = requests.get(server.url_for("/v1/models"))
|
||||
served_model = response.json()["data"][0]["id"]
|
||||
assert served_model == MODEL_NAME
|
||||
|
||||
# test /tokenize
|
||||
response = requests.post(
|
||||
server.url_for("/tokenize"),
|
||||
json={"model": model_name, "prompt": input_text},
|
||||
)
|
||||
assert response.json()["tokens"] == input_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_completion_request(server: RemoteOpenAIServer, model_name: str):
|
||||
# test input: str
|
||||
classification_response = requests.post(
|
||||
server.url_for("classify"),
|
||||
json={"model": model_name, "input": input_text},
|
||||
@@ -46,35 +62,34 @@ def test_single_input_classification(server: RemoteOpenAIServer, model_name: str
|
||||
assert hasattr(output.data[0], "label")
|
||||
assert hasattr(output.data[0], "probs")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_add_special_tokens_false(server: RemoteOpenAIServer, model_name: str):
|
||||
response = requests.post(
|
||||
server.url_for("classify"),
|
||||
json={"model": model_name, "input": "hello", "add_special_tokens": False},
|
||||
)
|
||||
response.raise_for_status()
|
||||
ClassificationResponse.model_validate(response.json())
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_multiple_inputs_classification(server: RemoteOpenAIServer, model_name: str):
|
||||
input_texts = [
|
||||
"The product arrived on time and works perfectly",
|
||||
"I'm very satisfied with my purchase, would buy again",
|
||||
"The customer service was helpful and resolved my issue quickly",
|
||||
"This product broke after one week, terrible quality",
|
||||
"I'm very disappointed with this purchase, complete waste of money",
|
||||
"The customer service was rude and unhelpful",
|
||||
]
|
||||
|
||||
# test input: list[int]
|
||||
classification_response = requests.post(
|
||||
server.url_for("classify"),
|
||||
json={"model": model_name, "input": input_texts},
|
||||
json={"model": model_name, "input": input_tokens},
|
||||
)
|
||||
|
||||
classification_response.raise_for_status()
|
||||
output = ClassificationResponse.model_validate(classification_response.json())
|
||||
|
||||
assert output.object == "list"
|
||||
assert output.model == MODEL_NAME
|
||||
assert len(output.data) == 1
|
||||
assert hasattr(output.data[0], "label")
|
||||
assert hasattr(output.data[0], "probs")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_completion_request_batched(server: RemoteOpenAIServer, model_name: str):
|
||||
N = 10
|
||||
|
||||
# test input: list[str]
|
||||
classification_response = requests.post(
|
||||
server.url_for("classify"),
|
||||
json={"model": model_name, "input": [input_text] * N},
|
||||
)
|
||||
output = ClassificationResponse.model_validate(classification_response.json())
|
||||
|
||||
assert len(output.data) == len(input_texts)
|
||||
assert len(output.data) == N
|
||||
for i, item in enumerate(output.data):
|
||||
assert item.index == i
|
||||
assert hasattr(item, "label")
|
||||
@@ -82,6 +97,44 @@ def test_multiple_inputs_classification(server: RemoteOpenAIServer, model_name:
|
||||
assert len(item.probs) == item.num_classes
|
||||
assert item.label in ["Default", "Spoiled"]
|
||||
|
||||
# test input: list[list[int]]
|
||||
classification_response = requests.post(
|
||||
server.url_for("classify"),
|
||||
json={"model": model_name, "input": [input_tokens] * N},
|
||||
)
|
||||
output = ClassificationResponse.model_validate(classification_response.json())
|
||||
|
||||
assert len(output.data) == N
|
||||
for i, item in enumerate(output.data):
|
||||
assert item.index == i
|
||||
assert hasattr(item, "label")
|
||||
assert hasattr(item, "probs")
|
||||
assert len(item.probs) == item.num_classes
|
||||
assert item.label in ["Default", "Spoiled"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_empty_input_error(server: RemoteOpenAIServer, model_name: str):
|
||||
classification_response = requests.post(
|
||||
server.url_for("classify"),
|
||||
json={"model": model_name, "input": ""},
|
||||
)
|
||||
|
||||
error = classification_response.json()
|
||||
assert classification_response.status_code == 400
|
||||
assert "error" in error
|
||||
|
||||
classification_response = requests.post(
|
||||
server.url_for("classify"),
|
||||
json={"model": model_name, "input": []},
|
||||
)
|
||||
classification_response.raise_for_status()
|
||||
output = ClassificationResponse.model_validate(classification_response.json())
|
||||
|
||||
assert output.object == "list"
|
||||
assert isinstance(output.data, list)
|
||||
assert len(output.data) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_truncate_prompt_tokens(server: RemoteOpenAIServer, model_name: str):
|
||||
@@ -101,11 +154,7 @@ def test_truncate_prompt_tokens(server: RemoteOpenAIServer, model_name: str):
|
||||
assert output.usage.prompt_tokens == 5
|
||||
assert output.usage.total_tokens == 5
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_invalid_truncate_prompt_tokens_error(
|
||||
server: RemoteOpenAIServer, model_name: str
|
||||
):
|
||||
# invalid_truncate_prompt_tokens
|
||||
classification_response = requests.post(
|
||||
server.url_for("classify"),
|
||||
json={"model": model_name, "input": "test", "truncate_prompt_tokens": 513},
|
||||
@@ -117,36 +166,28 @@ def test_invalid_truncate_prompt_tokens_error(
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_empty_input_error(server: RemoteOpenAIServer, model_name: str):
|
||||
classification_response = requests.post(
|
||||
def test_add_special_tokens(server: RemoteOpenAIServer, model_name: str):
|
||||
# FIXME: The add_special_tokens parameter doesn't seem to be working.
|
||||
response = requests.post(
|
||||
server.url_for("classify"),
|
||||
json={"model": model_name, "input": ""},
|
||||
json={"model": model_name, "input": input_text, "add_special_tokens": False},
|
||||
)
|
||||
response.raise_for_status()
|
||||
ClassificationResponse.model_validate(response.json())
|
||||
|
||||
error = classification_response.json()
|
||||
assert classification_response.status_code == 400
|
||||
assert "error" in error
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_batch_classification_empty_list(server: RemoteOpenAIServer, model_name: str):
|
||||
classification_response = requests.post(
|
||||
response = requests.post(
|
||||
server.url_for("classify"),
|
||||
json={"model": model_name, "input": []},
|
||||
json={"model": model_name, "input": input_text, "add_special_tokens": True},
|
||||
)
|
||||
classification_response.raise_for_status()
|
||||
output = ClassificationResponse.model_validate(classification_response.json())
|
||||
|
||||
assert output.object == "list"
|
||||
assert isinstance(output.data, list)
|
||||
assert len(output.data) == 0
|
||||
response.raise_for_status()
|
||||
ClassificationResponse.model_validate(response.json())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations(server: RemoteOpenAIServer):
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"input": "This product was excellent and exceeded my expectations",
|
||||
"input": input_text,
|
||||
}
|
||||
|
||||
classification_response = requests.post(
|
||||
@@ -175,8 +216,6 @@ async def test_invocations(server: RemoteOpenAIServer):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_use_activation(server: RemoteOpenAIServer, model_name: str):
|
||||
input_text = ["This product was excellent and exceeded my expectations"]
|
||||
|
||||
async def get_outputs(use_activation):
|
||||
response = requests.post(
|
||||
server.url_for("classify"),
|
||||
@@ -237,7 +276,6 @@ async def test_rerank(server: RemoteOpenAIServer, model_name: str):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
|
||||
input_text = "This product was excellent and exceeded my expectations"
|
||||
response = requests.post(
|
||||
server.url_for("pooling"),
|
||||
json={
|
||||
@@ -256,7 +294,6 @@ async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
|
||||
task = "token_classify"
|
||||
input_text = ["This product was excellent and exceeded my expectations"]
|
||||
response = requests.post(
|
||||
server.url_for("pooling"),
|
||||
json={
|
||||
@@ -282,7 +319,7 @@ async def test_pooling_not_supported(
|
||||
server.url_for("pooling"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"input": "test",
|
||||
"input": input_text,
|
||||
"encoding_format": "float",
|
||||
"task": task,
|
||||
},
|
||||
|
||||
@@ -31,7 +31,26 @@ from vllm.utils.serial_utils import (
|
||||
MODEL_NAME = "intfloat/multilingual-e5-small"
|
||||
DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501
|
||||
DTYPE = "bfloat16"
|
||||
|
||||
input_text = "The best thing about vLLM is that it supports many different models"
|
||||
input_tokens = [
|
||||
0,
|
||||
581,
|
||||
2965,
|
||||
13580,
|
||||
1672,
|
||||
81,
|
||||
23708,
|
||||
594,
|
||||
83,
|
||||
450,
|
||||
442,
|
||||
8060,
|
||||
7,
|
||||
5941,
|
||||
12921,
|
||||
115774,
|
||||
2,
|
||||
]
|
||||
|
||||
if current_platform.is_rocm():
|
||||
# Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers
|
||||
@@ -79,15 +98,36 @@ def hf_model(hf_runner):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_single_embedding(hf_model, client: openai.AsyncOpenAI, model_name: str):
|
||||
input_texts = [
|
||||
"The chef prepared a delicious meal.",
|
||||
]
|
||||
async def test_basic(
|
||||
server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
|
||||
):
|
||||
# test /v1/models
|
||||
response = requests.get(server.url_for("/v1/models"))
|
||||
model = response.json()["data"][0]["id"]
|
||||
assert model == MODEL_NAME
|
||||
|
||||
# test single embedding
|
||||
models = await client.models.list()
|
||||
models = models.data
|
||||
served_model = models[0]
|
||||
assert served_model.id == MODEL_NAME
|
||||
|
||||
# test /tokenize
|
||||
response = requests.post(
|
||||
server.url_for("/tokenize"),
|
||||
json={"model": model_name, "prompt": input_text},
|
||||
)
|
||||
assert response.json()["tokens"] == input_tokens
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_completion_request(
|
||||
client: openai.AsyncOpenAI, model_name: str, hf_model
|
||||
):
|
||||
# test input: str
|
||||
embedding_response = await client.embeddings.create(
|
||||
model=model_name,
|
||||
input=input_texts,
|
||||
input=input_text,
|
||||
encoding_format="float",
|
||||
)
|
||||
embeddings = EmbeddingResponse.model_validate(
|
||||
@@ -98,14 +138,13 @@ async def test_single_embedding(hf_model, client: openai.AsyncOpenAI, model_name
|
||||
assert len(embeddings.data) == 1
|
||||
assert len(embeddings.data[0].embedding) == 384
|
||||
assert embeddings.usage.completion_tokens == 0
|
||||
assert embeddings.usage.prompt_tokens == 11
|
||||
assert embeddings.usage.total_tokens == 11
|
||||
assert embeddings.usage.prompt_tokens == len(input_tokens)
|
||||
assert embeddings.usage.total_tokens == len(input_tokens)
|
||||
|
||||
vllm_outputs = [d.embedding for d in embeddings.data]
|
||||
run_embedding_correctness_test(hf_model, input_texts, vllm_outputs)
|
||||
run_embedding_correctness_test(hf_model, [input_text], vllm_outputs)
|
||||
|
||||
# test using token IDs
|
||||
input_tokens = [1, 1, 1, 1, 1]
|
||||
# test input: list[int]
|
||||
embedding_response = await client.embeddings.create(
|
||||
model=model_name,
|
||||
input=input_tokens,
|
||||
@@ -119,19 +158,22 @@ async def test_single_embedding(hf_model, client: openai.AsyncOpenAI, model_name
|
||||
assert len(embeddings.data) == 1
|
||||
assert len(embeddings.data[0].embedding) == 384
|
||||
assert embeddings.usage.completion_tokens == 0
|
||||
assert embeddings.usage.prompt_tokens == 5
|
||||
assert embeddings.usage.total_tokens == 5
|
||||
assert embeddings.usage.prompt_tokens == len(input_tokens)
|
||||
assert embeddings.usage.total_tokens == len(input_tokens)
|
||||
|
||||
vllm_outputs = [d.embedding for d in embeddings.data]
|
||||
run_embedding_correctness_test(hf_model, [input_text], vllm_outputs)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, model_name: str):
|
||||
# test list[str]
|
||||
input_texts = [
|
||||
"The cat sat on the mat.",
|
||||
"A feline was resting on a rug.",
|
||||
"Stars twinkle brightly in the night sky.",
|
||||
]
|
||||
async def test_completion_request_batched(
|
||||
client: openai.AsyncOpenAI, model_name: str, hf_model
|
||||
):
|
||||
N = 10
|
||||
input_texts = [input_text] * N
|
||||
|
||||
# test input: list[str]
|
||||
embedding_response = await client.embeddings.create(
|
||||
model=model_name,
|
||||
input=input_texts,
|
||||
@@ -142,25 +184,19 @@ async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, model_name:
|
||||
)
|
||||
|
||||
assert embeddings.id is not None
|
||||
assert len(embeddings.data) == 3
|
||||
assert len(embeddings.data) == N
|
||||
assert len(embeddings.data[0].embedding) == 384
|
||||
assert embeddings.usage.completion_tokens == 0
|
||||
assert embeddings.usage.prompt_tokens == 33
|
||||
assert embeddings.usage.total_tokens == 33
|
||||
assert embeddings.usage.prompt_tokens == len(input_tokens) * N
|
||||
assert embeddings.usage.total_tokens == len(input_tokens) * N
|
||||
|
||||
vllm_outputs = [d.embedding for d in embeddings.data]
|
||||
run_embedding_correctness_test(hf_model, input_texts, vllm_outputs)
|
||||
|
||||
# test list[list[int]]
|
||||
input_tokens = [
|
||||
[4, 5, 7, 9, 20],
|
||||
[15, 29, 499],
|
||||
[24, 24, 24, 24, 24],
|
||||
[25, 32, 64, 77],
|
||||
]
|
||||
embedding_response = await client.embeddings.create(
|
||||
model=model_name,
|
||||
input=input_tokens,
|
||||
input=[input_tokens] * N,
|
||||
encoding_format="float",
|
||||
)
|
||||
embeddings = EmbeddingResponse.model_validate(
|
||||
@@ -168,11 +204,14 @@ async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, model_name:
|
||||
)
|
||||
|
||||
assert embeddings.id is not None
|
||||
assert len(embeddings.data) == 4
|
||||
assert len(embeddings.data) == N
|
||||
assert len(embeddings.data[0].embedding) == 384
|
||||
assert embeddings.usage.completion_tokens == 0
|
||||
assert embeddings.usage.prompt_tokens == 17
|
||||
assert embeddings.usage.total_tokens == 17
|
||||
assert embeddings.usage.prompt_tokens == len(input_tokens) * N
|
||||
assert embeddings.usage.total_tokens == len(input_tokens) * N
|
||||
|
||||
vllm_outputs = [d.embedding for d in embeddings.data]
|
||||
run_embedding_correctness_test(hf_model, input_texts, vllm_outputs)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -235,9 +274,162 @@ async def test_conversation_embedding(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_batch_base64_embedding(
|
||||
hf_model, client: openai.AsyncOpenAI, model_name: str
|
||||
):
|
||||
async def test_truncate_prompt_tokens(client: openai.AsyncOpenAI, model_name: str):
|
||||
input_texts = [
|
||||
"Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
|
||||
]
|
||||
|
||||
# test single embedding
|
||||
embedding_response = await client.embeddings.create(
|
||||
model=model_name, input=input_texts, extra_body={"truncate_prompt_tokens": 10}
|
||||
)
|
||||
embeddings = EmbeddingResponse.model_validate(
|
||||
embedding_response.model_dump(mode="json")
|
||||
)
|
||||
|
||||
assert embeddings.id is not None
|
||||
assert len(embeddings.data) == 1
|
||||
assert len(embeddings.data[0].embedding) == 384
|
||||
assert embeddings.usage.completion_tokens == 0
|
||||
assert embeddings.usage.prompt_tokens == 10
|
||||
assert embeddings.usage.total_tokens == 10
|
||||
|
||||
input_tokens = [
|
||||
1,
|
||||
24428,
|
||||
289,
|
||||
18341,
|
||||
26165,
|
||||
285,
|
||||
19323,
|
||||
283,
|
||||
289,
|
||||
26789,
|
||||
3871,
|
||||
28728,
|
||||
9901,
|
||||
340,
|
||||
2229,
|
||||
385,
|
||||
340,
|
||||
315,
|
||||
28741,
|
||||
28804,
|
||||
2,
|
||||
]
|
||||
embedding_response = await client.embeddings.create(
|
||||
model=model_name, input=input_tokens, extra_body={"truncate_prompt_tokens": 10}
|
||||
)
|
||||
embeddings = EmbeddingResponse.model_validate(
|
||||
embedding_response.model_dump(mode="json")
|
||||
)
|
||||
|
||||
assert embeddings.id is not None
|
||||
assert len(embeddings.data) == 1
|
||||
assert len(embeddings.data[0].embedding) == 384
|
||||
assert embeddings.usage.completion_tokens == 0
|
||||
assert embeddings.usage.prompt_tokens == 10
|
||||
assert embeddings.usage.total_tokens == 10
|
||||
|
||||
# invalid_truncate_prompt_tokens
|
||||
input_texts = [
|
||||
"Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
|
||||
]
|
||||
|
||||
with pytest.raises(openai.BadRequestError):
|
||||
response = await client.embeddings.create(
|
||||
model=model_name,
|
||||
input=input_texts,
|
||||
extra_body={"truncate_prompt_tokens": 8193},
|
||||
)
|
||||
assert "error" in response.object
|
||||
assert (
|
||||
"truncate_prompt_tokens value is greater than max_model_len. "
|
||||
"Please, select a smaller truncation size." in response.message
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenAI):
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"input": input_text,
|
||||
"encoding_format": "float",
|
||||
}
|
||||
|
||||
completion_response = await client.embeddings.create(**request_args)
|
||||
|
||||
invocation_response = requests.post(
|
||||
server.url_for("invocations"), json=request_args
|
||||
)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
completion_output = completion_response.model_dump()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert completion_output.keys() == invocation_output.keys()
|
||||
for completion_data, invocation_data in zip(
|
||||
completion_output["data"], invocation_output["data"]
|
||||
):
|
||||
assert completion_data.keys() == invocation_data.keys()
|
||||
check_embeddings_close(
|
||||
embeddings_0_lst=[completion_data["embedding"]],
|
||||
embeddings_1_lst=[invocation_data["embedding"]],
|
||||
name_0="completion",
|
||||
name_1="invocation",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations_conversation(server: RemoteOpenAIServer):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "The cat sat on the mat.",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "A feline was resting on a rug.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Stars twinkle brightly in the night sky.",
|
||||
},
|
||||
]
|
||||
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"messages": messages,
|
||||
"encoding_format": "float",
|
||||
}
|
||||
|
||||
chat_response = requests.post(server.url_for("v1/embeddings"), json=request_args)
|
||||
chat_response.raise_for_status()
|
||||
|
||||
invocation_response = requests.post(
|
||||
server.url_for("invocations"), json=request_args
|
||||
)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
chat_output = chat_response.json()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert chat_output.keys() == invocation_output.keys()
|
||||
for chat_data, invocation_data in zip(
|
||||
chat_output["data"], invocation_output["data"]
|
||||
):
|
||||
assert chat_data.keys() == invocation_data.keys()
|
||||
check_embeddings_close(
|
||||
embeddings_0_lst=[chat_data["embedding"]],
|
||||
embeddings_1_lst=[invocation_data["embedding"]],
|
||||
name_0="chat",
|
||||
name_1="invocation",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_base64_embedding(hf_model, client: openai.AsyncOpenAI, model_name: str):
|
||||
input_texts = [
|
||||
"Hello my name is",
|
||||
"The best thing about vLLM is that it supports many different models",
|
||||
@@ -273,10 +465,7 @@ async def test_batch_base64_embedding(
|
||||
async def test_base64_embed_dtype_and_endianness(
|
||||
server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
|
||||
):
|
||||
input_texts = [
|
||||
"The best thing about vLLM is that it supports many different models",
|
||||
]
|
||||
|
||||
input_texts = [input_text] * 3
|
||||
responses_float = await client.embeddings.create(
|
||||
input=input_texts, model=model_name, encoding_format="float"
|
||||
)
|
||||
@@ -315,10 +504,7 @@ async def test_base64_embed_dtype_and_endianness(
|
||||
async def test_bytes_embed_dtype_and_endianness(
|
||||
server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
|
||||
):
|
||||
input_texts = [
|
||||
"The best thing about vLLM is that it supports many different models",
|
||||
]
|
||||
|
||||
input_texts = [input_text] * 3
|
||||
responses_float = await client.embeddings.create(
|
||||
input=input_texts, model=model_name, encoding_format="float"
|
||||
)
|
||||
@@ -408,15 +594,11 @@ async def test_bytes_only_embed_dtype_and_endianness(
|
||||
async def test_params_not_supported(
|
||||
server: RemoteOpenAIServer, model_name: str, param_name: str
|
||||
):
|
||||
input_texts = [
|
||||
"The best thing about vLLM is that it supports many different models",
|
||||
]
|
||||
|
||||
responses_base64 = requests.post(
|
||||
server.url_for("/v1/embeddings"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"input": input_texts,
|
||||
"input": input_text,
|
||||
"encoding_format": "base64",
|
||||
param_name: f"bad_{param_name}",
|
||||
},
|
||||
@@ -427,175 +609,9 @@ async def test_params_not_supported(
|
||||
assert f"bad_{param_name}" in responses_base64.json()["error"]["message"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_single_embedding_truncation(client: openai.AsyncOpenAI, model_name: str):
|
||||
input_texts = [
|
||||
"Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
|
||||
]
|
||||
|
||||
# test single embedding
|
||||
embedding_response = await client.embeddings.create(
|
||||
model=model_name, input=input_texts, extra_body={"truncate_prompt_tokens": 10}
|
||||
)
|
||||
embeddings = EmbeddingResponse.model_validate(
|
||||
embedding_response.model_dump(mode="json")
|
||||
)
|
||||
|
||||
assert embeddings.id is not None
|
||||
assert len(embeddings.data) == 1
|
||||
assert len(embeddings.data[0].embedding) == 384
|
||||
assert embeddings.usage.completion_tokens == 0
|
||||
assert embeddings.usage.prompt_tokens == 10
|
||||
assert embeddings.usage.total_tokens == 10
|
||||
|
||||
input_tokens = [
|
||||
1,
|
||||
24428,
|
||||
289,
|
||||
18341,
|
||||
26165,
|
||||
285,
|
||||
19323,
|
||||
283,
|
||||
289,
|
||||
26789,
|
||||
3871,
|
||||
28728,
|
||||
9901,
|
||||
340,
|
||||
2229,
|
||||
385,
|
||||
340,
|
||||
315,
|
||||
28741,
|
||||
28804,
|
||||
2,
|
||||
]
|
||||
embedding_response = await client.embeddings.create(
|
||||
model=model_name, input=input_tokens, extra_body={"truncate_prompt_tokens": 10}
|
||||
)
|
||||
embeddings = EmbeddingResponse.model_validate(
|
||||
embedding_response.model_dump(mode="json")
|
||||
)
|
||||
|
||||
assert embeddings.id is not None
|
||||
assert len(embeddings.data) == 1
|
||||
assert len(embeddings.data[0].embedding) == 384
|
||||
assert embeddings.usage.completion_tokens == 0
|
||||
assert embeddings.usage.prompt_tokens == 10
|
||||
assert embeddings.usage.total_tokens == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_single_embedding_truncation_invalid(
|
||||
client: openai.AsyncOpenAI, model_name: str
|
||||
):
|
||||
input_texts = [
|
||||
"Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
|
||||
]
|
||||
|
||||
with pytest.raises(openai.BadRequestError):
|
||||
response = await client.embeddings.create(
|
||||
model=model_name,
|
||||
input=input_texts,
|
||||
extra_body={"truncate_prompt_tokens": 8193},
|
||||
)
|
||||
assert "error" in response.object
|
||||
assert (
|
||||
"truncate_prompt_tokens value is greater than max_model_len. "
|
||||
"Please, select a smaller truncation size." in response.message
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenAI):
|
||||
input_texts = [
|
||||
"The chef prepared a delicious meal.",
|
||||
]
|
||||
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"input": input_texts,
|
||||
"encoding_format": "float",
|
||||
}
|
||||
|
||||
completion_response = await client.embeddings.create(**request_args)
|
||||
|
||||
invocation_response = requests.post(
|
||||
server.url_for("invocations"), json=request_args
|
||||
)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
completion_output = completion_response.model_dump()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert completion_output.keys() == invocation_output.keys()
|
||||
for completion_data, invocation_data in zip(
|
||||
completion_output["data"], invocation_output["data"]
|
||||
):
|
||||
assert completion_data.keys() == invocation_data.keys()
|
||||
check_embeddings_close(
|
||||
embeddings_0_lst=[completion_data["embedding"]],
|
||||
embeddings_1_lst=[invocation_data["embedding"]],
|
||||
name_0="completion",
|
||||
name_1="invocation",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations_conversation(server: RemoteOpenAIServer):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "The cat sat on the mat.",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "A feline was resting on a rug.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Stars twinkle brightly in the night sky.",
|
||||
},
|
||||
]
|
||||
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"messages": messages,
|
||||
"encoding_format": "float",
|
||||
}
|
||||
|
||||
chat_response = requests.post(server.url_for("v1/embeddings"), json=request_args)
|
||||
chat_response.raise_for_status()
|
||||
|
||||
invocation_response = requests.post(
|
||||
server.url_for("invocations"), json=request_args
|
||||
)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
chat_output = chat_response.json()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert chat_output.keys() == invocation_output.keys()
|
||||
for chat_data, invocation_data in zip(
|
||||
chat_output["data"], invocation_output["data"]
|
||||
):
|
||||
assert chat_data.keys() == invocation_data.keys()
|
||||
check_embeddings_close(
|
||||
embeddings_0_lst=[chat_data["embedding"]],
|
||||
embeddings_1_lst=[invocation_data["embedding"]],
|
||||
name_0="chat",
|
||||
name_1="invocation",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_normalize(server: RemoteOpenAIServer, model_name: str):
|
||||
input_text = ["The chef prepared a delicious meal."]
|
||||
|
||||
async def get_outputs(normalize):
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
@@ -626,8 +642,6 @@ async def test_normalize(server: RemoteOpenAIServer, model_name: str):
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_pooling_embed(server: RemoteOpenAIServer, model_name: str):
|
||||
task = "embed"
|
||||
input_text = ["The chef prepared a delicious meal."]
|
||||
|
||||
response = requests.post(
|
||||
server.url_for("pooling"),
|
||||
json={
|
||||
@@ -648,8 +662,6 @@ async def test_pooling_embed(server: RemoteOpenAIServer, model_name: str):
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_pooling_token_embed(server: RemoteOpenAIServer, model_name: str):
|
||||
task = "token_embed"
|
||||
input_text = ["The chef prepared a delicious meal."]
|
||||
|
||||
response = requests.post(
|
||||
server.url_for("pooling"),
|
||||
json={
|
||||
@@ -663,7 +675,7 @@ async def test_pooling_token_embed(server: RemoteOpenAIServer, model_name: str):
|
||||
poolings = PoolingResponse.model_validate(response.json())
|
||||
|
||||
assert len(poolings.data) == 1
|
||||
assert len(poolings.data[0].data) == 11
|
||||
assert len(poolings.data[0].data) == len(input_tokens)
|
||||
assert len(poolings.data[0].data[0]) == 384
|
||||
|
||||
|
||||
|
||||
@@ -24,6 +24,8 @@ from vllm.utils.serial_utils import (
|
||||
|
||||
MODEL_NAME = "internlm/internlm2-1_8b-reward"
|
||||
DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501
|
||||
input_text = "The chef prepared a delicious meal."
|
||||
input_tokens = [1, 918, 29981, 10166, 395, 18067, 15265, 281]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -46,30 +48,40 @@ def server():
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_basic(server: RemoteOpenAIServer, model_name: str):
|
||||
# test /v1/models
|
||||
response = requests.get(server.url_for("/v1/models"))
|
||||
served_model = response.json()["data"][0]["id"]
|
||||
assert served_model == MODEL_NAME
|
||||
|
||||
# test /tokenize
|
||||
response = requests.post(
|
||||
server.url_for("/tokenize"),
|
||||
json={"model": model_name, "prompt": input_text},
|
||||
)
|
||||
assert response.json()["tokens"] == input_tokens
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_single_pooling(server: RemoteOpenAIServer, model_name: str):
|
||||
input_texts = [
|
||||
"The chef prepared a delicious meal.",
|
||||
]
|
||||
|
||||
# test single pooling
|
||||
def test_completion_request(server: RemoteOpenAIServer, model_name: str):
|
||||
# test input: str
|
||||
response = requests.post(
|
||||
server.url_for("pooling"),
|
||||
json={"model": model_name, "input": input_texts, "encoding_format": "float"},
|
||||
json={"model": model_name, "input": input_text, "encoding_format": "float"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
poolings = PoolingResponse.model_validate(response.json())
|
||||
|
||||
assert poolings.id is not None
|
||||
assert len(poolings.data) == 1
|
||||
assert len(poolings.data[0].data) == 8
|
||||
assert len(poolings.data[0].data) == len(input_tokens)
|
||||
assert poolings.usage.completion_tokens == 0
|
||||
assert poolings.usage.prompt_tokens == 8
|
||||
assert poolings.usage.total_tokens == 8
|
||||
assert poolings.usage.prompt_tokens == len(input_tokens)
|
||||
assert poolings.usage.total_tokens == len(input_tokens)
|
||||
|
||||
# test using token IDs
|
||||
input_tokens = [1, 1, 1, 1, 1]
|
||||
# test input: list[int]
|
||||
response = requests.post(
|
||||
server.url_for("pooling"),
|
||||
json={"model": model_name, "input": input_tokens, "encoding_format": "float"},
|
||||
@@ -79,21 +91,17 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str):
|
||||
|
||||
assert poolings.id is not None
|
||||
assert len(poolings.data) == 1
|
||||
assert len(poolings.data[0].data) == 5
|
||||
assert len(poolings.data[0].data) == len(input_tokens)
|
||||
assert poolings.usage.completion_tokens == 0
|
||||
assert poolings.usage.prompt_tokens == 5
|
||||
assert poolings.usage.total_tokens == 5
|
||||
assert poolings.usage.prompt_tokens == len(input_tokens)
|
||||
assert poolings.usage.total_tokens == len(input_tokens)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str):
|
||||
# test list[str]
|
||||
input_texts = [
|
||||
"The cat sat on the mat.",
|
||||
"A feline was resting on a rug.",
|
||||
"Stars twinkle brightly in the night sky.",
|
||||
]
|
||||
def test_completion_request_batched(server: RemoteOpenAIServer, model_name: str):
|
||||
N = 10
|
||||
input_texts = [input_text] * N
|
||||
|
||||
response = requests.post(
|
||||
server.url_for("pooling"),
|
||||
json={"model": model_name, "input": input_texts, "encoding_format": "float"},
|
||||
@@ -102,32 +110,30 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str):
|
||||
poolings = PoolingResponse.model_validate(response.json())
|
||||
|
||||
assert poolings.id is not None
|
||||
assert len(poolings.data) == 3
|
||||
assert len(poolings.data[0].data) == 8
|
||||
assert len(poolings.data) == N
|
||||
assert len(poolings.data[0].data) == len(input_tokens)
|
||||
assert poolings.usage.completion_tokens == 0
|
||||
assert poolings.usage.prompt_tokens == 29
|
||||
assert poolings.usage.total_tokens == 29
|
||||
assert poolings.usage.prompt_tokens == len(input_tokens) * N
|
||||
assert poolings.usage.total_tokens == len(input_tokens) * N
|
||||
|
||||
# test list[list[int]]
|
||||
input_tokens = [
|
||||
[4, 5, 7, 9, 20],
|
||||
[15, 29, 499],
|
||||
[24, 24, 24, 24, 24],
|
||||
[25, 32, 64, 77],
|
||||
]
|
||||
response = requests.post(
|
||||
server.url_for("pooling"),
|
||||
json={"model": model_name, "input": input_tokens, "encoding_format": "float"},
|
||||
json={
|
||||
"model": model_name,
|
||||
"input": [input_tokens] * N,
|
||||
"encoding_format": "float",
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
poolings = PoolingResponse.model_validate(response.json())
|
||||
|
||||
assert poolings.id is not None
|
||||
assert len(poolings.data) == 4
|
||||
assert len(poolings.data[0].data) == 5
|
||||
assert len(poolings.data) == N
|
||||
assert len(poolings.data[0].data) == len(input_tokens)
|
||||
assert poolings.usage.completion_tokens == 0
|
||||
assert poolings.usage.prompt_tokens == 17
|
||||
assert poolings.usage.total_tokens == 17
|
||||
assert poolings.usage.prompt_tokens == len(input_tokens) * N
|
||||
assert poolings.usage.total_tokens == len(input_tokens) * N
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -259,9 +265,7 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, model_name: str)
|
||||
async def test_base64_embed_dtype_and_endianness(
|
||||
server: RemoteOpenAIServer, model_name: str
|
||||
):
|
||||
input_texts = [
|
||||
"The best thing about vLLM is that it supports many different models",
|
||||
]
|
||||
input_texts = [input_text] * 3
|
||||
|
||||
url = server.url_for("pooling")
|
||||
float_response = requests.post(
|
||||
@@ -308,9 +312,7 @@ async def test_base64_embed_dtype_and_endianness(
|
||||
async def test_bytes_embed_dtype_and_endianness(
|
||||
server: RemoteOpenAIServer, model_name: str
|
||||
):
|
||||
input_texts = [
|
||||
"The best thing about vLLM is that it supports many different models",
|
||||
]
|
||||
input_texts = [input_text] * 3
|
||||
|
||||
url = server.url_for("pooling")
|
||||
float_response = requests.post(
|
||||
@@ -358,9 +360,7 @@ async def test_bytes_embed_dtype_and_endianness(
|
||||
async def test_bytes_only_embed_dtype_and_endianness(
|
||||
server: RemoteOpenAIServer, model_name: str
|
||||
):
|
||||
input_texts = [
|
||||
"The best thing about vLLM is that it supports many different models",
|
||||
] * 2
|
||||
input_texts = [input_text] * 3
|
||||
|
||||
url = server.url_for("pooling")
|
||||
float_response = requests.post(
|
||||
@@ -414,15 +414,11 @@ async def test_bytes_only_embed_dtype_and_endianness(
|
||||
async def test_params_not_supported(
|
||||
server: RemoteOpenAIServer, model_name: str, param_name: str
|
||||
):
|
||||
input_texts = [
|
||||
"The best thing about vLLM is that it supports many different models",
|
||||
]
|
||||
|
||||
responses_base64 = requests.post(
|
||||
server.url_for("pooling"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"input": input_texts,
|
||||
"input": input_text,
|
||||
"encoding_format": "base64",
|
||||
param_name: f"bad_{param_name}",
|
||||
},
|
||||
@@ -435,13 +431,9 @@ async def test_params_not_supported(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations(server: RemoteOpenAIServer):
|
||||
input_texts = [
|
||||
"The chef prepared a delicious meal.",
|
||||
]
|
||||
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"input": input_texts,
|
||||
"input": input_text,
|
||||
"encoding_format": "float",
|
||||
}
|
||||
|
||||
|
||||
@@ -13,6 +13,8 @@ from vllm.platforms import current_platform
|
||||
|
||||
MODEL_NAME = "BAAI/bge-reranker-base"
|
||||
DTYPE = "bfloat16"
|
||||
input_text = "This product was excellent and exceeded my expectations"
|
||||
input_tokens = [0, 3293, 12996, 509, 40881, 136, 204839, 297, 759, 202702, 2]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -27,6 +29,21 @@ def server():
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_basic(server: RemoteOpenAIServer, model_name: str):
|
||||
# test /v1/models
|
||||
response = requests.get(server.url_for("/v1/models"))
|
||||
served_model = response.json()["data"][0]["id"]
|
||||
assert served_model == MODEL_NAME
|
||||
|
||||
# test /tokenize
|
||||
response = requests.post(
|
||||
server.url_for("/tokenize"),
|
||||
json={"model": model_name, "prompt": input_text},
|
||||
)
|
||||
assert response.json()["tokens"] == input_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_rerank_texts(server: RemoteOpenAIServer, model_name: str):
|
||||
query = "What is the capital of France?"
|
||||
@@ -170,7 +187,6 @@ async def test_use_activation(server: RemoteOpenAIServer, model_name: str):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
|
||||
input_text = "This product was excellent and exceeded my expectations"
|
||||
response = requests.post(
|
||||
server.url_for("pooling"),
|
||||
json={
|
||||
@@ -188,8 +204,6 @@ async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
|
||||
input_text = ["The chef prepared a delicious meal."]
|
||||
|
||||
response = requests.post(
|
||||
server.url_for("pooling"),
|
||||
json={"model": model_name, "input": input_text, "encoding_format": "float"},
|
||||
@@ -198,7 +212,7 @@ async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: st
|
||||
poolings = PoolingResponse.model_validate(response.json())
|
||||
|
||||
assert len(poolings.data) == 1
|
||||
assert len(poolings.data[0].data) == 11
|
||||
assert len(poolings.data[0].data) == len(input_tokens)
|
||||
assert len(poolings.data[0].data[0]) == 1
|
||||
|
||||
|
||||
@@ -212,7 +226,7 @@ async def test_pooling_not_supported(
|
||||
server.url_for("pooling"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"input": "test",
|
||||
"input": input_text,
|
||||
"encoding_format": "float",
|
||||
"task": task,
|
||||
},
|
||||
|
||||
@@ -7,7 +7,7 @@ import pytest
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateResolutionError
|
||||
from vllm.entrypoints.score_utils import get_score_prompt
|
||||
from vllm.entrypoints.pooling.score.utils import get_score_prompt
|
||||
from vllm.inputs import TokensPrompt
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
|
||||
@@ -212,7 +212,7 @@ class TestGetScorePrompt:
|
||||
return_value=mock_model_no_score_template,
|
||||
),
|
||||
patch(
|
||||
"vllm.entrypoints.score_utils.apply_hf_chat_template",
|
||||
"vllm.entrypoints.pooling.score.utils.apply_hf_chat_template",
|
||||
return_value="test querytest doc",
|
||||
),
|
||||
):
|
||||
@@ -245,7 +245,7 @@ class TestGetScorePrompt:
|
||||
return_value=mock_model_no_score_template,
|
||||
),
|
||||
patch(
|
||||
"vllm.entrypoints.score_utils.apply_hf_chat_template",
|
||||
"vllm.entrypoints.pooling.score.utils.apply_hf_chat_template",
|
||||
side_effect=ChatTemplateResolutionError("No template"),
|
||||
),
|
||||
):
|
||||
@@ -296,7 +296,7 @@ class TestGetScorePrompt:
|
||||
return_value=mock_model_no_score_template,
|
||||
),
|
||||
patch(
|
||||
"vllm.entrypoints.score_utils.apply_hf_chat_template",
|
||||
"vllm.entrypoints.pooling.score.utils.apply_hf_chat_template",
|
||||
side_effect=ChatTemplateResolutionError("No template"),
|
||||
),
|
||||
):
|
||||
@@ -331,7 +331,7 @@ class TestGetScorePrompt:
|
||||
return_value=mock_model_with_score_template,
|
||||
),
|
||||
patch(
|
||||
"vllm.entrypoints.score_utils.apply_hf_chat_template",
|
||||
"vllm.entrypoints.pooling.score.utils.apply_hf_chat_template",
|
||||
side_effect=ChatTemplateResolutionError("No template"),
|
||||
),
|
||||
):
|
||||
|
||||
@@ -10,7 +10,7 @@ from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionContentPartImageParam,
|
||||
ChatCompletionContentPartTextParam,
|
||||
)
|
||||
from vllm.entrypoints.score_utils import ScoreMultiModalParam
|
||||
from vllm.entrypoints.pooling.score.utils import ScoreMultiModalParam
|
||||
|
||||
from ....conftest import HfRunner, VllmRunner
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ from vllm.entrypoints.chat_utils import (
|
||||
parse_chat_messages,
|
||||
resolve_chat_template_content_format,
|
||||
)
|
||||
from vllm.entrypoints.score_utils import (
|
||||
from vllm.entrypoints.pooling.score.utils import (
|
||||
ScoreContentPartParam,
|
||||
ScoreMultiModalParam,
|
||||
_cosine_similarity,
|
||||
|
||||
@@ -54,10 +54,6 @@ from vllm.entrypoints.openai.translations.serving import (
|
||||
OpenAIServingTranscription,
|
||||
OpenAIServingTranslation,
|
||||
)
|
||||
from vllm.entrypoints.pooling.classify.serving import ServingClassification
|
||||
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
|
||||
from vllm.entrypoints.pooling.score.serving import ServingScores
|
||||
from vllm.entrypoints.serve.disagg.serving import ServingTokens
|
||||
from vllm.entrypoints.serve.elastic_ep.middleware import (
|
||||
ScalingMiddleware,
|
||||
@@ -73,7 +69,6 @@ from vllm.entrypoints.utils import (
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.logger import init_logger
|
||||
from vllm.reasoning import ReasoningParserManager
|
||||
from vllm.tasks import POOLING_TASKS
|
||||
from vllm.tool_parsers import ToolParserManager
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
@@ -761,59 +756,6 @@ async def init_app_state(
|
||||
if "generate" in supported_tasks
|
||||
else None
|
||||
)
|
||||
state.openai_serving_pooling = (
|
||||
(
|
||||
OpenAIServingPooling(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
supported_tasks=supported_tasks,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
trust_request_chat_template=args.trust_request_chat_template,
|
||||
log_error_stack=args.log_error_stack,
|
||||
)
|
||||
)
|
||||
if any(task in POOLING_TASKS for task in supported_tasks)
|
||||
else None
|
||||
)
|
||||
state.openai_serving_embedding = (
|
||||
OpenAIServingEmbedding(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
trust_request_chat_template=args.trust_request_chat_template,
|
||||
log_error_stack=args.log_error_stack,
|
||||
)
|
||||
if "embed" in supported_tasks
|
||||
else None
|
||||
)
|
||||
state.openai_serving_classification = (
|
||||
ServingClassification(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
trust_request_chat_template=args.trust_request_chat_template,
|
||||
log_error_stack=args.log_error_stack,
|
||||
)
|
||||
if "classify" in supported_tasks
|
||||
else None
|
||||
)
|
||||
state.openai_serving_scores = (
|
||||
ServingScores(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
score_template=resolved_chat_template,
|
||||
log_error_stack=args.log_error_stack,
|
||||
)
|
||||
if ("embed" in supported_tasks or "score" in supported_tasks)
|
||||
else None
|
||||
)
|
||||
state.openai_serving_tokenization = OpenAIServingTokenization(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
@@ -878,6 +820,10 @@ async def init_app_state(
|
||||
else None
|
||||
)
|
||||
|
||||
from vllm.entrypoints.pooling import init_pooling_state
|
||||
|
||||
await init_pooling_state(engine_client, state, args)
|
||||
|
||||
state.enable_server_load_tracking = args.enable_server_load_tracking
|
||||
state.server_load_metrics = 0
|
||||
|
||||
|
||||
@@ -1,8 +1,17 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from argparse import Namespace
|
||||
|
||||
from starlette.datastructures import State
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
|
||||
|
||||
def register_pooling_api_routers(app: FastAPI):
|
||||
from vllm.entrypoints.pooling.classify.api_router import router as classify_router
|
||||
@@ -14,3 +23,82 @@ def register_pooling_api_routers(app: FastAPI):
|
||||
app.include_router(embed_router)
|
||||
app.include_router(score_router)
|
||||
app.include_router(pooling_router)
|
||||
|
||||
|
||||
async def init_pooling_state(
|
||||
engine_client: "EngineClient", state: "State", args: "Namespace"
|
||||
):
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.pooling.classify.serving import ServingClassification
|
||||
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
|
||||
from vllm.entrypoints.pooling.score.serving import ServingScores
|
||||
from vllm.entrypoints.utils import process_chat_template
|
||||
from vllm.tasks import POOLING_TASKS
|
||||
|
||||
supported_tasks = await engine_client.get_supported_tasks()
|
||||
|
||||
vllm_config = engine_client.vllm_config
|
||||
|
||||
resolved_chat_template = await process_chat_template(
|
||||
args.chat_template, engine_client, vllm_config.model_config
|
||||
)
|
||||
|
||||
if args.enable_log_requests:
|
||||
request_logger = RequestLogger(max_log_len=args.max_log_len)
|
||||
else:
|
||||
request_logger = None
|
||||
|
||||
state.openai_serving_pooling = (
|
||||
(
|
||||
OpenAIServingPooling(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
supported_tasks=supported_tasks,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
trust_request_chat_template=args.trust_request_chat_template,
|
||||
log_error_stack=args.log_error_stack,
|
||||
)
|
||||
)
|
||||
if any(task in POOLING_TASKS for task in supported_tasks)
|
||||
else None
|
||||
)
|
||||
state.openai_serving_embedding = (
|
||||
OpenAIServingEmbedding(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
trust_request_chat_template=args.trust_request_chat_template,
|
||||
log_error_stack=args.log_error_stack,
|
||||
)
|
||||
if "embed" in supported_tasks
|
||||
else None
|
||||
)
|
||||
state.openai_serving_classification = (
|
||||
ServingClassification(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
trust_request_chat_template=args.trust_request_chat_template,
|
||||
log_error_stack=args.log_error_stack,
|
||||
)
|
||||
if "classify" in supported_tasks
|
||||
else None
|
||||
)
|
||||
state.openai_serving_scores = (
|
||||
ServingScores(
|
||||
engine_client,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
score_template=resolved_chat_template,
|
||||
log_error_stack=args.log_error_stack,
|
||||
)
|
||||
if ("embed" in supported_tasks or "score" in supported_tasks)
|
||||
else None
|
||||
)
|
||||
|
||||
0
vllm/entrypoints/pooling/base/__init__.py
Normal file
0
vllm/entrypoints/pooling/base/__init__.py
Normal file
46
vllm/entrypoints/pooling/base/protocol.py
Normal file
46
vllm/entrypoints/pooling/base/protocol.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
|
||||
class PoolingBasicRequestMixin(OpenAIBaseModel):
|
||||
model: str | None = None
|
||||
user: str | None = None
|
||||
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
||||
|
||||
request_id: str = Field(
|
||||
default_factory=random_uuid,
|
||||
description=(
|
||||
"The request_id related to this request. If the caller does "
|
||||
"not set it, a random_uuid will be generated. This id is used "
|
||||
"through out the inference process and return in response."
|
||||
),
|
||||
)
|
||||
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class CompletionRequestMixin(OpenAIBaseModel):
|
||||
input: list[int] | list[list[int]] | str | list[str]
|
||||
|
||||
add_special_tokens: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"If true (the default), special tokens (e.g. BOS) will be added to "
|
||||
"the prompt."
|
||||
),
|
||||
)
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import time
|
||||
from typing import Annotated, Any, TypeAlias
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from pydantic import (
|
||||
Field,
|
||||
@@ -12,39 +12,15 @@ from vllm import PoolingParams
|
||||
from vllm.config.pooler import get_use_activation
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
|
||||
from vllm.entrypoints.pooling.base.protocol import (
|
||||
CompletionRequestMixin,
|
||||
PoolingBasicRequestMixin,
|
||||
)
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
|
||||
class ClassificationCompletionRequest(OpenAIBaseModel):
|
||||
model: str | None = None
|
||||
input: list[str] | str
|
||||
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
||||
user: str | None = None
|
||||
|
||||
class ClassificationCompletionRequest(PoolingBasicRequestMixin, CompletionRequestMixin):
|
||||
# --8<-- [start:classification-extra-params]
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
add_special_tokens: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"If true (the default), special tokens (e.g. BOS) will be added to "
|
||||
"the prompt."
|
||||
),
|
||||
)
|
||||
request_id: str = Field(
|
||||
default_factory=random_uuid,
|
||||
description=(
|
||||
"The request_id related to this request. If the caller does "
|
||||
"not set it, a random_uuid will be generated. This id is used "
|
||||
"through out the inference process and return in response."
|
||||
),
|
||||
)
|
||||
softmax: bool | None = Field(
|
||||
default=None,
|
||||
description="softmax will be deprecated, please use use_activation instead.",
|
||||
@@ -69,11 +45,8 @@ class ClassificationCompletionRequest(OpenAIBaseModel):
|
||||
)
|
||||
|
||||
|
||||
class ClassificationChatRequest(OpenAIBaseModel):
|
||||
model: str | None = None
|
||||
class ClassificationChatRequest(PoolingBasicRequestMixin):
|
||||
messages: list[ChatCompletionMessageParam]
|
||||
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
||||
user: str | None = None
|
||||
|
||||
# --8<-- [start:chat-classification-extra-params]
|
||||
add_generation_prompt: bool = Field(
|
||||
@@ -119,23 +92,6 @@ class ClassificationChatRequest(OpenAIBaseModel):
|
||||
description=("Additional kwargs to pass to the HF processor."),
|
||||
)
|
||||
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
|
||||
request_id: str = Field(
|
||||
default_factory=random_uuid,
|
||||
description=(
|
||||
"The request_id related to this request. If the caller does "
|
||||
"not set it, a random_uuid will be generated. This id is used "
|
||||
"through out the inference process and return in response."
|
||||
),
|
||||
)
|
||||
softmax: bool | None = Field(
|
||||
default=None,
|
||||
description="softmax will be deprecated, please use use_activation instead.",
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import time
|
||||
from typing import Annotated, Any, TypeAlias
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from pydantic import (
|
||||
Field,
|
||||
@@ -11,44 +11,22 @@ from pydantic import (
|
||||
from vllm import PoolingParams
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
|
||||
from vllm.entrypoints.pooling.base.protocol import (
|
||||
CompletionRequestMixin,
|
||||
PoolingBasicRequestMixin,
|
||||
)
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
|
||||
|
||||
|
||||
class EmbeddingCompletionRequest(OpenAIBaseModel):
|
||||
class EmbeddingCompletionRequest(PoolingBasicRequestMixin, CompletionRequestMixin):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/embeddings
|
||||
model: str | None = None
|
||||
input: list[int] | list[list[int]] | str | list[str]
|
||||
|
||||
encoding_format: EncodingFormat = "float"
|
||||
dimensions: int | None = None
|
||||
user: str | None = None
|
||||
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
||||
|
||||
# --8<-- [start:embedding-extra-params]
|
||||
add_special_tokens: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"If true (the default), special tokens (e.g. BOS) will be added to "
|
||||
"the prompt."
|
||||
),
|
||||
)
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
request_id: str = Field(
|
||||
default_factory=random_uuid,
|
||||
description=(
|
||||
"The request_id related to this request. If the caller does "
|
||||
"not set it, a random_uuid will be generated. This id is used "
|
||||
"through out the inference process and return in response."
|
||||
),
|
||||
)
|
||||
normalize: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to normalize the embeddings outputs. Default is True.",
|
||||
@@ -73,20 +51,17 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
dimensions=self.dimensions,
|
||||
use_activation=self.normalize,
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingChatRequest(OpenAIBaseModel):
|
||||
model: str | None = None
|
||||
class EmbeddingChatRequest(PoolingBasicRequestMixin):
|
||||
messages: list[ChatCompletionMessageParam]
|
||||
|
||||
encoding_format: EncodingFormat = "float"
|
||||
dimensions: int | None = None
|
||||
user: str | None = None
|
||||
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
||||
|
||||
# --8<-- [start:chat-embedding-extra-params]
|
||||
add_generation_prompt: bool = Field(
|
||||
@@ -137,22 +112,6 @@ class EmbeddingChatRequest(OpenAIBaseModel):
|
||||
default=None,
|
||||
description=("Additional kwargs to pass to the HF processor."),
|
||||
)
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
request_id: str = Field(
|
||||
default_factory=random_uuid,
|
||||
description=(
|
||||
"The request_id related to this request. If the caller does "
|
||||
"not set it, a random_uuid will be generated. This id is used "
|
||||
"through out the inference process and return in response."
|
||||
),
|
||||
)
|
||||
normalize: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to normalize the embeddings outputs. Default is True.",
|
||||
|
||||
@@ -10,6 +10,7 @@ from pydantic import (
|
||||
from vllm import PoolingParams
|
||||
from vllm.config.pooler import get_use_activation
|
||||
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
|
||||
from vllm.entrypoints.pooling.base.protocol import PoolingBasicRequestMixin
|
||||
from vllm.entrypoints.pooling.embed.protocol import (
|
||||
EmbeddingChatRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
@@ -72,17 +73,8 @@ class PoolingChatRequest(EmbeddingChatRequest):
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
|
||||
model: str | None = None
|
||||
|
||||
priority: int = Field(default=0)
|
||||
"""
|
||||
The priority of the request (lower means earlier handling;
|
||||
default: 0). Any priority other than 0 will raise an error
|
||||
if the served model does not use priority scheduling.
|
||||
"""
|
||||
class IOProcessorRequest(PoolingBasicRequestMixin, Generic[T]):
|
||||
data: T
|
||||
|
||||
task: PoolingTask = "plugin"
|
||||
encoding_format: EncodingFormat = "float"
|
||||
embed_dtype: EmbedDType = Field(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import time
|
||||
from typing import Annotated, Any
|
||||
from typing import Any
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
@@ -11,32 +11,24 @@ from pydantic import (
|
||||
from vllm import PoolingParams
|
||||
from vllm.config.pooler import get_use_activation
|
||||
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel, UsageInfo
|
||||
from vllm.entrypoints.score_utils import ScoreContentPartParam, ScoreMultiModalParam
|
||||
from vllm.entrypoints.pooling.base.protocol import PoolingBasicRequestMixin
|
||||
from vllm.entrypoints.pooling.score.utils import (
|
||||
ScoreContentPartParam,
|
||||
ScoreMultiModalParam,
|
||||
)
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
|
||||
class ScoreRequest(OpenAIBaseModel):
|
||||
model: str | None = None
|
||||
class ScoreRequest(PoolingBasicRequestMixin):
|
||||
text_1: list[str] | str | ScoreMultiModalParam
|
||||
text_2: list[str] | str | ScoreMultiModalParam
|
||||
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
||||
|
||||
# --8<-- [start:score-extra-params]
|
||||
|
||||
mm_processor_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=("Additional kwargs to pass to the HF processor."),
|
||||
)
|
||||
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
|
||||
softmax: bool | None = Field(
|
||||
default=None,
|
||||
description="softmax will be deprecated, please use use_activation instead.",
|
||||
@@ -61,29 +53,16 @@ class ScoreRequest(OpenAIBaseModel):
|
||||
)
|
||||
|
||||
|
||||
class RerankRequest(OpenAIBaseModel):
|
||||
model: str | None = None
|
||||
class RerankRequest(PoolingBasicRequestMixin):
|
||||
query: str | ScoreMultiModalParam
|
||||
documents: list[str] | ScoreMultiModalParam
|
||||
top_n: int = Field(default_factory=lambda: 0)
|
||||
truncate_prompt_tokens: Annotated[int, Field(ge=-1)] | None = None
|
||||
|
||||
# --8<-- [start:rerank-extra-params]
|
||||
|
||||
mm_processor_kwargs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description=("Additional kwargs to pass to the HF processor."),
|
||||
)
|
||||
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
|
||||
softmax: bool | None = Field(
|
||||
default=None,
|
||||
description="softmax will be deprecated, please use use_activation instead.",
|
||||
|
||||
@@ -25,7 +25,7 @@ from vllm.entrypoints.pooling.score.protocol import (
|
||||
ScoreResponse,
|
||||
ScoreResponseData,
|
||||
)
|
||||
from vllm.entrypoints.score_utils import (
|
||||
from vllm.entrypoints.pooling.score.utils import (
|
||||
ScoreContentPartParam,
|
||||
ScoreMultiModalParam,
|
||||
_cosine_similarity,
|
||||
|
||||
Reference in New Issue
Block a user