[V1] V1 engine implements parallel sampling (AsyncLLM and LLMEngine) (#10980)
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
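For orientation, a minimal sketch of how a client would exercise parallel sampling against a running vLLM OpenAI-compatible server. The base URL, API key, and served model name below are placeholder assumptions, not values from this change:

import asyncio

import openai


async def main() -> None:
    # Placeholder endpoint/model: assumes a vLLM OpenAI-compatible server
    # is already running (e.g. started with `vllm serve ...`).
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    # n=3 asks the engine to sample three completions for a single prompt.
    completion = await client.completions.create(model="<served-model-name>",
                                                 prompt="What is an LLM?",
                                                 n=3,
                                                 max_tokens=5,
                                                 temperature=0.95)
    for choice in completion.choices:
        print(choice.index, repr(choice.text))


if __name__ == "__main__":
    asyncio.run(main())

The tests added below cover the same server path, both with and without streaming.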
@@ -250,6 +250,108 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
    assert "".join(chunks) == single_output


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_parallel_no_streaming(client: openai.AsyncOpenAI,
                                     model_name: str):
    """Parallel sampling without streaming.
    A single request output contains a list of completions.
    """

    prompt = "What is an LLM?"
    n = 3
    max_tokens = 5

    # High temperature to maximize chance of unique completions.
    completion = await client.completions.create(model=model_name,
                                                 prompt=prompt,
                                                 max_tokens=max_tokens,
                                                 n=n,
                                                 temperature=0.95,
                                                 stream=False,
                                                 seed=42)

    # Assert `n` completions
    num_completions = len(completion.choices)
    assert num_completions == n, (
        f"Num completions {num_completions} but expected {n}.")
    completion_repeats: Dict[str, int] = {}
    for idx, choice in enumerate(completion.choices):
        # Assert correct completion index & some finish reason.
        assert choice.index == idx, (
            f"Index {choice.index} but expected {idx}.")
        assert choice.finish_reason is not None, (
            "None finish_reason is invalid.")
        text = choice.text
        completion_repeats[text] = completion_repeats.get(text, 0) + 1
    # Assert `n` unique completions
    num_unique = len(completion_repeats)
    if num_unique != n:
        repeats = {
            txt: num
            for (txt, num) in completion_repeats.items() if num > 1
        }
        raise AssertionError(
            f"Expected {n} unique completions, got {num_unique};"
            f" repeats: {repeats}.")


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
    """Streaming for parallel sampling.
    The tokens from multiple samples are flattened into a single stream,
    with an index to indicate which sample each token belongs to.
    """

    prompt = "What is an LLM?"
    n = 3
    max_tokens = 5

    stream = await client.completions.create(model=model_name,
                                             prompt=prompt,
                                             max_tokens=max_tokens,
                                             n=n,
                                             temperature=0.95,
                                             stream=True,
                                             seed=42)
    chunks: List[List[str]] = [[] for i in range(n)]
    finish_reason_count = 0
    async for chunk in stream:
        index = chunk.choices[0].index
        text = chunk.choices[0].text
        chunks[index].append(text)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
    # Assert `n` completions with correct finish reasons
    assert finish_reason_count == n, (
        f"Expected {n} completions with valid indices and finish_reason.")
    completion_repeats: Dict[str, int] = {}
    for chunk in chunks:
        chunk_len = len(chunk)
        # Assert correct number of completion tokens
        assert chunk_len == max_tokens, (
            f"max_tokens={max_tokens} but chunk len is {chunk_len}.")
        text = "".join(chunk)
        completion_repeats[text] = completion_repeats.get(text, 0) + 1
        print(text)
    # Assert `n` unique completions
    num_unique = len(completion_repeats)
    if num_unique != n:
        repeats = {
            txt: num
            for (txt, num) in completion_repeats.items() if num > 1
        }
        raise AssertionError(f"{num_unique} unique completions, expected {n};"
                             f" repeats: {repeats}")


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
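The diff excerpt is truncated above. Since the commit title also covers the LLMEngine (offline) path, here is a minimal sketch, assuming a locally available model, of requesting parallel samples through the synchronous LLM API; the model name is a placeholder, not part of this diff:

from vllm import LLM, SamplingParams

# Placeholder model; any locally available model name works here.
llm = LLM(model="facebook/opt-125m")

# n=3 requests three sampled completions per prompt in one request.
sampling_params = SamplingParams(n=3, temperature=0.95, max_tokens=5, seed=42)

for request_output in llm.generate(["What is an LLM?"], sampling_params):
    # Each RequestOutput carries one CompletionOutput per sample.
    for completion in request_output.outputs:
        print(completion.index, repr(completion.text))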