[Frontend] Chat-based Embeddings API (#9759)
@@ -4,14 +4,18 @@ import numpy as np
 import openai
 import pytest
 import pytest_asyncio
+import requests
+
+from vllm.transformers_utils.tokenizer import get_tokenizer
 
 from ...utils import RemoteOpenAIServer
 
-EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
+MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
+DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}"""  # noqa: E501
 
 
 @pytest.fixture(scope="module")
-def embedding_server():
+def server():
     args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
@@ -19,31 +23,29 @@ def embedding_server():
         "--enforce-eager",
         "--max-model-len",
         "8192",
+        "--chat-template",
+        DUMMY_CHAT_TEMPLATE,
     ]
 
-    with RemoteOpenAIServer(EMBEDDING_MODEL_NAME, args) as remote_server:
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
         yield remote_server
 
 
 @pytest_asyncio.fixture
-async def embedding_client(embedding_server):
-    async with embedding_server.get_async_client() as async_client:
+async def client(server):
+    async with server.get_async_client() as async_client:
         yield async_client
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [EMBEDDING_MODEL_NAME],
-)
-async def test_single_embedding(embedding_client: openai.AsyncOpenAI,
-                                model_name: str):
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str):
     input_texts = [
         "The chef prepared a delicious meal.",
     ]
 
     # test single embedding
-    embeddings = await embedding_client.embeddings.create(
+    embeddings = await client.embeddings.create(
         model=model_name,
         input=input_texts,
         encoding_format="float",
@@ -57,7 +59,7 @@ async def test_single_embedding(embedding_client: openai.AsyncOpenAI,
 
     # test using token IDs
     input_tokens = [1, 1, 1, 1, 1]
-    embeddings = await embedding_client.embeddings.create(
+    embeddings = await client.embeddings.create(
         model=model_name,
         input=input_tokens,
         encoding_format="float",
@@ -71,18 +73,14 @@ async def test_single_embedding(embedding_client: openai.AsyncOpenAI,
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [EMBEDDING_MODEL_NAME],
-)
-async def test_batch_embedding(embedding_client: openai.AsyncOpenAI,
-                               model_name: str):
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str):
     # test List[str]
     input_texts = [
         "The cat sat on the mat.", "A feline was resting on a rug.",
         "Stars twinkle brightly in the night sky."
     ]
-    embeddings = await embedding_client.embeddings.create(
+    embeddings = await client.embeddings.create(
         model=model_name,
         input=input_texts,
         encoding_format="float",
@@ -90,11 +88,14 @@ async def test_batch_embedding(embedding_client: openai.AsyncOpenAI,
     assert embeddings.id is not None
     assert len(embeddings.data) == 3
    assert len(embeddings.data[0].embedding) == 4096
+    assert embeddings.usage.completion_tokens == 0
+    assert embeddings.usage.prompt_tokens == 32
+    assert embeddings.usage.total_tokens == 32
 
     # test List[List[int]]
     input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
                     [25, 32, 64, 77]]
-    embeddings = await embedding_client.embeddings.create(
+    embeddings = await client.embeddings.create(
         model=model_name,
         input=input_tokens,
         encoding_format="float",
@@ -108,22 +109,70 @@ async def test_batch_embedding(embedding_client: openai.AsyncOpenAI,
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [EMBEDDING_MODEL_NAME],
-)
-async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_conversation_embedding(server: RemoteOpenAIServer,
+                                      client: openai.AsyncOpenAI,
+                                      model_name: str):
+    messages = [{
+        "role": "user",
+        "content": "The cat sat on the mat.",
+    }, {
+        "role": "assistant",
+        "content": "A feline was resting on a rug.",
+    }, {
+        "role": "user",
+        "content": "Stars twinkle brightly in the night sky.",
+    }]
+
+    chat_response = requests.post(server.url_for("v1/embeddings"),
+                                  json={
+                                      "model": model_name,
+                                      "messages": messages,
+                                      "encoding_format": "float",
+                                  })
+    chat_response.raise_for_status()
+    chat_embeddings = chat_response.json()
+
+    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
+    prompt = tokenizer.apply_chat_template(
+        messages,
+        chat_template=DUMMY_CHAT_TEMPLATE,
+        add_generation_prompt=True,
+        continue_final_message=False,
+        tokenize=False,
+    )
+    completion_response = await client.embeddings.create(
+        model=model_name,
+        input=prompt,
+        encoding_format="float",
+        # To be consistent with chat
+        extra_body={"add_special_tokens": False},
+    )
+    completion_embeddings = completion_response.model_dump(mode="json")
+
+    assert chat_embeddings.pop("id") is not None
+    assert completion_embeddings.pop("id") is not None
+    assert chat_embeddings.pop("created") <= completion_embeddings.pop(
+        "created")
+    assert chat_embeddings == completion_embeddings
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_batch_base64_embedding(client: openai.AsyncOpenAI,
                                       model_name: str):
     input_texts = [
         "Hello my name is",
         "The best thing about vLLM is that it supports many different models"
     ]
 
-    responses_float = await embedding_client.embeddings.create(
-        input=input_texts, model=model_name, encoding_format="float")
+    responses_float = await client.embeddings.create(input=input_texts,
+                                                     model=model_name,
+                                                     encoding_format="float")
 
-    responses_base64 = await embedding_client.embeddings.create(
-        input=input_texts, model=model_name, encoding_format="base64")
+    responses_base64 = await client.embeddings.create(input=input_texts,
+                                                      model=model_name,
+                                                      encoding_format="base64")
 
     decoded_responses_base64_data = []
     for data in responses_base64.data:
@@ -137,8 +186,8 @@ async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
             1]
 
     # Default response is float32 decoded from base64 by OpenAI Client
-    responses_default = await embedding_client.embeddings.create(
-        input=input_texts, model=model_name)
+    responses_default = await client.embeddings.create(input=input_texts,
+                                                       model=model_name)
 
     assert responses_float.data[0].embedding == responses_default.data[
         0].embedding
@@ -147,18 +196,15 @@ async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [EMBEDDING_MODEL_NAME],
-)
-async def test_single_embedding_truncation(
-        embedding_client: openai.AsyncOpenAI, model_name: str):
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_single_embedding_truncation(client: openai.AsyncOpenAI,
+                                           model_name: str):
     input_texts = [
         "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
     ]
 
     # test single embedding
-    embeddings = await embedding_client.embeddings.create(
+    embeddings = await client.embeddings.create(
         model=model_name,
         input=input_texts,
         extra_body={"truncate_prompt_tokens": 10})
@@ -173,7 +219,7 @@ async def test_single_embedding_truncation(
         1, 24428, 289, 18341, 26165, 285, 19323, 283, 289, 26789, 3871, 28728,
         9901, 340, 2229, 385, 340, 315, 28741, 28804, 2
     ]
-    embeddings = await embedding_client.embeddings.create(
+    embeddings = await client.embeddings.create(
         model=model_name,
         input=input_tokens,
         extra_body={"truncate_prompt_tokens": 10})
@@ -187,18 +233,15 @@ async def test_single_embedding_truncation(
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [EMBEDDING_MODEL_NAME],
-)
-async def test_single_embedding_truncation_invalid(
-        embedding_client: openai.AsyncOpenAI, model_name: str):
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_single_embedding_truncation_invalid(client: openai.AsyncOpenAI,
+                                                   model_name: str):
     input_texts = [
         "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
     ]
 
     with pytest.raises(openai.BadRequestError):
-        embeddings = await embedding_client.embeddings.create(
+        embeddings = await client.embeddings.create(
             model=model_name,
             input=input_texts,
             extra_body={"truncate_prompt_tokens": 8193})