[Test] Add tests for n parameter in chat completions API (#35283)
Signed-off-by: KrxGu <krishom70@gmail.com>
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
|
||||
# imports for structured outputs tests
|
||||
import json
|
||||
from collections import defaultdict
|
||||
|
||||
import jsonschema
|
||||
import openai # use the official client for correctness check
|
||||
@@ -13,6 +14,11 @@ import requests
|
||||
import torch
|
||||
from openai import BadRequestError
|
||||
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionRequest,
|
||||
)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
# any model with a chat template should work here
|
||||
@@ -815,3 +821,203 @@ async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenA
|
||||
|
||||
assert chat_output.keys() == invocation_output.keys()
|
||||
assert chat_output["choices"] == invocation_output["choices"]
|
||||
|
||||
|
||||
# Test n parameter for chat completions
|
||||
# Tests that the n parameter works correctly for regular sampling
|
||||
# (non-beam search) in chat completions, addressing issue #34305.
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_chat_completion_n_parameter_non_streaming(
    client: openai.AsyncOpenAI, model_name: str
):
    """A non-streaming request with n=3 must yield three sequentially indexed choices.

    Regression coverage for issue #34305: the ``n`` sampling parameter should
    produce one choice per requested sample for regular (non-beam) sampling.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the opposite of big?"},
    ]

    # Ask for three parallel samples from a single prompt.
    response = await client.chat.completions.create(
        model=model_name,
        messages=conversation,
        max_completion_tokens=20,
        temperature=0.7,
        n=3,
        stream=False,
    )

    assert len(response.choices) == 3

    # Each choice must carry non-empty text and a sequential index.
    for expected_index, choice in enumerate(response.choices):
        assert choice.index == expected_index
        assert choice.message.content is not None
        assert len(choice.message.content) > 0

    # With temperature > 0 the samples should (almost surely) not all match.
    contents = [choice.message.content for choice in response.choices]
    assert len(set(contents)) > 1, "Expected different responses with n=3"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_chat_completion_n_parameter_streaming(
    client: openai.AsyncOpenAI, model_name: str
):
    """Test that n parameter returns multiple choices for streaming requests.

    With n=2 the stream must interleave chunks for exactly the choice indices
    0 and 1, and each index must reconstruct to a non-empty response.
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
    ]

    stream = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_completion_tokens=15,
        temperature=0.7,
        n=2,
        stream=True,
    )

    # Collect all content chunks, keyed by choice index.
    chunks_by_index = defaultdict(list)
    async for chunk in stream:
        for choice in chunk.choices:
            if choice.delta.content:
                chunks_by_index[choice.index].append(choice.delta.content)

    # The server must never emit content for indices beyond the requested n=2.
    # Checked before the per-index lookups below, since indexing a defaultdict
    # inserts missing keys as a side effect.
    unexpected = set(chunks_by_index) - {0, 1}
    assert not unexpected, f"Received chunks for unexpected choice indices: {unexpected}"

    # Verify both choices received content
    assert len(chunks_by_index[0]) > 0, "Choice 0 received no content chunks"
    assert len(chunks_by_index[1]) > 0, "Choice 1 received no content chunks"

    # Reconstruct full responses
    response_0 = "".join(chunks_by_index[0])
    response_1 = "".join(chunks_by_index[1])

    assert len(response_0) > 0, "Choice 0 has empty response"
    assert len(response_1) > 0, "Choice 1 has empty response"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_chat_completion_n_with_seed(client: openai.AsyncOpenAI, model_name: str):
    """Combining seed with n > 1 must still yield one valid choice per sample."""
    prompt = [
        {"role": "user", "content": "Say hello."},
    ]

    # A seeded two-sample request must be accepted by the server.
    completion = await client.chat.completions.create(
        model=model_name,
        messages=prompt,
        max_completion_tokens=10,
        temperature=0.8,
        n=2,
        seed=42,
        stream=False,
    )

    # Exactly two choices, each sequentially indexed with non-empty content.
    assert len(completion.choices) == 2
    for idx, choice in enumerate(completion.choices):
        assert choice.index == idx
        assert choice.message.content is not None
        assert len(choice.message.content) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_chat_completion_n_equals_1(client: openai.AsyncOpenAI, model_name: str):
    """The default single-sample case (n=1) must keep returning one choice."""
    completion = await client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "user", "content": "Hello!"},
        ],
        max_completion_tokens=10,
        temperature=0.7,
        n=1,
        stream=False,
    )

    # A single choice at index 0 carrying real content.
    assert len(completion.choices) == 1
    first = completion.choices[0]
    assert first.index == 0
    assert first.message.content is not None
|
||||
|
||||
|
||||
# Unit tests for n parameter in ChatCompletionRequest.to_sampling_params()
|
||||
def test_chat_completion_request_n_parameter_to_sampling_params():
    """n supplied on a ChatCompletionRequest must propagate into SamplingParams."""
    # Build a minimal request carrying n=3.
    request = ChatCompletionRequest(
        model="test-model",
        messages=[{"role": "user", "content": "Hello"}],
        n=3,
        max_tokens=10,
    )

    params = request.to_sampling_params(
        max_tokens=10,
        default_sampling_params={},
    )

    assert isinstance(params, SamplingParams)
    assert params.n == 3, f"Expected n=3, got n={params.n}"
|
||||
|
||||
|
||||
def test_chat_completion_request_n_parameter_default():
    """When n is omitted, the request and resulting SamplingParams default to 1."""
    request = ChatCompletionRequest(
        model="test-model",
        messages=[{"role": "user", "content": "Hello"}],
        # n intentionally omitted to exercise the default
        max_tokens=10,
    )

    assert request.n == 1, "n should default to 1"

    params = request.to_sampling_params(
        max_tokens=10,
        default_sampling_params={},
    )

    # SamplingParams.from_optional normalizes a missing n to 1.
    assert params.n == 1, f"Expected n=1 (default), got n={params.n}"
|
||||
|
||||
|
||||
def test_chat_completion_request_n_parameter_various_values():
    """Each requested n value must survive conversion into SamplingParams."""
    for expected_n in (1, 2, 5, 10):
        request = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Test"}],
            n=expected_n,
            max_tokens=10,
        )

        params = request.to_sampling_params(
            max_tokens=10,
            default_sampling_params={},
        )

        assert params.n == expected_n, (
            f"Expected n={expected_n}, got n={params.n}"
        )
|
||||
|
||||
Reference in New Issue
Block a user