[BugFix][V1] Fix parallel sampling finishing/aborts (#14512)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-03-12 13:29:48 -04:00
committed by GitHub
parent 916836bbfb
commit f5d3acd474
7 changed files with 137 additions and 113 deletions

View File

@@ -263,15 +263,16 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
prompt = "What is an LLM?"
n = 3
max_tokens = 5
max_tokens = 50 # we want some to finish earlier than others
# High temperature to maximize chance of unique completions.
completion = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=max_tokens,
n=n,
temperature=0.95,
temperature=1.0,
stream=False,
logprobs=0,
seed=42)
# Assert `n` completions
@@ -279,6 +280,7 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
assert num_completions == n, (
f"Num completions {num_completions} but expected {n}.")
completion_repeats: dict[str, int] = {}
output_token_lengths = set()
for idx, choice in enumerate(completion.choices):
# Assert correct completion index & some finish reason.
assert choice.index == idx, (
@@ -287,6 +289,9 @@ async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
"None finish_reason is invalid.")
text = choice.text
completion_repeats[text] = completion_repeats.get(text, 0) + 1
output_token_lengths.add(len(choice.logprobs.tokens))
# Assert subrequests finished at different times
assert len(output_token_lengths) > 1
# Assert `n` unique completions
num_unique = len(completion_repeats)
if num_unique != n:
@@ -312,16 +317,16 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
prompt = "What is an LLM?"
n = 3
max_tokens = 5
max_tokens = 50 # we want some to finish earlier than others
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=max_tokens,
n=n,
temperature=0.95,
temperature=1.0,
stream=True,
seed=42)
chunks: list[list[str]] = [[] for i in range(n)]
chunks: list[list[str]] = [[] for _ in range(n)]
finish_reason_count = 0
async for chunk in stream:
index = chunk.choices[0].index
@@ -333,14 +338,18 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
assert finish_reason_count == n, (
f"Expected {n} completions with valid indices and finish_reason.")
completion_repeats: dict[str, int] = {}
chunk_lengths = set()
for chunk in chunks:
chunk_len = len(chunk)
# Assert correct number of completion tokens
assert chunk_len == max_tokens, (
chunk_lengths.add(chunk_len)
assert chunk_len <= max_tokens, (
f"max_tokens={max_tokens} but chunk len is {chunk_len}.")
text = "".join(chunk)
completion_repeats[text] = completion_repeats.get(text, 0) + 1
print(text)
# Assert subrequests finished at different times
assert len(chunk_lengths) > 1
# Assert `n` unique completions
num_unique = len(completion_repeats)
if num_unique != n: