[Test] Add tests for n parameter in chat completions API (#35283)

Signed-off-by: KrxGu <krishom70@gmail.com>
This commit is contained in:
Krish Gupta
2026-02-26 14:44:07 +05:30
committed by GitHub
parent ade81f17fe
commit 3827c8c55a

View File

@@ -3,6 +3,7 @@
# imports for structured outputs tests
import json
from collections import defaultdict
import jsonschema
import openai # use the official client for correctness check
@@ -13,6 +14,11 @@ import requests
import torch
from openai import BadRequestError
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.sampling_params import SamplingParams
from ...utils import RemoteOpenAIServer
# any model with a chat template should work here
@@ -815,3 +821,203 @@ async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenA
assert chat_output.keys() == invocation_output.keys()
assert chat_output["choices"] == invocation_output["choices"]
# Test n parameter for chat completions
# Tests that the n parameter works correctly for regular sampling
# (non-beam search) in chat completions, addressing issue #34305.
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_chat_completion_n_parameter_non_streaming(
    client: openai.AsyncOpenAI, model_name: str
):
    """Non-streaming requests with n>1 must return that many distinct choices."""
    num_choices = 3
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the opposite of big?"},
    ]

    # Single non-streaming call asking for three completions.
    response = await client.chat.completions.create(
        model=model_name,
        messages=conversation,
        max_completion_tokens=20,
        temperature=0.7,
        n=num_choices,
        stream=False,
    )

    assert len(response.choices) == num_choices

    # Every choice carries non-empty content and a sequential index.
    texts = []
    for expected_index, choice in enumerate(response.choices):
        assert choice.index == expected_index
        text = choice.message.content
        assert text is not None
        assert len(text) > 0
        texts.append(text)

    # With temperature > 0, three samples are almost surely not all identical.
    assert len(set(texts)) > 1, "Expected different responses with n=3"
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_chat_completion_n_parameter_streaming(
    client: openai.AsyncOpenAI, model_name: str
):
    """Test that n parameter returns multiple choices for streaming requests.

    Streams a request with n=2 and verifies that exactly the choice indices
    0 and 1 appear in the stream and that both accumulate non-empty content.
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
    ]
    stream = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_completion_tokens=15,
        temperature=0.7,
        n=2,
        stream=True,
    )
    # Collect content deltas grouped by choice index.
    chunks_by_index = defaultdict(list)
    async for chunk in stream:
        for choice in chunk.choices:
            if choice.delta.content:
                chunks_by_index[choice.index].append(choice.delta.content)
    # With n=2, only indices 0 and 1 may ever appear in the stream. Check
    # this BEFORE the subscript asserts below: indexing the defaultdict
    # would silently insert missing keys and mask a spurious extra index.
    assert set(chunks_by_index) <= {0, 1}, (
        f"Unexpected choice indices in stream: {sorted(chunks_by_index)}"
    )
    # Verify both choices received content
    assert len(chunks_by_index[0]) > 0, "Choice 0 received no content chunks"
    assert len(chunks_by_index[1]) > 0, "Choice 1 received no content chunks"
    # Reconstruct full responses
    response_0 = "".join(chunks_by_index[0])
    response_1 = "".join(chunks_by_index[1])
    assert len(response_0) > 0, "Choice 0 has empty response"
    assert len(response_1) > 0, "Choice 1 has empty response"
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_chat_completion_n_with_seed(client: openai.AsyncOpenAI, model_name: str):
    """The seed parameter must be accepted alongside n > 1."""
    prompt = [
        {"role": "user", "content": "Say hello."},
    ]

    # Combine a fixed seed with a request for two completions.
    result = await client.chat.completions.create(
        model=model_name,
        messages=prompt,
        max_completion_tokens=10,
        temperature=0.8,
        n=2,
        seed=42,
        stream=False,
    )

    # Exactly two choices, each with a sequential index and real content.
    assert len(result.choices) == 2
    for position, choice in enumerate(result.choices):
        assert choice.index == position
        body = choice.message.content
        assert body is not None
        assert len(body) > 0
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_chat_completion_n_equals_1(client: openai.AsyncOpenAI, model_name: str):
    """Explicitly passing n=1 (the default) keeps the single-choice behavior."""
    completion = await client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": "Hello!"}],
        max_completion_tokens=10,
        temperature=0.7,
        n=1,
        stream=False,
    )

    # One choice at index 0 with actual content.
    assert len(completion.choices) == 1
    only_choice = completion.choices[0]
    assert only_choice.index == 0
    assert only_choice.message.content is not None
# Unit tests for n parameter in ChatCompletionRequest.to_sampling_params()
def test_chat_completion_request_n_parameter_to_sampling_params():
    """A request-level n=3 must be carried into the resulting SamplingParams."""
    req = ChatCompletionRequest(
        model="test-model",
        messages=[{"role": "user", "content": "Hello"}],
        n=3,
        max_tokens=10,
    )

    params = req.to_sampling_params(
        max_tokens=10,
        default_sampling_params={},
    )

    # The conversion must preserve n unchanged.
    assert isinstance(params, SamplingParams)
    assert params.n == 3, f"Expected n=3, got n={params.n}"
def test_chat_completion_request_n_parameter_default():
    """When n is omitted, both the request and SamplingParams default to 1."""
    req = ChatCompletionRequest(
        model="test-model",
        messages=[{"role": "user", "content": "Hello"}],
        # n deliberately omitted: the protocol default should be 1.
        max_tokens=10,
    )
    assert req.n == 1, "n should default to 1"

    params = req.to_sampling_params(
        max_tokens=10,
        default_sampling_params={},
    )
    # SamplingParams.from_optional converts None to 1
    assert params.n == 1, f"Expected n=1 (default), got n={params.n}"
def test_chat_completion_request_n_parameter_various_values():
    """Several representative n values all round-trip into SamplingParams."""
    for expected in (1, 2, 5, 10):
        req = ChatCompletionRequest(
            model="test-model",
            messages=[{"role": "user", "content": "Test"}],
            n=expected,
            max_tokens=10,
        )
        params = req.to_sampling_params(
            max_tokens=10,
            default_sampling_params={},
        )
        assert params.n == expected, (
            f"Expected n={expected}, got n={params.n}"
        )