[Frontend][2/n] Make pooling entrypoints request schema consensus | ChatRequest (#32574)
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
@@ -214,64 +214,6 @@ async def test_completion_request_batched(
     run_embedding_correctness_test(hf_model, input_texts, vllm_outputs)
 
 
-@pytest.mark.asyncio
-@pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_conversation_embedding(
-    server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
-):
-    messages = [
-        {
-            "role": "user",
-            "content": "The cat sat on the mat.",
-        },
-        {
-            "role": "assistant",
-            "content": "A feline was resting on a rug.",
-        },
-        {
-            "role": "user",
-            "content": "Stars twinkle brightly in the night sky.",
-        },
-    ]
-
-    chat_response = requests.post(
-        server.url_for("v1/embeddings"),
-        json={
-            "model": model_name,
-            "messages": messages,
-            "encoding_format": "float",
-        },
-    )
-    chat_response.raise_for_status()
-    chat_embeddings = EmbeddingResponse.model_validate(chat_response.json())
-
-    tokenizer = get_tokenizer(tokenizer_name=model_name)
-    prompt = tokenizer.apply_chat_template(
-        messages,
-        chat_template=DUMMY_CHAT_TEMPLATE,
-        add_generation_prompt=True,
-        continue_final_message=False,
-        tokenize=False,
-    )
-    completion_response = await client.embeddings.create(
-        model=model_name,
-        input=prompt,
-        encoding_format="float",
-        # To be consistent with chat
-        extra_body={"add_special_tokens": False},
-    )
-    completion_embeddings = EmbeddingResponse.model_validate(
-        completion_response.model_dump(mode="json")
-    )
-
-    assert chat_embeddings.id is not None
-    assert completion_embeddings.id is not None
-    assert chat_embeddings.created <= completion_embeddings.created
-    assert chat_embeddings.model_dump(exclude={"id", "created"}) == (
-        completion_embeddings.model_dump(exclude={"id", "created"})
-    )
-
-
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 async def test_truncate_prompt_tokens(client: openai.AsyncOpenAI, model_name: str):
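The function removed above is re-added in the next hunk as test_chat_request, with extra coverage for the templating flags. For orientation, a minimal standalone sketch of the chat-style embeddings request it exercises, assuming a vLLM pooling server already running on localhost:8000 (the URL and model name are stand-ins, not from this diff):

    import requests

    # Chat-style embeddings request: "messages" instead of "input".
    resp = requests.post(
        "http://localhost:8000/v1/embeddings",
        json={
            "model": "intfloat/e5-small",  # stand-in model name
            "messages": [{"role": "user", "content": "The cat sat on the mat."}],
            "encoding_format": "float",
        },
    )
    resp.raise_for_status()
    print(len(resp.json()["data"][0]["embedding"]))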
@@ -350,7 +292,129 @@ async def test_truncate_prompt_tokens(client: openai.AsyncOpenAI, model_name: st
 
 
 @pytest.mark.asyncio
-async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_chat_request(
+    server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
+):
+    messages = [
+        {
+            "role": "user",
+            "content": "The cat sat on the mat.",
+        },
+        {
+            "role": "assistant",
+            "content": "A feline was resting on a rug.",
+        },
+        {
+            "role": "user",
+            "content": "Stars twinkle brightly in the night sky.",
+        },
+    ]
+
+    # test chat request basic usage
+    chat_response = requests.post(
+        server.url_for("v1/embeddings"),
+        json={
+            "model": model_name,
+            "messages": messages,
+            "encoding_format": "float",
+        },
+    )
+    chat_response.raise_for_status()
+    chat_embeddings = EmbeddingResponse.model_validate(chat_response.json())
+
+    tokenizer = get_tokenizer(tokenizer_name=model_name)
+    prompt = tokenizer.apply_chat_template(
+        messages,
+        chat_template=DUMMY_CHAT_TEMPLATE,
+        add_generation_prompt=True,
+        continue_final_message=False,
+        tokenize=False,
+    )
+    completion_response = await client.embeddings.create(
+        model=model_name,
+        input=prompt,
+        encoding_format="float",
+        # To be consistent with chat
+        extra_body={"add_special_tokens": False},
+    )
+    completion_embeddings = EmbeddingResponse.model_validate(
+        completion_response.model_dump(mode="json")
+    )
+
+    assert chat_embeddings.id is not None
+    assert completion_embeddings.id is not None
+    assert chat_embeddings.created <= completion_embeddings.created
+    assert chat_embeddings.model_dump(exclude={"id", "created"}) == (
+        completion_embeddings.model_dump(exclude={"id", "created"})
+    )
+
+    # test add_generation_prompt
+    response = requests.post(
+        server.url_for("v1/embeddings"),
+        json={"model": model_name, "messages": messages, "add_generation_prompt": True},
+    )
+
+    response.raise_for_status()
+    output = EmbeddingResponse.model_validate(response.json())
+
+    assert output.object == "list"
+    assert len(output.data) == 1
+    assert output.model == MODEL_NAME
+    assert output.usage.prompt_tokens == 34
+
+    # test continue_final_message
+    response = requests.post(
+        server.url_for("v1/embeddings"),
+        json={
+            "model": model_name,
+            "messages": messages,
+            "continue_final_message": True,
+        },
+    )
+
+    response.raise_for_status()
+    output = EmbeddingResponse.model_validate(response.json())
+
+    assert output.object == "list"
+    assert len(output.data) == 1
+    assert output.model == MODEL_NAME
+    assert output.usage.prompt_tokens == 33
+
+    # test add_special_tokens
+    response = requests.post(
+        server.url_for("v1/embeddings"),
+        json={"model": model_name, "messages": messages, "add_special_tokens": True},
+    )
+
+    response.raise_for_status()
+    output = EmbeddingResponse.model_validate(response.json())
+
+    assert output.object == "list"
+    assert len(output.data) == 1
+    assert output.model == MODEL_NAME
+    assert output.usage.prompt_tokens == 36
+
+    # test continue_final_message with add_generation_prompt
+    response = requests.post(
+        server.url_for("v1/embeddings"),
+        json={
+            "model": model_name,
+            "messages": messages,
+            "continue_final_message": True,
+            "add_generation_prompt": True,
+        },
+    )
+    assert (
+        "Cannot set both `continue_final_message` and `add_generation_prompt` to True."
+        in response.json()["error"]["message"]
+    )
+
+
+@pytest.mark.asyncio
+async def test_invocations_completion_request(
+    server: RemoteOpenAIServer, client: openai.AsyncOpenAI
+):
     request_args = {
         "model": MODEL_NAME,
         "input": input_text,
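The exact prompt_tokens values asserted above (34, 33, 36) depend on MODEL_NAME and DUMMY_CHAT_TEMPLATE; the general mechanism is that add_generation_prompt appends an assistant header to the rendered prompt, continue_final_message leaves the final turn unclosed (one token fewer here), and add_special_tokens additionally counts tokens such as BOS. A rough sketch of inspecting this locally with a Hugging Face tokenizer; the template and tokenizer below are stand-ins, so the counts will differ:

    from transformers import AutoTokenizer

    # Stand-in Jinja chat template (not vLLM's DUMMY_CHAT_TEMPLATE).
    CHAT_TEMPLATE = (
        "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}"
        "{% if add_generation_prompt %}assistant:{% endif %}"
    )

    tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # stand-in tokenizer
    messages = [{"role": "user", "content": "The cat sat on the mat."}]

    for flag in (False, True):
        text = tok.apply_chat_template(
            messages,
            chat_template=CHAT_TEMPLATE,
            add_generation_prompt=flag,
            tokenize=False,
        )
        # add_special_tokens=True also counts tokens like [CLS]/[SEP].
        print(flag, len(tok(text, add_special_tokens=True).input_ids))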
@@ -381,7 +445,7 @@ async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenA
 
 
 @pytest.mark.asyncio
-async def test_invocations_conversation(server: RemoteOpenAIServer):
+async def test_invocations_chat_request(server: RemoteOpenAIServer):
     messages = [
         {
             "role": "user",
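For context, the test_invocations_* tests exercise vLLM's SageMaker-style /invocations route, which accepts the same payloads as the dedicated endpoints and dispatches on the request body. A minimal sketch of the chat-request variant, again assuming a local server and a stand-in model name:

    import requests

    # Same chat-style payload as above, but sent to /invocations
    # instead of /v1/embeddings.
    payload = {
        "model": "intfloat/e5-small",  # stand-in model name
        "messages": [{"role": "user", "content": "The cat sat on the mat."}],
        "encoding_format": "float",
    }
    resp = requests.post("http://localhost:8000/invocations", json=payload)
    resp.raise_for_status()
    print(resp.json()["usage"])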