[Frontend] Online Pooling API (#11457)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-12-24 17:54:30 +08:00
committed by GitHub
parent 4f074fbf53
commit 9edca6bf8f
15 changed files with 808 additions and 156 deletions

View File

@@ -6,6 +6,7 @@ import pytest
import pytest_asyncio
import requests
from vllm.entrypoints.openai.protocol import EmbeddingResponse
from vllm.transformers_utils.tokenizer import get_tokenizer
from ...utils import RemoteOpenAIServer
@@ -17,6 +18,8 @@ DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' +
@pytest.fixture(scope="module")
def server():
args = [
"--task",
"embed",
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
@@ -45,11 +48,14 @@ async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str):
]
# test single embedding
embeddings = await client.embeddings.create(
embedding_response = await client.embeddings.create(
model=model_name,
input=input_texts,
encoding_format="float",
)
embeddings = EmbeddingResponse.model_validate(
embedding_response.model_dump(mode="json"))
assert embeddings.id is not None
assert len(embeddings.data) == 1
assert len(embeddings.data[0].embedding) == 4096
@@ -59,11 +65,14 @@ async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str):
# test using token IDs
input_tokens = [1, 1, 1, 1, 1]
embeddings = await client.embeddings.create(
embedding_response = await client.embeddings.create(
model=model_name,
input=input_tokens,
encoding_format="float",
)
embeddings = EmbeddingResponse.model_validate(
embedding_response.model_dump(mode="json"))
assert embeddings.id is not None
assert len(embeddings.data) == 1
assert len(embeddings.data[0].embedding) == 4096
@@ -80,11 +89,14 @@ async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str):
"The cat sat on the mat.", "A feline was resting on a rug.",
"Stars twinkle brightly in the night sky."
]
embeddings = await client.embeddings.create(
embedding_response = await client.embeddings.create(
model=model_name,
input=input_texts,
encoding_format="float",
)
embeddings = EmbeddingResponse.model_validate(
embedding_response.model_dump(mode="json"))
assert embeddings.id is not None
assert len(embeddings.data) == 3
assert len(embeddings.data[0].embedding) == 4096
@@ -95,11 +107,14 @@ async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str):
# test List[List[int]]
input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
[25, 32, 64, 77]]
embeddings = await client.embeddings.create(
embedding_response = await client.embeddings.create(
model=model_name,
input=input_tokens,
encoding_format="float",
)
embeddings = EmbeddingResponse.model_validate(
embedding_response.model_dump(mode="json"))
assert embeddings.id is not None
assert len(embeddings.data) == 4
assert len(embeddings.data[0].embedding) == 4096
@@ -124,14 +139,16 @@ async def test_conversation_embedding(server: RemoteOpenAIServer,
"content": "Stars twinkle brightly in the night sky.",
}]
chat_response = requests.post(server.url_for("v1/embeddings"),
json={
"model": model_name,
"messages": messages,
"encoding_format": "float",
})
chat_response = requests.post(
server.url_for("v1/embeddings"),
json={
"model": model_name,
"messages": messages,
"encoding_format": "float",
},
)
chat_response.raise_for_status()
chat_embeddings = chat_response.json()
chat_embeddings = EmbeddingResponse.model_validate(chat_response.json())
tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
prompt = tokenizer.apply_chat_template(
@@ -148,13 +165,15 @@ async def test_conversation_embedding(server: RemoteOpenAIServer,
# To be consistent with chat
extra_body={"add_special_tokens": False},
)
completion_embeddings = completion_response.model_dump(mode="json")
completion_embeddings = EmbeddingResponse.model_validate(
completion_response.model_dump(mode="json"))
assert chat_embeddings.pop("id") is not None
assert completion_embeddings.pop("id") is not None
assert chat_embeddings.pop("created") <= completion_embeddings.pop(
"created")
assert chat_embeddings == completion_embeddings
assert chat_embeddings.id is not None
assert completion_embeddings.id is not None
assert chat_embeddings.created <= completion_embeddings.created
assert chat_embeddings.model_dump(
exclude={"id", "created"}) == (completion_embeddings.model_dump(
exclude={"id", "created"}))
@pytest.mark.asyncio
@@ -204,10 +223,13 @@ async def test_single_embedding_truncation(client: openai.AsyncOpenAI,
]
# test single embedding
embeddings = await client.embeddings.create(
embedding_response = await client.embeddings.create(
model=model_name,
input=input_texts,
extra_body={"truncate_prompt_tokens": 10})
embeddings = EmbeddingResponse.model_validate(
embedding_response.model_dump(mode="json"))
assert embeddings.id is not None
assert len(embeddings.data) == 1
assert len(embeddings.data[0].embedding) == 4096
@@ -219,10 +241,12 @@ async def test_single_embedding_truncation(client: openai.AsyncOpenAI,
1, 24428, 289, 18341, 26165, 285, 19323, 283, 289, 26789, 3871, 28728,
9901, 340, 2229, 385, 340, 315, 28741, 28804, 2
]
embeddings = await client.embeddings.create(
embedding_response = await client.embeddings.create(
model=model_name,
input=input_tokens,
extra_body={"truncate_prompt_tokens": 10})
embeddings = EmbeddingResponse.model_validate(
embedding_response.model_dump(mode="json"))
assert embeddings.id is not None
assert len(embeddings.data) == 1
@@ -241,10 +265,10 @@ async def test_single_embedding_truncation_invalid(client: openai.AsyncOpenAI,
]
with pytest.raises(openai.BadRequestError):
embeddings = await client.embeddings.create(
response = await client.embeddings.create(
model=model_name,
input=input_texts,
extra_body={"truncate_prompt_tokens": 8193})
assert "error" in embeddings.object
assert "error" in response.object
assert "truncate_prompt_tokens value is greater than max_model_len. "\
"Please, select a smaller truncation size." in embeddings.message
"Please, select a smaller truncation size." in response.message