[Bugfix] Make spec. decode respect per-request seed. (#6034)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Thomas Parnell authored on 2024-07-19 04:22:08 +02:00, committed by GitHub
parent b5672a112c
commit d4201e06d5
8 changed files with 293 additions and 46 deletions
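For context, the fix makes speculative decoding honor a seed supplied per request through SamplingParams. A minimal sketch of the user-facing behavior under test, assuming a generic target/draft model pair; the model names and speculative-decoding arguments below are illustrative, not taken from this commit:

from vllm import LLM, SamplingParams

# Illustrative model choices only; any compatible target/draft pair works.
llm = LLM(model="facebook/opt-6.7b",
          speculative_model="facebook/opt-125m",
          num_speculative_tokens=5)

prompts = ["Hello, my name is", "The president of the United States is"]

# One seed per request: sampling at temperature > 0 should now be
# reproducible per request, even when the requests are batched together.
sampling_params = [
    SamplingParams(temperature=1.0, seed=i, max_tokens=32)
    for i in range(len(prompts))
]

outputs = llm.generate(prompts, sampling_params)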


@@ -1,6 +1,6 @@
 import asyncio
 from itertools import cycle
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Sequence, Tuple, Union
 
 import pytest
 import ray
@@ -128,7 +128,9 @@ class AsyncLLM:
         try:
             for i in range(num_requests):
                 prompt = prompts[i] if prompts is not None else None
-                res = asyncio.run(get_output(prompt, sampling_params))
+                params = sampling_params[i] if isinstance(
+                    sampling_params, Sequence) else sampling_params
+                res = asyncio.run(get_output(prompt, params))
                 outputs.append(res)
         finally:
             ray.shutdown()
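The hunk above generalizes the async test harness so that sampling_params may be either a single SamplingParams object shared by every request or a per-request Sequence. The dispatch pattern in isolation, as a sketch (resolve_params is a hypothetical helper name, not part of the diff):

from typing import Sequence, Union

from vllm import SamplingParams


def resolve_params(sampling_params: Union[SamplingParams,
                                          Sequence[SamplingParams]],
                   i: int) -> SamplingParams:
    # Mirrors the branch in AsyncLLM.generate: a Sequence is indexed per
    # request; a single SamplingParams is broadcast to all requests.
    if isinstance(sampling_params, Sequence):
        return sampling_params[i]
    return sampling_params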
@@ -267,7 +269,31 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
     the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
     the same when temperature is zero.
     """
-    temperature = 0.0
+    run_equality_correctness_test(baseline_llm_generator,
+                                  test_llm_generator,
+                                  batch_size,
+                                  max_output_len,
+                                  force_output_len,
+                                  temperature=0.0,
+                                  seeded=False,
+                                  print_tokens=print_tokens,
+                                  ensure_all_accepted=ensure_all_accepted)
+
+
+def run_equality_correctness_test(baseline_llm_generator,
+                                  test_llm_generator,
+                                  batch_size,
+                                  max_output_len,
+                                  force_output_len: bool,
+                                  temperature: float,
+                                  seeded: bool,
+                                  print_tokens: bool = False,
+                                  ensure_all_accepted: bool = False):
+    """Helper method that compares the outputs of both the baseline LLM and
+    the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
+    the same when temperature is zero (or when temperature is > 0 and seeded).
+    """
 
     prompts = [
         "Hello, my name is",
@@ -286,11 +312,21 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
     # sampling params to ignore eos token.
     ignore_eos = force_output_len
 
-    sampling_params = SamplingParams(
-        max_tokens=max_output_len,
-        ignore_eos=ignore_eos,
-        temperature=temperature,
-    )
+    if seeded:
+        sampling_params = [
+            SamplingParams(
+                max_tokens=max_output_len,
+                ignore_eos=ignore_eos,
+                temperature=temperature,
+                seed=i,
+            ) for i in range(len(prompts))
+        ]
+    else:
+        sampling_params = SamplingParams(
+            max_tokens=max_output_len,
+            ignore_eos=ignore_eos,
+            temperature=temperature,
+        )
 
     (spec_batch_tokens, spec_batch_token_ids,
      acceptance_rate) = get_output_from_llm_generator(test_llm_generator,