[Frontend] Chat-based Embeddings API (#9759)

This commit is contained in:
Cyrus Leung
2024-11-01 16:13:35 +08:00
committed by GitHub
parent d3aa2a8b2f
commit 06386a64dd
21 changed files with 846 additions and 408 deletions

View File

@@ -4,14 +4,18 @@ import numpy as np
import openai
import pytest
import pytest_asyncio
import requests
from vllm.transformers_utils.tokenizer import get_tokenizer
from ...utils import RemoteOpenAIServer
EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501
@pytest.fixture(scope="module")
def embedding_server():
def server():
args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
@@ -19,31 +23,29 @@ def embedding_server():
"--enforce-eager",
"--max-model-len",
"8192",
"--chat-template",
DUMMY_CHAT_TEMPLATE,
]
with RemoteOpenAIServer(EMBEDDING_MODEL_NAME, args) as remote_server:
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def embedding_client(embedding_server):
async with embedding_server.get_async_client() as async_client:
async def client(server):
async with server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[EMBEDDING_MODEL_NAME],
)
async def test_single_embedding(embedding_client: openai.AsyncOpenAI,
model_name: str):
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str):
input_texts = [
"The chef prepared a delicious meal.",
]
# test single embedding
embeddings = await embedding_client.embeddings.create(
embeddings = await client.embeddings.create(
model=model_name,
input=input_texts,
encoding_format="float",
@@ -57,7 +59,7 @@ async def test_single_embedding(embedding_client: openai.AsyncOpenAI,
# test using token IDs
input_tokens = [1, 1, 1, 1, 1]
embeddings = await embedding_client.embeddings.create(
embeddings = await client.embeddings.create(
model=model_name,
input=input_tokens,
encoding_format="float",
@@ -71,18 +73,14 @@ async def test_single_embedding(embedding_client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[EMBEDDING_MODEL_NAME],
)
async def test_batch_embedding(embedding_client: openai.AsyncOpenAI,
model_name: str):
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str):
# test List[str]
input_texts = [
"The cat sat on the mat.", "A feline was resting on a rug.",
"Stars twinkle brightly in the night sky."
]
embeddings = await embedding_client.embeddings.create(
embeddings = await client.embeddings.create(
model=model_name,
input=input_texts,
encoding_format="float",
@@ -90,11 +88,14 @@ async def test_batch_embedding(embedding_client: openai.AsyncOpenAI,
assert embeddings.id is not None
assert len(embeddings.data) == 3
assert len(embeddings.data[0].embedding) == 4096
assert embeddings.usage.completion_tokens == 0
assert embeddings.usage.prompt_tokens == 32
assert embeddings.usage.total_tokens == 32
# test List[List[int]]
input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
[25, 32, 64, 77]]
embeddings = await embedding_client.embeddings.create(
embeddings = await client.embeddings.create(
model=model_name,
input=input_tokens,
encoding_format="float",
@@ -108,22 +109,70 @@ async def test_batch_embedding(embedding_client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[EMBEDDING_MODEL_NAME],
)
async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_conversation_embedding(server: RemoteOpenAIServer,
client: openai.AsyncOpenAI,
model_name: str):
messages = [{
"role": "user",
"content": "The cat sat on the mat.",
}, {
"role": "assistant",
"content": "A feline was resting on a rug.",
}, {
"role": "user",
"content": "Stars twinkle brightly in the night sky.",
}]
chat_response = requests.post(server.url_for("v1/embeddings"),
json={
"model": model_name,
"messages": messages,
"encoding_format": "float",
})
chat_response.raise_for_status()
chat_embeddings = chat_response.json()
tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
prompt = tokenizer.apply_chat_template(
messages,
chat_template=DUMMY_CHAT_TEMPLATE,
add_generation_prompt=True,
continue_final_message=False,
tokenize=False,
)
completion_response = await client.embeddings.create(
model=model_name,
input=prompt,
encoding_format="float",
# To be consistent with chat
extra_body={"add_special_tokens": False},
)
completion_embeddings = completion_response.model_dump(mode="json")
assert chat_embeddings.pop("id") is not None
assert completion_embeddings.pop("id") is not None
assert chat_embeddings.pop("created") <= completion_embeddings.pop(
"created")
assert chat_embeddings == completion_embeddings
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_batch_base64_embedding(client: openai.AsyncOpenAI,
model_name: str):
input_texts = [
"Hello my name is",
"The best thing about vLLM is that it supports many different models"
]
responses_float = await embedding_client.embeddings.create(
input=input_texts, model=model_name, encoding_format="float")
responses_float = await client.embeddings.create(input=input_texts,
model=model_name,
encoding_format="float")
responses_base64 = await embedding_client.embeddings.create(
input=input_texts, model=model_name, encoding_format="base64")
responses_base64 = await client.embeddings.create(input=input_texts,
model=model_name,
encoding_format="base64")
decoded_responses_base64_data = []
for data in responses_base64.data:
@@ -137,8 +186,8 @@ async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
1]
# Default response is float32 decoded from base64 by OpenAI Client
responses_default = await embedding_client.embeddings.create(
input=input_texts, model=model_name)
responses_default = await client.embeddings.create(input=input_texts,
model=model_name)
assert responses_float.data[0].embedding == responses_default.data[
0].embedding
@@ -147,18 +196,15 @@ async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[EMBEDDING_MODEL_NAME],
)
async def test_single_embedding_truncation(
embedding_client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_embedding_truncation(client: openai.AsyncOpenAI,
model_name: str):
input_texts = [
"Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
]
# test single embedding
embeddings = await embedding_client.embeddings.create(
embeddings = await client.embeddings.create(
model=model_name,
input=input_texts,
extra_body={"truncate_prompt_tokens": 10})
@@ -173,7 +219,7 @@ async def test_single_embedding_truncation(
1, 24428, 289, 18341, 26165, 285, 19323, 283, 289, 26789, 3871, 28728,
9901, 340, 2229, 385, 340, 315, 28741, 28804, 2
]
embeddings = await embedding_client.embeddings.create(
embeddings = await client.embeddings.create(
model=model_name,
input=input_tokens,
extra_body={"truncate_prompt_tokens": 10})
@@ -187,18 +233,15 @@ async def test_single_embedding_truncation(
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[EMBEDDING_MODEL_NAME],
)
async def test_single_embedding_truncation_invalid(
embedding_client: openai.AsyncOpenAI, model_name: str):
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_embedding_truncation_invalid(client: openai.AsyncOpenAI,
model_name: str):
input_texts = [
"Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
]
with pytest.raises(openai.BadRequestError):
embeddings = await embedding_client.embeddings.create(
embeddings = await client.embeddings.create(
model=model_name,
input=input_texts,
extra_body={"truncate_prompt_tokens": 8193})