[Speculative Decoding] Fixing hidden states handling in batch expansion (#7508)
This commit is contained in:
@@ -288,15 +288,17 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
ensure_all_accepted=ensure_all_accepted)
|
||||
|
||||
|
||||
def run_equality_correctness_test(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size,
|
||||
max_output_len,
|
||||
force_output_len: bool,
|
||||
temperature: float,
|
||||
seeded: bool,
|
||||
print_tokens: bool = False,
|
||||
ensure_all_accepted: bool = False):
|
||||
def run_equality_correctness_test(
|
||||
baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size,
|
||||
max_output_len,
|
||||
force_output_len: bool,
|
||||
temperature: float,
|
||||
seeded: bool,
|
||||
print_tokens: bool = False,
|
||||
ensure_all_accepted: bool = False,
|
||||
expected_acceptance_rate: Optional[float] = None):
|
||||
"""Helper method that compares the outputs of both the baseline LLM and
|
||||
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
|
||||
the same when temperature is zero (or when temperature is > 0 and seeded).
|
||||
@@ -357,5 +359,10 @@ def run_equality_correctness_test(baseline_llm_generator,
|
||||
print(f'{i=} {spec_token_ids=}')
|
||||
assert baseline_token_ids == spec_token_ids
|
||||
|
||||
print(f'{acceptance_rate=}')
|
||||
|
||||
if ensure_all_accepted:
|
||||
assert acceptance_rate == 1.0
|
||||
|
||||
if expected_acceptance_rate is not None:
|
||||
assert acceptance_rate >= expected_acceptance_rate - 1e-2
|
||||
|
||||
Reference in New Issue
Block a user