[Frontend] Online Pooling API (#11457)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -6,6 +6,7 @@ import pytest
|
||||
import pytest_asyncio
|
||||
import requests
|
||||
|
||||
from vllm.entrypoints.openai.protocol import EmbeddingResponse
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
@@ -17,6 +18,8 @@ DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' +
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
args = [
|
||||
"--task",
|
||||
"embed",
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
@@ -45,11 +48,14 @@ async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str):
|
||||
]
|
||||
|
||||
# test single embedding
|
||||
embeddings = await client.embeddings.create(
|
||||
embedding_response = await client.embeddings.create(
|
||||
model=model_name,
|
||||
input=input_texts,
|
||||
encoding_format="float",
|
||||
)
|
||||
embeddings = EmbeddingResponse.model_validate(
|
||||
embedding_response.model_dump(mode="json"))
|
||||
|
||||
assert embeddings.id is not None
|
||||
assert len(embeddings.data) == 1
|
||||
assert len(embeddings.data[0].embedding) == 4096
|
||||
@@ -59,11 +65,14 @@ async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str):
|
||||
|
||||
# test using token IDs
|
||||
input_tokens = [1, 1, 1, 1, 1]
|
||||
embeddings = await client.embeddings.create(
|
||||
embedding_response = await client.embeddings.create(
|
||||
model=model_name,
|
||||
input=input_tokens,
|
||||
encoding_format="float",
|
||||
)
|
||||
embeddings = EmbeddingResponse.model_validate(
|
||||
embedding_response.model_dump(mode="json"))
|
||||
|
||||
assert embeddings.id is not None
|
||||
assert len(embeddings.data) == 1
|
||||
assert len(embeddings.data[0].embedding) == 4096
|
||||
@@ -80,11 +89,14 @@ async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str):
|
||||
"The cat sat on the mat.", "A feline was resting on a rug.",
|
||||
"Stars twinkle brightly in the night sky."
|
||||
]
|
||||
embeddings = await client.embeddings.create(
|
||||
embedding_response = await client.embeddings.create(
|
||||
model=model_name,
|
||||
input=input_texts,
|
||||
encoding_format="float",
|
||||
)
|
||||
embeddings = EmbeddingResponse.model_validate(
|
||||
embedding_response.model_dump(mode="json"))
|
||||
|
||||
assert embeddings.id is not None
|
||||
assert len(embeddings.data) == 3
|
||||
assert len(embeddings.data[0].embedding) == 4096
|
||||
@@ -95,11 +107,14 @@ async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str):
|
||||
# test List[List[int]]
|
||||
input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
|
||||
[25, 32, 64, 77]]
|
||||
embeddings = await client.embeddings.create(
|
||||
embedding_response = await client.embeddings.create(
|
||||
model=model_name,
|
||||
input=input_tokens,
|
||||
encoding_format="float",
|
||||
)
|
||||
embeddings = EmbeddingResponse.model_validate(
|
||||
embedding_response.model_dump(mode="json"))
|
||||
|
||||
assert embeddings.id is not None
|
||||
assert len(embeddings.data) == 4
|
||||
assert len(embeddings.data[0].embedding) == 4096
|
||||
@@ -124,14 +139,16 @@ async def test_conversation_embedding(server: RemoteOpenAIServer,
|
||||
"content": "Stars twinkle brightly in the night sky.",
|
||||
}]
|
||||
|
||||
chat_response = requests.post(server.url_for("v1/embeddings"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"messages": messages,
|
||||
"encoding_format": "float",
|
||||
})
|
||||
chat_response = requests.post(
|
||||
server.url_for("v1/embeddings"),
|
||||
json={
|
||||
"model": model_name,
|
||||
"messages": messages,
|
||||
"encoding_format": "float",
|
||||
},
|
||||
)
|
||||
chat_response.raise_for_status()
|
||||
chat_embeddings = chat_response.json()
|
||||
chat_embeddings = EmbeddingResponse.model_validate(chat_response.json())
|
||||
|
||||
tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
@@ -148,13 +165,15 @@ async def test_conversation_embedding(server: RemoteOpenAIServer,
|
||||
# To be consistent with chat
|
||||
extra_body={"add_special_tokens": False},
|
||||
)
|
||||
completion_embeddings = completion_response.model_dump(mode="json")
|
||||
completion_embeddings = EmbeddingResponse.model_validate(
|
||||
completion_response.model_dump(mode="json"))
|
||||
|
||||
assert chat_embeddings.pop("id") is not None
|
||||
assert completion_embeddings.pop("id") is not None
|
||||
assert chat_embeddings.pop("created") <= completion_embeddings.pop(
|
||||
"created")
|
||||
assert chat_embeddings == completion_embeddings
|
||||
assert chat_embeddings.id is not None
|
||||
assert completion_embeddings.id is not None
|
||||
assert chat_embeddings.created <= completion_embeddings.created
|
||||
assert chat_embeddings.model_dump(
|
||||
exclude={"id", "created"}) == (completion_embeddings.model_dump(
|
||||
exclude={"id", "created"}))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -204,10 +223,13 @@ async def test_single_embedding_truncation(client: openai.AsyncOpenAI,
|
||||
]
|
||||
|
||||
# test single embedding
|
||||
embeddings = await client.embeddings.create(
|
||||
embedding_response = await client.embeddings.create(
|
||||
model=model_name,
|
||||
input=input_texts,
|
||||
extra_body={"truncate_prompt_tokens": 10})
|
||||
embeddings = EmbeddingResponse.model_validate(
|
||||
embedding_response.model_dump(mode="json"))
|
||||
|
||||
assert embeddings.id is not None
|
||||
assert len(embeddings.data) == 1
|
||||
assert len(embeddings.data[0].embedding) == 4096
|
||||
@@ -219,10 +241,12 @@ async def test_single_embedding_truncation(client: openai.AsyncOpenAI,
|
||||
1, 24428, 289, 18341, 26165, 285, 19323, 283, 289, 26789, 3871, 28728,
|
||||
9901, 340, 2229, 385, 340, 315, 28741, 28804, 2
|
||||
]
|
||||
embeddings = await client.embeddings.create(
|
||||
embedding_response = await client.embeddings.create(
|
||||
model=model_name,
|
||||
input=input_tokens,
|
||||
extra_body={"truncate_prompt_tokens": 10})
|
||||
embeddings = EmbeddingResponse.model_validate(
|
||||
embedding_response.model_dump(mode="json"))
|
||||
|
||||
assert embeddings.id is not None
|
||||
assert len(embeddings.data) == 1
|
||||
@@ -241,10 +265,10 @@ async def test_single_embedding_truncation_invalid(client: openai.AsyncOpenAI,
|
||||
]
|
||||
|
||||
with pytest.raises(openai.BadRequestError):
|
||||
embeddings = await client.embeddings.create(
|
||||
response = await client.embeddings.create(
|
||||
model=model_name,
|
||||
input=input_texts,
|
||||
extra_body={"truncate_prompt_tokens": 8193})
|
||||
assert "error" in embeddings.object
|
||||
assert "error" in response.object
|
||||
assert "truncate_prompt_tokens value is greater than max_model_len. "\
|
||||
"Please, select a smaller truncation size." in embeddings.message
|
||||
"Please, select a smaller truncation size." in response.message
|
||||
|
||||
Reference in New Issue
Block a user