[Speculative Decoding] Fixing hidden states handling in batch expansion (#7508)

This commit is contained in:
Abhinav Goyal
2024-08-20 06:28:14 +05:30
committed by GitHub
parent e54ebc2f8f
commit 312f761232
6 changed files with 139 additions and 41 deletions

View File

@@ -288,15 +288,17 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
ensure_all_accepted=ensure_all_accepted)
def run_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len: bool,
temperature: float,
seeded: bool,
print_tokens: bool = False,
ensure_all_accepted: bool = False):
def run_equality_correctness_test(
baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len: bool,
temperature: float,
seeded: bool,
print_tokens: bool = False,
ensure_all_accepted: bool = False,
expected_acceptance_rate: Optional[float] = None):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero (or when temperature is > 0 and seeded).
@@ -357,5 +359,10 @@ def run_equality_correctness_test(baseline_llm_generator,
print(f'{i=} {spec_token_ids=}')
assert baseline_token_ids == spec_token_ids
print(f'{acceptance_rate=}')
if ensure_all_accepted:
assert acceptance_rate == 1.0
if expected_acceptance_rate is not None:
assert acceptance_rate >= expected_acceptance_rate - 1e-2